Use new let binding for cvc printer (#5561)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 2 Dec 2020 16:17:09 +0000 (10:17 -0600)
committerGitHub <noreply@github.com>
Wed, 2 Dec 2020 16:17:09 +0000 (10:17 -0600)
Also changes names slightly to avoid accidental uses of toStream, which can lead to infinite loops if the wrong version is used.

src/printer/cvc/cvc_printer.cpp
src/printer/cvc/cvc_printer.h
test/unit/expr/node_black.h

index be530099b91716c90ecc62b479e6b23a8d0fd08d..c681a223f8e0b5f716c79db3debad9fca4b8aa28 100644 (file)
@@ -31,7 +31,7 @@
 #include "expr/sequence.h"
 #include "options/language.h"  // for LANG_AST
 #include "options/smt_options.h"
-#include "printer/dagification_visitor.h"
+#include "printer/let_binding.h"
 #include "smt/command.h"
 #include "smt/node_command.h"
 #include "smt/smt_engine.h"
@@ -51,31 +51,10 @@ void CvcPrinter::toStream(std::ostream& out,
                           size_t dag) const
 {
   if(dag != 0) {
-    DagificationVisitor dv(dag);
-    NodeVisitor<DagificationVisitor> visitor;
-    visitor.run(dv, n);
-    const theory::SubstitutionMap& lets = dv.getLets();
-    if(!lets.empty()) {
-      out << "LET ";
-      bool first = true;
-      for(theory::SubstitutionMap::const_iterator i = lets.begin();
-          i != lets.end();
-          ++i) {
-        if(! first) {
-          out << ", ";
-        } else {
-          first = false;
-        }
-        toStream(out, (*i).second, toDepth, false);
-        out << " = ";
-        toStream(out, (*i).first, toDepth, false);
-      }
-      out << " IN ";
-    }
-    Node body = dv.getDagifiedBody();
-    toStream(out, body, toDepth, false);
+    LetBinding lbind(dag + 1);
+    toStreamNodeWithLetify(out, n, toDepth, false, &lbind);
   } else {
-    toStream(out, n, toDepth, false);
+    toStreamNode(out, n, toDepth, false, nullptr);
   }
 }
 
@@ -93,10 +72,11 @@ void toStreamRational(std::ostream& out, Node n, bool forceRational)
   }
 }
 
