Use traversal iterators in IntToBv (#4169)
authorAlex Ozdemir <aozdemir@hmc.edu>
Sat, 20 Jun 2020 05:45:32 +0000 (22:45 -0700)
committerGitHub <noreply@github.com>
Sat, 20 Jun 2020 05:45:32 +0000 (22:45 -0700)
This commit rips the traversal machinery out of Int-to-Bv, replacing it with traversal iterators.

Also, cleaned `childrenTypesChanged` a bit.

While basically I just cut out some lines, the diff is rather messy (I think the diffing tool doesn't like indentation changes).

src/expr/node_traversal.cpp
src/expr/node_traversal.h
src/preprocessing/passes/int_to_bv.cpp
test/unit/expr/node_traversal_black.h

index 0d207b302824788cacbfede5994764cb436f34ef..5968c43488ded602edf560a8e92e9ed1797f9b3d 100644 (file)
@@ -63,8 +63,11 @@ TNode& NodeDfsIterator::operator*()
   return d_current;
 }
 
-bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const
+bool NodeDfsIterator::operator==(NodeDfsIterator& other)
 {
+  // Unitialize this node, and the other, before comparing.
+  initializeIfUninitialized();
+  other.initializeIfUninitialized();
   // The stack and current node uniquely represent traversal state. We need not
   // use the scheduled node set.
   //
@@ -74,7 +77,7 @@ bool NodeDfsIterator::operator==(const NodeDfsIterator& other) const
   return d_stack == other.d_stack && d_current == other.d_current;
 }
 
-bool NodeDfsIterator::operator!=(const NodeDfsIterator& other) const
+bool NodeDfsIterator::operator!=(NodeDfsIterator& other)
 {
   return !(*this == other);
 }
index 1398f6485fd8f638782709c99b4c907728cd35e5..586cbb64db3bfcc8591bfcaf9d31ca7c9330dfb1 100644 (file)
@@ -67,9 +67,13 @@ class NodeDfsIterator
   // Dereference
   reference operator*();
   // Equals
-  bool operator==(const NodeDfsIterator&) const;
+  // It is not constant, because an unitilized node must be initialized before
+  // comparison
+  bool operator==(NodeDfsIterator&);
   // Not equals
-  bool operator!=(const NodeDfsIterator&) const;
+  // It is not constant, because an unitilized node must be initialized before
+  // comparison
+  bool operator!=(NodeDfsIterator&);
 
  private:
   // While we're not at an appropriate visit (see d_postorder), advance.
index 2740cb771e409411019a97e30ff081402e8806e4..e30c03f26308ba0b1e0efd0ccc109e28a153f722 100644 (file)
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "expr/node_traversal.h"
 #include "theory/rewriter.h"
 #include "theory/theory.h"
 
@@ -36,107 +37,58 @@ using NodeMap = std::unordered_map<Node, Node, NodeHashFunction>;
 
 namespace {
 
-// TODO: clean this up
-struct intToBV_stack_element
-{
-  TNode d_node;
-  bool d_children_added;
-  intToBV_stack_element(TNode node) : d_node(node), d_children_added(false) {}
-}; /* struct intToBV_stack_element */
-
 bool childrenTypesChanged(Node n, NodeMap& cache) {
-  bool result = false;
   for (Node child : n) {
     TypeNode originalType = child.getType();
     TypeNode newType = cache[child].getType();
     if (! newType.isSubtypeOf(originalType)) {
-      result = true;
-      break;
+      return true;
     }
   }
-  return result;
+  return false;
 }
 
 
 Node intToBVMakeBinary(TNode n, NodeMap& cache)
 {
-  // Do a topological sort of the subexpressions and substitute them
-  vector<intToBV_stack_element> toVisit;
-  toVisit.push_back(n);
-
-  while (!toVisit.empty())
+  for (TNode current : NodeDfsIterable(n, VisitOrder::POSTORDER,
+           [&cache](TNode nn) { return cache.count(nn) > 0; }))
   {
-    // The current node we are processing
-    intToBV_stack_element& stackHead = toVisit.back();
-    TNode current = stackHead.d_node;
-
-    NodeMap::iterator find = cache.find(current);
-    if (find != cache.end())
+    Node result;
+    NodeManager* nm = NodeManager::currentNM();
+    if (current.getNumChildren() == 0)
     {
-      toVisit.pop_back();
-      continue;
+      result = current;
     }
-    if (stackHead.d_children_added)
+    else if (current.getNumChildren() > 2
+             && (current.getKind() == kind::PLUS
+                 || current.getKind() == kind::MULT))
     {
-      // Children have been processed, so rebuild this node
-      Node result;
-      NodeManager* nm = NodeManager::currentNM();
-      if (current.getNumChildren() > 2
-          && (current.getKind() == kind::PLUS
-              || current.getKind() == kind::MULT))
-      {
-        Assert(cache.find(current[0]) != cache.end());
-        result = cache[current[0]];
-        for (unsigned i = 1; i < current.getNumChildren(); ++i)
-        {
-          Assert(cache.find(current[i]) != cache.end());
-          Node child = current[i];
-          Node childRes = cache[current[i]];
-          result = nm->mkNode(current.getKind(), result, childRes);
-        }
-      }
-      else
+      Assert(cache.find(current[0]) != cache.end());
+      result = cache[current[0]];
+      for (unsigned i = 1; i < current.getNumChildren(); ++i)
       {
-        NodeBuilder<> builder(current.getKind());
-        if (current.getMetaKind() == kind::metakind::PARAMETERIZED) {
-          builder << current.getOperator();
-        }
-
-        for (unsigned i = 0; i < current.getNumChildren(); ++i)
-        {
-          Assert(cache.find(current[i]) != cache.end());
-          builder << cache[current[i]];
-        }
-        result = builder;
+        Assert(cache.find(current[i]) != cache.end());
+        Node child = current[i];
+        Node childRes = cache[current[i]];
+        result = nm->mkNode(current.getKind(), result, childRes);
       }
-      cache[current] = result;
-      toVisit.pop_back();
     }
     else
     {
-      // Mark that we have added the children if any
-      if (current.getNumChildren() > 0)
-      {
-        stackHead.d_children_added = true;
-        // We need to add the children
-        for (TNode::iterator child_it = current.begin();
-             child_it != current.end();
-             ++child_it)
-        {
-          TNode childNode = *child_it;
-          NodeMap::iterator childFind = cache.find(childNode);
-          if (childFind == cache.end())
-          {
-            toVisit.push_back(childNode);
-          }
-        }
+      NodeBuilder<> builder(current.getKind());
+      if (current.getMetaKind() == kind::metakind::PARAMETERIZED) {
+        builder << current.getOperator();
       }
-      else
+
+      for (unsigned i = 0; i < current.getNumChildren(); ++i)
       {
-        cache[current] = current;
-        toVisit.pop_back();
+        Assert(cache.find(current[i]) != cache.end());
+        builder << cache[current[i]];
       }
+      result = builder;
     }
+    cache[current] = result;
   }
   return cache[n];
 }
@@ -147,30 +99,16 @@ Node intToBV(TNode n, NodeMap& cache)
   AlwaysAssert(size > 0);
   AlwaysAssert(!options::incrementalSolving());
 
-  vector<intToBV_stack_element> toVisit;
   NodeMap binaryCache;
   Node n_binary = intToBVMakeBinary(n, binaryCache);
-  toVisit.push_back(TNode(n_binary));
 
-  while (!toVisit.empty())
+  for (TNode current : NodeDfsIterable(n_binary, VisitOrder::POSTORDER,
+           [&cache](TNode nn) { return cache.count(nn) > 0; }))
   {
-    // The current node we are processing
-    intToBV_stack_element& stackHead = toVisit.back();
-    TNode current = stackHead.d_node;
-
-    // If node is already in the cache we're done, pop from the stack
-    NodeMap::iterator find = cache.find(current);
-    if (find != cache.end())
-    {
-      toVisit.pop_back();
-      continue;
-    }
-
-    // Not yet substituted, so process
     NodeManager* nm = NodeManager::currentNM();
-    if (stackHead.d_children_added)
+    if (current.getNumChildren() > 0)
     {
-      // Children have been processed, so rebuild this node
+      // Not a leaf
       vector<Node> children;
       unsigned max = 0;
       for (unsigned i = 0; i < current.getNumChildren(); ++i)
@@ -258,73 +196,51 @@ Node intToBV(TNode n, NodeMap& cache)
 
       result = Rewriter::rewrite(result);
       cache[current] = result;
-      toVisit.pop_back();
     }
     else
     {
-      // Mark that we have added the children if any
-      if (current.getNumChildren() > 0)
+      // It's a leaf: could be a variable or a numeral
+      Node result = current;
+      if (current.isVar())
       {
-        stackHead.d_children_added = true;
-        // We need to add the children
-        for (TNode::iterator child_it = current.begin();
-             child_it != current.end();
-             ++child_it)
+        if (current.getType() == nm->integerType())
         {
-          TNode childNode = *child_it;
-          NodeMap::iterator childFind = cache.find(childNode);
-          if (childFind == cache.end())
-          {
-            toVisit.push_back(childNode);
-          }
+          result = nm->mkSkolem("__intToBV_var",
+                                nm->mkBitVectorType(size),
+                                "Variable introduced in intToBV pass");
         }
       }
-      else
+      else if (current.isConst())
       {
-        // It's a leaf: could be a variable or a numeral
-        Node result = current;
-        if (current.isVar())
+        switch (current.getKind())
         {
-          if (current.getType() == nm->integerType())
+          case kind::CONST_RATIONAL:
           {
-            result = nm->mkSkolem("__intToBV_var",
-                                  nm->mkBitVectorType(size),
-                                  "Variable introduced in intToBV pass");
-          }
-        }
-        else if (current.isConst())
-        {
-          switch (current.getKind())
-          {
-            case kind::CONST_RATIONAL:
-            {
-              Rational constant = current.getConst<Rational>();
-              if (constant.isIntegral()) {
-                AlwaysAssert(constant >= 0);
-                BitVector bv(size, constant.getNumerator());
-                if (bv.toSignedInteger() != constant.getNumerator())
-                {
-                  throw TypeCheckingException(
-                      current.toExpr(),
-                      string("Not enough bits for constant in intToBV: ")
-                          + current.toString());
-                }
-                result = nm->mkConst(bv);
+            Rational constant = current.getConst<Rational>();
+            if (constant.isIntegral()) {
+              AlwaysAssert(constant >= 0);
+              BitVector bv(size, constant.getNumerator());
+              if (bv.toSignedInteger() != constant.getNumerator())
+              {
+                throw TypeCheckingException(
+                    current.toExpr(),
+                    string("Not enough bits for constant in intToBV: ")
+                        + current.toString());
               }
-              break;
+              result = nm->mkConst(bv);
             }
-            default: break;
+            break;
           }
+          default: break;
         }
-        else
-        {
-          throw TypeCheckingException(
-              current.toExpr(),
-              string("Cannot translate to BV: ") + current.toString());
-        }
-        cache[current] = result;
-        toVisit.pop_back();
       }
+      else
+      {
+        throw TypeCheckingException(
+            current.toExpr(),
+            string("Cannot translate to BV: ") + current.toString());
+      }
+      cache[current] = result;
     }
   }
   return cache[n_binary];
index 02f8db6a3473c7c7cd3514723e772a4fcfc36c19..6f5dfaf0438322c44a2d1c25520843ad47be2c0d 100644 (file)
@@ -177,6 +177,22 @@ class NodePostorderTraversalBlack : public CxxTest::TestSuite
     std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
     TS_ASSERT_EQUALS(actual, expected);
   }
+
+  void testSkipAll()
+  {
+    Node tb = d_nodeManager->mkConst(true);
+    Node eb = d_nodeManager->mkConst(false);
+    Node cnd = d_nodeManager->mkNode(XOR, tb, eb);
+    Node top = d_nodeManager->mkNode(XOR, cnd, cnd);
+    std::vector<TNode> expected = {};
+
+    auto traversal = NodeDfsIterable(top, VisitOrder::POSTORDER,
+        [](TNode n) { return true; });
+
+    std::vector<TNode> actual;
+    std::copy(traversal.begin(), traversal.end(), std::back_inserter(actual));
+    TS_ASSERT_EQUALS(actual, expected);
+  }
 };
 
 class NodePreorderTraversalBlack : public CxxTest::TestSuite