Finer-grained inference of substitutions in incremental mode (#2403)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 5 Sep 2018 17:41:47 +0000 (12:41 -0500)
committerGitHub <noreply@github.com>
Wed, 5 Sep 2018 17:41:47 +0000 (12:41 -0500)
src/expr/node_algorithm.cpp
src/expr/node_algorithm.h
src/smt/smt_engine.cpp

index 5443a3a2ae8f150beb7de48281a348a929a66202..9240e4a8e1d40dc708f6165c5662c5d2ff6bacc2 100644 (file)
@@ -166,5 +166,41 @@ bool hasFreeVar(TNode n)
   return false;
 }
 
+void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms)
+{
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  getSymbols(n, syms);
+}
+
+void getSymbols(TNode n,
+                std::unordered_set<Node, NodeHashFunction>& syms,
+                std::unordered_set<TNode, TNodeHashFunction>& visited)
+{
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    if (visited.find(cur) == visited.end())
+    {
+      visited.insert(cur);
+      if (cur.isVar() && cur.getKind() != kind::BOUND_VARIABLE)
+      {
+        syms.insert(cur);
+      }
+      if (cur.hasOperator())
+      {
+        visit.push_back(cur.getOperator());
+      }
+      for (TNode cn : cur)
+      {
+        visit.push_back(cn);
+      }
+    }
+  } while (!visit.empty());
+}
+
 }  // namespace expr
 }  // namespace CVC4
index 61e81c4c2eea94e59f2626b1b3a3249b7c980c06..7453bc292e32520a536139705170b383665cfcda 100644 (file)
@@ -39,20 +39,33 @@ namespace expr {
 bool hasSubterm(TNode n, TNode t, bool strict = false);
 
 /**
- * Returns true iff the node n contains a bound variable. This bound
- * variable may or may not be free.
+ * Returns true iff the node n contains a bound variable, that is a node of
+ * kind BOUND_VARIABLE. This bound variable may or may not be free.
  * @param n The node under investigation
  * @return true iff this node contains a bound variable
  */
 bool hasBoundVar(TNode n);
 
 /**
- * Returns true iff the node n contains a free variable.
+ * Returns true iff the node n contains a free variable, that is, a node
+ * of kind BOUND_VARIABLE that is not bound in n.
  * @param n The node under investigation
  * @return true iff this node contains a free variable.
  */
 bool hasFreeVar(TNode n);
 
+/**
+ * For term n, this function collects the symbols that occur as a subterms
+ * of n. A symbol is a variable that does not have kind BOUND_VARIABLE.
+ * @param n The node under investigation
+ * @param syms The set which the symbols of n are added to
+ */
+void getSymbols(TNode n, std::unordered_set<Node, NodeHashFunction>& syms);
+/** Same as above, with a visited cache */
+void getSymbols(TNode n,
+                std::unordered_set<Node, NodeHashFunction>& syms,
+                std::unordered_set<TNode, TNodeHashFunction>& visited);
+
 }  // namespace expr
 }  // namespace CVC4
 
index cdd5ab3e09c2dcf8c91eb25bdc5f2052fff49808..17edaad416b1a45626e9d78a5240cb8a9aa577e9 100644 (file)
@@ -45,6 +45,7 @@
 #include "expr/kind.h"
 #include "expr/metakind.h"
 #include "expr/node.h"
+#include "expr/node_algorithm.h"
 #include "expr/node_builder.h"
 #include "expr/node_self_iterator.h"
 #include "options/arith_options.h"
@@ -448,6 +449,7 @@ class SmtEnginePrivate : public NodeManagerListener {
 
   typedef unordered_map<Node, Node, NodeHashFunction> NodeToNodeHashMap;
   typedef unordered_map<Node, bool, NodeHashFunction> NodeToBoolHashMap;
+  typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
 
   /**
    * Manager for limiting time and abstract resource usage.
@@ -503,6 +505,13 @@ class SmtEnginePrivate : public NodeManagerListener {
    */
   SubstitutionMap d_abstractValueMap;
 
+  /**
+   * The (user-context-dependent) set of symbols that occur in at least one
+   * assertion in the current user context. This is used by the
+   * nonClausalSimplify pass.
+   */
+  NodeSet d_symsInAssertions;
+
   /**
    * A mapping of all abstract values (actual value |-> abstract) that
    * we've handed out.  This is necessary to ensure that we give the
@@ -545,6 +554,13 @@ class SmtEnginePrivate : public NodeManagerListener {
    */
   bool nonClausalSimplify();
 
+  /** record symbols in assertions
+   *
+   * This method is called when a set of assertions is finalized. It adds
+   * the symbols to d_symsInAssertions that occur in assertions.
+   */
+  void recordSymbolsInAssertions(const std::vector<Node>& assertions);
+
   /**
    * Helper function to fix up assertion list to restore invariants needed after
    * ite removal.
@@ -579,6 +595,7 @@ class SmtEnginePrivate : public NodeManagerListener {
         d_assertionsProcessed(smt.d_userContext, false),
         d_fakeContext(),
         d_abstractValueMap(&d_fakeContext),
+        d_symsInAssertions(smt.d_userContext),
         d_abstractValues(),
         d_simplifyAssertionsDepth(0),
         // d_needsExpandDefs(true),  //TODO?
@@ -833,7 +850,6 @@ class SmtEnginePrivate : public NodeManagerListener {
     }
   }
   //------------------------------- end expression names
-
 };/* class SmtEnginePrivate */
 
 }/* namespace CVC4::smt */
