From: Alex Ozdemir Date: Sat, 11 Apr 2020 17:16:05 +0000 (-0700) Subject: Add skip predicate to node traversal. (#4222) X-Git-Tag: cvc5-1.0.0~3385 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0073464e433e80311269ce07e0fa5de417b5eefc;p=cvc5.git Add skip predicate to node traversal. (#4222) Sometime you want to skip specific sub-DAGs when traversing a node. For example, you might be doing a transformation with a cache, and want to skip sub-DAGs that you've already processed. This PR would add a skipIf builder method to NodeDfsIterable, which allows the user to provide a predicate according to which nodes will be omitted from the subsequent traversal. --- diff --git a/src/expr/node_traversal.cpp b/src/expr/node_traversal.cpp index 9e7a82c24..ad1a9ec71 100644 --- a/src/expr/node_traversal.cpp +++ b/src/expr/node_traversal.cpp @@ -16,11 +16,14 @@ namespace CVC4 { -NodeDfsIterator::NodeDfsIterator(TNode n, bool postorder) +NodeDfsIterator::NodeDfsIterator(TNode n, + bool postorder, + std::function skipIf) : d_stack{n}, d_visited(), d_postorder(postorder), - d_current(TNode()) + d_current(TNode()), + d_skipIf(skipIf) { } @@ -28,7 +31,8 @@ NodeDfsIterator::NodeDfsIterator(bool postorder) : d_stack(), d_visited(), d_postorder(postorder), - d_current(TNode()) + d_current(TNode()), + d_skipIf([](TNode) { return false; }) { } @@ -64,7 +68,8 @@ bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const // The stack and current node uniquely represent traversal state. We need not // use the scheduled node set. // - // Users should not compare iterators for traversals of different nodes. + // Users should not compare iterators for traversals of different nodes, or + // traversals with different skipIfs. Assert(d_postorder == other.d_postorder); return d_stack == other.d_stack && d_current == other.d_current; } @@ -84,6 +89,12 @@ void NodeDfsIterator::advanceToNextVisit() if (visitEntry == d_visited.end()) { // if we haven't pre-visited this node, pre-visit it + if (d_skipIf(back)) + { + // actually, skip it if the skip predicate says so... + d_stack.pop_back(); + continue; + } d_visited[back] = false; d_current = back; // Use integer underflow to reverse-iterate @@ -123,7 +134,10 @@ void NodeDfsIterator::initializeIfUninitialized() } } -NodeDfsIterable::NodeDfsIterable(TNode n) : d_node(n), d_postorder(true) {} +NodeDfsIterable::NodeDfsIterable(TNode n) + : d_node(n), d_postorder(true), d_skipIf([](TNode) { return false; }) +{ +} NodeDfsIterable& NodeDfsIterable::inPostorder() { @@ -137,9 +151,16 @@ NodeDfsIterable& NodeDfsIterable::inPreorder() return *this; } +NodeDfsIterable& NodeDfsIterable::skipIf( + std::function skipCondition) +{ + d_skipIf = skipCondition; + return *this; +} + NodeDfsIterator NodeDfsIterable::begin() const { - return NodeDfsIterator(d_node, d_postorder); + return NodeDfsIterator(d_node, d_postorder, d_skipIf); } NodeDfsIterator NodeDfsIterable::end() const diff --git a/src/expr/node_traversal.h b/src/expr/node_traversal.h index fffc1d746..1078f08c8 100644 --- a/src/expr/node_traversal.h +++ b/src/expr/node_traversal.h @@ -18,6 +18,7 @@ #define CVC4__EXPR__NODE_TRAVERSAL_H #include +#include #include #include #include @@ -39,7 +40,7 @@ class NodeDfsIterator using difference_type = std::ptrdiff_t; // Construct a traversal iterator beginning at `n` - NodeDfsIterator(TNode n, bool postorder); + NodeDfsIterator(TNode n, bool postorder, std::function skipIf); // Construct an end-of-traversal iterator NodeDfsIterator(bool postorder); @@ -97,6 +98,9 @@ class NodeDfsIterator // Current referent node. A valid node to visit if non-null. // Null after construction (but before first access) and at the end. TNode d_current; + + // When to omit a node and its descendants from the traversal + std::function d_skipIf; }; // Node wrapper that is iterable in DAG post-order @@ -111,6 +115,9 @@ class NodeDfsIterable // Modify this iterable to be in pre-order NodeDfsIterable& inPreorder(); + // Skip a node (and its descendants) if true. + NodeDfsIterable& skipIf(std::function skipCondition); + // Move/copy construction and assignment. Destructor. NodeDfsIterable(NodeDfsIterable&&) = default; NodeDfsIterable& operator=(NodeDfsIterable&&) = default; @@ -124,6 +131,7 @@ class NodeDfsIterable private: TNode d_node; bool d_postorder; + std::function d_skipIf; }; } // namespace CVC4 diff --git a/test/unit/expr/node_traversal_black.h b/test/unit/expr/node_traversal_black.h index b4a7c449c..b751a0999 100644 --- a/test/unit/expr/node_traversal_black.h +++ b/test/unit/expr/node_traversal_black.h @@ -161,6 +161,22 @@ class NodePostorderTraversalBlack : public CxxTest::TestSuite std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual)); TS_ASSERT_EQUALS(actual, expected); } + + void testSkipIf() + { + Node tb = d_nodeManager->mkConst(true); + Node eb = d_nodeManager->mkConst(false); + Node cnd = d_nodeManager->mkNode(XOR, tb, eb); + Node top = d_nodeManager->mkNode(XOR, cnd, cnd); + std::vector expected = {top}; + + auto traversal = NodeDfsIterable(top).inPostorder().skipIf( + [&cnd](TNode n) { return n == cnd; }); + + std::vector actual; + std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual)); + TS_ASSERT_EQUALS(actual, expected); + } }; class NodePreorderTraversalBlack : public CxxTest::TestSuite @@ -278,4 +294,20 @@ class NodePreorderTraversalBlack : public CxxTest::TestSuite std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual)); TS_ASSERT_EQUALS(actual, expected); } + + void testSkipIf() + { + Node tb = d_nodeManager->mkConst(true); + Node eb = d_nodeManager->mkConst(false); + Node cnd = d_nodeManager->mkNode(XOR, tb, eb); + Node top = d_nodeManager->mkNode(XOR, cnd, cnd); + std::vector expected = {top, cnd, eb}; + + auto traversal = NodeDfsIterable(top).inPreorder().skipIf( + [&tb](TNode n) { return n == tb; }); + + std::vector actual; + std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual)); + TS_ASSERT_EQUALS(actual, expected); + } };