-void CvcPrinter::toStream(std::ostream& out,
-                          TNode n,
-                          int depth,
-                          bool bracket) const
+void CvcPrinter::toStreamNode(std::ostream& out,
+                              TNode n,
+                              int depth,
+                              bool bracket,
+                              LetBinding* lbind) const
 {
   if (depth == 0) {
     out << "(...)";
@@ -286,11 +266,11 @@ void CvcPrinter::toStream(std::ostream& out,
       break;
     case kind::ITE:
       out << "IF ";
-      toStream(out, n[0], depth, true);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " THEN ";
-      toStream(out, n[1], depth, true);
+      toStreamNode(out, n[1], depth, true, lbind);
       out << " ELSE ";
-      toStream(out, n[2], depth, true);
+      toStreamNode(out, n[2], depth, true, lbind);
       out << " ENDIF";
       return;
       break;
@@ -300,7 +280,7 @@ void CvcPrinter::toStream(std::ostream& out,
         if (i > 0) {
           out << ", ";
         }
-        toStream(out, n[i], depth, false);
+        toStreamNode(out, n[i], depth, false, lbind);
       }
       out << ']';
       return;
@@ -310,22 +290,22 @@ void CvcPrinter::toStream(std::ostream& out,
       break;
     case kind::LAMBDA:
       out << "(LAMBDA";
-      toStream(out, n[0], depth, true);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << ": ";
-      toStream(out, n[1], depth, true);
+      toStreamNodeWithLetify(out, n[1], depth, true, lbind);
       out << ")";
       return;
       break;
     case kind::WITNESS:
       out << "(WITNESS";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " : ";
-      toStream(out, n[1], depth, false);
+      toStreamNodeWithLetify(out, n[1], depth, true, lbind);
       out << ')';
       return;
     case kind::DISTINCT:
       // distinct not supported directly, blast it away with the rewriter
-      toStream(out, theory::Rewriter::rewrite(n), depth, true);
+      toStreamNode(out, theory::Rewriter::rewrite(n), depth, true, lbind);
       return;
     case kind::SORT_TYPE:
     {
@@ -360,7 +340,9 @@ void CvcPrinter::toStream(std::ostream& out,
       break;
 
     // UF
-    case kind::APPLY_UF: toStream(op, n.getOperator(), depth, false); break;
+    case kind::APPLY_UF:
+      toStreamNode(op, n.getOperator(), depth, false, lbind);
+      break;
     case kind::CARDINALITY_CONSTRAINT:
     case kind::COMBINED_CARDINALITY_CONSTRAINT:
       out << "CARDINALITY_CONSTRAINT";
@@ -375,14 +357,14 @@ void CvcPrinter::toStream(std::ostream& out,
           if (i > 1) {
             out << ", ";
           }
-          toStream(out, n[i - 1], depth, false);
+          toStreamNode(out, n[i - 1], depth, false, lbind);
         }
         if (n.getNumChildren() > 2) {
           out << ')';
         }
       }
       out << " -> ";
-      toStream(out, n[n.getNumChildren() - 1], depth, false);
+      toStreamNode(out, n[n.getNumChildren() - 1], depth, false, lbind);
       return;
       break;
 
@@ -404,7 +386,7 @@ void CvcPrinter::toStream(std::ostream& out,
       return;
       break;
     case kind::APPLY_TYPE_ASCRIPTION: {
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << "::";
       TypeNode t = n.getOperator().getConst<AscriptionType>().getType();
       out << (t.isFunctionLike() ? t.getRangeType() : t);
@@ -430,14 +412,14 @@ void CvcPrinter::toStream(std::ostream& out,
               out << ", ";
             }
             out << recCons[i].getName() << " := ";
-            toStream(out, n[i], depth, false);
+            toStreamNode(out, n[i], depth, false, lbind);
           }
           out << " #)";
           return;
         }
         else
         {
-          toStream(op, n.getOperator(), depth, false);
+          toStreamNode(op, n.getOperator(), depth, false, lbind);
           if (n.getNumChildren() == 0)
           {
             // for datatype constants d, we print "d" and not "d()"
@@ -453,11 +435,11 @@ void CvcPrinter::toStream(std::ostream& out,
         Node opn = n.getOperator();
         if (!t.isDatatype())
         {
-          toStream(op, opn, depth, false);
+          toStreamNode(op, opn, depth, false, lbind);
         }
         else if (t.isTuple() || t.isRecord())
         {
-          toStream(out, n[0], depth, true);
+          toStreamNode(out, n[0], depth, true, lbind);
           out << '.';
           const DType& dt = t.getDType();
           if (t.isTuple())
@@ -476,11 +458,11 @@ void CvcPrinter::toStream(std::ostream& out,
           }
           else
           {
-            toStream(out, opn, depth, false);
+            toStreamNode(out, opn, depth, false, lbind);
           }
           return;
         }else{
-          toStream(op, opn, depth, false);
+          toStreamNode(op, opn, depth, false, lbind);
         }
       }
       break;
@@ -489,7 +471,7 @@ void CvcPrinter::toStream(std::ostream& out,
       op << "is_";
       unsigned cindex = DType::indexOf(n.getOperator());
       const DType& dt = DType::datatypeOf(n.getOperator());
-      toStream(op, dt[cindex].getConstructor(), depth, false);
+      toStreamNode(op, dt[cindex].getConstructor(), depth, false, lbind);
     }
       break;
     case kind::CONSTRUCTOR_TYPE:
@@ -502,45 +484,45 @@ void CvcPrinter::toStream(std::ostream& out,
           if(i > 0) {
             out << ", ";
           }
-          toStream(out, n[i], depth, false);
+          toStreamNode(out, n[i], depth, false, lbind);
         }
         if(n.getNumChildren() > 2) {
           out << ')';
         }
         out << " -> ";
       }
-      toStream(out, n[n.getNumChildren() - 1], depth, false);
+      toStreamNode(out, n[n.getNumChildren() - 1], depth, false, lbind);
       return;
     case kind::TESTER_TYPE:
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << " -> BOOLEAN";
       return;
       break;
     case kind::TUPLE_UPDATE:
-      toStream(out, n[0], depth, true);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " WITH ." << n.getOperator().getConst<TupleUpdate>().getIndex() << " := ";
-      toStream(out, n[1], depth, true);
+      toStreamNode(out, n[1], depth, true, lbind);
       return;
       break;
     case kind::RECORD_UPDATE:
-      toStream(out, n[0], depth, true);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " WITH ." << n.getOperator().getConst<RecordUpdate>().getField() << " := ";
-      toStream(out, n[1], depth, true);
+      toStreamNode(out, n[1], depth, true, lbind);
       return;
       break;
 
     // ARRAYS
     case kind::ARRAY_TYPE:
       out << "ARRAY ";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << " OF ";
-      toStream(out, n[1], depth, false);
+      toStreamNode(out, n[1], depth, false, lbind);
       return;
       break;
     case kind::SELECT:
-      toStream(out, n[0], depth, true);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << '[';
-      toStream(out, n[1], depth, false);
+      toStreamNode(out, n[1], depth, false, lbind);
       out << ']';
       return;
       break;
@@ -554,18 +536,18 @@ void CvcPrinter::toStream(std::ostream& out,
         out << '(';
       }
       TNode x = stk.top();
-      toStream(out, x[0], depth, false);
+      toStreamNode(out, x[0], depth, false, lbind);
       out << " WITH [";
-      toStream(out, x[1], depth, false);
+      toStreamNode(out, x[1], depth, false, lbind);
       out << "] := ";
-      toStream(out, x[2], depth, false);
+      toStreamNode(out, x[2], depth, false, lbind);
       stk.pop();
       while(!stk.empty()) {
         x = stk.top();
         out << ", [";
-        toStream(out, x[1], depth, false);
+        toStreamNode(out, x[1], depth, false, lbind);
         out << "] := ";
-        toStream(out, x[2], depth, false);
+        toStreamNode(out, x[2], depth, false, lbind);
         stk.pop();
       }
       if (bracket) {
@@ -651,13 +633,13 @@ void CvcPrinter::toStream(std::ostream& out,
       else
       {
         // ignore, there is no to-real in CVC language
-        toStream(out, n[0], depth, false);
+        toStreamNode(out, n[0], depth, false, lbind);
       }
       return;
     }
     case kind::DIVISIBLE:
       out << "DIVISIBLE(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << n.getOperator().getConst<Divisible>().k << ")";
       return;
 
@@ -758,16 +740,16 @@ void CvcPrinter::toStream(std::ostream& out,
         out << "BVPLUS(";
         out << BitVectorType(n.getType().toType()).getSize();
         out << ',';
-        toStream(out, n[child], depth, false);
+        toStreamNode(out, n[child], depth, false, lbind);
         out << ',';
         ++child;
       }
       out << "BVPLUS(";
       out << BitVectorType(n.getType().toType()).getSize();
       out << ',';
-      toStream(out, n[child], depth, false);
+      toStreamNode(out, n[child], depth, false, lbind);
       out << ',';
-      toStream(out, n[child + 1], depth, false);
+      toStreamNode(out, n[child + 1], depth, false, lbind);
       while (child > 0) {
         out << ')';
         --child;
@@ -781,9 +763,9 @@ void CvcPrinter::toStream(std::ostream& out,
       Assert(n.getType().isBitVector());
       out << BitVectorType(n.getType().toType()).getSize();
       out << ',';
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ',';
-      toStream(out, n[1], depth, false);
+      toStreamNode(out, n[1], depth, false, lbind);
       out << ')';
       return;
       break;
@@ -795,16 +777,16 @@ void CvcPrinter::toStream(std::ostream& out,
         out << "BVMULT(";
         out << BitVectorType(n.getType().toType()).getSize();
         out << ',';
-        toStream(out, n[child], depth, false);
+        toStreamNode(out, n[child], depth, false, lbind);
         out << ',';
         ++child;
         }
       out << "BVMULT(";
       out << BitVectorType(n.getType().toType()).getSize();
       out << ',';
-      toStream(out, n[child], depth, false);
+      toStreamNode(out, n[child], depth, false, lbind);
       out << ',';
-      toStream(out, n[child + 1], depth, false);
+      toStreamNode(out, n[child + 1], depth, false, lbind);
       while (child > 0) {
         out << ')';
         --child;
@@ -823,31 +805,31 @@ void CvcPrinter::toStream(std::ostream& out,
       break;
     case kind::BITVECTOR_REPEAT:
       out << "BVREPEAT(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << n.getOperator().getConst<BitVectorRepeat>() << ')';
       return;
       break;
     case kind::BITVECTOR_ZERO_EXTEND:
       out << "BVZEROEXTEND(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << n.getOperator().getConst<BitVectorZeroExtend>() << ')';
       return;
       break;
     case kind::BITVECTOR_SIGN_EXTEND:
       out << "SX(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << BitVectorType(n.getType().toType()).getSize() << ')';
       return;
       break;
     case kind::BITVECTOR_ROTATE_LEFT:
       out << "BVROTL(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << n.getOperator().getConst<BitVectorRotateLeft>() << ')';
       return;
       break;
     case kind::BITVECTOR_ROTATE_RIGHT:
       out << "BVROTR(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ", " << n.getOperator().getConst<BitVectorRotateRight>() << ')';
       return;
       break;
@@ -855,7 +837,7 @@ void CvcPrinter::toStream(std::ostream& out,
     // SETS
     case kind::SET_TYPE:
       out << "SET OF ";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       return;
       break;
     case kind::UNION:
@@ -908,7 +890,7 @@ void CvcPrinter::toStream(std::ostream& out,
       break;
     case kind::SINGLETON:
       out << "{";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << "}";
       return;
       break;
@@ -918,13 +900,13 @@ void CvcPrinter::toStream(std::ostream& out,
       }
       out << '{';
       size_t i = 0;
-      toStream(out, n[i++], depth, false);
+      toStreamNode(out, n[i++], depth, false, lbind);
       for(;i+1 < n.getNumChildren(); ++i) {
         out << ", ";
-        toStream(out, n[i], depth, false);
+        toStreamNode(out, n[i], depth, false, lbind);
       }
       out << "} | ";
-      toStream(out, n[i], depth, true);
+      toStreamNode(out, n[i], depth, true, lbind);
       if(bracket) {
         out << ')';
       }
@@ -933,7 +915,7 @@ void CvcPrinter::toStream(std::ostream& out,
     }
     case kind::CARD: {
       out << "CARD(";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, false, lbind);
       out << ")";
       return;
       break;
@@ -942,17 +924,17 @@ void CvcPrinter::toStream(std::ostream& out,
     // Quantifiers
     case kind::FORALL:
       out << "(FORALL";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " : ";
-      toStream(out, n[1], depth, false);
+      toStreamNodeWithLetify(out, n[1], depth, true, lbind);
       out << ')';
       // TODO: user patterns?
       return;
     case kind::EXISTS:
       out << "(EXISTS";
-      toStream(out, n[0], depth, false);
+      toStreamNode(out, n[0], depth, true, lbind);
       out << " : ";
-      toStream(out, n[1], depth, false);
+      toStreamNodeWithLetify(out, n[1], depth, true, lbind);
       out << ')';
       // TODO: user patterns?
       return;
@@ -965,7 +947,7 @@ void CvcPrinter::toStream(std::ostream& out,
         if(i > 0) {
           out << ", ";
         }
-        toStream(out, n[i], -1, false);
+        toStreamNode(out, n[i], -1, false, lbind);
         out << ":";
         n[i].getType().toStream(out, language::output::LANG_CVC4);
       }
@@ -1021,7 +1003,7 @@ void CvcPrinter::toStream(std::ostream& out,
         out << ", ";
       }
     }
-    toStream(out, n[i], depth, opType == INFIX);
+    toStreamNode(out, n[i], depth, opType == INFIX, lbind);
   }
 
   switch (opType) {
@@ -1040,8 +1022,7 @@ void CvcPrinter::toStream(std::ostream& out,
       out << ')' << op.str();
       break;
   }
-
-}/* CvcPrinter::toStream(TNode) */
+}
 
 template <class T>
 static bool tryToStream(std::ostream& out, const Command* c, bool cvc3Mode);
@@ -1611,6 +1592,48 @@ static bool tryToStream(std::ostream& out,
   return false;
 }
 
+void CvcPrinter::toStreamNodeWithLetify(std::ostream& out,
+                                        Node n,
+                                        int toDepth,
+                                        bool bracket,
+                                        LetBinding* lbind) const
+{
+  if (lbind == nullptr)
+  {
+    toStreamNode(out, n, toDepth, bracket, nullptr);
+    return;
+  }
+  std::vector<Node> letList;
+  lbind->letify(n, letList);
+  if (!letList.empty())
+  {
+    std::map<Node, uint32_t>::const_iterator it;
+    out << "LET ";
+    bool first = true;
+    for (size_t i = 0, nlets = letList.size(); i < nlets; i++)
+    {
+      if (!first)
+      {
+        out << ", ";
+      }
+      else
+      {
+        first = false;
+      }
+      Node nl = letList[i];
+      uint32_t id = lbind->getId(nl);
+      out << "_let_" << id << " = ";
+      Node nlc = lbind->convert(nl, "_let_", false);
+      toStreamNode(out, nlc, toDepth, true, lbind);
+    }
+    out << " IN ";
+  }
+  Node nc = lbind->convert(n, "_let_");
+  // print the body, passing the lbind object
+  toStreamNode(out, nc, toDepth, bracket, lbind);
+  lbind->popScope();
+}
+
 }/* CVC4::printer::cvc namespace */
 }/* CVC4::printer namespace */
 }/* CVC4 namespace */
index b0328bc3cdfa5b2748d0c14f403be669be6c4e54..799bba8d2c6c7fd118b9b7366d1b7973549433a4 100644 (file)
@@ -24,6 +24,9 @@
 #include "printer/printer.h"
 
 namespace CVC4 {
+
+class LetBinding;
+
 namespace printer {
 namespace cvc {
 
@@ -163,7 +166,14 @@ class CvcPrinter : public CVC4::Printer
       std::ostream& out, const std::vector<Command*>& sequence) const override;
 
  private:
-  void toStream(std::ostream& out, TNode n, int toDepth, bool bracket) const;
+  /**
+   * The main method for printing Nodes.
+   */
+  void toStreamNode(std::ostream& out,
+                    TNode n,
+                    int toDepth,
+                    bool bracket,
+                    LetBinding* lbind) const;
   /**
    * To stream model sort. This prints the appropriate output for type
    * tn declared via declare-sort or declare-datatype.
@@ -179,6 +189,15 @@ class CvcPrinter : public CVC4::Printer
   void toStreamModelTerm(std::ostream& out,
                          const smt::Model& m,
                          Node n) const override;
+  /**
+   * To stream with let binding. This prints n, possibly in the scope
+   * of letification generated by this method based on lbind.
+   */
+  void toStreamNodeWithLetify(std::ostream& out,
+                              Node n,
+                              int toDepth,
+                              bool bracket,
+                              LetBinding* lbind) const;
 
   bool d_cvc3Mode;
 }; /* class CvcPrinter */
index 5b82e0a58b0af051ba3494f21d962ed5d005aaa4..4e4a39654048fbda20aec2e66f6e976e9b6a1272 100644 (file)
@@ -631,31 +631,6 @@ class NodeBlack : public CxxTest::TestSuite {
     TS_ASSERT(sstr.str() ==
               "(f(f(f(x))) = x) OR (f(f(f(x))) = y) OR (f(x) = g(x)) OR (x = "
               "y) OR (f(g(x)) = g(y))");
-
-    sstr.str(string());
-    sstr << Node::dag(true) << n;  // always dagify
-    TS_ASSERT(sstr.str() ==
-              "LET _let_0 = f(x), _let_1 = g(x), _let_2 = f(f(_let_0)) IN "
-              "(_let_2 = x) OR (_let_2 = y) OR (_let_0 = _let_1) OR (x = y) OR "
-              "(f(_let_1) = g(y))");
-
-    sstr.str(string());
-    sstr << Node::dag(2) << n;  // dagify subexprs occurring > 2 times
-    TS_ASSERT(sstr.str() ==
-              "LET _let_0 = f(x) IN (f(f(_let_0)) = x) OR (f(f(_let_0)) = y) "
-              "OR (_let_0 = g(x)) OR (x = y) OR (f(g(x)) = g(y))");
-
-    Warning() << Node::setdepth(-1)
-              << Node::setlanguage(language::output::LANG_CVC4) << Node::dag(2)
-              << n << std::endl;
-    sstr.str(string());
-    sstr << Node::dag(3) << n;  // dagify subexprs occurring > 3 times
-    TS_ASSERT(sstr.str() ==
-              "(f(f(f(x))) = x) OR (f(f(f(x))) = y) OR (f(x) = g(x)) OR (x = "
-              "y) OR (f(g(x)) = g(y))");
-    Warning() << Node::setdepth(-1)
-              << Node::setlanguage(language::output::LANG_CVC4) << Node::dag(2)
-              << n << std::endl;
   }
 
   void testForEachOverNodeAsNodes() {