From 4e4068f1d29ddc1ffe0bde8e6f2cf3094fd6bd40 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 5 Sep 2018 12:41:47 -0500 Subject: [PATCH] Finer-grained inference of substitutions in incremental mode (#2403) --- src/expr/node_algorithm.cpp | 36 ++++++++++++++ src/expr/node_algorithm.h | 19 ++++++-- src/smt/smt_engine.cpp | 95 +++++++++++++++++++++++++++---------- 3 files changed, 121 insertions(+), 29 deletions(-) diff --git a/src/expr/node_algorithm.cpp b/src/expr/node_algorithm.cpp index 5443a3a2a..9240e4a8e 100644 --- a/src/expr/node_algorithm.cpp +++ b/src/expr/node_algorithm.cpp @@ -166,5 +166,41 @@ bool hasFreeVar(TNode n) return false; } +void getSymbols(TNode n, std::unordered_set& syms) +{ + std::unordered_set visited; + getSymbols(n, syms); +} + +void getSymbols(TNode n, + std::unordered_set& syms, + std::unordered_set& visited) +{ + std::vector visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + if (visited.find(cur) == visited.end()) + { + visited.insert(cur); + if (cur.isVar() && cur.getKind() != kind::BOUND_VARIABLE) + { + syms.insert(cur); + } + if (cur.hasOperator()) + { + visit.push_back(cur.getOperator()); + } + for (TNode cn : cur) + { + visit.push_back(cn); + } + } + } while (!visit.empty()); +} + } // namespace expr } // namespace CVC4 diff --git a/src/expr/node_algorithm.h b/src/expr/node_algorithm.h index 61e81c4c2..7453bc292 100644 --- a/src/expr/node_algorithm.h +++ b/src/expr/node_algorithm.h @@ -39,20 +39,33 @@ namespace expr { bool hasSubterm(TNode n, TNode t, bool strict = false); /** - * Returns true iff the node n contains a bound variable. This bound - * variable may or may not be free. + * Returns true iff the node n contains a bound variable, that is a node of + * kind BOUND_VARIABLE. This bound variable may or may not be free. * @param n The node under investigation * @return true iff this node contains a bound variable */ bool hasBoundVar(TNode n); /** - * Returns true iff the node n contains a free variable. + * Returns true iff the node n contains a free variable, that is, a node + * of kind BOUND_VARIABLE that is not bound in n. * @param n The node under investigation * @return true iff this node contains a free variable. */ bool hasFreeVar(TNode n); +/** + * For term n, this function collects the symbols that occur as a subterms + * of n. A symbol is a variable that does not have kind BOUND_VARIABLE. + * @param n The node under investigation + * @param syms The set which the symbols of n are added to + */ +void getSymbols(TNode n, std::unordered_set& syms); +/** Same as above, with a visited cache */ +void getSymbols(TNode n, + std::unordered_set& syms, + std::unordered_set& visited); + } // namespace expr } // namespace CVC4 diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index cdd5ab3e0..17edaad41 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -45,6 +45,7 @@ #include "expr/kind.h" #include "expr/metakind.h" #include "expr/node.h" +#include "expr/node_algorithm.h" #include "expr/node_builder.h" #include "expr/node_self_iterator.h" #include "options/arith_options.h" @@ -448,6 +449,7 @@ class SmtEnginePrivate : public NodeManagerListener { typedef unordered_map NodeToNodeHashMap; typedef unordered_map NodeToBoolHashMap; + typedef context::CDHashSet NodeSet; /** * Manager for limiting time and abstract resource usage. @@ -503,6 +505,13 @@ class SmtEnginePrivate : public NodeManagerListener { */ SubstitutionMap d_abstractValueMap; + /** + * The (user-context-dependent) set of symbols that occur in at least one + * assertion in the current user context. This is used by the + * nonClausalSimplify pass. + */ + NodeSet d_symsInAssertions; + /** * A mapping of all abstract values (actual value |-> abstract) that * we've handed out. This is necessary to ensure that we give the @@ -545,6 +554,13 @@ class SmtEnginePrivate : public NodeManagerListener { */ bool nonClausalSimplify(); + /** record symbols in assertions + * + * This method is called when a set of assertions is finalized. It adds + * the symbols to d_symsInAssertions that occur in assertions. + */ + void recordSymbolsInAssertions(const std::vector& assertions); + /** * Helper function to fix up assertion list to restore invariants needed after * ite removal. @@ -579,6 +595,7 @@ class SmtEnginePrivate : public NodeManagerListener { d_assertionsProcessed(smt.d_userContext, false), d_fakeContext(), d_abstractValueMap(&d_fakeContext), + d_symsInAssertions(smt.d_userContext), d_abstractValues(), d_simplifyAssertionsDepth(0), // d_needsExpandDefs(true), //TODO? @@ -833,7 +850,6 @@ class SmtEnginePrivate : public NodeManagerListener { } } //------------------------------- end expression names - };/* class SmtEnginePrivate */ }/* namespace CVC4::smt */ @@ -3126,35 +3142,42 @@ bool SmtEnginePrivate::nonClausalSimplify() { << assertion << endl; } - // If in incremental mode, add substitutions to the list of assertions - if (substs_index > 0) + // add substitutions to model, or as assertions if needed (when incremental) + TheoryModel* m = d_smt.d_theoryEngine->getModel(); + Assert(m != nullptr); + NodeManager* nm = NodeManager::currentNM(); + NodeBuilder<> substitutionsBuilder(kind::AND); + for (pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos) { - NodeBuilder<> substitutionsBuilder(kind::AND); - substitutionsBuilder << d_assertions[substs_index]; - pos = newSubstitutions.begin(); - for (; pos != newSubstitutions.end(); ++pos) { - // Add back this substitution as an assertion - TNode lhs = (*pos).first, rhs = newSubstitutions.apply((*pos).second); - Node n = NodeManager::currentNM()->mkNode(kind::EQUAL, lhs, rhs); - substitutionsBuilder << n; - Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl; - } - if (substitutionsBuilder.getNumChildren() > 1) { - d_assertions.replace(substs_index, - Rewriter::rewrite(Node(substitutionsBuilder))); + Node lhs = (*pos).first; + Node rhs = newSubstitutions.apply((*pos).second); + // If using incremental, we must check whether this variable has occurred + // before now. If it hasn't we can add this as a substitution. + if (substs_index == 0 + || d_symsInAssertions.find(lhs) == d_symsInAssertions.end()) + { + Trace("simplify") + << "SmtEnginePrivate::nonClausalSimplify(): substitute: " << lhs + << " " << rhs << endl; + m->addSubstitution(lhs, rhs); } - } else { - // If not in incremental mode, must add substitutions to model - TheoryModel* m = d_smt.d_theoryEngine->getModel(); - if(m != NULL) { - for(pos = newSubstitutions.begin(); pos != newSubstitutions.end(); ++pos) { - Node n = (*pos).first; - Node v = newSubstitutions.apply((*pos).second); - Trace("model") << "Add substitution : " << n << " " << v << endl; - m->addSubstitution( n, v ); - } + else + { + // if it has, the substitution becomes an assertion + Node eq = nm->mkNode(kind::EQUAL, lhs, rhs); + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " + "substitute: will notify SAT layer of substitution: " + << eq << endl; + substitutionsBuilder << eq; } } + // add to the last assertion if necessary + if (substitutionsBuilder.getNumChildren() > 0) + { + substitutionsBuilder << d_assertions[substs_index]; + d_assertions.replace(substs_index, + Rewriter::rewrite(Node(substitutionsBuilder))); + } NodeBuilder<> learnedBuilder(kind::AND); Assert(d_assertions.getRealAssertionsEnd() <= d_assertions.size()); @@ -3415,6 +3438,20 @@ void SmtEnginePrivate::collectSkolems(TNode n, set& skolemSet, unordered_ cache[n] = true; } +void SmtEnginePrivate::recordSymbolsInAssertions( + const std::vector& assertions) +{ + std::unordered_set visited; + std::unordered_set syms; + for (TNode cn : assertions) + { + expr::getSymbols(cn, syms, visited); + } + for (const Node& s : syms) + { + d_symsInAssertions.insert(s); + } +} bool SmtEnginePrivate::checkForBadSkolems(TNode n, TNode skolem, unordered_map& cache) { @@ -3831,6 +3868,12 @@ void SmtEnginePrivate::processAssertions() { Trace("smt-proc") << "SmtEnginePrivate::processAssertions() end" << endl; dumpAssertions("post-everything", d_assertions); + // if incremental, compute which variables are assigned + if (options::incrementalSolving()) + { + recordSymbolsInAssertions(d_assertions.ref()); + } + // Push the formula to SAT { Chat() << "converting to CNF..." << endl; -- 2.30.2