From a81ea9670349abe03da96cf45fbaa115c4098325 Mon Sep 17 00:00:00 2001 From: Clark Barrett Date: Tue, 12 Jun 2012 03:01:41 +0000 Subject: [PATCH] Adding constant propagation code - needs more testing - off by default --- src/smt/smt_engine.cpp | 233 ++++++++++++++++++++++++++--------- src/smt/smt_engine.h | 2 + src/theory/substitutions.cpp | 127 +++++++++++++++++-- src/theory/substitutions.h | 21 ++++ 4 files changed, 318 insertions(+), 65 deletions(-) diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index e0d507475..5874692b5 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -140,6 +140,8 @@ class SmtEnginePrivate { * then nothing has been pushed out yet. */ context::CDO d_lastSubstitutionPos; + static const bool d_doConstantProp = false; + /** * Runs the nonclausal solver and tries to solve all the assigned * theory literals. @@ -256,6 +258,7 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) : d_private(new smt::SmtEnginePrivate(*this)), d_definitionExpansionTime("smt::SmtEngine::definitionExpansionTime"), d_nonclausalSimplificationTime("smt::SmtEngine::nonclausalSimplificationTime"), + d_numConstantProps("smt::SmtEngine::numConstantProps", 0), d_staticLearningTime("smt::SmtEngine::staticLearningTime"), d_simpITETime("smt::SmtEngine::simpITETime"), d_unconstrainedSimpTime("smt::SmtEngine::unconstrainedSimpTime"), @@ -270,6 +273,7 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) : StatisticsRegistry::registerStat(&d_definitionExpansionTime); StatisticsRegistry::registerStat(&d_nonclausalSimplificationTime); + StatisticsRegistry::registerStat(&d_numConstantProps); StatisticsRegistry::registerStat(&d_staticLearningTime); StatisticsRegistry::registerStat(&d_simpITETime); StatisticsRegistry::registerStat(&d_unconstrainedSimpTime); @@ -389,6 +393,7 @@ SmtEngine::~SmtEngine() throw() { StatisticsRegistry::unregisterStat(&d_definitionExpansionTime); StatisticsRegistry::unregisterStat(&d_nonclausalSimplificationTime); + StatisticsRegistry::unregisterStat(&d_numConstantProps); StatisticsRegistry::unregisterStat(&d_staticLearningTime); StatisticsRegistry::unregisterStat(&d_simpITETime); StatisticsRegistry::unregisterStat(&d_unconstrainedSimpTime); @@ -847,35 +852,61 @@ void SmtEnginePrivate::nonClausalSimplify() { d_assertionsToPreprocess.clear(); d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst(false)); return; - } else { - // No, conflict, go through the literals and solve them - unsigned j = 0; - for(unsigned i = 0, i_end = d_nonClausalLearnedLiterals.size(); i < i_end; ++ i) { - // Simplify the literal we learned wrt previous substitutions - Node learnedLiteral = - theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_nonClausalLearnedLiterals[i])); - // It might just simplify to a constant - if (learnedLiteral.isConst()) { - if (learnedLiteral.getConst()) { - // If the learned literal simplifies to true, it's redundant - continue; - } else { - // If the learned literal simplifies to false, we're in conflict - Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " - << "conflict with " - << d_nonClausalLearnedLiterals[i] << endl; - d_assertionsToPreprocess.clear(); - d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst(false)); - return; - } + } + + // No, conflict, go through the literals and solve them + theory::SubstitutionMap constantPropagations(d_smt.d_context); + unsigned j = 0; + for(unsigned i = 0, i_end = d_nonClausalLearnedLiterals.size(); i < i_end; ++ i) { + // Simplify the literal we learned wrt previous substitutions + Node learnedLiteral = d_nonClausalLearnedLiterals[i]; + Node learnedLiteralNew = d_topLevelSubstitutions.apply(learnedLiteral); + if (learnedLiteral != learnedLiteralNew) { + learnedLiteral = theory::Rewriter::rewrite(learnedLiteralNew); + } + for (;;) { + learnedLiteralNew = constantPropagations.apply(learnedLiteral); + if (learnedLiteralNew == learnedLiteral) { + break; + } + ++d_smt.d_numConstantProps; + learnedLiteral = theory::Rewriter::rewrite(learnedLiteralNew); + } + // It might just simplify to a constant + if (learnedLiteral.isConst()) { + if (learnedLiteral.getConst()) { + // If the learned literal simplifies to true, it's redundant + continue; + } else { + // If the learned literal simplifies to false, we're in conflict + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " + << "conflict with " + << d_nonClausalLearnedLiterals[i] << endl; + d_assertionsToPreprocess.clear(); + d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst(false)); + return; } - // Solve it with the corresponding theory - Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " - << "solving " << learnedLiteral << endl; - Theory::PPAssertStatus solveStatus = - d_smt.d_theoryEngine->solve(learnedLiteral, d_topLevelSubstitutions); + } + // Solve it with the corresponding theory + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " + << "solving " << learnedLiteral << endl; + Theory::PPAssertStatus solveStatus = + d_smt.d_theoryEngine->solve(learnedLiteral, d_topLevelSubstitutions); - switch (solveStatus) { + switch (solveStatus) { + case Theory::PP_ASSERT_STATUS_SOLVED: { + // The literal should rewrite to true + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " + << "solved " << learnedLiteral << endl; + Assert(theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(learnedLiteral)).isConst()); + vector > equations; + constantPropagations.simplifyLHS(d_topLevelSubstitutions, equations, true); + if (equations.empty()) { + break; + } + Assert(equations[0].first.isConst() && equations[0].second.isConst() && equations[0].first != equations[0].second); + // else fall through + } case Theory::PP_ASSERT_STATUS_CONFLICT: // If in conflict, we return false Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " @@ -884,46 +915,109 @@ void SmtEnginePrivate::nonClausalSimplify() { d_assertionsToPreprocess.clear(); d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst(false)); return; - case Theory::PP_ASSERT_STATUS_SOLVED: - // The literal should rewrite to true - Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " - << "solved " << learnedLiteral << endl; - Assert(theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(learnedLiteral)).isConst()); - break; default: - // Keep the literal - d_nonClausalLearnedLiterals[j++] = d_nonClausalLearnedLiterals[i]; + if (d_doConstantProp && learnedLiteral.getKind() == kind::EQUAL && (learnedLiteral[0].isConst() || learnedLiteral[1].isConst())) { + // constant propagation + TNode t; + TNode c; + if (learnedLiteral[0].isConst()) { + t = learnedLiteral[1]; + c = learnedLiteral[0]; + } + else { + t = learnedLiteral[0]; + c = learnedLiteral[1]; + } + Assert(!t.isConst()); + Assert(constantPropagations.apply(t) == t); + Assert(d_topLevelSubstitutions.apply(t) == t); + vector > equations; + constantPropagations.simplifyLHS(t, c, equations, true); + if (!equations.empty()) { + Assert(equations[0].first.isConst() && equations[0].second.isConst() && equations[0].first != equations[0].second); + d_assertionsToPreprocess.clear(); + d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst(false)); + return; + } + d_topLevelSubstitutions.simplifyRHS(constantPropagations); + } + else { + // Keep the literal + d_nonClausalLearnedLiterals[j++] = d_nonClausalLearnedLiterals[i]; + } break; + } + + if( Options::current()->incrementalSolving || + Options::current()->simplificationMode == Options::SIMPLIFICATION_MODE_INCREMENTAL ) { + // Tell PropEngine about new substitutions + SubstitutionMap::iterator pos = d_lastSubstitutionPos; + if(pos == d_topLevelSubstitutions.end()) { + pos = d_topLevelSubstitutions.begin(); + } else { + ++pos; } - if( Options::current()->incrementalSolving || - Options::current()->simplificationMode == Options::SIMPLIFICATION_MODE_INCREMENTAL ) { - // Tell PropEngine about new substitutions - SubstitutionMap::iterator pos = d_lastSubstitutionPos; - if(pos == d_topLevelSubstitutions.end()) { - pos = d_topLevelSubstitutions.begin(); - } else { - ++pos; - } + while(pos != d_topLevelSubstitutions.end()) { + // Push out this substitution + TNode lhs = (*pos).first, rhs = (*pos).second; + Node n = NodeManager::currentNM()->mkNode(lhs.getType().isBoolean() ? kind::IFF : kind::EQUAL, lhs, rhs); + d_assertionsToCheck.push_back(n); + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl; + d_lastSubstitutionPos = pos; + ++pos; + } + } - while(pos != d_topLevelSubstitutions.end()) { - // Push out this substitution - TNode lhs = (*pos).first, rhs = (*pos).second; - Node n = NodeManager::currentNM()->mkNode(lhs.getType().isBoolean() ? kind::IFF : kind::EQUAL, lhs, rhs); - d_assertionsToCheck.push_back(n); - Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl; - d_lastSubstitutionPos = pos; - ++pos; - } +#ifdef CVC4_ASSERTIONS + // Check data structure invariants: + // 1. for each lhs of d_topLevelSubstitutions, does not appear anywhere in rhs of d_topLevelSubstitutions or anywhere in constantPropagations + // 2. each lhs of constantPropagations rewrites to itself + // 3. if l -> r is a constant propagation and l is a subterm of l' with l' -> r' another constant propagation, then l'[l/r] -> r' should be a + // constant propagation too + // 4. each lhs of constantPropagations is different from each rhs + SubstitutionMap::iterator pos = d_topLevelSubstitutions.begin(); + for (; pos != d_topLevelSubstitutions.end(); ++pos) { + Assert((*pos).first.isVar()); + Assert(d_topLevelSubstitutions.apply((*pos).second) == (*pos).second); + } + for (pos = constantPropagations.begin(); pos != constantPropagations.end(); ++pos) { + Assert((*pos).second.isConst()); + Assert(Rewriter::rewrite((*pos).first) == (*pos).first); + Node newLeft = d_topLevelSubstitutions.apply((*pos).first); + if (newLeft != (*pos).first) { + newLeft = Rewriter::rewrite(newLeft); + Assert(newLeft == (*pos).second || + (constantPropagations.hasSubstitution(newLeft) && constantPropagations.apply(newLeft) == (*pos).second)); } + newLeft = constantPropagations.apply((*pos).first); + if (newLeft != (*pos).first) { + newLeft = Rewriter::rewrite(newLeft); + Assert(newLeft == (*pos).second || + (constantPropagations.hasSubstitution(newLeft) && constantPropagations.apply(newLeft) == (*pos).second)); + } + Assert(constantPropagations.apply((*pos).second) == (*pos).second); } - // Resize the learnt - d_nonClausalLearnedLiterals.resize(j); +#endif } + // Resize the learnt + d_nonClausalLearnedLiterals.resize(j); hash_set s; for (unsigned i = 0; i < d_assertionsToPreprocess.size(); ++ i) { - Node assertion = theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_assertionsToPreprocess[i])); + Node assertion = d_assertionsToPreprocess[i]; + Node assertionNew = d_topLevelSubstitutions.apply(assertion); + if (assertion != assertionNew) { + assertion = theory::Rewriter::rewrite(assertionNew); + } + for (;;) { + assertionNew = constantPropagations.apply(assertion); + if (assertionNew == assertion) { + break; + } + ++d_smt.d_numConstantProps; + assertion = theory::Rewriter::rewrite(assertionNew); + } s.insert(assertion); d_assertionsToCheck.push_back(assertion); Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " @@ -933,7 +1027,19 @@ void SmtEnginePrivate::nonClausalSimplify() { d_assertionsToPreprocess.clear(); for (unsigned i = 0; i < d_nonClausalLearnedLiterals.size(); ++ i) { - Node learned = theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_nonClausalLearnedLiterals[i])); + Node learned = d_nonClausalLearnedLiterals[i]; + Node learnedNew = d_topLevelSubstitutions.apply(learned); + if (learned != learnedNew) { + learned = theory::Rewriter::rewrite(learnedNew); + } + for (;;) { + learnedNew = constantPropagations.apply(learned); + if (learnedNew == learned) { + break; + } + ++d_smt.d_numConstantProps; + learned = theory::Rewriter::rewrite(learnedNew); + } if (s.find(learned) != s.end()) { continue; } @@ -945,8 +1051,21 @@ void SmtEnginePrivate::nonClausalSimplify() { } d_nonClausalLearnedLiterals.clear(); + SubstitutionMap::iterator pos = constantPropagations.begin(); + for (; pos != constantPropagations.end(); ++pos) { + Node cProp = (*pos).first.eqNode((*pos).second); + if (s.find(cProp) != s.end()) { + continue; + } + s.insert(cProp); + d_assertionsToCheck.push_back(cProp); + Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): " + << "non-clausal constant propagation : " + << d_assertionsToCheck.back() << endl; + } } + void SmtEnginePrivate::simpITE() { TimerStat::CodeTimer simpITETimer(d_smt.d_simpITETime); diff --git a/src/smt/smt_engine.h b/src/smt/smt_engine.h index 4c0fed74c..e0f544f9e 100644 --- a/src/smt/smt_engine.h +++ b/src/smt/smt_engine.h @@ -239,6 +239,8 @@ class CVC4_PUBLIC SmtEngine { TimerStat d_definitionExpansionTime; /** time spent in non-clausal simplification */ TimerStat d_nonclausalSimplificationTime; + /** Num of constant propagations found during nonclausal simp */ + IntStat d_numConstantProps; /** time spent in static learning */ TimerStat d_staticLearningTime; /** time spent in simplifying ITEs */ diff --git a/src/theory/substitutions.cpp b/src/theory/substitutions.cpp index df4c919d8..3f8a6d630 100644 --- a/src/theory/substitutions.cpp +++ b/src/theory/substitutions.cpp @@ -17,6 +17,7 @@ **/ #include "theory/substitutions.h" +#include "theory/rewriter.h" using namespace std; @@ -103,21 +104,127 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& substitutionCache) return substitutionCache[t]; } -void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache, bool backSub, bool forwardSub) { - Debug("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << std::endl; - Assert(d_substitutions.find(x) == d_substitutions.end()); - if (backSub) { - // Temporary substitution cache +void SubstitutionMap::simplifyRHS(const SubstitutionMap& subMap) +{ + // Put the new substitutions into the old ones + NodeMap::iterator it = d_substitutions.begin(); + NodeMap::iterator it_end = d_substitutions.end(); + for(; it != it_end; ++ it) { + d_substitutions[(*it).first] = subMap.apply((*it).second); + } +} + + +void SubstitutionMap::simplifyRHS(TNode x, TNode t) { + // Temporary substitution cache + NodeCache tempCache; + tempCache[x] = t; + + // Put the new substitution into the old ones + NodeMap::iterator it = d_substitutions.begin(); + NodeMap::iterator it_end = d_substitutions.end(); + for(; it != it_end; ++ it) { + d_substitutions[(*it).first] = internalSubstitute((*it).second, tempCache); + } +} + + +/* We use subMap to simplify the left-hand sides of the current substitution map. If rewrite is true, + * we also apply the rewriter to the result. + * We want to maintain the invariant that all lhs are distinct from each other and from all rhs. + * If for some l -> r, l reduces to l', we try to add a new rule l' -> r. There are two cases + * where this fails + * (i) if l' is equal to some ll (in a rule ll -> rr), then if r != rr we add (r,rr) to the equation list + * (i) if l' is equalto some rr (in a rule ll -> rr), then if r != rr we add (r,rr) to the equation list + */ +void SubstitutionMap::simplifyLHS(const SubstitutionMap& subMap, vector >& equalities, bool rewrite) +{ + Assert(d_worklist.empty()); + // First, apply subMap to every LHS in d_substitutions + NodeMap::iterator it = d_substitutions.begin(); + NodeMap::iterator it_end = d_substitutions.end(); + Node newLeft; + for(; it != it_end; ++ it) { + newLeft = subMap.apply((*it).first); + if (newLeft != (*it).first) { + if (rewrite) { + newLeft = Rewriter::rewrite(newLeft); + } + d_worklist.push_back(pair(newLeft, (*it).second)); + } + } + processWorklist(equalities, rewrite); + Assert(d_worklist.empty()); +} + + +void SubstitutionMap::simplifyLHS(TNode lhs, TNode rhs, vector >& equalities, bool rewrite) +{ + Assert(d_worklist.empty()); + d_worklist.push_back(pair(lhs,rhs)); + processWorklist(equalities, rewrite); + Assert(d_worklist.empty()); +} + + +void SubstitutionMap::processWorklist(vector >& equalities, bool rewrite) +{ + // Add each new rewrite rule, taking care not to invalidate invariants and looking + // for any new rewrite rules we can learn + Node newLeft, newRight; + while (!d_worklist.empty()) { + newLeft = d_worklist.back().first; + newRight = d_worklist.back().second; + d_worklist.pop_back(); + NodeCache tempCache; - tempCache[x] = t; + tempCache[newLeft] = newRight; - // Put in the new substitutions into the old ones + Node newLeft2; + unsigned size = d_worklist.size(); + bool addThisRewrite = true; NodeMap::iterator it = d_substitutions.begin(); NodeMap::iterator it_end = d_substitutions.end(); + for(; it != it_end; ++ it) { - d_substitutions[(*it).first] = internalSubstitute((*it).second, tempCache); + + // Check for invariant violation. If new rewrite is redundant, do nothing + // Otherwise, add an equality to the output equalities + // In either case undo any work done by this rewrite + if (newLeft == (*it).first || newLeft == (*it).second) { + if ((*it).second != newRight) { + equalities.push_back(pair((*it).second, newRight)); + } + while (d_worklist.size() > size) { + d_worklist.pop_back(); + } + addThisRewrite = false; + break; + } + + newLeft2 = internalSubstitute((*it).first, tempCache); + if (newLeft2 != (*it).first) { + if (rewrite) { + newLeft2 = Rewriter::rewrite(newLeft2); + } + d_worklist.push_back(pair(newLeft2, (*it).second)); + } } + if (addThisRewrite) { + d_substitutions[newLeft] = newRight; + d_cacheInvalidated = true; + } + } +} + + +void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache, bool backSub, bool forwardSub) { + Debug("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << std::endl; + Assert(d_substitutions.find(x) == d_substitutions.end()); + + if (backSub) { + simplifyRHS(x, t); } // Put the new substitution in @@ -181,6 +288,10 @@ void SubstitutionMap::print(ostream& out) const { } } +void SubstitutionMap::debugPrint() const { + print(std::cout); +} + }/* CVC4::theory namespace */ std::ostream& operator<<(std::ostream& out, const theory::SubstitutionMap::iterator& i) { diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index ee2a15f6f..32ed35074 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -91,6 +91,10 @@ private: */ CacheInvalidator d_cacheInvalidator; + // Helper list and method for simplifyLHS methods + std::vector > d_worklist; + void processWorklist(std::vector >& equalities, bool rewrite); + public: SubstitutionMap(context::Context* context) : @@ -158,10 +162,27 @@ public: // should best interact with cache invalidation on context // pops. + // Simplify right-hand sides of current map using the given substitutions + void simplifyRHS(const SubstitutionMap& subMap); + + // Simplify right-hand sides of current map with lhs -> rhs + void simplifyRHS(TNode lhs, TNode rhs); + + // Simplify left-hand sides of current map using the given substitutions + void simplifyLHS(const SubstitutionMap& subMap, + std::vector >& equalities, + bool rewrite = true); + + // Simplify left-hand sides of current map with lhs -> rhs and then add lhs -> rhs to the substitutions set + void simplifyLHS(TNode lhs, TNode rhs, + std::vector >& equalities, + bool rewrite = true); + /** * Print to the output stream */ void print(std::ostream& out) const; + void debugPrint() const; };/* class SubstitutionMap */ -- 2.30.2