Sometime you want to skip specific sub-DAGs when traversing a node. For example, you might be doing a transformation with a cache, and want to skip sub-DAGs that you've already processed.
This PR would add a skipIf builder method to NodeDfsIterable, which allows the user to provide a predicate according to which nodes will be omitted from the subsequent traversal.
namespace CVC4 {
-NodeDfsIterator::NodeDfsIterator(TNode n, bool postorder)
+NodeDfsIterator::NodeDfsIterator(TNode n,
+ bool postorder,
+ std::function<bool(TNode)> skipIf)
: d_stack{n},
d_visited(),
d_postorder(postorder),
- d_current(TNode())
+ d_current(TNode()),
+ d_skipIf(skipIf)
{
}
: d_stack(),
d_visited(),
d_postorder(postorder),
- d_current(TNode())
+ d_current(TNode()),
+ d_skipIf([](TNode) { return false; })
{
}
// The stack and current node uniquely represent traversal state. We need not
// use the scheduled node set.
//
- // Users should not compare iterators for traversals of different nodes.
+ // Users should not compare iterators for traversals of different nodes, or
+ // traversals with different skipIfs.
Assert(d_postorder == other.d_postorder);
return d_stack == other.d_stack && d_current == other.d_current;
}
if (visitEntry == d_visited.end())
{
// if we haven't pre-visited this node, pre-visit it
+ if (d_skipIf(back))
+ {
+ // actually, skip it if the skip predicate says so...
+ d_stack.pop_back();
+ continue;
+ }
d_visited[back] = false;
d_current = back;
// Use integer underflow to reverse-iterate
}
}
-NodeDfsIterable::NodeDfsIterable(TNode n) : d_node(n), d_postorder(true) {}
+NodeDfsIterable::NodeDfsIterable(TNode n)
+ : d_node(n), d_postorder(true), d_skipIf([](TNode) { return false; })
+{
+}
NodeDfsIterable& NodeDfsIterable::inPostorder()
{
return *this;
}
+NodeDfsIterable& NodeDfsIterable::skipIf(
+ std::function<bool(TNode)> skipCondition)
+{
+ d_skipIf = skipCondition;
+ return *this;
+}
+
NodeDfsIterator NodeDfsIterable::begin() const
{
- return NodeDfsIterator(d_node, d_postorder);
+ return NodeDfsIterator(d_node, d_postorder, d_skipIf);
}
NodeDfsIterator NodeDfsIterable::end() const
#define CVC4__EXPR__NODE_TRAVERSAL_H
#include <cstddef>
+#include <functional>
#include <iterator>
#include <unordered_map>
#include <vector>
using difference_type = std::ptrdiff_t;
// Construct a traversal iterator beginning at `n`
- NodeDfsIterator(TNode n, bool postorder);
+ NodeDfsIterator(TNode n, bool postorder, std::function<bool(TNode)> skipIf);
// Construct an end-of-traversal iterator
NodeDfsIterator(bool postorder);
// Current referent node. A valid node to visit if non-null.
// Null after construction (but before first access) and at the end.
TNode d_current;
+
+ // When to omit a node and its descendants from the traversal
+ std::function<bool(TNode)> d_skipIf;
};
// Node wrapper that is iterable in DAG post-order
// Modify this iterable to be in pre-order
NodeDfsIterable& inPreorder();
+ // Skip a node (and its descendants) if true.
+ NodeDfsIterable& skipIf(std::function<bool(TNode)> skipCondition);
+
// Move/copy construction and assignment. Destructor.
NodeDfsIterable(NodeDfsIterable&&) = default;
NodeDfsIterable& operator=(NodeDfsIterable&&) = default;
private:
TNode d_node;
bool d_postorder;
+ std::function<bool(TNode)> d_skipIf;
};
} // namespace CVC4
std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
TS_ASSERT_EQUALS(actual, expected);
}
+
+ void testSkipIf()
+ {
+ Node tb = d_nodeManager->mkConst(true);
+ Node eb = d_nodeManager->mkConst(false);
+ Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+ Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+ std::vector<TNode> expected = {top};
+
+ auto traversal = NodeDfsIterable(top).inPostorder().skipIf(
+ [&cnd](TNode n) { return n == cnd; });
+
+ std::vector<TNode> actual;
+ std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+ TS_ASSERT_EQUALS(actual, expected);
+ }
};
class NodePreorderTraversalBlack : public CxxTest::TestSuite
std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
TS_ASSERT_EQUALS(actual, expected);
}
+
+ void testSkipIf()
+ {
+ Node tb = d_nodeManager->mkConst(true);
+ Node eb = d_nodeManager->mkConst(false);
+ Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+ Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+ std::vector<TNode> expected = {top, cnd, eb};
+
+ auto traversal = NodeDfsIterable(top).inPreorder().skipIf(
+ [&tb](TNode n) { return n == tb; });
+
+ std::vector<TNode> actual;
+ std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+ TS_ASSERT_EQUALS(actual, expected);
+ }
};