@@ -3126,35 +3142,42 @@ bool SmtEnginePrivate::nonClausalSimplify() {
                       << assertion << endl;
   }
 
-  // If in incremental mode, add substitutions to the list of assertions
-  if (substs_index > 0)
+  // add substitutions to model, or as assertions if needed (when incremental)
+  TheoryModel* m = d_smt.d_theoryEngine->getModel();
+  Assert(m != nullptr);
+  NodeManager* nm = NodeManager::currentNM();
+  NodeBuilder<> substitutionsBuilder(kind::AND);
+  for (pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos)
   {
-    NodeBuilder<> substitutionsBuilder(kind::AND);
-    substitutionsBuilder << d_assertions[substs_index];
-    pos = newSubstitutions.begin();
-    for (; pos != newSubstitutions.end(); ++pos) {
-      // Add back this substitution as an assertion
-      TNode lhs = (*pos).first, rhs = newSubstitutions.apply((*pos).second);
-      Node n = NodeManager::currentNM()->mkNode(kind::EQUAL, lhs, rhs);
-      substitutionsBuilder << n;
-      Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl;
-    }
-    if (substitutionsBuilder.getNumChildren() > 1) {
-      d_assertions.replace(substs_index,
-                           Rewriter::rewrite(Node(substitutionsBuilder)));
+    Node lhs = (*pos).first;
+    Node rhs = newSubstitutions.apply((*pos).second);
+    // If using incremental, we must check whether this variable has occurred
+    // before now. If it hasn't we can add this as a substitution.
+    if (substs_index == 0
+        || d_symsInAssertions.find(lhs) == d_symsInAssertions.end())
+    {
+      Trace("simplify")
+          << "SmtEnginePrivate::nonClausalSimplify(): substitute: " << lhs
+          << " " << rhs << endl;
+      m->addSubstitution(lhs, rhs);
     }
-  } else {
-    // If not in incremental mode, must add substitutions to model
-    TheoryModel* m = d_smt.d_theoryEngine->getModel();
-    if(m != NULL) {
-      for(pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos) {
-        Node n = (*pos).first;
-        Node v = newSubstitutions.apply((*pos).second);
-        Trace("model") << "Add substitution : " << n << " " << v << endl;
-        m->addSubstitution( n, v );
-      }
+    else
+    {
+      // if it has, the substitution becomes an assertion
+      Node eq = nm->mkNode(kind::EQUAL, lhs, rhs);
+      Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+                           "substitute: will notify SAT layer of substitution: "
+                        << eq << endl;
+      substitutionsBuilder << eq;
     }
   }
+  // add to the last assertion if necessary
+  if (substitutionsBuilder.getNumChildren() > 0)
+  {
+    substitutionsBuilder << d_assertions[substs_index];
+    d_assertions.replace(substs_index,
+                         Rewriter::rewrite(Node(substitutionsBuilder)));
+  }
 
   NodeBuilder<> learnedBuilder(kind::AND);
   Assert(d_assertions.getRealAssertionsEnd() <= d_assertions.size());
@@ -3415,6 +3438,20 @@ void SmtEnginePrivate::collectSkolems(TNode n, set<TNode>& skolemSet, unordered_
   cache[n] = true;
 }
 
+void SmtEnginePrivate::recordSymbolsInAssertions(
+    const std::vector<Node>& assertions)
+{
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  std::unordered_set<Node, NodeHashFunction> syms;
+  for (TNode cn : assertions)
+  {
+    expr::getSymbols(cn, syms, visited);
+  }
+  for (const Node& s : syms)
+  {
+    d_symsInAssertions.insert(s);
+  }
+}
 
 bool SmtEnginePrivate::checkForBadSkolems(TNode n, TNode skolem, unordered_map<Node, bool, NodeHashFunction>& cache)
 {
@@ -3831,6 +3868,12 @@ void SmtEnginePrivate::processAssertions() {
   Trace("smt-proc") << "SmtEnginePrivate::processAssertions() end" << endl;
   dumpAssertions("post-everything", d_assertions);
 
+  // if incremental, compute which variables are assigned
+  if (options::incrementalSolving())
+  {
+    recordSymbolsInAssertions(d_assertions.ref());
+  }
+
   // Push the formula to SAT
   {
     Chat() << "converting to CNF..." << endl;