From: Dejan Jovanović Date: Wed, 9 May 2012 21:25:17 +0000 (+0000) Subject: * simplifying equality engine interface X-Git-Tag: cvc5-1.0.0~8194 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=1ce0650dcf8ce30424b546deb540974cc510c215;p=cvc5.git * simplifying equality engine interface * notifications are now through the interface subclass instead of a template * notifications include constants being merged * changed contextNotifyObj::notify to contextNotifyObj::contextNotifyPop so it's more descriptive and doesn't clutter methods when subclassed * sat solver now has explicit methods to make true and false constants * 0-level literals are removed from explanations of propagations --- diff --git a/src/context/context.cpp b/src/context/context.cpp index abb1575d4..da60a5bc4 100644 --- a/src/context/context.cpp +++ b/src/context/context.cpp @@ -80,7 +80,7 @@ void Context::pop() { while(pCNO != NULL) { // pre-store the "next" pointer in case pCNO deletes itself on notify() ContextNotifyObj* next = pCNO->d_pCNOnext; - pCNO->notify(); + pCNO->contextNotifyPop(); pCNO = next; } @@ -101,7 +101,7 @@ void Context::pop() { while(pCNO != NULL) { // pre-store the "next" pointer in case pCNO deletes itself on notify() ContextNotifyObj* next = pCNO->d_pCNOnext; - pCNO->notify(); + pCNO->contextNotifyPop(); pCNO = next; } diff --git a/src/context/context.h b/src/context/context.h index f0dbff72b..165c35c58 100644 --- a/src/context/context.h +++ b/src/context/context.h @@ -658,6 +658,7 @@ public: * the ContextObj objects have been restored). */ class ContextNotifyObj { + /** * Context is our friend so that when the Context is deleted, any * remaining ContextNotifyObj can be removed from the Context list. @@ -686,6 +687,15 @@ class ContextNotifyObj { */ ContextNotifyObj**& prev() throw() { return d_ppCNOprev; } +protected: + + /** + * This is the method called to notify the object of a pop. It must be + * implemented by the subclass. It is protected since context is out + * friend. + */ + virtual void contextNotifyPop() = 0; + public: /** @@ -703,12 +713,6 @@ public: */ virtual ~ContextNotifyObj() throw(AssertionException); - /** - * This is the method called to notify the object of a pop. It must be - * implemented by the subclass. - */ - virtual void notify() = 0; - };/* class ContextNotifyObj */ inline void ContextObj::makeCurrent() throw(AssertionException) { diff --git a/src/context/stacking_map.h b/src/context/stacking_map.h index 2dec1845c..ba644596e 100644 --- a/src/context/stacking_map.h +++ b/src/context/stacking_map.h @@ -96,6 +96,14 @@ class StackingMap : context::ContextNotifyObj { /** Our current offset in the d_trace stack (context-dependent). */ context::CDO d_offset; +protected: + + /** + * Called by the Context when a pop occurs. Cancels everything to the + * current context level. Overrides ContextNotifyObj::contextNotifyPop(). + */ + void contextNotifyPop(); + public: typedef typename MapType::const_iterator const_iterator; @@ -128,12 +136,6 @@ public: */ void set(ArgType n, const ValueType& newValue); - /** - * Called by the Context when a pop occurs. Cancels everything to the - * current context level. Overrides ContextNotifyObj::notify(). - */ - void notify(); - };/* class StackingMap<> */ template @@ -146,7 +148,7 @@ void StackingMap::set(ArgType n, const ValueType& n } template -void StackingMap::notify() { +void StackingMap::contextNotifyPop() { Trace("sm") << "SM cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl; while(d_offset < d_trace.size()) { std::pair p = d_trace.back(); diff --git a/src/context/stacking_vector.h b/src/context/stacking_vector.h index 9987731d4..ed311b952 100644 --- a/src/context/stacking_vector.h +++ b/src/context/stacking_vector.h @@ -82,7 +82,7 @@ public: * Called by the Context when a pop occurs. Cancels everything to the * current context level. Overrides ContextNotifyObj::notify(). */ - void notify(); + void contextNotifyPop(); };/* class StackingVector<> */ @@ -99,7 +99,7 @@ void StackingVector::set(size_t n, const T& newValue) { } template -void StackingVector::notify() { +void StackingVector::contextNotifyPop() { Trace("sv") << "SV cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl; while(d_offset < d_trace.size()) { std::pair p = d_trace.back(); diff --git a/src/prop/bvminisat/bvminisat.cpp b/src/prop/bvminisat/bvminisat.cpp index 124fc35f1..4868db6f5 100644 --- a/src/prop/bvminisat/bvminisat.cpp +++ b/src/prop/bvminisat/bvminisat.cpp @@ -73,7 +73,7 @@ SatValue BVMinisatSatSolver::assertAssumption(SatLiteral lit, bool propagate) { return toSatLiteralValue(d_minisat->assertAssumption(toMinisatLit(lit), propagate)); } -void BVMinisatSatSolver::notify() { +void BVMinisatSatSolver::contextNotifyPop() { while (d_assertionsCount > d_assertionsRealCount) { popAssumption(); d_assertionsCount --; diff --git a/src/prop/bvminisat/bvminisat.h b/src/prop/bvminisat/bvminisat.h index cd2a2c6b9..60cdd1c28 100644 --- a/src/prop/bvminisat/bvminisat.h +++ b/src/prop/bvminisat/bvminisat.h @@ -54,6 +54,10 @@ private: context::CDO d_assertionsRealCount; context::CDO d_lastPropagation; +protected: + + void contextNotifyPop(); + public: BVMinisatSatSolver() : @@ -70,10 +74,12 @@ public: SatVariable newVar(bool theoryAtom = false); + SatVariable trueVar() { return d_minisat->trueVar(); } + SatVariable falseVar() { return d_minisat->falseVar(); } + void markUnremovable(SatLiteral lit); void interrupt(); - void notify(); SatValue solve(); SatValue solve(long unsigned int&); diff --git a/src/prop/bvminisat/core/Solver.cc b/src/prop/bvminisat/core/Solver.cc index e24fcac1a..c96b6e4b2 100644 --- a/src/prop/bvminisat/core/Solver.cc +++ b/src/prop/bvminisat/core/Solver.cc @@ -119,7 +119,15 @@ Solver::Solver(CVC4::context::Context* c) : , propagation_budget (-1) , asynch_interrupt (false) , clause_added(false) -{} +{ + // Create the constant variables + varTrue = newVar(true, false); + varFalse = newVar(false, false); + + // Assert the constants + uncheckedEnqueue(mkLit(varTrue, false)); + uncheckedEnqueue(mkLit(varFalse, true)); +} Solver::~Solver() diff --git a/src/prop/bvminisat/core/Solver.h b/src/prop/bvminisat/core/Solver.h index c323bfe2b..ae5efd81e 100644 --- a/src/prop/bvminisat/core/Solver.h +++ b/src/prop/bvminisat/core/Solver.h @@ -64,6 +64,12 @@ class Solver { /** Cvc4 context */ CVC4::context::Context* c; + /** True constant */ + Var varTrue; + + /** False constant */ + Var varFalse; + public: // Constructor/Destructor: @@ -76,6 +82,9 @@ public: // Problem specification: // Var newVar (bool polarity = true, bool dvar = true); // Add a new variable with parameters specifying variable mode. + Var trueVar() const { return varTrue; } + Var falseVar() const { return varFalse; } + bool addClause (const vec& ps); // Add a clause to the solver. bool addEmptyClause(); // Add the empty clause, making the solver contradictory. diff --git a/src/prop/bvminisat/simp/SimpSolver.cc b/src/prop/bvminisat/simp/SimpSolver.cc index c8ce13410..59820e9e3 100644 --- a/src/prop/bvminisat/simp/SimpSolver.cc +++ b/src/prop/bvminisat/simp/SimpSolver.cc @@ -63,11 +63,25 @@ SimpSolver::SimpSolver(CVC4::context::Context* c) : , bwdsub_assigns (0) , n_touched (0) { - CVC4::StatisticsRegistry::registerStat(&total_eliminate_time); + CVC4::StatisticsRegistry::registerStat(&total_eliminate_time); vec dummy(1,lit_Undef); ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below. bwdsub_tmpunit = ca.alloc(dummy); remove_satisfied = false; + + // add the initialization for all the internal variables + for (int i = frozen.size(); i < vardata.size(); ++ i) { + frozen .push(1); + eliminated.push(0); + if (use_simplification){ + n_occ .push(0); + n_occ .push(0); + occurs .init(i); + touched .push(0); + elim_heap .insert(i); + } + } + } diff --git a/src/prop/cnf_stream.cpp b/src/prop/cnf_stream.cpp index 3a4fa781a..d18ec6e69 100644 --- a/src/prop/cnf_stream.cpp +++ b/src/prop/cnf_stream.cpp @@ -175,7 +175,15 @@ SatLiteral CnfStream::newLiteral(TNode node, bool theoryLiteral) { SatLiteral lit; if (!hasLiteral(node)) { // If no literal, we'll make one - lit = SatLiteral(d_satSolver->newVar(theoryLiteral)); + if (node.getKind() == kind::CONST_BOOLEAN) { + if (node.getConst()) { + lit = SatLiteral(d_satSolver->trueVar()); + } else { + lit = SatLiteral(d_satSolver->falseVar()); + } + } else { + lit = SatLiteral(d_satSolver->newVar(theoryLiteral)); + } d_translationCache[node].literal = lit; d_translationCache[node.notNode()].literal = ~lit; } else { diff --git a/src/prop/minisat/core/Solver.cc b/src/prop/minisat/core/Solver.cc index 5e1b032a3..6ee508eba 100644 --- a/src/prop/minisat/core/Solver.cc +++ b/src/prop/minisat/core/Solver.cc @@ -126,6 +126,14 @@ Solver::Solver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* context, , asynch_interrupt (false) { PROOF(ProofManager::initSatProof(this);) + + // Create the constant variables + varTrue = newVar(true, false, false); + varFalse = newVar(false, false, false); + + // Assert the constants + uncheckedEnqueue(mkLit(varTrue, false)); + uncheckedEnqueue(mkLit(varFalse, true)); } @@ -190,16 +198,26 @@ CRef Solver::reason(Var x) { // Compute the assertion level for this clause int explLevel = 0; - for (int i = 0; i < explanation.size(); ++ i) { + int i, j; + for (i = 0, j = 0; i < explanation.size(); ++ i) { int varLevel = intro_level(var(explanation[i])); if (varLevel > explLevel) { explLevel = varLevel; } Assert(value(explanation[i]) != l_Undef); Assert(i == 0 || trail_index(var(explanation[0])) > trail_index(var(explanation[i]))); + // ignore zero level literals + if (i == 0 || level(var(explanation[i])) > 0) { + explanation[j++] = explanation[i]; + } + } + explanation.shrink(i - j); + if (j == 1) { + // Add not TRUE to the clause + explanation.push(mkLit(varTrue, true)); } - // Construct the reason (level 0) + // Construct the reason CRef real_reason = ca.alloc(explLevel, explanation, true); vardata[x] = mkVarData(real_reason, level(x), intro_level(x), trail_index(x)); clauses_removable.push(real_reason); diff --git a/src/prop/minisat/core/Solver.h b/src/prop/minisat/core/Solver.h index cfeb06211..e677d7220 100644 --- a/src/prop/minisat/core/Solver.h +++ b/src/prop/minisat/core/Solver.h @@ -65,6 +65,13 @@ protected: /** The current assertion level (user) */ int assertionLevel; + + /** Variable representing true */ + Var varTrue; + + /** Variable representing false */ + Var varFalse; + public: /** Returns the current user assertion level */ int getAssertionLevel() const { return assertionLevel; } @@ -108,6 +115,8 @@ public: // Problem specification: // Var newVar (bool polarity = true, bool dvar = true, bool theoryAtom = false); // Add a new variable with parameters specifying variable mode. + Var trueVar() const { return varTrue; } + Var falseVar() const { return varFalse; } // Less than for literals in a lemma struct lemma_lt { diff --git a/src/prop/minisat/minisat.cpp b/src/prop/minisat/minisat.cpp index bed30d658..4f2a16670 100644 --- a/src/prop/minisat/minisat.cpp +++ b/src/prop/minisat/minisat.cpp @@ -121,7 +121,6 @@ SatVariable MinisatSatSolver::newVar(bool theoryAtom) { return d_minisat->newVar(true, true, theoryAtom); } - SatValue MinisatSatSolver::solve(unsigned long& resource) { Trace("limit") << "SatSolver::solve(): have limit of " << resource << " conflicts" << std::endl; if(resource == 0) { diff --git a/src/prop/minisat/minisat.h b/src/prop/minisat/minisat.h index 9cf75a12e..19ade8ffa 100644 --- a/src/prop/minisat/minisat.h +++ b/src/prop/minisat/minisat.h @@ -56,6 +56,8 @@ public: void addClause(SatClause& clause, bool removable); SatVariable newVar(bool theoryAtom = false); + SatVariable trueVar() { return d_minisat->trueVar(); } + SatVariable falseVar() { return d_minisat->falseVar(); } SatValue solve(); SatValue solve(long unsigned int&); diff --git a/src/prop/minisat/simp/SimpSolver.cc b/src/prop/minisat/simp/SimpSolver.cc index 2cacfbcc0..8da3856ff 100644 --- a/src/prop/minisat/simp/SimpSolver.cc +++ b/src/prop/minisat/simp/SimpSolver.cc @@ -67,6 +67,19 @@ SimpSolver::SimpSolver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* c ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below. bwdsub_tmpunit = ca.alloc(0, dummy); remove_satisfied = false; + + // add the initialization for all the internal variables + for (int i = frozen.size(); i < vardata.size(); ++ i) { + frozen .push(1); + eliminated.push(0); + if (use_simplification){ + n_occ .push(0); + n_occ .push(0); + occurs .init(i); + touched .push(0); + elim_heap .insert(i); + } + } } diff --git a/src/prop/sat_solver.h b/src/prop/sat_solver.h index 898709c43..2865f2cb5 100644 --- a/src/prop/sat_solver.h +++ b/src/prop/sat_solver.h @@ -46,6 +46,12 @@ public: /** Create a new boolean variable in the solver. */ virtual SatVariable newVar(bool theoryAtom = false) = 0; + /** Create a new (or return an existing) boolean variable representing the constant true */ + virtual SatVariable trueVar() = 0; + + /** Create a new (or return an existing) boolean variable representing the constant false */ + virtual SatVariable falseVar() = 0; + /** Check the satisfiability of the added clauses */ virtual SatValue solve() = 0; diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index e636b9142..2759f5717 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -304,7 +304,6 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) : setTimeLimit(Options::current()->cumulativeMillisecondLimit, true); } - d_propEngine->assertFormula(NodeManager::currentNM()->mkConst(true)); d_propEngine->assertFormula(NodeManager::currentNM()->mkConst(false).notNode()); } diff --git a/src/theory/arith/congruence_manager.cpp b/src/theory/arith/congruence_manager.cpp index 201eb08e7..39468e928 100644 --- a/src/theory/arith/congruence_manager.cpp +++ b/src/theory/arith/congruence_manager.cpp @@ -1,5 +1,4 @@ #include "theory/arith/congruence_manager.h" -#include "theory/uf/equality_engine_impl.h" #include "theory/arith/constraint.h" #include "theory/arith/arith_utilities.h" @@ -17,8 +16,7 @@ ArithCongruenceManager::ArithCongruenceManager(context::Context* c, ConstraintDa d_constraintDatabase(cd), d_setupLiteral(setup), d_av2Node(av2Node), - d_ee(d_notify, c, "theory::arith::ArithCongruenceManager"), - d_false(mkBoolNode(false)) + d_ee(d_notify, c, "theory::arith::ArithCongruenceManager") {} ArithCongruenceManager::Statistics::Statistics(): @@ -113,7 +111,7 @@ bool ArithCongruenceManager::propagate(TNode x){ }else{ ++(d_statistics.d_conflicts); - Node conf = explainInternal(x); + Node conf = flattenAnd(explainInternal(x)); d_conflict.set(conf); Debug("arith::congruenceManager") << "rewritten to false "<& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - lhs = literal[0]; - rhs = d_false; - break; - default: - Unreachable(); + if (literal.getKind() != kind::NOT) { + d_ee.explainEquality(literal[0], literal[1], true, assumptions); + } else { + d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions); } - d_ee.explainEquality(lhs, rhs, assumptions); } void ArithCongruenceManager::enqueueIntoNB(const std::set s, NodeBuilder<>& nb){ @@ -258,13 +247,10 @@ void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar TNode eq = d_watchedEqualities[s]; Assert(eq.getKind() == kind::EQUAL); - TNode x = eq[0]; - TNode y = eq[1]; - if(isEquality){ - d_ee.addEquality(x, y, reason); + d_ee.assertEquality(eq, true, reason); }else{ - d_ee.addDisequality(x, y, reason); + d_ee.assertEquality(eq, false, reason); } } @@ -286,7 +272,7 @@ void ArithCongruenceManager::equalsConstant(Constraint c){ Node reason = c->explainForConflict(); d_keepAlive.push_back(reason); - d_ee.addEquality(xAsNode, asRational, reason); + d_ee.assertEquality(eq, true, reason); } void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){ @@ -310,7 +296,7 @@ void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){ d_keepAlive.push_back(reason); - d_ee.addEquality(xAsNode, asRational, reason); + d_ee.assertEquality(eq, true, reason); } void ArithCongruenceManager::addSharedTerm(Node x){ diff --git a/src/theory/arith/congruence_manager.h b/src/theory/arith/congruence_manager.h index a72989498..18ecbeb9d 100644 --- a/src/theory/arith/congruence_manager.h +++ b/src/theory/arith/congruence_manager.h @@ -37,24 +37,43 @@ private: ArithVarToNodeMap d_watchedEqualities; - class ArithCongruenceNotify { + class ArithCongruenceNotify : public eq::EqualityEngineNotify { private: ArithCongruenceManager& d_acm; public: ArithCongruenceNotify(ArithCongruenceManager& acm): d_acm(acm) {} - bool notify(TNode propagation) { - Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << propagation << ")" << std::endl; - // Just forward to dm - return d_acm.propagate(propagation); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; + if (value) { + return d_acm.propagate(equality); + } else { + return d_acm.propagate(equality.notNode()); + } } - void notify(TNode t1, TNode t2) { - Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << t1 << ", " << t2 << ")" << std::endl; - Node equality = t1.eqNode(t2); - d_acm.propagate(equality); + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); } - }; + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl; + if (value) { + return d_acm.propagate(t1.eqNode(t2)); + } else { + return d_acm.propagate(t1.eqNode(t2).notNode()); + } + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + if (t1.getKind() == kind::CONST_BOOLEAN) { + return d_acm.propagate(t1.iffNode(t2)); + } else { + return d_acm.propagate(t1.eqNode(t2)); + } + } + }; ArithCongruenceNotify d_notify; context::CDList d_keepAlive; @@ -75,8 +94,7 @@ private: const ArithVarNodeMap& d_av2Node; - theory::uf::EqualityEngine d_ee; - Node d_false; + eq::EqualityEngine d_ee; public: diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index c7072de72..6bb3821da 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -1001,8 +1001,8 @@ Node TheoryArith::assertionCases(TNode assertion){ if(Debug.isOn("whytheoryenginewhy")){ debugPrintFacts(); } - Warning() << "arith: Theory engine is sending me both a literal and its negation?" - << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl; +// Warning() << "arith: Theory engine is sending me both a literal and its negation?" +// << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl; } Debug("arith::eq") << constraint << endl; Debug("arith::eq") << negation << endl; diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 80bcb47dd..1dd74f060 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -23,17 +23,13 @@ #include #include "theory/rewriter.h" #include "expr/command.h" -#include "theory/uf/equality_engine_impl.h" - using namespace std; - namespace CVC4 { namespace theory { namespace arrays { - // These are the options that produce the best empirical results on QF_AX benchmarks. // eagerLemmas = true // eagerIndexSplitting = false @@ -58,14 +54,12 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC d_numNonLinear("theory::arrays::number of calls to setNonLinear", 0), d_numSharedArrayVarSplits("theory::arrays::number of shared array var splits", 0), d_checkTimer("theory::arrays::checkTime"), - d_ppNotify(), - d_ppEqualityEngine(d_ppNotify, u, "theory::arrays::TheoryArraysPP"), + d_ppEqualityEngine(u, "theory::arrays::TheoryArraysPP"), d_ppFacts(u), // d_ppCache(u), d_literalsToPropagate(c), d_literalsToPropagateIndex(c, 0), - d_mayEqualNotify(), - d_mayEqualEqualityEngine(d_mayEqualNotify, c, "theory::arrays::TheoryArraysMayEqual"), + d_mayEqualEqualityEngine(c, "theory::arrays::TheoryArraysMayEqual"), d_notify(*this), d_equalityEngine(d_notify, c, "theory::arrays::TheoryArrays"), d_conflict(c, false), @@ -91,14 +85,6 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); - d_ppEqualityEngine.addTerm(d_true); - d_ppEqualityEngine.addTerm(d_false); - d_ppEqualityEngine.addTriggerEquality(d_true, d_false, d_false); - - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); - // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::SELECT); if (d_ccStore) { @@ -281,7 +267,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs case kind::EQUAL: { d_ppFacts.push_back(in); - d_ppEqualityEngine.addEquality(in[0], in[1], in); + d_ppEqualityEngine.assertEquality(in, true, in); if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) { outSubstitutions.addSubstitution(in[0], in[1]); return PP_ASSERT_STATUS_SOLVED; @@ -299,7 +285,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs in[0].getKind() == kind::IFF ); Node a = in[0][0]; Node b = in[0][1]; - d_ppEqualityEngine.addDisequality(a, b, in); + d_ppEqualityEngine.assertEquality(in[0], false, in); break; } default: @@ -335,10 +321,8 @@ bool TheoryArrays::propagate(TNode literal) Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl; std::vector assumptions; Node negatedLiteral; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -357,67 +341,40 @@ bool TheoryArrays::propagate(TNode literal) void TheoryArrays::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::SELECT: - lhs = literal; - rhs = d_true; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + // Do the work + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } - /** - * Stores in d_infoMap the following information for each term a of type array: - * - * - all i, such that there exists a term a[i] or a = store(b i v) - * (i.e. all indices it is being read atl; store(b i v) is implicitly read at - * position i due to the implicit axiom store(b i v)[i] = v ) - * - * - all the stores a is congruent to (this information is context dependent) - * - * - all store terms of the form store (a i v) (i.e. in which a appears - * directly; this is invariant because no new store terms are created) - * - * Note: completeness depends on having pre-register called on all the input - * terms before starting to instantiate lemmas. - */ +/** + * Stores in d_infoMap the following information for each term a of type array: + * + * - all i, such that there exists a term a[i] or a = store(b i v) + * (i.e. all indices it is being read atl; store(b i v) is implicitly read at + * position i due to the implicit axiom store(b i v)[i] = v ) + * + * - all the stores a is congruent to (this information is context dependent) + * + * - all store terms of the form store (a i v) (i.e. in which a appears + * directly; this is invariant because no new store terms are created) + * + * Note: completeness depends on having pre-register called on all the input + * terms before starting to instantiate lemmas. + */ void TheoryArrays::preRegisterTerm(TNode node) { Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::preRegisterTerm(" << node << ")" << std::endl; switch (node.getKind()) { case kind::EQUAL: - // Add the terms - // d_equalityEngine.addTerm(node[0]); - // d_equalityEngine.addTerm(node[1]); - d_equalityEngine.addTerm(node); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; case kind::SELECT: { // Reads @@ -438,7 +395,7 @@ void TheoryArrays::preRegisterTerm(TNode node) Assert(!d_equalityEngine.hasTerm(ni)); preRegisterTerm(ni); } - d_equalityEngine.addEquality(ni, s[2], d_true); + d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true); Assert(++it == stores->end()); } } @@ -447,8 +404,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // TODO: remove this or keep it if we allow Boolean elements in arrays. if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerEquality(node, d_true, node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerPredicate(node); } d_infoMap.addIndex(node[0], node[1]); @@ -463,7 +419,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // TNode i = node[1]; // TNode v = node[2]; - d_mayEqualEqualityEngine.addEquality(node, a, d_true); + d_mayEqualEqualityEngine.assertEquality(node.eqNode(a), true, d_true); // NodeManager* nm = NodeManager::currentNM(); // Node ni = nm->mkNode(kind::SELECT, node, i); @@ -508,10 +464,8 @@ void TheoryArrays::propagate(Effort e) Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(): in conflict, normalized = " << normalized << std::endl; Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -727,17 +681,17 @@ void TheoryArrays::check(Effort e) { // Do the work switch (fact.getKind()) { case kind::EQUAL: - d_equalityEngine.addEquality(fact[0], fact[1], fact); + d_equalityEngine.assertEquality(fact, true, fact); break; case kind::SELECT: - d_equalityEngine.addPredicate(fact, true, fact); + d_equalityEngine.assertPredicate(fact, true, fact); break; case kind::NOT: if (fact[0].getKind() == kind::SELECT) { - d_equalityEngine.addPredicate(fact[0], false, fact); + d_equalityEngine.assertPredicate(fact[0], false, fact); } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1])) { // Assert the dis-equality - d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); + d_equalityEngine.assertEquality(fact[0], false, fact); // Apply ArrDiseq Rule if diseq is between arrays if(fact[0][0].getType().isArray()) { @@ -764,7 +718,7 @@ void TheoryArrays::check(Effort e) { if (!d_equalityEngine.hasTerm(bk)) { preRegisterTerm(bk); } - d_equalityEngine.addDisequality(ak, bk, fact); + d_equalityEngine.assertEquality(ak.eqNode(bk), false, fact); Trace("arrays-lem")<<"Arrays::addExtLemma "<< ak << " /= " << bk <<"\n"; ++d_numExt; } @@ -807,14 +761,11 @@ Node TheoryArrays::mkAnd(std::vector& conjunctions) for (; i < conjunctions.size(); ++i) { t = conjunctions[i]; - // Remove true node - represents axiomatically true assertion - if (t == d_true) continue; - // Expand explanation resulting from propagating a ROW lemma if (t.getKind() == kind::OR) { if ((explained.find(t) == explained.end())) { Assert(t[1].getKind() == kind::EQUAL); - d_equalityEngine.explainDisequality(t[1][0], t[1][1], conjunctions); + d_equalityEngine.explainEquality(t[1][0], t[1][1], false, conjunctions); explained.insert(t); } continue; @@ -949,7 +900,7 @@ void TheoryArrays::checkRIntro1(TNode a, TNode b) Node ni = nm->mkNode(kind::SELECT, s, s[1]); Assert(!d_equalityEngine.hasTerm(ni)); preRegisterTerm(ni); - d_equalityEngine.addEquality(ni, s[2], d_true); + d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true); } } @@ -1004,7 +955,7 @@ void TheoryArrays::mergeArrays(TNode a, TNode b) } } - d_mayEqualEqualityEngine.addEquality(a, b, d_true); + d_mayEqualEqualityEngine.assertEquality(a.eqNode(b), true, d_true); checkRowLemmas(a,b); checkRowLemmas(b,a); @@ -1186,7 +1137,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) if (!bjExists) { preRegisterTerm(bj); } - d_equalityEngine.addEquality(aj, bj, reason); + d_equalityEngine.assertEquality(aj.eqNode(bj), true, reason); ++d_numProp; return; } @@ -1194,7 +1145,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<mkNode(kind::OR, i.eqNode(j), aj.eqNode(bj)); d_permRef.push_back(reason); - d_equalityEngine.addEquality(i, j, reason); + d_equalityEngine.assertEquality(i.eqNode(j), true, reason); ++d_numProp; return; } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index d18b3abde..88986ee7a 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -133,18 +133,8 @@ class TheoryArrays : public Theory { private: - // PPNotifyClass: dummy template class for d_ppEqualityEngine - notifications not used - class PPNotifyClass { - public: - bool notify(TNode propagation) { return true; } - void notify(TNode t1, TNode t2) { } - }; - - /** The notify class for d_ppEqualityEngine */ - PPNotifyClass d_ppNotify; - /** Equaltity engine */ - uf::EqualityEngine d_ppEqualityEngine; + eq::EqualityEngine d_ppEqualityEngine; // List of facts learned by preprocessor - needed for permanent ref for benefit of d_ppEqualityEngine context::CDList d_ppFacts; @@ -187,17 +177,8 @@ class TheoryArrays : public Theory { private: - class MayEqualNotifyClass { - public: - bool notify(TNode propagation) { return true; } - void notify(TNode t1, TNode t2) { } - }; - - /** The notify class for d_mayEqualEqualityEngine */ - MayEqualNotifyClass d_mayEqualNotify; - /** Equaltity engine for determining if two arrays might be equal */ - uf::EqualityEngine d_mayEqualEqualityEngine; + eq::EqualityEngine d_mayEqualEqualityEngine; public: @@ -238,37 +219,57 @@ class TheoryArrays : public Theory { private: // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { TheoryArrays& d_arrays; public: NotifyClass(TheoryArrays& arrays): d_arrays(arrays) {} - bool notify(TNode propagation) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl; + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; // Just forward to arrays - return d_arrays.propagate(propagation); + if (value) { + return d_arrays.propagate(equality); + } else { + return d_arrays.propagate(equality.notNode()); + } + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); } - void notify(TNode t1, TNode t2) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - if (t1.getType().isArray()) { - d_arrays.mergeArrays(t1, t2); - if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) { - return; + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl; + if (value) { + if (t1.getType().isArray()) { + d_arrays.mergeArrays(t1, t2); + if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) { + return true; + } } + // Propagate equality between shared terms + Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); + d_arrays.propagate(equality); } - // Propagate equality between shared terms - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_arrays.propagate(equality); + // TODO: implement negation propagation + return true; } - }; + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + if (Theory::theoryOf(t1) == THEORY_BOOL) { + return d_arrays.propagate(t1.iffNode(t2)); + } else { + return d_arrays.propagate(t1.eqNode(t2)); + } + } + }; /** The notify class for d_equalityEngine */ NotifyClass d_notify; /** Equaltity engine */ - uf::EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; // Are we in conflict? context::CDO d_conflict; diff --git a/src/theory/booleans/circuit_propagator.h b/src/theory/booleans/circuit_propagator.h index 78221a617..f5e4f4630 100644 --- a/src/theory/booleans/circuit_propagator.h +++ b/src/theory/booleans/circuit_propagator.h @@ -79,17 +79,17 @@ private: template class DataClearer : context::ContextNotifyObj { T& d_data; + protected: + void contextNotifyPop() { + Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data " + << "(size was " << d_data.size() << ")" << std::endl; + d_data.clear(); + } public: DataClearer(context::Context* context, T& data) : context::ContextNotifyObj(context), d_data(data) { } - - void notify() { - Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data " - << "(size was " << d_data.size() << ")" << std::endl; - d_data.clear(); - } };/* class DataClearer */ /** diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index c9d58574e..4076a7ee0 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -21,7 +21,6 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/valuation.h" #include "theory/bv/bv_sat.h" -#include "theory/uf/equality_engine_impl.h" using namespace CVC4; using namespace CVC4::theory; @@ -52,18 +51,7 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& d_toBitBlast(c), d_propagatedBy(c) { - d_true = utils::mkTrue(); - d_false = utils::mkFalse(); - if (d_useEqualityEngine) { - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); - - // add disequality between 0 and 1 bits - d_equalityEngine.addDisequality(utils::mkConst(BitVector((unsigned)1, (unsigned)0)), - utils::mkConst(BitVector((unsigned)1, (unsigned)1)), - d_true); // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT); @@ -137,11 +125,8 @@ void TheoryBV::preRegisterTerm(TNode node) { if (d_useEqualityEngine) { switch (node.getKind()) { case kind::EQUAL: - // Add the terms - d_equalityEngine.addTerm(node); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; default: d_equalityEngine.addTerm(node); @@ -185,15 +170,15 @@ void TheoryBV::check(Effort e) if (predicate.getKind() == kind::EQUAL) { if (negated) { // dis-equality - d_equalityEngine.addDisequality(predicate[0], predicate[1], fact); + d_equalityEngine.assertEquality(predicate, false, fact); } else { // equality - d_equalityEngine.addEquality(predicate[0], predicate[1], fact); + d_equalityEngine.assertEquality(predicate, true, fact); } } else { // Adding predicate if the congruence over it is turned on if (d_equalityEngine.isFunctionKind(predicate.getKind())) { - d_equalityEngine.addPredicate(predicate, !negated, fact); + d_equalityEngine.assertPredicate(predicate, !negated, fact); } } } @@ -279,16 +264,16 @@ void TheoryBV::propagate(Effort e) { bool satValue; if (!d_valuation.hasSatValue(normalized, satValue) || satValue) { // check if we already propagated the negation - Node neg_literal = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal); - if (d_alreadyPropagatedSet.find(neg_literal) != d_alreadyPropagatedSet.end()) { + Node negLiteral = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal); + if (d_alreadyPropagatedSet.find(negLiteral) != d_alreadyPropagatedSet.end()) { Debug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): in conflict " << literal << " and its negation both propagated \n"; // we are in conflict std::vector assumptions; explain(literal, assumptions); - explain(neg_literal, assumptions); + explain(negLiteral, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; - return; + return; } BVDebug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): " << literal << std::endl; @@ -299,10 +284,8 @@ void TheoryBV::propagate(Effort e) { Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); assumptions.push_back(negatedLiteral); - } explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -352,8 +335,6 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) // If propagated already, just skip PropagatedMap::const_iterator find = d_propagatedBy.find(literal); if (find != d_propagatedBy.end()) { - //unsigned theories = (*find).second | (unsigned) subtheory; - //d_propagatedBy[literal] = theories; return true; } else { d_propagatedBy[literal] = subtheory; @@ -362,56 +343,37 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) // See if the literal has been asserted already bool satValue = false; bool hasSatValue = d_valuation.hasSatValue(literal, satValue); - // If asserted, we might be in conflict + // If asserted, we might be in conflict if (hasSatValue && !satValue) { - Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl; - std::vector assumptions; - Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode(); - assumptions.push_back(negatedLiteral); - explain(literal, assumptions); - d_conflictNode = mkAnd(assumptions); - d_conflict = true; - return false; + Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl; + std::vector assumptions; + Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode(); + assumptions.push_back(negatedLiteral); + explain(literal, assumptions); + d_conflictNode = mkAnd(assumptions); + d_conflict = true; + return false; } // Nothing, just enqueue it for propagation and mark it as asserted already Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => enqueuing for propagation" << std::endl; d_literalsToPropagate.push_back(literal); + // No conflict return true; }/* TheoryBV::propagate(TNode) */ void TheoryBV::explain(TNode literal, std::vector& assumptions) { - if (propagatedBy(literal, SUB_EQUALITY)) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } else { Assert(propagatedBy(literal, SUB_BITBLASTER)); d_bitblaster->explain(literal, assumptions); @@ -430,7 +392,9 @@ Node TheoryBV::explain(TNode node) { return utils::mkTrue(); } // return the explanation - return mkAnd(assumptions); + Node explanation = mkAnd(assumptions); + Debug("bitvector::explain") << "TheoryBV::explain(" << node << ") => " << explanation << std::endl; + return explanation; } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 0ced179ec..e46d052f8 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -61,8 +61,6 @@ private: /** Bitblaster */ Bitblaster* d_bitblaster; - Node d_true; - Node d_false; /** Context dependent set of atoms we already propagated */ context::CDHashSet d_alreadyPropagatedSet; @@ -99,22 +97,44 @@ private: // Added by Clark // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { + TheoryBV& d_bv; + public: + NotifyClass(TheoryBV& uf): d_bv(uf) {} - bool notify(TNode propagation) { - Debug("bitvector") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl; - // Just forward to bv - return d_bv.storePropagation(propagation, SUB_EQUALITY); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_bv.storePropagation(equality, SUB_EQUALITY); + } else { + return d_bv.storePropagation(equality.notNode(), SUB_EQUALITY); + } + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_bv.storePropagation(predicate, SUB_EQUALITY); + } else { + return d_bv.storePropagation(predicate, SUB_EQUALITY); + } + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl; + if (value) { + return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); + } else { + return d_bv.storePropagation(t1.eqNode(t2).notNode(), SUB_EQUALITY); + } } - void notify(TNode t1, TNode t2) { - Debug("arrays") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - // Propagate equality between shared terms - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("bitvector") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); } }; @@ -122,7 +142,7 @@ private: NotifyClass d_notify; /** Equaltity engine */ - uf::EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; // Are we in conflict? context::CDO d_conflict; diff --git a/src/theory/datatypes/union_find.cpp b/src/theory/datatypes/union_find.cpp index eacc4e798..34706719e 100644 --- a/src/theory/datatypes/union_find.cpp +++ b/src/theory/datatypes/union_find.cpp @@ -31,7 +31,7 @@ namespace theory { namespace datatypes { template -void UnionFind::notify() { +void UnionFind::contextNotifyPop() { Trace("datatypesuf") << "datatypesUF cancelling : " << d_offset << " < " << d_trace.size() << " ?" << endl; while(d_offset < d_trace.size()) { pair p = d_trace.back(); @@ -50,9 +50,9 @@ void UnionFind::notify() { // The following declarations allow us to put functions in the .cpp file // instead of the header, since we know which instantiations are needed. -template void UnionFind::notify(); +template void UnionFind::contextNotifyPop(); -template void UnionFind::notify(); +template void UnionFind::contextNotifyPop(); }/* CVC4::theory::datatypes namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/datatypes/union_find.h b/src/theory/datatypes/union_find.h index 51d1d85bc..4893c3502 100644 --- a/src/theory/datatypes/union_find.h +++ b/src/theory/datatypes/union_find.h @@ -84,13 +84,13 @@ public: */ inline void setCanon(TNode n, TNode newParent); +protected: -public: /** * Called by the Context when a pop occurs. Cancels everything to the - * current context level. Overrides ContextNotifyObj::notify(). + * current context level. Overrides ContextNotifyObj::contextNotifyPop(). */ - void notify(); + void contextNotifyPop(); };/* class UnionFind<> */ diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 577e1b957..4f5475e97 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -16,7 +16,6 @@ **/ #include "theory/shared_terms_database.h" -#include "theory/uf/equality_engine_impl.h" using namespace CVC4; using namespace theory; @@ -36,15 +35,8 @@ SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context d_equalityEngine(d_EENotify, context, "SharedTermsDatabase") { StatisticsRegistry::registerStat(&d_statSharedTerms); - NodeManager* nm = NodeManager::currentNM(); - d_true = nm->mkConst(true); - d_false = nm->mkConst(false); - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); } - SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException) { StatisticsRegistry::unregisterStat(&d_statSharedTerms); @@ -53,9 +45,9 @@ SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException) } } - void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) { Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl; + std::pair search_pair(atom, term); SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair); if (find == d_termsToTheories.end()) { @@ -243,23 +235,21 @@ bool SharedTermsDatabase::areDisequal(TNode a, TNode b) { return d_equalityEngine.areDisequal(a,b); } - void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason) { bool negated = literal.getKind() == kind::NOT; TNode atom = negated ? literal[0] : literal; if (negated) { Assert(!d_equalityEngine.areDisequal(atom[0], atom[1])); - d_equalityEngine.addDisequality(atom[0], atom[1], reason); + d_equalityEngine.assertEquality(atom, false, reason); // !!! need to send this out } else { Assert(!d_equalityEngine.areEqual(atom[0], atom[1])); - d_equalityEngine.addEquality(atom[0], atom[1], reason); + d_equalityEngine.assertEquality(atom, true, reason); } } - static Node mkAnd(const std::vector& conjunctions) { Assert(conjunctions.size() > 0); @@ -286,31 +276,12 @@ static Node mkAnd(const std::vector& conjunctions) { Node SharedTermsDatabase::explain(TNode literal) { std::vector assumptions; - explain(literal, assumptions); - return mkAnd(assumptions); -} - - -void SharedTermsDatabase::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + if (literal.getKind() == kind::NOT) { + Assert(literal[0].getKind() == kind::EQUAL); + d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions); + } else { + Assert(literal.getKind() == kind::EQUAL); + d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); + return mkAnd(assumptions); } diff --git a/src/theory/shared_terms_database.h b/src/theory/shared_terms_database.h index 6af7fd41f..403c90ced 100644 --- a/src/theory/shared_terms_database.h +++ b/src/theory/shared_terms_database.h @@ -28,25 +28,23 @@ class SharedTermsDatabase : public context::ContextNotifyObj { public: - /** A conainer for a list of shared terms */ + /** A container for a list of shared terms */ typedef std::vector shared_terms_list; - /** The iterator to go rhough the shared terms list */ + + /** The iterator to go through the shared terms list */ typedef shared_terms_list::const_iterator shared_terms_iterator; private: - Node d_true; - - Node d_false; - /** The context */ context::Context* d_context; /** Some statistics */ IntStat d_statSharedTerms; - // Needs to be a map from Nodes as after a backtrack they might not exist + // Needs to be a map from Nodes as after a backtrack they might not exist typedef std::hash_map SharedTermsMap; + /** A map from atoms to a list of shared terms */ SharedTermsMap d_atomsToTerms; @@ -57,14 +55,17 @@ private: context::CDO d_addedSharedTermsSize; typedef context::CDHashMap, theory::Theory::Set, TNodePairHashFunction> SharedTermsTheoriesMap; + /** A map from atoms and subterms to the theories that use it */ SharedTermsTheoriesMap d_termsToTheories; typedef context::CDHashMap AlreadyNotifiedMap; + /** Map from term to theories that have already been notified about the shared term */ AlreadyNotifiedMap d_alreadyNotifiedMap; public: + /** Class for notifications about new shared term equalities */ class SharedTermsNotifyClass { public: @@ -74,6 +75,7 @@ public: }; private: + // Instance of class to send shared term notifications to SharedTermsNotifyClass& d_sharedNotify; @@ -101,21 +103,37 @@ private: void backtrack(); // EENotifyClass: template helper class for d_equalityEngine - handles call-backs - class EENotifyClass { + class EENotifyClass : public theory::eq::EqualityEngineNotify { SharedTermsDatabase& d_sharedTerms; public: EENotifyClass(SharedTermsDatabase& shared): d_sharedTerms(shared) {} - bool notify(TNode propagation) { return true; } // Not used - void notify(TNode t1, TNode t2) { - d_sharedTerms.mergeSharedTerms(t1, t2); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Unreachable(); + return true; + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); + return true; + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + if (value) { + d_sharedTerms.mergeSharedTerms(t1, t2); + } + return true; + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + return true; } }; /** The notify class for d_equalityEngine */ EENotifyClass d_EENotify; - /** Equaltity engine */ - theory::uf::EqualityEngine d_equalityEngine; + /** Equality engine */ + theory::eq::EqualityEngine d_equalityEngine; /** Attach a new notify list to an equivalence class representative */ NotifyList* getNewNotifyList(); @@ -123,9 +141,6 @@ private: /** Method called by equalityEngine when a becomes equal to b */ void mergeSharedTerms(TNode a, TNode b); - /** Internal explanation method */ - void explain(TNode literal, std::vector& assumptions); - public: SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context); @@ -179,10 +194,12 @@ public: Node explain(TNode literal); +protected: + /** * This method gets called on backtracks from the context manager. */ - void notify() { + void contextNotifyPop() { backtrack(); } }; diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index 27c1a2b69..958f50276 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -73,16 +73,16 @@ private: /** Helper class to invalidate cache on user pop */ class CacheInvalidator : public context::ContextNotifyObj { bool& d_cacheInvalidated; - + protected: + void contextNotifyPop() { + d_cacheInvalidated = true; + } public: CacheInvalidator(context::Context* context, bool& cacheInvalidated) : context::ContextNotifyObj(context), d_cacheInvalidated(cacheInvalidated) { } - void notify() { - d_cacheInvalidated = true; - } };/* class SubstitutionMap::CacheInvalidator */ /** diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index a3aee985d..c19bdda91 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -124,6 +124,50 @@ void TheoryEngine::preRegister(TNode preprocessed) { // } } +void TheoryEngine::printAssertions(const char* tag) { + if (Debug.isOn(tag)) { + for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) { + Theory* theory = d_theoryTable[theoryId]; + if (theory && d_logicInfo.isTheoryEnabled(theoryId)) { + Debug(tag) << "--------------------------------------------" << std::endl; + Debug(tag) << "Assertions of " << theory->getId() << ": " << std::endl; + context::CDList::const_iterator it = theory->facts_begin(), it_end = theory->facts_end(); + for (unsigned i = 0; it != it_end; ++ it, ++i) { + if ((*it).isPreregistered) { + Debug(tag) << "[" << i << "]: "; + } else { + Debug(tag) << "(" << i << "): "; + } + Debug(tag) << (*it).assertion << endl; + } + + if (d_logicInfo.isSharingEnabled()) { + Debug(tag) << "Shared terms of " << theory->getId() << ": " << std::endl; + context::CDList::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end(); + for (unsigned i = 0; it != it_end; ++ it, ++i) { + Debug(tag) << "[" << i << "]: " << (*it) << endl; + } + } + } + } + + } +} + +template +class scoped_vector_clear { + vector& d_v; +public: + scoped_vector_clear(vector& v) + : d_v(v) { + Assert(!doAssert || d_v.empty()); + } + ~scoped_vector_clear() { + d_v.clear(); + } + +}; + /** * Check all (currently-active) theories for conflicts. * @param effort the effort level to use @@ -143,12 +187,12 @@ void TheoryEngine::check(Theory::Effort effort) { } \ } + // make sure d_propagatedSharedLiterals is cleared on exit + scoped_vector_clear clear_shared_literals(d_propagatedSharedLiterals); + // Do the checking try { - // Clear any leftover propagated shared literals - d_propagatedSharedLiterals.clear(); - // Mark the output channel unused (if this is FULL_EFFORT, and nothing // is done by the theories, no additional check will be needed) d_outputChannelUsed = false; @@ -159,32 +203,10 @@ void TheoryEngine::check(Theory::Effort effort) { while (true) { Debug("theory") << "TheoryEngine::check(" << effort << "): running check" << std::endl; + Assert(d_propagatedSharedLiterals.empty()); if (Debug.isOn("theory::assertions")) { - for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) { - Theory* theory = d_theoryTable[theoryId]; - if (theory && d_logicInfo.isTheoryEnabled(theoryId)) { - Debug("theory::assertions") << "--------------------------------------------" << std::endl; - Debug("theory::assertions") << "Assertions of " << theory->getId() << ": " << std::endl; - context::CDList::const_iterator it = theory->facts_begin(), it_end = theory->facts_end(); - for (unsigned i = 0; it != it_end; ++ it, ++i) { - if ((*it).isPreregistered) { - Debug("theory::assertions") << "[" << i << "]: "; - } else { - Debug("theory::assertions") << "(" << i << "): "; - } - Debug("theory::assertions") << (*it).assertion << endl; - } - - if (d_logicInfo.isSharingEnabled()) { - Debug("theory::assertions") << "Shared terms of " << theory->getId() << ": " << std::endl; - context::CDList::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end(); - for (unsigned i = 0; it != it_end; ++ it, ++i) { - Debug("theory::assertions") << "[" << i << "]: " << (*it) << endl; - } - } - } - } + printAssertions("theory::assertions"); } // Do the checking @@ -232,9 +254,6 @@ void TheoryEngine::check(Theory::Effort effort) { } } - // Clear any leftover propagated shared literals - d_propagatedSharedLiterals.clear(); - Debug("theory") << "TheoryEngine::check(" << effort << "): done, we are " << (d_inConflict ? "unsat" : "sat") << (d_lemmasAdded ? " with new lemmas" : " with no new lemmas") << std::endl; } catch(const theory::Interrupted&) { @@ -243,6 +262,9 @@ void TheoryEngine::check(Theory::Effort effort) { } void TheoryEngine::outputSharedLiterals() { + + scoped_vector_clear clear_shared_literals(d_propagatedSharedLiterals); + // Assert all the shared literals for (unsigned i = 0; i < d_propagatedSharedLiterals.size(); ++ i) { const SharedLiteral& eq = d_propagatedSharedLiterals[i]; @@ -258,8 +280,6 @@ void TheoryEngine::outputSharedLiterals() { } } } - // Clear the equalities - d_propagatedSharedLiterals.clear(); } @@ -269,7 +289,9 @@ void TheoryEngine::combineTheories() { TimerStat::CodeTimer combineTheoriesTimer(d_combineTheoriesTime); + // Care graph we'll be building CareGraph careGraph; + #ifdef CVC4_FOR_EACH_THEORY_STATEMENT #undef CVC4_FOR_EACH_THEORY_STATEMENT #endif @@ -278,6 +300,7 @@ void TheoryEngine::combineTheories() { reinterpret_cast::theory_class*>(theoryOf(THEORY))->getCareGraph(careGraph); \ } + // Call on each parametric theory to give us its care graph CVC4_FOR_EACH_THEORY; // Now add splitters for the ones we are interested in @@ -833,6 +856,8 @@ Node TheoryEngine::getExplanation(TNode node) { } Assert(properExplanation(node, explanation)); + Debug("theory::explain") << "TheoryEngine::getExplanation(" << node << ") => " << explanation << std::endl; + return explanation; } diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 5c73da1f6..2871d5559 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -309,8 +309,10 @@ class TheoryEngine { } struct SharedLiteral { - /** The node/theory pair for the assertion */ - /** THEORY_LAST indicates this is a SAT literal and should be sent to the SAT solver */ + /** + * The node/theory pair for the assertion. THEORY_LAST indicates this is a SAT + * literal and should be sent to the SAT solver + */ NodeTheoryPair toAssert; /** This is the node that we will use to explain it */ Node toExplain; @@ -319,7 +321,7 @@ class TheoryEngine { : toAssert(assertion, receivingTheory), toExplain(original) { } - };/* struct SharedLiteral */ + }; /** * Map from nodes to theories. @@ -728,6 +730,9 @@ private: /** Visitor for collecting shared terms */ SharedTermsVisitor d_sharedTermsVisitor; + /** Prints the assertions to the debug stream */ + void printAssertions(const char* tag); + };/* class TheoryEngine */ }/* CVC4 namespace */ diff --git a/src/theory/uf/Makefile.am b/src/theory/uf/Makefile.am index f25e50ec9..9d95eaa22 100644 --- a/src/theory/uf/Makefile.am +++ b/src/theory/uf/Makefile.am @@ -11,7 +11,7 @@ libuf_la_SOURCES = \ theory_uf_type_rules.h \ theory_uf_rewriter.h \ equality_engine.h \ - equality_engine_impl.h \ + equality_engine.cpp \ symmetry_breaker.h \ symmetry_breaker.cpp diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp new file mode 100644 index 000000000..b78015c00 --- /dev/null +++ b/src/theory/uf/equality_engine.cpp @@ -0,0 +1,995 @@ +/********************* */ +/*! \file equality_engine_impl.h + ** \verbatim + ** Original author: dejan + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "theory/uf/equality_engine.h" + +namespace CVC4 { +namespace theory { +namespace eq { + +/** + * Data used in the BFS search through the equality graph. + */ +struct BfsData { + // The current node + EqualityNodeId nodeId; + // The index of the edge we traversed + EqualityEdgeId edgeId; + // Index in the queue of the previous node. Shouldn't be too much of them, at most the size + // of the biggest equivalence class + size_t previousIndex; + + BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) + : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} +}; + +class ScopedBool { + bool& watch; + bool oldValue; +public: + ScopedBool(bool& watch, bool newValue) + : watch(watch), oldValue(watch) { + watch = newValue; + } + ~ScopedBool() { + watch = oldValue; + } +}; + +EqualityEngineNotifyNone EqualityEngine::s_notifyNone; + +void EqualityEngine::init() { + Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; + d_true = NodeManager::currentNM()->mkConst(true); + d_false = NodeManager::currentNM()->mkConst(false); + addTerm(d_true); + addTerm(d_false); +} + + +EqualityEngine::EqualityEngine(context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(s_notifyNone) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(notify) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +void EqualityEngine::enqueue(const MergeCandidate& candidate) { + Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl; + d_propagationQueue.push(candidate); +} + +EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { + Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl; + + ++ d_stats.functionTermsCount; + + // Get another id for this + EqualityNodeId funId = newNode(original); + FunctionApplication funOriginal(t1, t2); + // The function application we're creating + EqualityNodeId t1ClassId = getEqualityNode(t1).getFind(); + EqualityNodeId t2ClassId = getEqualityNode(t2).getFind(); + FunctionApplication funNormalized(t1ClassId, t2ClassId); + + // We add the original version + d_applications[funId] = FunctionApplicationPair(funOriginal, funNormalized); + + // Add the lookup data, if it's not already there + typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized); + if (find == d_applicationLookup.end()) { + // When we backtrack, if the lookup is not there anymore, we'll add it again + Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): no lookup, setting up" << std::endl; + // Mark the normalization to the lookup + storeApplicationLookup(funNormalized, funId); + } else { + // If it's there, we need to merge these two + Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): lookup exists, adding to queue" << std::endl; + enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null())); + propagate(); + } + + // Add to the use lists + d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes); + d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes); + + // Return the new id + Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl; + + return funId; +} + +EqualityNodeId EqualityEngine::newNode(TNode node) { + + Debug("equality") << "EqualityEngine::newNode(" << node << ")" << std::endl; + + ++ d_stats.termsCount; + + // Register the new id of the term + EqualityNodeId newId = d_nodes.size(); + d_nodeIds[node] = newId; + // Add the node to it's position + d_nodes.push_back(node); + // Note if this is an application or not + d_applications.push_back(FunctionApplicationPair()); + // Add the trigger list for this node + d_nodeTriggers.push_back(+null_trigger); + // Add it to the equality graph + d_equalityGraph.push_back(+null_edge); + // Mark the no-individual trigger + d_nodeIndividualTrigger.push_back(+null_id); + // Mark non-constant by default + d_constantRepresentative.push_back(node.isConst() ? newId : +null_id); + // Add the equality node to the nodes + d_equalityNodes.push_back(EqualityNode(newId)); + + // Increase the counters + d_nodesCount = d_nodesCount + 1; + + Debug("equality") << "EqualityEngine::newNode(" << node << ") => " << newId << std::endl; + + return newId; +} + +void EqualityEngine::addTerm(TNode t) { + + Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; + + // If there already, we're done + if (hasTerm(t)) { + Debug("equality") << "EqualityEngine::addTerm(" << t << "): already there" << std::endl; + return; + } + + EqualityNodeId result; + + // If a function application we go in + if (t.getNumChildren() > 0 && d_congruenceKinds[t.getKind()]) + { + // Add the operator + TNode tOp = t.getOperator(); + addTerm(tOp); + // Add all the children and Curryfy + result = getNodeId(tOp); + for (unsigned i = 0; i < t.getNumChildren(); ++ i) { + // Add the child + addTerm(t[i]); + // Add the application + result = newApplicationNode(t, result, getNodeId(t[i])); + } + } else { + // Otherwise we just create the new id + result = newNode(t); + } + + Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl; +} + +bool EqualityEngine::hasTerm(TNode t) const { + return d_nodeIds.find(t) != d_nodeIds.end(); +} + +EqualityNodeId EqualityEngine::getNodeId(TNode node) const { + Assert(hasTerm(node), node.toString().c_str()); + return (*d_nodeIds.find(node)).second; +} + +EqualityNode& EqualityEngine::getEqualityNode(TNode t) { + return getEqualityNode(getNodeId(t)); +} + +EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) { + Assert(nodeId < d_equalityNodes.size()); + return d_equalityNodes[nodeId]; +} + +const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const { + return getEqualityNode(getNodeId(t)); +} + +const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const { + Assert(nodeId < d_equalityNodes.size()); + return d_equalityNodes[nodeId]; +} + +void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) { + + Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl; + + // Add the terms if they are not already in the database + addTerm(t1); + addTerm(t2); + + // Add to the queue and propagate + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); + enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason)); + + propagate(); +} + +void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) { + Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl; + Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead"); + assertEqualityInternal(t, polarity ? d_true : d_false, reason); +} + +void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { + Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl; + if (polarity) { + // Add equality between terms + assertEqualityInternal(eq[0], eq[1], reason); + // Add eq = true for dis-equality propagation + assertEqualityInternal(eq, d_true, reason); + } else { + assertEqualityInternal(eq, d_false, reason); + Node eqSymm = eq[1].eqNode(eq[0]); + assertEqualityInternal(eqSymm, d_false, reason); + } +} + +TNode EqualityEngine::getRepresentative(TNode t) const { + Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; + Assert(hasTerm(t)); + EqualityNodeId representativeId = getEqualityNode(t).getFind(); + Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl; + return d_nodes[representativeId]; +} + +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { + Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; + + Assert(hasTerm(t1)); + Assert(hasTerm(t2)); + + // Both following commands are semantically const + EqualityNodeId rep1 = getEqualityNode(t1).getFind(); + EqualityNodeId rep2 = getEqualityNode(t2).getFind(); + + Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl; + + return rep1 == rep2; +} + +bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { + + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; + + Assert(triggers.empty()); + + ++ d_stats.mergesCount; + + EqualityNodeId class1Id = class1.getFind(); + EqualityNodeId class2Id = class2.getFind(); + + // Update class2 representative information + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl; + EqualityNodeId currentId = class2Id; + do { + // Get the current node + EqualityNode& currentNode = getEqualityNode(currentId); + + // Update it's find to class1 id + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << "->" << class1Id << std::endl; + currentNode.setFind(class1Id); + + // Go through the triggers and inform if necessary + TriggerId currentTrigger = d_nodeTriggers[currentId]; + while (currentTrigger != null_trigger) { + Trigger& trigger = d_equalityTriggers[currentTrigger]; + Trigger& otherTrigger = d_equalityTriggers[currentTrigger ^ 1]; + + // If the two are not already in the same class + if (otherTrigger.classId != trigger.classId) { + trigger.classId = class1Id; + // If they became the same, call the trigger + if (otherTrigger.classId == class1Id) { + // Id of the real trigger is half the internal one + triggers.push_back(currentTrigger); + } + } + + // Go to the next trigger + currentTrigger = trigger.nextTrigger; + } + + // Move to the next node + currentId = currentNode.getNext(); + + } while (currentId != class2Id); + + + // Update class2 table lookup and information + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of " << class2Id << std::endl; + do { + // Get the current node + EqualityNode& currentNode = getEqualityNode(currentId); + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl; + + // Go through the uselist and check for congruences + UseListNodeId currentUseId = currentNode.getUseList(); + while (currentUseId != null_uselist_id) { + // Get the node of the use list + UseListNode& useNode = d_useListNodes[currentUseId]; + // Get the function application + EqualityNodeId funId = useNode.getApplicationId(); + Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl; + const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized; + // Check if there is an application with find arguments + EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind(); + EqualityNodeId bNormalized = getEqualityNode(fun.b).getFind(); + FunctionApplication funNormalized(aNormalized, bNormalized); + typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized); + if (find != d_applicationLookup.end()) { + // Applications fun and the funNormalized can be merged due to congruence + if (getEqualityNode(funId).getFind() != getEqualityNode(find->second).getFind()) { + enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null())); + } + } else { + // There is no representative, so we can add one, we remove this when backtracking + storeApplicationLookup(funNormalized, funId); + } + // Go to the next one in the use list + currentUseId = useNode.getNext(); + } + + // Move to the next node + currentId = currentNode.getNext(); + } while (currentId != class2Id); + + // Now merge the lists + class1.merge(class2); + + // Check for constants + EqualityNodeId class2constId = d_constantRepresentative[class2Id]; + if (class2constId != +null_id) { + EqualityNodeId class1constId = d_constantRepresentative[class1Id]; + if (class1constId != +null_id) { + if (d_performNotify) { + TNode const1 = d_nodes[class1constId]; + TNode const2 = d_nodes[class2constId]; + if (!d_notify.eqNotifyConstantTermMerge(const1, const2)) { + return false; + } + } + } else { + // If the class we're merging in is constant, mark the representative as constant + d_constantRepresentative[class1Id] = d_constantRepresentative[class2Id]; + d_constantRepresentatives.push_back(class1Id); + d_constantRepresentativesSize = d_constantRepresentativesSize + 1; + } + } + + // Notify the trigger term merges + EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id]; + if (class2triggerId != +null_id) { + EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; + if (class1triggerId == +null_id) { + // If class1 is not an individual trigger, but class2 is, mark it + d_nodeIndividualTrigger[class1Id] = class2triggerId; + // Add it to the list for backtracking + d_individualTriggers.push_back(class1Id); + d_individualTriggersSize = d_individualTriggersSize + 1; + } else { + // Notify when done + if (d_performNotify) { + if (!d_notify.eqNotifyTriggerTermEquality(d_nodes[class1triggerId], d_nodes[class2triggerId], true)) { + return false; + } + } + } + } + + // Everything fine + return true; +} + +void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { + + Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; + + // Now unmerge the lists (same as merge) + class1.merge(class2); + + // Update class2 representative information + EqualityNodeId currentId = class2Id; + Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << "): undoing representative info" << std::endl; + do { + // Get the current node + EqualityNode& currentNode = getEqualityNode(currentId); + + // Update it's find to class1 id + currentNode.setFind(class2Id); + + // Go through the trigger list (if any) and undo the class + TriggerId currentTrigger = d_nodeTriggers[currentId]; + while (currentTrigger != null_trigger) { + Trigger& trigger = d_equalityTriggers[currentTrigger]; + trigger.classId = class2Id; + currentTrigger = trigger.nextTrigger; + } + + // Move to the next node + currentId = currentNode.getNext(); + + } while (currentId != class2Id); + +} + +void EqualityEngine::backtrack() { + + Debug("equality::backtrack") << "backtracking" << std::endl; + + // If we need to backtrack then do it + if (d_assertedEqualitiesCount < d_assertedEqualities.size()) { + + // Clear the propagation queue + while (!d_propagationQueue.empty()) { + d_propagationQueue.pop(); + } + + Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl; + + for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) { + // Get the ids of the merged classes + Equality& eq = d_assertedEqualities[i]; + // Undo the merge + undoMerge(d_equalityNodes[eq.lhs], d_equalityNodes[eq.rhs], eq.rhs); + } + + d_assertedEqualities.resize(d_assertedEqualitiesCount); + + Debug("equality") << "EqualityEngine::backtrack(): edges" << std::endl; + + for (int i = (int)d_equalityEdges.size() - 2, i_end = (int)(2*d_assertedEqualitiesCount); i >= i_end; i -= 2) { + EqualityEdge& edge1 = d_equalityEdges[i]; + EqualityEdge& edge2 = d_equalityEdges[i | 1]; + d_equalityGraph[edge2.getNodeId()] = edge1.getNext(); + d_equalityGraph[edge1.getNodeId()] = edge2.getNext(); + } + + d_equalityEdges.resize(2 * d_assertedEqualitiesCount); + } + + if (d_individualTriggers.size() > d_individualTriggersSize) { + // Unset the individual triggers + for (int i = d_individualTriggers.size() - 1, i_end = d_individualTriggersSize; i >= i_end; -- i) { + d_nodeIndividualTrigger[d_individualTriggers[i]] = +null_id; + } + d_individualTriggers.resize(d_individualTriggersSize); + } + + if (d_constantRepresentatives.size() > d_constantRepresentativesSize) { + // Unset the constant representatives + for (int i = d_constantRepresentatives.size() - 1, i_end = d_constantRepresentativesSize; i >= i_end; -- i) { + d_constantRepresentative[d_constantRepresentatives[i]] = +null_id; + } + d_constantRepresentatives.resize(d_constantRepresentativesSize); + } + + if (d_equalityTriggers.size() > d_equalityTriggersCount) { + // Unlink the triggers from the lists + for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) { + const Trigger& trigger = d_equalityTriggers[i]; + d_nodeTriggers[trigger.classId] = trigger.nextTrigger; + } + // Get rid of the triggers + d_equalityTriggers.resize(d_equalityTriggersCount); + d_equalityTriggersOriginal.resize(d_equalityTriggersCount); + } + + if (d_applicationLookups.size() > d_applicationLookupsCount) { + for (int i = d_applicationLookups.size() - 1, i_end = (int) d_applicationLookupsCount; i >= i_end; -- i) { + d_applicationLookup.erase(d_applicationLookups[i]); + } + d_applicationLookups.resize(d_applicationLookupsCount); + } + + if (d_nodes.size() > d_nodesCount) { + // Go down the nodes, check the application nodes and remove them from use-lists + for(int i = d_nodes.size() - 1, i_end = (int)d_nodesCount; i >= i_end; -- i) { + // Remove from the node -> id map + Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl; + d_nodeIds.erase(d_nodes[i]); + + const FunctionApplication& app = d_applications[i].normalized; + if (app.isApplication()) { + // Remove b from use-list + getEqualityNode(app.b).removeTopFromUseList(d_useListNodes); + // Remove a from use-list + getEqualityNode(app.a).removeTopFromUseList(d_useListNodes); + } + } + + // Now get rid of the nodes and the rest + d_nodes.resize(d_nodesCount); + d_applications.resize(d_nodesCount); + d_nodeTriggers.resize(d_nodesCount); + d_nodeIndividualTrigger.resize(d_nodesCount); + d_constantRepresentative.resize(d_nodesCount); + d_equalityGraph.resize(d_nodesCount); + d_equalityNodes.resize(d_nodesCount); + } +} + +void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { + Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl; + EqualityEdgeId edge = d_equalityEdges.size(); + d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason)); + d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2], type, reason)); + d_equalityGraph[t1] = edge; + d_equalityGraph[t2] = edge | 1; + + if (Debug.isOn("equality::internal")) { + debugPrintGraph(); + } +} + +std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { + std::stringstream out; + bool first = true; + if (edgeId == null_edge) { + out << "null"; + } else { + while (edgeId != null_edge) { + const EqualityEdge& edge = d_equalityEdges[edgeId]; + if (!first) out << ","; + out << d_nodes[edge.getNodeId()]; + edgeId = edge.getNext(); + first = false; + } + } + return out.str(); +} + +void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector& equalities) { + Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl; + + // Don't notify during this check + ScopedBool turnOffNotify(d_performNotify, false); + + // Add the terms (they might not be there) + addTerm(t1); + addTerm(t2); + + if (polarity) { + // Get the explanation + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); + getExplanation(t1Id, t2Id, equalities); + } else { + // Add the equality + Node equality = t1.eqNode(t2); + addTerm(equality); + + // Get the explanation + EqualityNodeId equalityId = getNodeId(equality); + EqualityNodeId falseId = getNodeId(d_false); + getExplanation(equalityId, falseId, equalities); + } +} + +void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector& assertions) { + Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl; + + // Don't notify during this check + ScopedBool turnOffNotify(d_performNotify, false); + + // Add the terms + addTerm(p); + + // Get the explanation + getExplanation(getNodeId(p), getNodeId(polarity ? d_true : d_false), assertions); +} + +void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector& equalities) const { + + Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; + + Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind()); + + // If the nodes are the same, we're done + if (t1Id == t2Id) return; + + if (Debug.isOn("equality::internal")) { + debugPrintGraph(); + } + + // Queue for the BFS containing nodes + std::vector bfsQueue; + + // Find a path from t1 to t2 in the graph (BFS) + bfsQueue.push_back(BfsData(t1Id, null_id, 0)); + size_t currentIndex = 0; + while (true) { + // There should always be a path, and every node can be visited only once (tree) + Assert(currentIndex < bfsQueue.size()); + + // The next node to visit + BfsData current = bfsQueue[currentIndex]; + EqualityNodeId currentNode = current.nodeId; + + Debug("equality") << "EqualityEngine::getExplanation(): currentNode = " << d_nodes[currentNode] << std::endl; + + // Go through the equality edges of this node + EqualityEdgeId currentEdge = d_equalityGraph[currentNode]; + Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl; + Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl; + + while (currentEdge != null_edge) { + // Get the edge + const EqualityEdge& edge = d_equalityEdges[currentEdge]; + + // If not just the backwards edge + if ((currentEdge | 1u) != (current.edgeId | 1u)) { + + Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = (" << d_nodes[currentNode] << "," << d_nodes[edge.getNodeId()] << ")" << std::endl; + + // Did we find the path + if (edge.getNodeId() == t2Id) { + + Debug("equality") << "EqualityEngine::getExplanation(): path found: " << std::endl; + + // Reconstruct the path + do { + // The current node + currentNode = bfsQueue[currentIndex].nodeId; + EqualityNodeId edgeNode = d_equalityEdges[currentEdge].getNodeId(); + MergeReasonType reasonType = d_equalityEdges[currentEdge].getReasonType(); + + Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = " << currentEdge << ", currentNode = " << currentNode << std::endl; + + // Add the actual equality to the vector + if (reasonType == MERGED_THROUGH_CONGRUENCE) { + // f(x1, x2) == f(y1, y2) because x1 = y1 and x2 = y2 + Debug("equality") << "EqualityEngine::getExplanation(): due to congruence, going deeper" << std::endl; + const FunctionApplication& f1 = d_applications[currentNode].original; + const FunctionApplication& f2 = d_applications[edgeNode].original; + Debug("equality") << push; + getExplanation(f1.a, f2.a, equalities); + getExplanation(f1.b, f2.b, equalities); + Debug("equality") << pop; + } else { + // Construct the equality + Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityEdges[currentEdge].getReason() << std::endl; + equalities.push_back(d_equalityEdges[currentEdge].getReason()); + } + + // Go to the previous + currentEdge = bfsQueue[currentIndex].edgeId; + currentIndex = bfsQueue[currentIndex].previousIndex; + } while (currentEdge != null_id); + + // Done + return; + } + + // Push to the visitation queue if it's not the backward edge + bfsQueue.push_back(BfsData(edge.getNodeId(), currentEdge, currentIndex)); + } + + // Go to the next edge + currentEdge = edge.getNext(); + } + + // Go to the next node to visit + ++ currentIndex; + } +} + +void EqualityEngine::addTriggerEquality(TNode eq) { + Assert(eq.getKind() == kind::EQUAL); + // Add the terms + addTerm(eq); + // Positive trigger + addTriggerEqualityInternal(eq[0], eq[1], eq, true); + // Negative trigger + addTriggerEqualityInternal(eq, d_false, eq, false); +} + +void EqualityEngine::addTriggerPredicate(TNode predicate) { + Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL); + Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates"); + // Add the term + addTerm(predicate); + // Positive trigger + addTriggerEqualityInternal(predicate, d_true, predicate, true); + // Negative trigger + addTriggerEqualityInternal(predicate, d_false, predicate, false); +} + +void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity) { + + Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl; + + Assert(hasTerm(t1)); + Assert(hasTerm(t2)); + + // Get the information about t1 + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t1classId = getEqualityNode(t1Id).getFind(); + TriggerId t1TriggerId = d_nodeTriggers[t1classId]; + + // Get the information about t2 + EqualityNodeId t2Id = getNodeId(t2); + EqualityNodeId t2classId = getEqualityNode(t2Id).getFind(); + TriggerId t2TriggerId = d_nodeTriggers[t2classId]; + + Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl; + + // Create the triggers + TriggerId t1NewTriggerId = d_equalityTriggers.size(); + TriggerId t2NewTriggerId = t1NewTriggerId | 1; + d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId)); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); + d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId)); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); + + // Update the counters + d_equalityTriggersCount = d_equalityTriggersCount + 2; + + // Add the trigger to the trigger graph + d_nodeTriggers[t1classId] = t1NewTriggerId; + d_nodeTriggers[t2classId] = t2NewTriggerId; + + // If the trigger is already on, we propagate + if (t1classId == t2classId) { + Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; + if (d_performNotify) { + d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value + } + } + + if (Debug.isOn("equality::internal")) { + debugPrintGraph(); + } + + Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl; +} + +void EqualityEngine::propagate() { + + Debug("equality") << "EqualityEngine::propagate()" << std::endl; + + bool done = false; + while (!d_propagationQueue.empty()) { + + // The current merge candidate + const MergeCandidate current = d_propagationQueue.front(); + d_propagationQueue.pop(); + + if (done) { + // If we're done, just empty the queue + continue; + } + + // Get the representatives + EqualityNodeId t1classId = getEqualityNode(current.t1Id).getFind(); + EqualityNodeId t2classId = getEqualityNode(current.t2Id).getFind(); + + // If already the same, we're done + if (t1classId == t2classId) { + continue; + } + + // Get the nodes of the representatives + EqualityNode& node1 = getEqualityNode(t1classId); + EqualityNode& node2 = getEqualityNode(t2classId); + + Assert(node1.getFind() == t1classId); + Assert(node2.getFind() == t2classId); + + // Add the actual equality to the equality graph + addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason); + + // One more equality added + d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1; + + // Depending on the merge preference (such as size), merge them + std::vector triggers; + if (node2.getSize() > node1.getSize()) { + Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; + d_assertedEqualities.push_back(Equality(t2classId, t1classId)); + done = !merge(node2, node1, triggers); + } else { + Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl; + d_assertedEqualities.push_back(Equality(t1classId, t2classId)); + done = !merge(node1, node2, triggers); + } + + // Notify the triggers + if (d_performNotify && !done) { + for (size_t trigger_i = 0, trigger_end = triggers.size(); trigger_i < trigger_end && !done; ++ trigger_i) { + const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[triggers[trigger_i]]; + // Notify the trigger and exit if it fails + if (triggerInfo.trigger.getKind() == kind::EQUAL) { + done = !d_notify.eqNotifyTriggerEquality(triggerInfo.trigger, triggerInfo.polarity); + } else { + done = !d_notify.eqNotifyTriggerPredicate(triggerInfo.trigger, triggerInfo.polarity); + } + } + } + } +} + +void EqualityEngine::debugPrintGraph() const { + for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) { + + Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):"; + + EqualityEdgeId edgeId = d_equalityGraph[nodeId]; + while (edgeId != null_edge) { + const EqualityEdge& edge = d_equalityEdges[edgeId]; + Debug("equality::graph") << " " << d_nodes[edge.getNodeId()] << ":" << edge.getReason(); + edgeId = edge.getNext(); + } + + Debug("equality::graph") << std::endl; + } +} + +bool EqualityEngine::areEqual(TNode t1, TNode t2) +{ + // Don't notify during this check + ScopedBool turnOffNotify(d_performNotify, false); + + // Add the terms + addTerm(t1); + addTerm(t2); + bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind(); + + // Return whether the two terms are equal + return equal; +} + +bool EqualityEngine::areDisequal(TNode t1, TNode t2) +{ + // Don't notify during this check + ScopedBool turnOffNotify(d_performNotify, false); + + // Add the terms + addTerm(t1); + addTerm(t2); + + // Check (t1 = t2) = false + // No need to check the symmetric version: we can only deduce a disequality from an existing + // diseqality, and each of those is asserted in the symmetric version also + Node equality = t1.eqNode(t2); + addTerm(equality); + if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) { + return true; + } + + // Return whether the terms are disequal + return false; +} + +size_t EqualityEngine::getSize(TNode t) +{ + // Add the term + addTerm(t); + return getEqualityNode(getEqualityNode(t).getFind()).getSize(); +} + +void EqualityEngine::addTriggerTerm(TNode t) +{ + Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl; + + // Add the term if it's not already there + addTerm(t); + + // Get the node id + EqualityNodeId eqNodeId = getNodeId(t); + EqualityNode& eqNode = getEqualityNode(eqNodeId); + EqualityNodeId classId = eqNode.getFind(); + + if (d_nodeIndividualTrigger[classId] != +null_id) { + // No need to keep it, just propagate the existing individual triggers + if (d_performNotify) { + d_notify.eqNotifyTriggerTermEquality(t, d_nodes[d_nodeIndividualTrigger[classId]], true); + } + } else { + // Add it to the list for backtracking + d_individualTriggers.push_back(classId); + d_individualTriggersSize = d_individualTriggersSize + 1; + // Mark the class id as a trigger + d_nodeIndividualTrigger[classId] = eqNodeId; + } +} + +bool EqualityEngine::isTriggerTerm(TNode t) const { + if (!hasTerm(t)) return false; + EqualityNodeId classId = getEqualityNode(t).getFind(); + return d_nodeIndividualTrigger[classId] != +null_id; +} + + +TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const { + Assert(isTriggerTerm(t)); + EqualityNodeId classId = getEqualityNode(t).getFind(); + return d_nodes[d_nodeIndividualTrigger[classId]]; +} + +void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { + Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end()); + d_applicationLookup[funNormalized] = funId; + d_applicationLookups.push_back(funNormalized); + d_applicationLookupsCount = d_applicationLookupsCount + 1; + Debug("equality::backtrack") << "d_applicationLookupsCount = " << d_applicationLookupsCount << std::endl; + Debug("equality::backtrack") << "d_applicationLookups.size() = " << d_applicationLookups.size() << std::endl; + Assert(d_applicationLookupsCount == d_applicationLookups.size()); +} + +void EqualityEngine::getUseListTerms(TNode t, std::set& output) { + if (hasTerm(t)) { + // Get the equivalence class + EqualityNodeId classId = getEqualityNode(t).getFind(); + // Go through the equivalence class and get where t is used in + EqualityNodeId currentId = classId; + do { + // Get the current node + EqualityNode& currentNode = getEqualityNode(currentId); + // Go through the use-list + UseListNodeId currentUseId = currentNode.getUseList(); + while (currentUseId != null_uselist_id) { + // Get the node of the use list + UseListNode& useNode = d_useListNodes[currentUseId]; + // Get the function application + EqualityNodeId funId = useNode.getApplicationId(); + output.insert(d_nodes[funId]); + // Go to the next one in the use list + currentUseId = useNode.getNext(); + } + // Move to the next node + currentId = currentNode.getNext(); + } while (currentId != classId); + } +} + +} // Namespace uf +} // Namespace theory +} // Namespace CVC4 + diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index dccd5ba56..f9c10d1b6 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -35,7 +35,7 @@ namespace CVC4 { namespace theory { -namespace uf { +namespace eq { /** Id of the node */ typedef size_t EqualityNodeId; @@ -213,9 +213,74 @@ public: } }; -template +/** + * Interface for equality engine notifications. All the notifications + * are safe as TNodes, but not necessarily for negations. + */ +class EqualityEngineNotify { + + friend class EqualityEngine; + +public: + + virtual ~EqualityEngineNotify() {}; + + /** + * Notifies about a trigger equality that became true or false. + * + * @param eq the equality that became true or false + * @param value the value of the equality + */ + virtual bool eqNotifyTriggerEquality(TNode equality, bool value) = 0; + + /** + * Notifies about a trigger predicate that became true or false. + * + * @param predicate the trigger predicate that bacame true or false + * @param value the value of the predicate + */ + virtual bool eqNotifyTriggerPredicate(TNode predicate, bool value) = 0; + + /** + * Notifies about the merge of two trigger terms. + * + * @param t1 a term marked as trigger + * @param t2 a term marked as trigger + * @param value true if equal, false if dis-equal + */ + virtual bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) = 0; + + /** + * Notifies about the merge of two constant terms. + * + * @param t1 a constant term + * @param t2 a constnat term + */ + virtual bool eqNotifyConstantTermMerge(TNode t1, TNode t2) = 0; +}; + +/** + * Implementation of the notification interface that ignores all the + * notifications. + */ +class EqualityEngineNotifyNone : public EqualityEngineNotify { +public: + bool eqNotifyTriggerEquality(TNode equality, bool value) { return true; } + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { return true; } + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { return true; } + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { return true; } +}; + + +/** + * Class for keeping an incremental congurence closure over a set of terms. It provides + * notifications via an EqualityEngineNotify object. + */ class EqualityEngine : public context::ContextNotifyObj { + /** Default implementation of the notification object */ + static EqualityEngineNotifyNone s_notifyNone; + public: /** Statistics about the equality engine instance */ @@ -226,21 +291,26 @@ public: IntStat termsCount; /** Number of function terms managed by the system */ IntStat functionTermsCount; + /** Number of constant terms managed by the system */ + IntStat constantTermsCount; Statistics(std::string name) : mergesCount(name + "::mergesCount", 0), termsCount(name + "::termsCount", 0), - functionTermsCount(name + "::functionTermsCount", 0) + functionTermsCount(name + "::functionTermsCount", 0), + constantTermsCount(name + "::constantTermsCount", 0) { StatisticsRegistry::registerStat(&mergesCount); StatisticsRegistry::registerStat(&termsCount); StatisticsRegistry::registerStat(&functionTermsCount); + StatisticsRegistry::registerStat(&constantTermsCount); } ~Statistics() { StatisticsRegistry::unregisterStat(&mergesCount); StatisticsRegistry::unregisterStat(&termsCount); StatisticsRegistry::unregisterStat(&functionTermsCount); + StatisticsRegistry::unregisterStat(&constantTermsCount); } }; @@ -282,7 +352,7 @@ private: bool d_performNotify; /** The class to notify when a representative changes for a term */ - NotifyClass d_notify; + EqualityEngineNotify& d_notify; /** The map of kinds to be treated as function applications */ KindMap d_congruenceKinds; @@ -428,8 +498,11 @@ private: /** Returns the id of the node */ EqualityNodeId getNodeId(TNode node) const; - /** Merge the class2 into class1 */ - void merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers); + /** + * Merge the class2 into class1 + * @return true if ok, false if to break out + */ + bool merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers); /** Undo the mereg of class2 into class1 */ void undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id); @@ -437,29 +510,13 @@ private: /** Backtrack the information if necessary */ void backtrack(); - /** - * Data used in the BFS search through the equality graph. - */ - struct BfsData { - // The current node - EqualityNodeId nodeId; - // The index of the edge we traversed - EqualityEdgeId edgeId; - // Index in the queue of the previous node. Shouldn't be too much of them, at most the size - // of the biggest equivalence class - size_t previousIndex; - - BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) - : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} - }; - /** * Trigger that will be updated */ struct Trigger { /** The current class id of the LHS of the trigger */ EqualityNodeId classId; - /** Next trigger for class 1 */ + /** Next trigger for class */ TriggerId nextTrigger; Trigger(EqualityNodeId classId = null_id, TriggerId nextTrigger = null_trigger) @@ -473,10 +530,20 @@ private: */ std::vector d_equalityTriggers; + struct TriggerInfo { + /** The trigger itself */ + Node trigger; + /** Polarity of the trigger */ + bool polarity; + TriggerInfo() {} + TriggerInfo(Node trigger, bool polarity) + : trigger(trigger), polarity(polarity) {} + }; + /** * Vector of original equalities of the triggers. */ - std::vector d_equalityTriggersOriginal; + std::vector d_equalityTriggersOriginal; /** * Context dependent count of triggers @@ -504,6 +571,19 @@ private: */ std::vector d_nodeIndividualTrigger; + /** + * Map from ids to the id of the constant that is the representative. + */ + std::vector d_constantRepresentative; + + /** + * Size of the constant representatives list. + */ + context::CDO d_constantRepresentativesSize; + + /** The list of representatives that became constant. */ + std::vector d_constantRepresentatives; + /** * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId. */ @@ -516,7 +596,7 @@ private: EqualityNodeId newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2); /** Add a new node to the database */ - EqualityNodeId newNode(TNode t, bool isApplication); + EqualityNodeId newNode(TNode t); struct MergeCandidate { EqualityNodeId t1Id, t2Id; @@ -561,44 +641,41 @@ private: /** * Adds an equality of terms t1 and t2 to the database. */ - void addEqualityInternal(TNode t1, TNode t2, TNode reason); + void assertEqualityInternal(TNode t1, TNode t2, TNode reason); -public: + /** + * Adds a trigger equality to the database with the trigger node and polarity for notification. + */ + void addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity); /** - * Initialize the equality engine, given the owning class. This will initialize the notifier with - * the owner information. - */ - EqualityEngine(NotifyClass& notify, context::Context* context, std::string name) - : ContextNotifyObj(context), - d_context(context), - d_performNotify(true), - d_notify(notify), - d_applicationLookupsCount(context, 0), - d_nodesCount(context, 0), - d_assertedEqualitiesCount(context, 0), - d_equalityTriggersCount(context, 0), - d_individualTriggersSize(context, 0), - d_stats(name) - { - Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; - d_true = NodeManager::currentNM()->mkConst(true); - d_false = NodeManager::currentNM()->mkConst(false); + * This method gets called on backtracks from the context manager. + */ + void contextNotifyPop() { + backtrack(); } /** - * Just a destructor. + * Constructor initialization stuff. */ - virtual ~EqualityEngine() throw(AssertionException) {} + void init(); + +public: /** - * This method gets called on backtracks from the context manager. + * Initialize the equality engine, given the notification class. */ - void notify() { - backtrack(); - } + EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name); + + /** + * Initialize the equality engine with no notification class. + */ + EqualityEngine(context::Context* context, std::string name); + + /** + * Just a destructor. + */ + virtual ~EqualityEngine() throw(AssertionException) {} /** * Adds a term to the term database. @@ -629,77 +706,91 @@ public: bool hasTerm(TNode t) const; /** - * Adds aa predicate t with given polarity + * Adds a predicate p with given polarity. The predicate asserted + * should be in the coungruence closure kinds (otherwise it's + * useless. + * + * @param p the (non-negated) predicate + * @param polarity true if asserting the predicate, false if + * asserting the negated predicate + * @param the reason to keep for building explanations */ - void addPredicate(TNode t, bool polarity, TNode reason); + void assertPredicate(TNode p, bool polarity, TNode reason); /** - * Adds an equality t1 = t2 to the database. + * Adds an equality eq with the given polarity to the database. + * + * @param eq the (non-negated) equality + * @param polarity true if asserting the equality, false if + * asserting the negated equality + * @param the reason to keep for building explanations */ - void addEquality(TNode t1, TNode t2, TNode reason); + void assertEquality(TNode eq, bool polarity, TNode reason); /** - * Adds an dis-equality t1 != t2 to the database. - */ - void addDisequality(TNode t1, TNode t2, TNode reason); - - /** - * Returns the representative of the term t. + * Returns the current representative of the term t. */ TNode getRepresentative(TNode t) const; /** - * Add all the terms where the given term appears in (directly or implicitly). + * Add all the terms where the given term appears as a first child + * (directly or implicitly). */ void getUseListTerms(TNode t, std::set& output); /** - * Returns true if the two nodes are in the same class. + * Returns true if the two nodes are in the same equivalence class. */ bool areEqual(TNode t1, TNode t2) const; /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the equality t1 = t2 begin true of false. + * Returns the reasons (added when asserting) that imply it + * in the assertions vector. */ - void explainEquality(TNode t1, TNode t2, std::vector& equalities); + void explainEquality(TNode t1, TNode t2, bool polarity, std::vector& assertions); /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the predicate being true or false. + * Returns the reasons (added when asserting) that imply imply it + * in the assertions vector. */ - void explainDisequality(TNode t1, TNode t2, std::vector& equalities); + void explainPredicate(TNode p, bool polarity, std::vector& assertions); /** - * Add term to the trigger terms. The notify class will get notified when two - * trigger terms become equal. Thihs will only happen on trigger term - * representatives. + * Add term to the trigger terms. The notify class will get notified + * when two trigger terms become equal or dis-equal. The notification + * will not happen on all the terms, but only on the ones that are + * represent the class. */ void addTriggerTerm(TNode t); /** - * Returns true if t is a trigger term or equal to some other trigger term. + * Returns true if t is a trigger term or in the same equivalence + * class as some other trigger term. */ bool isTriggerTerm(TNode t) const; /** - * Returns the representative trigger term (isTriggerTerm(t)) should be true. + * Returns the representative trigger term of the given term. + * + * @param t the term to check where isTriggerTerm(t) should be true */ TNode getTriggerTermRepresentative(TNode t) const; /** - * Adds a notify trigger for equality t1 = t2, i.e. when t1 = t2 the notify will be called with - * trigger. + * Adds a notify trigger for equality. When equality becomes true eqNotifyTriggerEquality + * will be called with value = true, and when equality becomes false eqNotifyTriggerEquality + * will be called with value = false. */ - void addTriggerEquality(TNode t1, TNode t2, TNode trigger); + void addTriggerEquality(TNode equality); /** - * Adds a notify trigger for dis-equality t1 != t2, i.e. when t1 != t2 the notify will be called with - * trigger. + * Adds a notify trigger for the predicate p. When the predicate becomes true + * eqNotifyTriggerPredicate will be called with value = true, and when equality becomes false + * eqNotifyTriggerPredicate will be called with value = false. */ - void addTriggerDisequality(TNode t1, TNode t2, TNode trigger); + void addTriggerPredicate(TNode predicate); /** * Check whether the two terms are equal. @@ -712,7 +803,7 @@ public: bool areDisequal(TNode t1, TNode t2); /** - * Return the number of nodes in the equivalence class contianing t + * Return the number of nodes in the equivalence class containing t * Adds t if not already there. */ size_t getSize(TNode t); diff --git a/src/theory/uf/equality_engine_impl.h b/src/theory/uf/equality_engine_impl.h deleted file mode 100644 index be12e5f19..000000000 --- a/src/theory/uf/equality_engine_impl.h +++ /dev/null @@ -1,947 +0,0 @@ -/********************* */ -/*! \file equality_engine_impl.h - ** \verbatim - ** Original author: dejan - ** Major contributors: none - ** Minor contributors (to current version): none - ** This file is part of the CVC4 prototype. - ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) - ** Courant Institute of Mathematical Sciences - ** New York University - ** See the file COPYING in the top-level source directory for licensing - ** information.\endverbatim - ** - ** \brief [[ Add one-line brief description here ]] - ** - ** [[ Add lengthier description here ]] - ** \todo document this file - **/ - -#include "cvc4_private.h" - -#pragma once - -#include "theory/uf/equality_engine.h" - -namespace CVC4 { -namespace theory { -namespace uf { - -class ScopedBool { - bool& watch; - bool oldValue; -public: - ScopedBool(bool& watch, bool newValue) - : watch(watch), oldValue(watch) { - watch = newValue; - } - ~ScopedBool() { - watch = oldValue; - } -}; - -template -void EqualityEngine::enqueue(const MergeCandidate& candidate) { - Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl; - d_propagationQueue.push(candidate); -} - -template -EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { - Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl; - - ++ d_stats.functionTermsCount; - - // Get another id for this - EqualityNodeId funId = newNode(original, true); - FunctionApplication funOriginal(t1, t2); - // The function application we're creating - EqualityNodeId t1ClassId = getEqualityNode(t1).getFind(); - EqualityNodeId t2ClassId = getEqualityNode(t2).getFind(); - FunctionApplication funNormalized(t1ClassId, t2ClassId); - - // We add the original version - d_applications[funId] = FunctionApplicationPair(funOriginal, funNormalized); - - // Add the lookup data, if it's not already there - typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized); - if (find == d_applicationLookup.end()) { - // When we backtrack, if the lookup is not there anymore, we'll add it again - Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): no lookup, setting up" << std::endl; - // Mark the normalization to the lookup - storeApplicationLookup(funNormalized, funId); - } else { - // If it's there, we need to merge these two - Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << "): lookup exists, adding to queue" << std::endl; - enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null())); - propagate(); - } - - // Add to the use lists - d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes); - d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes); - - // Return the new id - Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl; - - return funId; -} - -template -EqualityNodeId EqualityEngine::newNode(TNode node, bool isApplication) { - - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ")" << std::endl; - - ++ d_stats.termsCount; - - // Register the new id of the term - EqualityNodeId newId = d_nodes.size(); - d_nodeIds[node] = newId; - // Add the node to it's position - d_nodes.push_back(node); - // Note if this is an application or not - d_applications.push_back(FunctionApplicationPair()); - // Add the trigger list for this node - d_nodeTriggers.push_back(+null_trigger); - // Add it to the equality graph - d_equalityGraph.push_back(+null_edge); - // Mark the no-individual trigger - d_nodeIndividualTrigger.push_back(+null_id); - // Add the equality node to the nodes - d_equalityNodes.push_back(EqualityNode(newId)); - - // Increase the counters - d_nodesCount = d_nodesCount + 1; - - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ") => " << newId << std::endl; - - return newId; -} - - -template -void EqualityEngine::addTerm(TNode t) { - - Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; - - // If there already, we're done - if (hasTerm(t)) { - Debug("equality") << "EqualityEngine::addTerm(" << t << "): already there" << std::endl; - return; - } - - EqualityNodeId result; - - // If a function application we go in - if (t.getNumChildren() > 0 && d_congruenceKinds[t.getKind()]) - { - // Add the operator - TNode tOp = t.getOperator(); - addTerm(tOp); - // Add all the children and Curryfy - result = getNodeId(tOp); - for (unsigned i = 0; i < t.getNumChildren(); ++ i) { - // Add the child - addTerm(t[i]); - // Add the application - result = newApplicationNode(t, result, getNodeId(t[i])); - } - } else { - // Otherwise we just create the new id - result = newNode(t, false); - } - - Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl; -} - -template -bool EqualityEngine::hasTerm(TNode t) const { - return d_nodeIds.find(t) != d_nodeIds.end(); -} - -template -EqualityNodeId EqualityEngine::getNodeId(TNode node) const { - Assert(hasTerm(node), node.toString().c_str()); - return (*d_nodeIds.find(node)).second; -} - -template -EqualityNode& EqualityEngine::getEqualityNode(TNode t) { - return getEqualityNode(getNodeId(t)); -} - -template -EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) { - Assert(nodeId < d_equalityNodes.size()); - return d_equalityNodes[nodeId]; -} - -template -const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const { - return getEqualityNode(getNodeId(t)); -} - -template -const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const { - Assert(nodeId < d_equalityNodes.size()); - return d_equalityNodes[nodeId]; -} - -template -void EqualityEngine::addEqualityInternal(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl; - - // Add the terms if they are not already in the database - addTerm(t1); - addTerm(t2); - - // Add to the queue and propagate - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); - enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason)); - - propagate(); -} - -template -void EqualityEngine::addPredicate(TNode t, bool polarity, TNode reason) { - - Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl; - - addEqualityInternal(t, polarity ? d_true : d_false, reason); -} - -template -void EqualityEngine::addEquality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; - - addEqualityInternal(t1, t2, reason); - - Node equality = t1.eqNode(t2); - addEqualityInternal(equality, d_true, reason); -} - -template -void EqualityEngine::addDisequality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl; - - Node equality1 = t1.eqNode(t2); - addEqualityInternal(equality1, d_false, reason); - - Node equality2 = t2.eqNode(t1); - addEqualityInternal(equality2, d_false, reason); -} - - -template -TNode EqualityEngine::getRepresentative(TNode t) const { - - Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; - - Assert(hasTerm(t)); - - // Both following commands are semantically const - EqualityNodeId representativeId = getEqualityNode(t).getFind(); - - Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl; - - return d_nodes[representativeId]; -} - -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) const { - Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; - - Assert(hasTerm(t1)); - Assert(hasTerm(t2)); - - // Both following commands are semantically const - EqualityNodeId rep1 = getEqualityNode(t1).getFind(); - EqualityNodeId rep2 = getEqualityNode(t2).getFind(); - - Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl; - - return rep1 == rep2; -} - -template -void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { - - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; - - Assert(triggers.empty()); - - ++ d_stats.mergesCount; - - EqualityNodeId class1Id = class1.getFind(); - EqualityNodeId class2Id = class2.getFind(); - - // Update class2 representative information - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl; - EqualityNodeId currentId = class2Id; - do { - // Get the current node - EqualityNode& currentNode = getEqualityNode(currentId); - - // Update it's find to class1 id - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << "->" << class1Id << std::endl; - currentNode.setFind(class1Id); - - // Go through the triggers and inform if necessary - TriggerId currentTrigger = d_nodeTriggers[currentId]; - while (currentTrigger != null_trigger) { - Trigger& trigger = d_equalityTriggers[currentTrigger]; - Trigger& otherTrigger = d_equalityTriggers[currentTrigger ^ 1]; - - // If the two are not already in the same class - if (otherTrigger.classId != trigger.classId) { - trigger.classId = class1Id; - // If they became the same, call the trigger - if (otherTrigger.classId == class1Id) { - // Id of the real trigger is half the internal one - triggers.push_back(currentTrigger); - } - } - - // Go to the next trigger - currentTrigger = trigger.nextTrigger; - } - - // Move to the next node - currentId = currentNode.getNext(); - - } while (currentId != class2Id); - - - // Update class2 table lookup and information - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of " << class2Id << std::endl; - do { - // Get the current node - EqualityNode& currentNode = getEqualityNode(currentId); - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl; - - // Go through the uselist and check for congruences - UseListNodeId currentUseId = currentNode.getUseList(); - while (currentUseId != null_uselist_id) { - // Get the node of the use list - UseListNode& useNode = d_useListNodes[currentUseId]; - // Get the function application - EqualityNodeId funId = useNode.getApplicationId(); - Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): " << currentId << " in " << d_nodes[funId] << std::endl; - const FunctionApplication& fun = d_applications[useNode.getApplicationId()].normalized; - // Check if there is an application with find arguments - EqualityNodeId aNormalized = getEqualityNode(fun.a).getFind(); - EqualityNodeId bNormalized = getEqualityNode(fun.b).getFind(); - FunctionApplication funNormalized(aNormalized, bNormalized); - typename ApplicationIdsMap::iterator find = d_applicationLookup.find(funNormalized); - if (find != d_applicationLookup.end()) { - // Applications fun and the funNormalized can be merged due to congruence - if (getEqualityNode(funId).getFind() != getEqualityNode(find->second).getFind()) { - enqueue(MergeCandidate(funId, find->second, MERGED_THROUGH_CONGRUENCE, TNode::null())); - } - } else { - // There is no representative, so we can add one, we remove this when backtracking - storeApplicationLookup(funNormalized, funId); - } - // Go to the next one in the use list - currentUseId = useNode.getNext(); - } - - // Move to the next node - currentId = currentNode.getNext(); - } while (currentId != class2Id); - - // Now merge the lists - class1.merge(class2); - - // Notfiy the triggers - EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; - EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id]; - if (class2triggerId != +null_id) { - if (class1triggerId == +null_id) { - // If class1 is not an individual trigger, but class2 is, mark it - d_nodeIndividualTrigger[class1Id] = class2triggerId; - // Add it to the list for backtracking - d_individualTriggers.push_back(class1Id); - d_individualTriggersSize = d_individualTriggersSize + 1; - } else { - // Notify when done - if (d_performNotify) { - d_notify.notify(d_nodes[class1triggerId], d_nodes[class2triggerId]); - } - } - } -} - -template -void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { - - Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; - - // Now unmerge the lists (same as merge) - class1.merge(class2); - - // Update class2 representative information - EqualityNodeId currentId = class2Id; - Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << "): undoing representative info" << std::endl; - do { - // Get the current node - EqualityNode& currentNode = getEqualityNode(currentId); - - // Update it's find to class1 id - currentNode.setFind(class2Id); - - // Go through the trigger list (if any) and undo the class - TriggerId currentTrigger = d_nodeTriggers[currentId]; - while (currentTrigger != null_trigger) { - Trigger& trigger = d_equalityTriggers[currentTrigger]; - trigger.classId = class2Id; - currentTrigger = trigger.nextTrigger; - } - - // Move to the next node - currentId = currentNode.getNext(); - - } while (currentId != class2Id); - -} - -template -void EqualityEngine::backtrack() { - - Debug("equality::backtrack") << "backtracking" << std::endl; - - // If we need to backtrack then do it - if (d_assertedEqualitiesCount < d_assertedEqualities.size()) { - - // Clear the propagation queue - while (!d_propagationQueue.empty()) { - d_propagationQueue.pop(); - } - - Debug("equality") << "EqualityEngine::backtrack(): nodes" << std::endl; - - for (int i = (int)d_assertedEqualities.size() - 1, i_end = (int)d_assertedEqualitiesCount; i >= i_end; --i) { - // Get the ids of the merged classes - Equality& eq = d_assertedEqualities[i]; - // Undo the merge - undoMerge(d_equalityNodes[eq.lhs], d_equalityNodes[eq.rhs], eq.rhs); - } - - d_assertedEqualities.resize(d_assertedEqualitiesCount); - - Debug("equality") << "EqualityEngine::backtrack(): edges" << std::endl; - - for (int i = (int)d_equalityEdges.size() - 2, i_end = (int)(2*d_assertedEqualitiesCount); i >= i_end; i -= 2) { - EqualityEdge& edge1 = d_equalityEdges[i]; - EqualityEdge& edge2 = d_equalityEdges[i | 1]; - d_equalityGraph[edge2.getNodeId()] = edge1.getNext(); - d_equalityGraph[edge1.getNodeId()] = edge2.getNext(); - } - - d_equalityEdges.resize(2 * d_assertedEqualitiesCount); - } - - if (d_individualTriggers.size() > d_individualTriggersSize) { - // Unset the individual triggers - for (int i = d_individualTriggers.size() - 1, i_end = d_individualTriggersSize; i >= i_end; -- i) { - d_nodeIndividualTrigger[d_individualTriggers[i]] = +null_id; - } - d_individualTriggers.resize(d_individualTriggersSize); - } - - if (d_equalityTriggers.size() > d_equalityTriggersCount) { - // Unlink the triggers from the lists - for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) { - const Trigger& trigger = d_equalityTriggers[i]; - d_nodeTriggers[trigger.classId] = trigger.nextTrigger; - } - // Get rid of the triggers - d_equalityTriggers.resize(d_equalityTriggersCount); - d_equalityTriggersOriginal.resize(d_equalityTriggersCount); - } - - if (d_applicationLookups.size() > d_applicationLookupsCount) { - for (int i = d_applicationLookups.size() - 1, i_end = (int) d_applicationLookupsCount; i >= i_end; -- i) { - d_applicationLookup.erase(d_applicationLookups[i]); - } - d_applicationLookups.resize(d_applicationLookupsCount); - } - - if (d_nodes.size() > d_nodesCount) { - // Go down the nodes, check the application nodes and remove them from use-lists - for(int i = d_nodes.size() - 1, i_end = (int)d_nodesCount; i >= i_end; -- i) { - // Remove from the node -> id map - Debug("equality") << "EqualityEngine::backtrack(): removing node " << d_nodes[i] << std::endl; - d_nodeIds.erase(d_nodes[i]); - - const FunctionApplication& app = d_applications[i].normalized; - if (app.isApplication()) { - // Remove b from use-list - getEqualityNode(app.b).removeTopFromUseList(d_useListNodes); - // Remove a from use-list - getEqualityNode(app.a).removeTopFromUseList(d_useListNodes); - } - } - - // Now get rid of the nodes and the rest - d_nodes.resize(d_nodesCount); - d_applications.resize(d_nodesCount); - d_nodeTriggers.resize(d_nodesCount); - d_nodeIndividualTrigger.resize(d_nodesCount); - d_equalityGraph.resize(d_nodesCount); - d_equalityNodes.resize(d_nodesCount); - } -} - -template -void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { - Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl; - EqualityEdgeId edge = d_equalityEdges.size(); - d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason)); - d_equalityEdges.push_back(EqualityEdge(t1, d_equalityGraph[t2], type, reason)); - d_equalityGraph[t1] = edge; - d_equalityGraph[t2] = edge | 1; - - if (Debug.isOn("equality::internal")) { - debugPrintGraph(); - } -} - -template -std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { - std::stringstream out; - bool first = true; - if (edgeId == null_edge) { - out << "null"; - } else { - while (edgeId != null_edge) { - const EqualityEdge& edge = d_equalityEdges[edgeId]; - if (!first) out << ","; - out << d_nodes[edge.getNodeId()]; - edgeId = edge.getNext(); - first = false; - } - } - return out.str(); -} - -template -void EqualityEngine::explainEquality(TNode t1, TNode t2, std::vector& equalities) { - Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl; - - // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); - - // Add the terms (they might not be there) - addTerm(t1); - addTerm(t2); - - Assert(getRepresentative(t1) == getRepresentative(t2), - "Cannot explain an equality, because the two terms are not equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - t1.toString().c_str(), getRepresentative(t1).toString().c_str(), - t2.toString().c_str(), getRepresentative(t2).toString().c_str()); - - // Get the explanation - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); - getExplanation(t1Id, t2Id, equalities); - -} - -template -void EqualityEngine::explainDisequality(TNode t1, TNode t2, std::vector& equalities) { - Debug("equality") << "EqualityEngine::explainDisequality(" << t1 << "," << t2 << ")" << std::endl; - - // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); - - // Add the terms - addTerm(t1); - addTerm(t2); - - // Add the equality - Node equality = t1.eqNode(t2); - addTerm(equality); - - Assert(getRepresentative(equality) == getRepresentative(d_false), - "Cannot explain the dis-equality, because the two terms are not dis-equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - equality.toString().c_str(), getRepresentative(equality).toString().c_str(), - d_false.toString().c_str(), getRepresentative(d_false).toString().c_str()); - - // Get the explanation - EqualityNodeId equalityId = getNodeId(equality); - EqualityNodeId falseId = getNodeId(d_false); - getExplanation(equalityId, falseId, equalities); - -} - - -template -void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector& equalities) const { - - Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; - - // If the nodes are the same, we're done - if (t1Id == t2Id) return; - - if (Debug.isOn("equality::internal")) { - debugPrintGraph(); - } - - // Queue for the BFS containing nodes - std::vector bfsQueue; - - // Find a path from t1 to t2 in the graph (BFS) - bfsQueue.push_back(BfsData(t1Id, null_id, 0)); - size_t currentIndex = 0; - while (true) { - // There should always be a path, and every node can be visited only once (tree) - Assert(currentIndex < bfsQueue.size()); - - // The next node to visit - BfsData current = bfsQueue[currentIndex]; - EqualityNodeId currentNode = current.nodeId; - - Debug("equality") << "EqualityEngine::getExplanation(): currentNode = " << d_nodes[currentNode] << std::endl; - - // Go through the equality edges of this node - EqualityEdgeId currentEdge = d_equalityGraph[currentNode]; - Debug("equality") << "EqualityEngine::getExplanation(): edgesId = " << currentEdge << std::endl; - Debug("equality") << "EqualityEngine::getExplanation(): edges = " << edgesToString(currentEdge) << std::endl; - - while (currentEdge != null_edge) { - // Get the edge - const EqualityEdge& edge = d_equalityEdges[currentEdge]; - - // If not just the backwards edge - if ((currentEdge | 1u) != (current.edgeId | 1u)) { - - Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = (" << d_nodes[currentNode] << "," << d_nodes[edge.getNodeId()] << ")" << std::endl; - - // Did we find the path - if (edge.getNodeId() == t2Id) { - - Debug("equality") << "EqualityEngine::getExplanation(): path found: " << std::endl; - - // Reconstruct the path - do { - // The current node - currentNode = bfsQueue[currentIndex].nodeId; - EqualityNodeId edgeNode = d_equalityEdges[currentEdge].getNodeId(); - MergeReasonType reasonType = d_equalityEdges[currentEdge].getReasonType(); - - Debug("equality") << "EqualityEngine::getExplanation(): currentEdge = " << currentEdge << ", currentNode = " << currentNode << std::endl; - - // Add the actual equality to the vector - if (reasonType == MERGED_THROUGH_CONGRUENCE) { - // f(x1, x2) == f(y1, y2) because x1 = y1 and x2 = y2 - Debug("equality") << "EqualityEngine::getExplanation(): due to congruence, going deeper" << std::endl; - const FunctionApplication& f1 = d_applications[currentNode].original; - const FunctionApplication& f2 = d_applications[edgeNode].original; - Debug("equality") << push; - getExplanation(f1.a, f2.a, equalities); - getExplanation(f1.b, f2.b, equalities); - Debug("equality") << pop; - } else { - // Construct the equality - Debug("equality") << "EqualityEngine::getExplanation(): adding: " << d_equalityEdges[currentEdge].getReason() << std::endl; - equalities.push_back(d_equalityEdges[currentEdge].getReason()); - } - - // Go to the previous - currentEdge = bfsQueue[currentIndex].edgeId; - currentIndex = bfsQueue[currentIndex].previousIndex; - } while (currentEdge != null_id); - - // Done - return; - } - - // Push to the visitation queue if it's not the backward edge - bfsQueue.push_back(BfsData(edge.getNodeId(), currentEdge, currentIndex)); - } - - // Go to the next edge - currentEdge = edge.getNext(); - } - - // Go to the next node to visit - ++ currentIndex; - } -} - -template -void EqualityEngine::addTriggerDisequality(TNode t1, TNode t2, TNode trigger) { - Node equality = t1.eqNode(t2); - addTerm(equality); - addTriggerEquality(equality, d_false, trigger); -} - -template -void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode trigger) { - - Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl; - - Assert(hasTerm(t1)); - Assert(hasTerm(t2)); - - // Get the information about t1 - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t1classId = getEqualityNode(t1Id).getFind(); - TriggerId t1TriggerId = d_nodeTriggers[t1classId]; - - // Get the information about t2 - EqualityNodeId t2Id = getNodeId(t2); - EqualityNodeId t2classId = getEqualityNode(t2Id).getFind(); - TriggerId t2TriggerId = d_nodeTriggers[t2classId]; - - Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl; - - // Create the triggers - TriggerId t1NewTriggerId = d_equalityTriggers.size(); - TriggerId t2NewTriggerId = t1NewTriggerId | 1; - d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); - d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); - - // Update the counters - d_equalityTriggersCount = d_equalityTriggersCount + 2; - - // Add the trigger to the trigger graph - d_nodeTriggers[t1classId] = t1NewTriggerId; - d_nodeTriggers[t2classId] = t2NewTriggerId; - - // If the trigger is already on, we propagate - if (t1classId == t2classId) { - Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; - if (d_performNotify) { - d_notify.notify(trigger); // Don't care about the return value - } - } - - if (Debug.isOn("equality::internal")) { - debugPrintGraph(); - } - - Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl; -} - -template -void EqualityEngine::propagate() { - - Debug("equality") << "EqualityEngine::propagate()" << std::endl; - - bool done = false; - while (!d_propagationQueue.empty()) { - - // The current merge candidate - const MergeCandidate current = d_propagationQueue.front(); - d_propagationQueue.pop(); - - if (done) { - // If we're done, just empty the queue - continue; - } - - // Get the representatives - EqualityNodeId t1classId = getEqualityNode(current.t1Id).getFind(); - EqualityNodeId t2classId = getEqualityNode(current.t2Id).getFind(); - - // If already the same, we're done - if (t1classId == t2classId) { - continue; - } - - // Get the nodes of the representatives - EqualityNode& node1 = getEqualityNode(t1classId); - EqualityNode& node2 = getEqualityNode(t2classId); - - Assert(node1.getFind() == t1classId); - Assert(node2.getFind() == t2classId); - - // Add the actual equality to the equality graph - addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason); - - // One more equality added - d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1; - - // Depending on the merge preference (such as size), merge them - std::vector triggers; - if (node2.getSize() > node1.getSize()) { - Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; - d_assertedEqualities.push_back(Equality(t2classId, t1classId)); - merge(node2, node1, triggers); - } else { - Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl; - d_assertedEqualities.push_back(Equality(t1classId, t2classId)); - merge(node1, node2, triggers); - } - - // Notify the triggers - if (d_performNotify) { - for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) { - // Notify the trigger and exit if it fails - done = !d_notify.notify(d_equalityTriggersOriginal[triggers[trigger]]); - } - } - } -} - -template -void EqualityEngine::debugPrintGraph() const { - for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) { - - Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):"; - - EqualityEdgeId edgeId = d_equalityGraph[nodeId]; - while (edgeId != null_edge) { - const EqualityEdge& edge = d_equalityEdges[edgeId]; - Debug("equality::graph") << " " << d_nodes[edge.getNodeId()] << ":" << edge.getReason(); - edgeId = edge.getNext(); - } - - Debug("equality::graph") << std::endl; - } -} - -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) -{ - // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); - - // Add the terms - addTerm(t1); - addTerm(t2); - bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind(); - - // Return whether the two terms are equal - return equal; -} - -template -bool EqualityEngine::areDisequal(TNode t1, TNode t2) -{ - // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); - - // Add the terms - addTerm(t1); - addTerm(t2); - - // Check (t1 = t2) = false - Node equality = t1.eqNode(t2); - addTerm(equality); - if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) { - return true; - } - - // Return whether the terms are disequal - return false; -} - -template -size_t EqualityEngine::getSize(TNode t) -{ - // Add the term - addTerm(t); - return getEqualityNode(getEqualityNode(t).getFind()).getSize(); -} - -template -void EqualityEngine::addTriggerTerm(TNode t) -{ - Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl; - - // Add the term if it's not already there - addTerm(t); - - // Get the node id - EqualityNodeId eqNodeId = getNodeId(t); - EqualityNode& eqNode = getEqualityNode(eqNodeId); - EqualityNodeId classId = eqNode.getFind(); - - if (d_nodeIndividualTrigger[classId] != +null_id) { - // No need to keep it, just propagate the existing individual triggers - if (d_performNotify) { - d_notify.notify(t, d_nodes[d_nodeIndividualTrigger[classId]]); - } - } else { - // Add it to the list for backtracking - d_individualTriggers.push_back(classId); - d_individualTriggersSize = d_individualTriggersSize + 1; - // Mark the class id as a trigger - d_nodeIndividualTrigger[classId] = eqNodeId; - } -} - -template -bool EqualityEngine::isTriggerTerm(TNode t) const { - if (!hasTerm(t)) return false; - EqualityNodeId classId = getEqualityNode(t).getFind(); - return d_nodeIndividualTrigger[classId] != +null_id; -} - - -template -TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const { - Assert(isTriggerTerm(t)); - EqualityNodeId classId = getEqualityNode(t).getFind(); - return d_nodes[d_nodeIndividualTrigger[classId]]; -} - -template -void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { - Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end()); - d_applicationLookup[funNormalized] = funId; - d_applicationLookups.push_back(funNormalized); - d_applicationLookupsCount = d_applicationLookupsCount + 1; - Debug("equality::backtrack") << "d_applicationLookupsCount = " << d_applicationLookupsCount << std::endl; - Debug("equality::backtrack") << "d_applicationLookups.size() = " << d_applicationLookups.size() << std::endl; - Assert(d_applicationLookupsCount == d_applicationLookups.size()); -} - -template -void EqualityEngine::getUseListTerms(TNode t, std::set& output) { - if (hasTerm(t)) { - // Get the equivalence class - EqualityNodeId classId = getEqualityNode(t).getFind(); - // Go through the equivalence class and get where t is used in - EqualityNodeId currentId = classId; - do { - // Get the current node - EqualityNode& currentNode = getEqualityNode(currentId); - // Go through the use-list - UseListNodeId currentUseId = currentNode.getUseList(); - while (currentUseId != null_uselist_id) { - // Get the node of the use list - UseListNode& useNode = d_useListNodes[currentUseId]; - // Get the function application - EqualityNodeId funId = useNode.getApplicationId(); - output.insert(d_nodes[funId]); - // Go to the next one in the use list - currentUseId = useNode.getNext(); - } - // Move to the next node - currentId = currentNode.getNext(); - } while (currentId != classId); - } -} - -} // Namespace uf -} // Namespace theory -} // Namespace CVC4 - diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index ec28dad75..cddd01b07 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -18,13 +18,11 @@ **/ #include "theory/uf/theory_uf.h" -#include "theory/uf/equality_engine_impl.h" using namespace std; - -namespace CVC4 { -namespace theory { -namespace uf { +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::uf; /** Constructs a new instance of TheoryUF w.r.t. the provided context.*/ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo) : @@ -40,12 +38,6 @@ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& d_equalityEngine.addFunctionKind(kind::APPLY_UF); d_equalityEngine.addFunctionKind(kind::EQUAL); - // The boolean constants - d_true = NodeManager::currentNM()->mkConst(true); - d_false = NodeManager::currentNM()->mkConst(false); - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); }/* TheoryUF::TheoryUF() */ static Node mkAnd(const std::vector& conjunctions) { @@ -91,23 +83,12 @@ void TheoryUF::check(Effort level) { } // Do the work - switch (fact.getKind()) { - case kind::EQUAL: - d_equalityEngine.addEquality(fact[0], fact[1], fact); - break; - case kind::APPLY_UF: - d_equalityEngine.addPredicate(fact, true, fact); - break; - case kind::NOT: - if (fact[0].getKind() == kind::APPLY_UF) { - d_equalityEngine.addPredicate(fact[0], false, fact); - } else { - // Assert the dis-equality - d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); - } - break; - default: - Unreachable(); + bool polarity = fact.getKind() != kind::NOT; + TNode atom = polarity ? fact : fact[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.assertEquality(atom, polarity, fact); + } else { + d_equalityEngine.assertPredicate(atom, polarity, fact); } } @@ -139,10 +120,8 @@ void TheoryUF::propagate(Effort level) { Debug("uf") << "TheoryUF::propagate(): in conflict, normalized = " << normalized << std::endl; Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -157,21 +136,17 @@ void TheoryUF::preRegisterTerm(TNode node) { switch (node.getKind()) { case kind::EQUAL: - // Add the terms - d_equalityEngine.addTerm(node[0]); - d_equalityEngine.addTerm(node[1]); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; case kind::APPLY_UF: - // Function applications/predicates - d_equalityEngine.addTerm(node); // Maybe it's a predicate if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerEquality(node, d_true, node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerPredicate(node); + } else { + // Function applications/predicates + d_equalityEngine.addTerm(node); } // Remember the function and predicate terms d_functionsTerms.push_back(node); @@ -194,26 +169,20 @@ bool TheoryUF::propagate(TNode literal) { // See if the literal has been asserted already Node normalized = Rewriter::rewrite(literal); - bool satValue = false; - bool isAsserted = normalized == d_false || d_valuation.hasSatValue(normalized, satValue); - // If asserted, we're done or in conflict - if (isAsserted) { - if (!satValue) { + // If asserted and is false, we're done or in conflict + // Note that even trivial equalities have a SAT value (i.e. 1 = 2 -> false) + bool satValue = false; + if (d_valuation.hasSatValue(normalized, satValue) && !satValue) { Debug("uf") << "TheoryUF::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl; std::vector assumptions; Node negatedLiteral; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; return false; - } - // Propagate even if already known in SAT - could be a new equation between shared terms - // (terms that weren't shared when the literal was asserted!) } // Nothing, just enqueue it for propagation and mark it as asserted already @@ -224,36 +193,14 @@ bool TheoryUF::propagate(TNode literal) { }/* TheoryUF::propagate(TNode) */ void TheoryUF::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::APPLY_UF: - lhs = literal; - rhs = d_true; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + // Do the work + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } Node TheoryUF::explain(TNode literal) { @@ -508,7 +455,3 @@ void TheoryUF::computeCareGraph() { } } }/* TheoryUF::computeCareGraph() */ - -}/* CVC4::theory::uf namespace */ -}/* CVC4::theory namespace */ -}/* CVC4 namespace */ diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 6956390f5..9017928b7 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -39,21 +39,46 @@ namespace uf { class TheoryUF : public Theory { public: - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { TheoryUF& d_uf; public: NotifyClass(TheoryUF& uf): d_uf(uf) {} - bool notify(TNode propagation) { - Debug("uf") << "NotifyClass::notify(" << propagation << ")" << std::endl; - // Just forward to uf - return d_uf.propagate(propagation); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(equality); + } else { + // We use only literal triggers so taking not is safe + return d_uf.propagate(equality.notNode()); + } } - - void notify(TNode t1, TNode t2) { - Debug("uf") << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_uf.propagate(equality); + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(predicate); + } else { + return d_uf.propagate(predicate.notNode()); + } + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl; + if (value) { + return d_uf.propagate(t1.eqNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2).notNode()); + } + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("uf") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + if (Theory::theoryOf(t1) == THEORY_BOOL) { + return d_uf.propagate(t1.iffNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2)); + } } }; @@ -63,7 +88,7 @@ private: NotifyClass d_notify; /** Equaltity engine */ - EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; /** Are we in conflict */ context::CDO d_conflict; @@ -72,7 +97,8 @@ private: Node d_conflictNode; /** - * Should be called to propagate the literal. + * Should be called to propagate the literal. We use a node here + * since some of the propagated literals are not kept anywhere. */ bool propagate(TNode literal); @@ -90,12 +116,6 @@ private: /** All the function terms that the theory has seen */ context::CDList d_functionsTerms; - /** True node for predicates = true */ - Node d_true; - - /** True node for predicates = false */ - Node d_false; - /** Symmetry analyzer */ SymmetryBreaker d_symb; diff --git a/src/util/configuration.cpp b/src/util/configuration.cpp index 66b0a2f90..6f01d6cf4 100644 --- a/src/util/configuration.cpp +++ b/src/util/configuration.cpp @@ -162,7 +162,7 @@ bool Configuration::isDebugTag(char const *tag){ return true; } } -#endif * CVC4_DEBUG */ +#endif /* CVC4_DEBUG */ return false; } diff --git a/test/unit/context/context_black.h b/test/unit/context/context_black.h index 33863e848..1a50d0637 100644 --- a/test/unit/context/context_black.h +++ b/test/unit/context/context_black.h @@ -37,7 +37,7 @@ struct MyContextNotifyObj : public ContextNotifyObj { nCalls(0) { } - void notify() { + void contextNotifyPop() { ++nCalls; } }; diff --git a/test/unit/prop/cnf_stream_black.h b/test/unit/prop/cnf_stream_black.h index 63ba95b57..c24104acc 100644 --- a/test/unit/prop/cnf_stream_black.h +++ b/test/unit/prop/cnf_stream_black.h @@ -59,6 +59,14 @@ public: return d_nextVar++; } + SatVariable trueVar() { + return d_nextVar++; + } + + SatVariable falseVar() { + return d_nextVar++; + } + void addClause(SatClause& c, bool lemma) { d_addClauseCalled = true; }