Add skip predicate to node traversal. (#4222)
authorAlex Ozdemir <aozdemir@hmc.edu>
Sat, 11 Apr 2020 17:16:05 +0000 (10:16 -0700)
committerGitHub <noreply@github.com>
Sat, 11 Apr 2020 17:16:05 +0000 (12:16 -0500)
Sometime you want to skip specific sub-DAGs when traversing a node. For example, you might be doing a transformation with a cache, and want to skip sub-DAGs that you've already processed.

This PR would add a skipIf builder method to NodeDfsIterable, which allows the user to provide a predicate according to which nodes will be omitted from the subsequent traversal.

src/expr/node_traversal.cpp
src/expr/node_traversal.h
test/unit/expr/node_traversal_black.h

index 9e7a82c2416d75d89a5b0e8e25e505f509b05941..ad1a9ec71bad368b010f2a01e458010a73dddd86 100644 (file)
 
 namespace CVC4 {
 
-NodeDfsIterator::NodeDfsIterator(TNode n, bool postorder)
+NodeDfsIterator::NodeDfsIterator(TNode n,
+                                 bool postorder,
+                                 std::function<bool(TNode)> skipIf)
     : d_stack{n},
       d_visited(),
       d_postorder(postorder),
-      d_current(TNode())
+      d_current(TNode()),
+      d_skipIf(skipIf)
 {
 }
 
@@ -28,7 +31,8 @@ NodeDfsIterator::NodeDfsIterator(bool postorder)
     : d_stack(),
       d_visited(),
       d_postorder(postorder),
-      d_current(TNode())
+      d_current(TNode()),
+      d_skipIf([](TNode) { return false; })
 {
 }
 
@@ -64,7 +68,8 @@ bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const
   // The stack and current node uniquely represent traversal state. We need not
   // use the scheduled node set.
   //
-  // Users should not compare iterators for traversals of different nodes.
+  // Users should not compare iterators for traversals of different nodes, or
+  // traversals with different skipIfs.
   Assert(d_postorder == other.d_postorder);
   return d_stack == other.d_stack && d_current == other.d_current;
 }
@@ -84,6 +89,12 @@ void NodeDfsIterator::advanceToNextVisit()
     if (visitEntry == d_visited.end())
     {
       // if we haven't pre-visited this node, pre-visit it
+      if (d_skipIf(back))
+      {
+        // actually, skip it if the skip predicate says so...
+        d_stack.pop_back();
+        continue;
+      }
       d_visited[back] = false;
       d_current = back;
       // Use integer underflow to reverse-iterate
@@ -123,7 +134,10 @@ void NodeDfsIterator::initializeIfUninitialized()
   }
 }
 
-NodeDfsIterable::NodeDfsIterable(TNode n) : d_node(n), d_postorder(true) {}
+NodeDfsIterable::NodeDfsIterable(TNode n)
+    : d_node(n), d_postorder(true), d_skipIf([](TNode) { return false; })
+{
+}
 
 NodeDfsIterable& NodeDfsIterable::inPostorder()
 {
@@ -137,9 +151,16 @@ NodeDfsIterable& NodeDfsIterable::inPreorder()
   return *this;
 }
 
+NodeDfsIterable& NodeDfsIterable::skipIf(
+    std::function<bool(TNode)> skipCondition)
+{
+  d_skipIf = skipCondition;
+  return *this;
+}
+
 NodeDfsIterator NodeDfsIterable::begin() const
 {
-  return NodeDfsIterator(d_node, d_postorder);
+  return NodeDfsIterator(d_node, d_postorder, d_skipIf);
 }
 
 NodeDfsIterator NodeDfsIterable::end() const
index fffc1d746f1f43cbeae94cb0337dcfd52809e95a..1078f08c82c73cd91a028bfbd50870232cb9b037 100644 (file)
@@ -18,6 +18,7 @@
 #define CVC4__EXPR__NODE_TRAVERSAL_H
 
 #include <cstddef>
+#include <functional>
 #include <iterator>
 #include <unordered_map>
 #include <vector>
@@ -39,7 +40,7 @@ class NodeDfsIterator
   using difference_type = std::ptrdiff_t;
 
   // Construct a traversal iterator beginning at `n`
-  NodeDfsIterator(TNode n, bool postorder);
+  NodeDfsIterator(TNode n, bool postorder, std::function<bool(TNode)> skipIf);
   // Construct an end-of-traversal iterator
   NodeDfsIterator(bool postorder);
 
@@ -97,6 +98,9 @@ class NodeDfsIterator
   // Current referent node. A valid node to visit if non-null.
   // Null after construction (but before first access) and at the end.
   TNode d_current;
+
+  // When to omit a node and its descendants from the traversal
+  std::function<bool(TNode)> d_skipIf;
 };
 
 // Node wrapper that is iterable in DAG post-order
@@ -111,6 +115,9 @@ class NodeDfsIterable
   // Modify this iterable to be in pre-order
   NodeDfsIterable& inPreorder();
 
+  // Skip a node (and its descendants) if true.
+  NodeDfsIterable& skipIf(std::function<bool(TNode)> skipCondition);
+
   // Move/copy construction and assignment. Destructor.
   NodeDfsIterable(NodeDfsIterable&&) = default;
   NodeDfsIterable& operator=(NodeDfsIterable&&) = default;
@@ -124,6 +131,7 @@ class NodeDfsIterable
  private:
   TNode d_node;
   bool d_postorder;
+  std::function<bool(TNode)> d_skipIf;
 };
 
 }  // namespace CVC4
index b4a7c449c096bee566456ae84311c90ea4e9dffa..b751a099987bc52d9e44d565797bc7f743af5a90 100644 (file)
@@ -161,6 +161,22 @@ class NodePostorderTraversalBlack : public CxxTest::TestSuite
     std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
     TS_ASSERT_EQUALS(actual, expected);
   }
+
+  void testSkipIf()
+  {
+    Node tb = d_nodeManager->mkConst(true);
+    Node eb = d_nodeManager->mkConst(false);
+    Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+    Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+    std::vector<TNode> expected = {top};
+
+    auto traversal = NodeDfsIterable(top).inPostorder().skipIf(
+        [&cnd](TNode n) { return n == cnd; });
+
+    std::vector<TNode> actual;
+    std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+    TS_ASSERT_EQUALS(actual, expected);
+  }
 };
 
 class NodePreorderTraversalBlack : public CxxTest::TestSuite
@@ -278,4 +294,20 @@ class NodePreorderTraversalBlack : public CxxTest::TestSuite
     std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
     TS_ASSERT_EQUALS(actual, expected);
   }
+
+  void testSkipIf()
+  {
+    Node tb = d_nodeManager->mkConst(true);
+    Node eb = d_nodeManager->mkConst(false);
+    Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+    Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+    std::vector<TNode> expected = {top, cnd, eb};
+
+    auto traversal = NodeDfsIterable(top).inPreorder().skipIf(
+        [&tb](TNode n) { return n == tb; });
+
+    std::vector<TNode> actual;
+    std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+    TS_ASSERT_EQUALS(actual, expected);
+  }
 };