From 1ce0650dcf8ce30424b546deb540974cc510c215 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dejan=20Jovanovi=C4=87?= Date: Wed, 9 May 2012 21:25:17 +0000 Subject: [PATCH] * simplifying equality engine interface * notifications are now through the interface subclass instead of a template * notifications include constants being merged * changed contextNotifyObj::notify to contextNotifyObj::contextNotifyPop so it's more descriptive and doesn't clutter methods when subclassed * sat solver now has explicit methods to make true and false constants * 0-level literals are removed from explanations of propagations --- src/context/context.cpp | 4 +- src/context/context.h | 16 +- src/context/stacking_map.h | 16 +- src/context/stacking_vector.h | 4 +- src/prop/bvminisat/bvminisat.cpp | 2 +- src/prop/bvminisat/bvminisat.h | 8 +- src/prop/bvminisat/core/Solver.cc | 10 +- src/prop/bvminisat/core/Solver.h | 9 + src/prop/bvminisat/simp/SimpSolver.cc | 16 +- src/prop/cnf_stream.cpp | 10 +- src/prop/minisat/core/Solver.cc | 22 +- src/prop/minisat/core/Solver.h | 9 + src/prop/minisat/minisat.cpp | 1 - src/prop/minisat/minisat.h | 2 + src/prop/minisat/simp/SimpSolver.cc | 13 + src/prop/sat_solver.h | 6 + src/smt/smt_engine.cpp | 1 - src/theory/arith/congruence_manager.cpp | 34 +- src/theory/arith/congruence_manager.h | 42 +- src/theory/arith/theory_arith.cpp | 4 +- src/theory/arrays/theory_arrays.cpp | 137 ++----- src/theory/arrays/theory_arrays.h | 73 ++-- src/theory/booleans/circuit_propagator.h | 12 +- src/theory/bv/theory_bv.cpp | 90 ++--- src/theory/bv/theory_bv.h | 46 ++- src/theory/datatypes/union_find.cpp | 6 +- src/theory/datatypes/union_find.h | 6 +- src/theory/shared_terms_database.cpp | 49 +-- src/theory/shared_terms_database.h | 51 ++- src/theory/substitutions.h | 8 +- src/theory/theory_engine.cpp | 89 +++-- src/theory/theory_engine.h | 11 +- src/theory/uf/Makefile.am | 2 +- ...lity_engine_impl.h => equality_engine.cpp} | 376 ++++++++++-------- src/theory/uf/equality_engine.h | 261 ++++++++---- src/theory/uf/theory_uf.cpp | 115 ++---- src/theory/uf/theory_uf.h | 56 ++- src/util/configuration.cpp | 2 +- test/unit/context/context_black.h | 2 +- test/unit/prop/cnf_stream_black.h | 8 + 40 files changed, 897 insertions(+), 732 deletions(-) rename src/theory/uf/{equality_engine_impl.h => equality_engine.cpp} (73%) diff --git a/src/context/context.cpp b/src/context/context.cpp index abb1575d4..da60a5bc4 100644 --- a/src/context/context.cpp +++ b/src/context/context.cpp @@ -80,7 +80,7 @@ void Context::pop() { while(pCNO != NULL) { // pre-store the "next" pointer in case pCNO deletes itself on notify() ContextNotifyObj* next = pCNO->d_pCNOnext; - pCNO->notify(); + pCNO->contextNotifyPop(); pCNO = next; } @@ -101,7 +101,7 @@ void Context::pop() { while(pCNO != NULL) { // pre-store the "next" pointer in case pCNO deletes itself on notify() ContextNotifyObj* next = pCNO->d_pCNOnext; - pCNO->notify(); + pCNO->contextNotifyPop(); pCNO = next; } diff --git a/src/context/context.h b/src/context/context.h index f0dbff72b..165c35c58 100644 --- a/src/context/context.h +++ b/src/context/context.h @@ -658,6 +658,7 @@ public: * the ContextObj objects have been restored). */ class ContextNotifyObj { + /** * Context is our friend so that when the Context is deleted, any * remaining ContextNotifyObj can be removed from the Context list. @@ -686,6 +687,15 @@ class ContextNotifyObj { */ ContextNotifyObj**& prev() throw() { return d_ppCNOprev; } +protected: + + /** + * This is the method called to notify the object of a pop. It must be + * implemented by the subclass. It is protected since context is out + * friend. + */ + virtual void contextNotifyPop() = 0; + public: /** @@ -703,12 +713,6 @@ public: */ virtual ~ContextNotifyObj() throw(AssertionException); - /** - * This is the method called to notify the object of a pop. It must be - * implemented by the subclass. - */ - virtual void notify() = 0; - };/* class ContextNotifyObj */ inline void ContextObj::makeCurrent() throw(AssertionException) { diff --git a/src/context/stacking_map.h b/src/context/stacking_map.h index 2dec1845c..ba644596e 100644 --- a/src/context/stacking_map.h +++ b/src/context/stacking_map.h @@ -96,6 +96,14 @@ class StackingMap : context::ContextNotifyObj { /** Our current offset in the d_trace stack (context-dependent). */ context::CDO d_offset; +protected: + + /** + * Called by the Context when a pop occurs. Cancels everything to the + * current context level. Overrides ContextNotifyObj::contextNotifyPop(). + */ + void contextNotifyPop(); + public: typedef typename MapType::const_iterator const_iterator; @@ -128,12 +136,6 @@ public: */ void set(ArgType n, const ValueType& newValue); - /** - * Called by the Context when a pop occurs. Cancels everything to the - * current context level. Overrides ContextNotifyObj::notify(). - */ - void notify(); - };/* class StackingMap<> */ template @@ -146,7 +148,7 @@ void StackingMap::set(ArgType n, const ValueType& n } template -void StackingMap::notify() { +void StackingMap::contextNotifyPop() { Trace("sm") << "SM cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl; while(d_offset < d_trace.size()) { std::pair p = d_trace.back(); diff --git a/src/context/stacking_vector.h b/src/context/stacking_vector.h index 9987731d4..ed311b952 100644 --- a/src/context/stacking_vector.h +++ b/src/context/stacking_vector.h @@ -82,7 +82,7 @@ public: * Called by the Context when a pop occurs. Cancels everything to the * current context level. Overrides ContextNotifyObj::notify(). */ - void notify(); + void contextNotifyPop(); };/* class StackingVector<> */ @@ -99,7 +99,7 @@ void StackingVector::set(size_t n, const T& newValue) { } template -void StackingVector::notify() { +void StackingVector::contextNotifyPop() { Trace("sv") << "SV cancelling : " << d_offset << " < " << d_trace.size() << " ?" << std::endl; while(d_offset < d_trace.size()) { std::pair p = d_trace.back(); diff --git a/src/prop/bvminisat/bvminisat.cpp b/src/prop/bvminisat/bvminisat.cpp index 124fc35f1..4868db6f5 100644 --- a/src/prop/bvminisat/bvminisat.cpp +++ b/src/prop/bvminisat/bvminisat.cpp @@ -73,7 +73,7 @@ SatValue BVMinisatSatSolver::assertAssumption(SatLiteral lit, bool propagate) { return toSatLiteralValue(d_minisat->assertAssumption(toMinisatLit(lit), propagate)); } -void BVMinisatSatSolver::notify() { +void BVMinisatSatSolver::contextNotifyPop() { while (d_assertionsCount > d_assertionsRealCount) { popAssumption(); d_assertionsCount --; diff --git a/src/prop/bvminisat/bvminisat.h b/src/prop/bvminisat/bvminisat.h index cd2a2c6b9..60cdd1c28 100644 --- a/src/prop/bvminisat/bvminisat.h +++ b/src/prop/bvminisat/bvminisat.h @@ -54,6 +54,10 @@ private: context::CDO d_assertionsRealCount; context::CDO d_lastPropagation; +protected: + + void contextNotifyPop(); + public: BVMinisatSatSolver() : @@ -70,10 +74,12 @@ public: SatVariable newVar(bool theoryAtom = false); + SatVariable trueVar() { return d_minisat->trueVar(); } + SatVariable falseVar() { return d_minisat->falseVar(); } + void markUnremovable(SatLiteral lit); void interrupt(); - void notify(); SatValue solve(); SatValue solve(long unsigned int&); diff --git a/src/prop/bvminisat/core/Solver.cc b/src/prop/bvminisat/core/Solver.cc index e24fcac1a..c96b6e4b2 100644 --- a/src/prop/bvminisat/core/Solver.cc +++ b/src/prop/bvminisat/core/Solver.cc @@ -119,7 +119,15 @@ Solver::Solver(CVC4::context::Context* c) : , propagation_budget (-1) , asynch_interrupt (false) , clause_added(false) -{} +{ + // Create the constant variables + varTrue = newVar(true, false); + varFalse = newVar(false, false); + + // Assert the constants + uncheckedEnqueue(mkLit(varTrue, false)); + uncheckedEnqueue(mkLit(varFalse, true)); +} Solver::~Solver() diff --git a/src/prop/bvminisat/core/Solver.h b/src/prop/bvminisat/core/Solver.h index c323bfe2b..ae5efd81e 100644 --- a/src/prop/bvminisat/core/Solver.h +++ b/src/prop/bvminisat/core/Solver.h @@ -64,6 +64,12 @@ class Solver { /** Cvc4 context */ CVC4::context::Context* c; + /** True constant */ + Var varTrue; + + /** False constant */ + Var varFalse; + public: // Constructor/Destructor: @@ -76,6 +82,9 @@ public: // Problem specification: // Var newVar (bool polarity = true, bool dvar = true); // Add a new variable with parameters specifying variable mode. + Var trueVar() const { return varTrue; } + Var falseVar() const { return varFalse; } + bool addClause (const vec& ps); // Add a clause to the solver. bool addEmptyClause(); // Add the empty clause, making the solver contradictory. diff --git a/src/prop/bvminisat/simp/SimpSolver.cc b/src/prop/bvminisat/simp/SimpSolver.cc index c8ce13410..59820e9e3 100644 --- a/src/prop/bvminisat/simp/SimpSolver.cc +++ b/src/prop/bvminisat/simp/SimpSolver.cc @@ -63,11 +63,25 @@ SimpSolver::SimpSolver(CVC4::context::Context* c) : , bwdsub_assigns (0) , n_touched (0) { - CVC4::StatisticsRegistry::registerStat(&total_eliminate_time); + CVC4::StatisticsRegistry::registerStat(&total_eliminate_time); vec dummy(1,lit_Undef); ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below. bwdsub_tmpunit = ca.alloc(dummy); remove_satisfied = false; + + // add the initialization for all the internal variables + for (int i = frozen.size(); i < vardata.size(); ++ i) { + frozen .push(1); + eliminated.push(0); + if (use_simplification){ + n_occ .push(0); + n_occ .push(0); + occurs .init(i); + touched .push(0); + elim_heap .insert(i); + } + } + } diff --git a/src/prop/cnf_stream.cpp b/src/prop/cnf_stream.cpp index 3a4fa781a..d18ec6e69 100644 --- a/src/prop/cnf_stream.cpp +++ b/src/prop/cnf_stream.cpp @@ -175,7 +175,15 @@ SatLiteral CnfStream::newLiteral(TNode node, bool theoryLiteral) { SatLiteral lit; if (!hasLiteral(node)) { // If no literal, we'll make one - lit = SatLiteral(d_satSolver->newVar(theoryLiteral)); + if (node.getKind() == kind::CONST_BOOLEAN) { + if (node.getConst()) { + lit = SatLiteral(d_satSolver->trueVar()); + } else { + lit = SatLiteral(d_satSolver->falseVar()); + } + } else { + lit = SatLiteral(d_satSolver->newVar(theoryLiteral)); + } d_translationCache[node].literal = lit; d_translationCache[node.notNode()].literal = ~lit; } else { diff --git a/src/prop/minisat/core/Solver.cc b/src/prop/minisat/core/Solver.cc index 5e1b032a3..6ee508eba 100644 --- a/src/prop/minisat/core/Solver.cc +++ b/src/prop/minisat/core/Solver.cc @@ -126,6 +126,14 @@ Solver::Solver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* context, , asynch_interrupt (false) { PROOF(ProofManager::initSatProof(this);) + + // Create the constant variables + varTrue = newVar(true, false, false); + varFalse = newVar(false, false, false); + + // Assert the constants + uncheckedEnqueue(mkLit(varTrue, false)); + uncheckedEnqueue(mkLit(varFalse, true)); } @@ -190,16 +198,26 @@ CRef Solver::reason(Var x) { // Compute the assertion level for this clause int explLevel = 0; - for (int i = 0; i < explanation.size(); ++ i) { + int i, j; + for (i = 0, j = 0; i < explanation.size(); ++ i) { int varLevel = intro_level(var(explanation[i])); if (varLevel > explLevel) { explLevel = varLevel; } Assert(value(explanation[i]) != l_Undef); Assert(i == 0 || trail_index(var(explanation[0])) > trail_index(var(explanation[i]))); + // ignore zero level literals + if (i == 0 || level(var(explanation[i])) > 0) { + explanation[j++] = explanation[i]; + } + } + explanation.shrink(i - j); + if (j == 1) { + // Add not TRUE to the clause + explanation.push(mkLit(varTrue, true)); } - // Construct the reason (level 0) + // Construct the reason CRef real_reason = ca.alloc(explLevel, explanation, true); vardata[x] = mkVarData(real_reason, level(x), intro_level(x), trail_index(x)); clauses_removable.push(real_reason); diff --git a/src/prop/minisat/core/Solver.h b/src/prop/minisat/core/Solver.h index cfeb06211..e677d7220 100644 --- a/src/prop/minisat/core/Solver.h +++ b/src/prop/minisat/core/Solver.h @@ -65,6 +65,13 @@ protected: /** The current assertion level (user) */ int assertionLevel; + + /** Variable representing true */ + Var varTrue; + + /** Variable representing false */ + Var varFalse; + public: /** Returns the current user assertion level */ int getAssertionLevel() const { return assertionLevel; } @@ -108,6 +115,8 @@ public: // Problem specification: // Var newVar (bool polarity = true, bool dvar = true, bool theoryAtom = false); // Add a new variable with parameters specifying variable mode. + Var trueVar() const { return varTrue; } + Var falseVar() const { return varFalse; } // Less than for literals in a lemma struct lemma_lt { diff --git a/src/prop/minisat/minisat.cpp b/src/prop/minisat/minisat.cpp index bed30d658..4f2a16670 100644 --- a/src/prop/minisat/minisat.cpp +++ b/src/prop/minisat/minisat.cpp @@ -121,7 +121,6 @@ SatVariable MinisatSatSolver::newVar(bool theoryAtom) { return d_minisat->newVar(true, true, theoryAtom); } - SatValue MinisatSatSolver::solve(unsigned long& resource) { Trace("limit") << "SatSolver::solve(): have limit of " << resource << " conflicts" << std::endl; if(resource == 0) { diff --git a/src/prop/minisat/minisat.h b/src/prop/minisat/minisat.h index 9cf75a12e..19ade8ffa 100644 --- a/src/prop/minisat/minisat.h +++ b/src/prop/minisat/minisat.h @@ -56,6 +56,8 @@ public: void addClause(SatClause& clause, bool removable); SatVariable newVar(bool theoryAtom = false); + SatVariable trueVar() { return d_minisat->trueVar(); } + SatVariable falseVar() { return d_minisat->falseVar(); } SatValue solve(); SatValue solve(long unsigned int&); diff --git a/src/prop/minisat/simp/SimpSolver.cc b/src/prop/minisat/simp/SimpSolver.cc index 2cacfbcc0..8da3856ff 100644 --- a/src/prop/minisat/simp/SimpSolver.cc +++ b/src/prop/minisat/simp/SimpSolver.cc @@ -67,6 +67,19 @@ SimpSolver::SimpSolver(CVC4::prop::TheoryProxy* proxy, CVC4::context::Context* c ca.extra_clause_field = true; // NOTE: must happen before allocating the dummy clause below. bwdsub_tmpunit = ca.alloc(0, dummy); remove_satisfied = false; + + // add the initialization for all the internal variables + for (int i = frozen.size(); i < vardata.size(); ++ i) { + frozen .push(1); + eliminated.push(0); + if (use_simplification){ + n_occ .push(0); + n_occ .push(0); + occurs .init(i); + touched .push(0); + elim_heap .insert(i); + } + } } diff --git a/src/prop/sat_solver.h b/src/prop/sat_solver.h index 898709c43..2865f2cb5 100644 --- a/src/prop/sat_solver.h +++ b/src/prop/sat_solver.h @@ -46,6 +46,12 @@ public: /** Create a new boolean variable in the solver. */ virtual SatVariable newVar(bool theoryAtom = false) = 0; + /** Create a new (or return an existing) boolean variable representing the constant true */ + virtual SatVariable trueVar() = 0; + + /** Create a new (or return an existing) boolean variable representing the constant false */ + virtual SatVariable falseVar() = 0; + /** Check the satisfiability of the added clauses */ virtual SatValue solve() = 0; diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index e636b9142..2759f5717 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -304,7 +304,6 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) : setTimeLimit(Options::current()->cumulativeMillisecondLimit, true); } - d_propEngine->assertFormula(NodeManager::currentNM()->mkConst(true)); d_propEngine->assertFormula(NodeManager::currentNM()->mkConst(false).notNode()); } diff --git a/src/theory/arith/congruence_manager.cpp b/src/theory/arith/congruence_manager.cpp index 201eb08e7..39468e928 100644 --- a/src/theory/arith/congruence_manager.cpp +++ b/src/theory/arith/congruence_manager.cpp @@ -1,5 +1,4 @@ #include "theory/arith/congruence_manager.h" -#include "theory/uf/equality_engine_impl.h" #include "theory/arith/constraint.h" #include "theory/arith/arith_utilities.h" @@ -17,8 +16,7 @@ ArithCongruenceManager::ArithCongruenceManager(context::Context* c, ConstraintDa d_constraintDatabase(cd), d_setupLiteral(setup), d_av2Node(av2Node), - d_ee(d_notify, c, "theory::arith::ArithCongruenceManager"), - d_false(mkBoolNode(false)) + d_ee(d_notify, c, "theory::arith::ArithCongruenceManager") {} ArithCongruenceManager::Statistics::Statistics(): @@ -113,7 +111,7 @@ bool ArithCongruenceManager::propagate(TNode x){ }else{ ++(d_statistics.d_conflicts); - Node conf = explainInternal(x); + Node conf = flattenAnd(explainInternal(x)); d_conflict.set(conf); Debug("arith::congruenceManager") << "rewritten to false "<& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - lhs = literal[0]; - rhs = d_false; - break; - default: - Unreachable(); + if (literal.getKind() != kind::NOT) { + d_ee.explainEquality(literal[0], literal[1], true, assumptions); + } else { + d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions); } - d_ee.explainEquality(lhs, rhs, assumptions); } void ArithCongruenceManager::enqueueIntoNB(const std::set s, NodeBuilder<>& nb){ @@ -258,13 +247,10 @@ void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar TNode eq = d_watchedEqualities[s]; Assert(eq.getKind() == kind::EQUAL); - TNode x = eq[0]; - TNode y = eq[1]; - if(isEquality){ - d_ee.addEquality(x, y, reason); + d_ee.assertEquality(eq, true, reason); }else{ - d_ee.addDisequality(x, y, reason); + d_ee.assertEquality(eq, false, reason); } } @@ -286,7 +272,7 @@ void ArithCongruenceManager::equalsConstant(Constraint c){ Node reason = c->explainForConflict(); d_keepAlive.push_back(reason); - d_ee.addEquality(xAsNode, asRational, reason); + d_ee.assertEquality(eq, true, reason); } void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){ @@ -310,7 +296,7 @@ void ArithCongruenceManager::equalsConstant(Constraint lb, Constraint ub){ d_keepAlive.push_back(reason); - d_ee.addEquality(xAsNode, asRational, reason); + d_ee.assertEquality(eq, true, reason); } void ArithCongruenceManager::addSharedTerm(Node x){ diff --git a/src/theory/arith/congruence_manager.h b/src/theory/arith/congruence_manager.h index a72989498..18ecbeb9d 100644 --- a/src/theory/arith/congruence_manager.h +++ b/src/theory/arith/congruence_manager.h @@ -37,24 +37,43 @@ private: ArithVarToNodeMap d_watchedEqualities; - class ArithCongruenceNotify { + class ArithCongruenceNotify : public eq::EqualityEngineNotify { private: ArithCongruenceManager& d_acm; public: ArithCongruenceNotify(ArithCongruenceManager& acm): d_acm(acm) {} - bool notify(TNode propagation) { - Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << propagation << ")" << std::endl; - // Just forward to dm - return d_acm.propagate(propagation); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; + if (value) { + return d_acm.propagate(equality); + } else { + return d_acm.propagate(equality.notNode()); + } } - void notify(TNode t1, TNode t2) { - Debug("arith::congruences") << "ArithCongruenceNotify::notify(" << t1 << ", " << t2 << ")" << std::endl; - Node equality = t1.eqNode(t2); - d_acm.propagate(equality); + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); } - }; + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl; + if (value) { + return d_acm.propagate(t1.eqNode(t2)); + } else { + return d_acm.propagate(t1.eqNode(t2).notNode()); + } + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + if (t1.getKind() == kind::CONST_BOOLEAN) { + return d_acm.propagate(t1.iffNode(t2)); + } else { + return d_acm.propagate(t1.eqNode(t2)); + } + } + }; ArithCongruenceNotify d_notify; context::CDList d_keepAlive; @@ -75,8 +94,7 @@ private: const ArithVarNodeMap& d_av2Node; - theory::uf::EqualityEngine d_ee; - Node d_false; + eq::EqualityEngine d_ee; public: diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index c7072de72..6bb3821da 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -1001,8 +1001,8 @@ Node TheoryArith::assertionCases(TNode assertion){ if(Debug.isOn("whytheoryenginewhy")){ debugPrintFacts(); } - Warning() << "arith: Theory engine is sending me both a literal and its negation?" - << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl; +// Warning() << "arith: Theory engine is sending me both a literal and its negation?" +// << "BOOOOOOOOOOOOOOOOOOOOOO!!!!"<< endl; } Debug("arith::eq") << constraint << endl; Debug("arith::eq") << negation << endl; diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 80bcb47dd..1dd74f060 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -23,17 +23,13 @@ #include #include "theory/rewriter.h" #include "expr/command.h" -#include "theory/uf/equality_engine_impl.h" - using namespace std; - namespace CVC4 { namespace theory { namespace arrays { - // These are the options that produce the best empirical results on QF_AX benchmarks. // eagerLemmas = true // eagerIndexSplitting = false @@ -58,14 +54,12 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC d_numNonLinear("theory::arrays::number of calls to setNonLinear", 0), d_numSharedArrayVarSplits("theory::arrays::number of shared array var splits", 0), d_checkTimer("theory::arrays::checkTime"), - d_ppNotify(), - d_ppEqualityEngine(d_ppNotify, u, "theory::arrays::TheoryArraysPP"), + d_ppEqualityEngine(u, "theory::arrays::TheoryArraysPP"), d_ppFacts(u), // d_ppCache(u), d_literalsToPropagate(c), d_literalsToPropagateIndex(c, 0), - d_mayEqualNotify(), - d_mayEqualEqualityEngine(d_mayEqualNotify, c, "theory::arrays::TheoryArraysMayEqual"), + d_mayEqualEqualityEngine(c, "theory::arrays::TheoryArraysMayEqual"), d_notify(*this), d_equalityEngine(d_notify, c, "theory::arrays::TheoryArrays"), d_conflict(c, false), @@ -91,14 +85,6 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); - d_ppEqualityEngine.addTerm(d_true); - d_ppEqualityEngine.addTerm(d_false); - d_ppEqualityEngine.addTriggerEquality(d_true, d_false, d_false); - - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); - // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::SELECT); if (d_ccStore) { @@ -281,7 +267,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs case kind::EQUAL: { d_ppFacts.push_back(in); - d_ppEqualityEngine.addEquality(in[0], in[1], in); + d_ppEqualityEngine.assertEquality(in, true, in); if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) { outSubstitutions.addSubstitution(in[0], in[1]); return PP_ASSERT_STATUS_SOLVED; @@ -299,7 +285,7 @@ Theory::PPAssertStatus TheoryArrays::ppAssert(TNode in, SubstitutionMap& outSubs in[0].getKind() == kind::IFF ); Node a = in[0][0]; Node b = in[0][1]; - d_ppEqualityEngine.addDisequality(a, b, in); + d_ppEqualityEngine.assertEquality(in[0], false, in); break; } default: @@ -335,10 +321,8 @@ bool TheoryArrays::propagate(TNode literal) Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl; std::vector assumptions; Node negatedLiteral; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -357,67 +341,40 @@ bool TheoryArrays::propagate(TNode literal) void TheoryArrays::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::SELECT: - lhs = literal; - rhs = d_true; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + // Do the work + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } - /** - * Stores in d_infoMap the following information for each term a of type array: - * - * - all i, such that there exists a term a[i] or a = store(b i v) - * (i.e. all indices it is being read atl; store(b i v) is implicitly read at - * position i due to the implicit axiom store(b i v)[i] = v ) - * - * - all the stores a is congruent to (this information is context dependent) - * - * - all store terms of the form store (a i v) (i.e. in which a appears - * directly; this is invariant because no new store terms are created) - * - * Note: completeness depends on having pre-register called on all the input - * terms before starting to instantiate lemmas. - */ +/** + * Stores in d_infoMap the following information for each term a of type array: + * + * - all i, such that there exists a term a[i] or a = store(b i v) + * (i.e. all indices it is being read atl; store(b i v) is implicitly read at + * position i due to the implicit axiom store(b i v)[i] = v ) + * + * - all the stores a is congruent to (this information is context dependent) + * + * - all store terms of the form store (a i v) (i.e. in which a appears + * directly; this is invariant because no new store terms are created) + * + * Note: completeness depends on having pre-register called on all the input + * terms before starting to instantiate lemmas. + */ void TheoryArrays::preRegisterTerm(TNode node) { Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::preRegisterTerm(" << node << ")" << std::endl; switch (node.getKind()) { case kind::EQUAL: - // Add the terms - // d_equalityEngine.addTerm(node[0]); - // d_equalityEngine.addTerm(node[1]); - d_equalityEngine.addTerm(node); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; case kind::SELECT: { // Reads @@ -438,7 +395,7 @@ void TheoryArrays::preRegisterTerm(TNode node) Assert(!d_equalityEngine.hasTerm(ni)); preRegisterTerm(ni); } - d_equalityEngine.addEquality(ni, s[2], d_true); + d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true); Assert(++it == stores->end()); } } @@ -447,8 +404,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // TODO: remove this or keep it if we allow Boolean elements in arrays. if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerEquality(node, d_true, node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerPredicate(node); } d_infoMap.addIndex(node[0], node[1]); @@ -463,7 +419,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // TNode i = node[1]; // TNode v = node[2]; - d_mayEqualEqualityEngine.addEquality(node, a, d_true); + d_mayEqualEqualityEngine.assertEquality(node.eqNode(a), true, d_true); // NodeManager* nm = NodeManager::currentNM(); // Node ni = nm->mkNode(kind::SELECT, node, i); @@ -508,10 +464,8 @@ void TheoryArrays::propagate(Effort e) Debug("arrays") << spaces(getSatContext()->getLevel()) << "TheoryArrays::propagate(): in conflict, normalized = " << normalized << std::endl; Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -727,17 +681,17 @@ void TheoryArrays::check(Effort e) { // Do the work switch (fact.getKind()) { case kind::EQUAL: - d_equalityEngine.addEquality(fact[0], fact[1], fact); + d_equalityEngine.assertEquality(fact, true, fact); break; case kind::SELECT: - d_equalityEngine.addPredicate(fact, true, fact); + d_equalityEngine.assertPredicate(fact, true, fact); break; case kind::NOT: if (fact[0].getKind() == kind::SELECT) { - d_equalityEngine.addPredicate(fact[0], false, fact); + d_equalityEngine.assertPredicate(fact[0], false, fact); } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1])) { // Assert the dis-equality - d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); + d_equalityEngine.assertEquality(fact[0], false, fact); // Apply ArrDiseq Rule if diseq is between arrays if(fact[0][0].getType().isArray()) { @@ -764,7 +718,7 @@ void TheoryArrays::check(Effort e) { if (!d_equalityEngine.hasTerm(bk)) { preRegisterTerm(bk); } - d_equalityEngine.addDisequality(ak, bk, fact); + d_equalityEngine.assertEquality(ak.eqNode(bk), false, fact); Trace("arrays-lem")<<"Arrays::addExtLemma "<< ak << " /= " << bk <<"\n"; ++d_numExt; } @@ -807,14 +761,11 @@ Node TheoryArrays::mkAnd(std::vector& conjunctions) for (; i < conjunctions.size(); ++i) { t = conjunctions[i]; - // Remove true node - represents axiomatically true assertion - if (t == d_true) continue; - // Expand explanation resulting from propagating a ROW lemma if (t.getKind() == kind::OR) { if ((explained.find(t) == explained.end())) { Assert(t[1].getKind() == kind::EQUAL); - d_equalityEngine.explainDisequality(t[1][0], t[1][1], conjunctions); + d_equalityEngine.explainEquality(t[1][0], t[1][1], false, conjunctions); explained.insert(t); } continue; @@ -949,7 +900,7 @@ void TheoryArrays::checkRIntro1(TNode a, TNode b) Node ni = nm->mkNode(kind::SELECT, s, s[1]); Assert(!d_equalityEngine.hasTerm(ni)); preRegisterTerm(ni); - d_equalityEngine.addEquality(ni, s[2], d_true); + d_equalityEngine.assertEquality(ni.eqNode(s[2]), true, d_true); } } @@ -1004,7 +955,7 @@ void TheoryArrays::mergeArrays(TNode a, TNode b) } } - d_mayEqualEqualityEngine.addEquality(a, b, d_true); + d_mayEqualEqualityEngine.assertEquality(a.eqNode(b), true, d_true); checkRowLemmas(a,b); checkRowLemmas(b,a); @@ -1186,7 +1137,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) if (!bjExists) { preRegisterTerm(bj); } - d_equalityEngine.addEquality(aj, bj, reason); + d_equalityEngine.assertEquality(aj.eqNode(bj), true, reason); ++d_numProp; return; } @@ -1194,7 +1145,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<mkNode(kind::OR, i.eqNode(j), aj.eqNode(bj)); d_permRef.push_back(reason); - d_equalityEngine.addEquality(i, j, reason); + d_equalityEngine.assertEquality(i.eqNode(j), true, reason); ++d_numProp; return; } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index d18b3abde..88986ee7a 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -133,18 +133,8 @@ class TheoryArrays : public Theory { private: - // PPNotifyClass: dummy template class for d_ppEqualityEngine - notifications not used - class PPNotifyClass { - public: - bool notify(TNode propagation) { return true; } - void notify(TNode t1, TNode t2) { } - }; - - /** The notify class for d_ppEqualityEngine */ - PPNotifyClass d_ppNotify; - /** Equaltity engine */ - uf::EqualityEngine d_ppEqualityEngine; + eq::EqualityEngine d_ppEqualityEngine; // List of facts learned by preprocessor - needed for permanent ref for benefit of d_ppEqualityEngine context::CDList d_ppFacts; @@ -187,17 +177,8 @@ class TheoryArrays : public Theory { private: - class MayEqualNotifyClass { - public: - bool notify(TNode propagation) { return true; } - void notify(TNode t1, TNode t2) { } - }; - - /** The notify class for d_mayEqualEqualityEngine */ - MayEqualNotifyClass d_mayEqualNotify; - /** Equaltity engine for determining if two arrays might be equal */ - uf::EqualityEngine d_mayEqualEqualityEngine; + eq::EqualityEngine d_mayEqualEqualityEngine; public: @@ -238,37 +219,57 @@ class TheoryArrays : public Theory { private: // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { TheoryArrays& d_arrays; public: NotifyClass(TheoryArrays& arrays): d_arrays(arrays) {} - bool notify(TNode propagation) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl; + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; // Just forward to arrays - return d_arrays.propagate(propagation); + if (value) { + return d_arrays.propagate(equality); + } else { + return d_arrays.propagate(equality.notNode()); + } + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); } - void notify(TNode t1, TNode t2) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - if (t1.getType().isArray()) { - d_arrays.mergeArrays(t1, t2); - if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) { - return; + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl; + if (value) { + if (t1.getType().isArray()) { + d_arrays.mergeArrays(t1, t2); + if (!d_arrays.isShared(t1) || !d_arrays.isShared(t2)) { + return true; + } } + // Propagate equality between shared terms + Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); + d_arrays.propagate(equality); } - // Propagate equality between shared terms - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_arrays.propagate(equality); + // TODO: implement negation propagation + return true; } - }; + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + if (Theory::theoryOf(t1) == THEORY_BOOL) { + return d_arrays.propagate(t1.iffNode(t2)); + } else { + return d_arrays.propagate(t1.eqNode(t2)); + } + } + }; /** The notify class for d_equalityEngine */ NotifyClass d_notify; /** Equaltity engine */ - uf::EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; // Are we in conflict? context::CDO d_conflict; diff --git a/src/theory/booleans/circuit_propagator.h b/src/theory/booleans/circuit_propagator.h index 78221a617..f5e4f4630 100644 --- a/src/theory/booleans/circuit_propagator.h +++ b/src/theory/booleans/circuit_propagator.h @@ -79,17 +79,17 @@ private: template class DataClearer : context::ContextNotifyObj { T& d_data; + protected: + void contextNotifyPop() { + Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data " + << "(size was " << d_data.size() << ")" << std::endl; + d_data.clear(); + } public: DataClearer(context::Context* context, T& data) : context::ContextNotifyObj(context), d_data(data) { } - - void notify() { - Trace("circuit-prop") << "CircuitPropagator::DataClearer: clearing data " - << "(size was " << d_data.size() << ")" << std::endl; - d_data.clear(); - } };/* class DataClearer */ /** diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index c9d58574e..4076a7ee0 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -21,7 +21,6 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/valuation.h" #include "theory/bv/bv_sat.h" -#include "theory/uf/equality_engine_impl.h" using namespace CVC4; using namespace CVC4::theory; @@ -52,18 +51,7 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& d_toBitBlast(c), d_propagatedBy(c) { - d_true = utils::mkTrue(); - d_false = utils::mkFalse(); - if (d_useEqualityEngine) { - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); - - // add disequality between 0 and 1 bits - d_equalityEngine.addDisequality(utils::mkConst(BitVector((unsigned)1, (unsigned)0)), - utils::mkConst(BitVector((unsigned)1, (unsigned)1)), - d_true); // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT); @@ -137,11 +125,8 @@ void TheoryBV::preRegisterTerm(TNode node) { if (d_useEqualityEngine) { switch (node.getKind()) { case kind::EQUAL: - // Add the terms - d_equalityEngine.addTerm(node); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; default: d_equalityEngine.addTerm(node); @@ -185,15 +170,15 @@ void TheoryBV::check(Effort e) if (predicate.getKind() == kind::EQUAL) { if (negated) { // dis-equality - d_equalityEngine.addDisequality(predicate[0], predicate[1], fact); + d_equalityEngine.assertEquality(predicate, false, fact); } else { // equality - d_equalityEngine.addEquality(predicate[0], predicate[1], fact); + d_equalityEngine.assertEquality(predicate, true, fact); } } else { // Adding predicate if the congruence over it is turned on if (d_equalityEngine.isFunctionKind(predicate.getKind())) { - d_equalityEngine.addPredicate(predicate, !negated, fact); + d_equalityEngine.assertPredicate(predicate, !negated, fact); } } } @@ -279,16 +264,16 @@ void TheoryBV::propagate(Effort e) { bool satValue; if (!d_valuation.hasSatValue(normalized, satValue) || satValue) { // check if we already propagated the negation - Node neg_literal = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal); - if (d_alreadyPropagatedSet.find(neg_literal) != d_alreadyPropagatedSet.end()) { + Node negLiteral = literal.getKind() == kind::NOT ? (Node)literal[0] : mkNot(literal); + if (d_alreadyPropagatedSet.find(negLiteral) != d_alreadyPropagatedSet.end()) { Debug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): in conflict " << literal << " and its negation both propagated \n"; // we are in conflict std::vector assumptions; explain(literal, assumptions); - explain(neg_literal, assumptions); + explain(negLiteral, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; - return; + return; } BVDebug("bitvector") << spaces(getSatContext()->getLevel()) << "TheoryBV::propagate(): " << literal << std::endl; @@ -299,10 +284,8 @@ void TheoryBV::propagate(Effort e) { Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); assumptions.push_back(negatedLiteral); - } explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -352,8 +335,6 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) // If propagated already, just skip PropagatedMap::const_iterator find = d_propagatedBy.find(literal); if (find != d_propagatedBy.end()) { - //unsigned theories = (*find).second | (unsigned) subtheory; - //d_propagatedBy[literal] = theories; return true; } else { d_propagatedBy[literal] = subtheory; @@ -362,56 +343,37 @@ bool TheoryBV::storePropagation(TNode literal, SubTheory subtheory) // See if the literal has been asserted already bool satValue = false; bool hasSatValue = d_valuation.hasSatValue(literal, satValue); - // If asserted, we might be in conflict + // If asserted, we might be in conflict if (hasSatValue && !satValue) { - Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl; - std::vector assumptions; - Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode(); - assumptions.push_back(negatedLiteral); - explain(literal, assumptions); - d_conflictNode = mkAnd(assumptions); - d_conflict = true; - return false; + Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => conflict" << std::endl; + std::vector assumptions; + Node negatedLiteral = literal.getKind() == kind::NOT ? (Node) literal[0] : literal.notNode(); + assumptions.push_back(negatedLiteral); + explain(literal, assumptions); + d_conflictNode = mkAnd(assumptions); + d_conflict = true; + return false; } // Nothing, just enqueue it for propagation and mark it as asserted already Debug("bitvector-prop") << spaces(getSatContext()->getLevel()) << "TheoryBV::storePropagation(" << literal << ") => enqueuing for propagation" << std::endl; d_literalsToPropagate.push_back(literal); + // No conflict return true; }/* TheoryBV::propagate(TNode) */ void TheoryBV::explain(TNode literal, std::vector& assumptions) { - if (propagatedBy(literal, SUB_EQUALITY)) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } else { Assert(propagatedBy(literal, SUB_BITBLASTER)); d_bitblaster->explain(literal, assumptions); @@ -430,7 +392,9 @@ Node TheoryBV::explain(TNode node) { return utils::mkTrue(); } // return the explanation - return mkAnd(assumptions); + Node explanation = mkAnd(assumptions); + Debug("bitvector::explain") << "TheoryBV::explain(" << node << ") => " << explanation << std::endl; + return explanation; } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 0ced179ec..e46d052f8 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -61,8 +61,6 @@ private: /** Bitblaster */ Bitblaster* d_bitblaster; - Node d_true; - Node d_false; /** Context dependent set of atoms we already propagated */ context::CDHashSet d_alreadyPropagatedSet; @@ -99,22 +97,44 @@ private: // Added by Clark // NotifyClass: template helper class for d_equalityEngine - handles call-back from congruence closure module - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { + TheoryBV& d_bv; + public: + NotifyClass(TheoryBV& uf): d_bv(uf) {} - bool notify(TNode propagation) { - Debug("bitvector") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << propagation << ")" << std::endl; - // Just forward to bv - return d_bv.storePropagation(propagation, SUB_EQUALITY); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_bv.storePropagation(equality, SUB_EQUALITY); + } else { + return d_bv.storePropagation(equality.notNode(), SUB_EQUALITY); + } + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_bv.storePropagation(predicate, SUB_EQUALITY); + } else { + return d_bv.storePropagation(predicate, SUB_EQUALITY); + } + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("bitvector") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl; + if (value) { + return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); + } else { + return d_bv.storePropagation(t1.eqNode(t2).notNode(), SUB_EQUALITY); + } } - void notify(TNode t1, TNode t2) { - Debug("arrays") << spaces(d_bv.getSatContext()->getLevel()) << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - // Propagate equality between shared terms - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("bitvector") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + return d_bv.storePropagation(t1.eqNode(t2), SUB_EQUALITY); } }; @@ -122,7 +142,7 @@ private: NotifyClass d_notify; /** Equaltity engine */ - uf::EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; // Are we in conflict? context::CDO d_conflict; diff --git a/src/theory/datatypes/union_find.cpp b/src/theory/datatypes/union_find.cpp index eacc4e798..34706719e 100644 --- a/src/theory/datatypes/union_find.cpp +++ b/src/theory/datatypes/union_find.cpp @@ -31,7 +31,7 @@ namespace theory { namespace datatypes { template -void UnionFind::notify() { +void UnionFind::contextNotifyPop() { Trace("datatypesuf") << "datatypesUF cancelling : " << d_offset << " < " << d_trace.size() << " ?" << endl; while(d_offset < d_trace.size()) { pair p = d_trace.back(); @@ -50,9 +50,9 @@ void UnionFind::notify() { // The following declarations allow us to put functions in the .cpp file // instead of the header, since we know which instantiations are needed. -template void UnionFind::notify(); +template void UnionFind::contextNotifyPop(); -template void UnionFind::notify(); +template void UnionFind::contextNotifyPop(); }/* CVC4::theory::datatypes namespace */ }/* CVC4::theory namespace */ diff --git a/src/theory/datatypes/union_find.h b/src/theory/datatypes/union_find.h index 51d1d85bc..4893c3502 100644 --- a/src/theory/datatypes/union_find.h +++ b/src/theory/datatypes/union_find.h @@ -84,13 +84,13 @@ public: */ inline void setCanon(TNode n, TNode newParent); +protected: -public: /** * Called by the Context when a pop occurs. Cancels everything to the - * current context level. Overrides ContextNotifyObj::notify(). + * current context level. Overrides ContextNotifyObj::contextNotifyPop(). */ - void notify(); + void contextNotifyPop(); };/* class UnionFind<> */ diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 577e1b957..4f5475e97 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -16,7 +16,6 @@ **/ #include "theory/shared_terms_database.h" -#include "theory/uf/equality_engine_impl.h" using namespace CVC4; using namespace theory; @@ -36,15 +35,8 @@ SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context d_equalityEngine(d_EENotify, context, "SharedTermsDatabase") { StatisticsRegistry::registerStat(&d_statSharedTerms); - NodeManager* nm = NodeManager::currentNM(); - d_true = nm->mkConst(true); - d_false = nm->mkConst(false); - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); } - SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException) { StatisticsRegistry::unregisterStat(&d_statSharedTerms); @@ -53,9 +45,9 @@ SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException) } } - void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) { Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl; + std::pair search_pair(atom, term); SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair); if (find == d_termsToTheories.end()) { @@ -243,23 +235,21 @@ bool SharedTermsDatabase::areDisequal(TNode a, TNode b) { return d_equalityEngine.areDisequal(a,b); } - void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason) { bool negated = literal.getKind() == kind::NOT; TNode atom = negated ? literal[0] : literal; if (negated) { Assert(!d_equalityEngine.areDisequal(atom[0], atom[1])); - d_equalityEngine.addDisequality(atom[0], atom[1], reason); + d_equalityEngine.assertEquality(atom, false, reason); // !!! need to send this out } else { Assert(!d_equalityEngine.areEqual(atom[0], atom[1])); - d_equalityEngine.addEquality(atom[0], atom[1], reason); + d_equalityEngine.assertEquality(atom, true, reason); } } - static Node mkAnd(const std::vector& conjunctions) { Assert(conjunctions.size() > 0); @@ -286,31 +276,12 @@ static Node mkAnd(const std::vector& conjunctions) { Node SharedTermsDatabase::explain(TNode literal) { std::vector assumptions; - explain(literal, assumptions); - return mkAnd(assumptions); -} - - -void SharedTermsDatabase::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + if (literal.getKind() == kind::NOT) { + Assert(literal[0].getKind() == kind::EQUAL); + d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions); + } else { + Assert(literal.getKind() == kind::EQUAL); + d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); + return mkAnd(assumptions); } diff --git a/src/theory/shared_terms_database.h b/src/theory/shared_terms_database.h index 6af7fd41f..403c90ced 100644 --- a/src/theory/shared_terms_database.h +++ b/src/theory/shared_terms_database.h @@ -28,25 +28,23 @@ class SharedTermsDatabase : public context::ContextNotifyObj { public: - /** A conainer for a list of shared terms */ + /** A container for a list of shared terms */ typedef std::vector shared_terms_list; - /** The iterator to go rhough the shared terms list */ + + /** The iterator to go through the shared terms list */ typedef shared_terms_list::const_iterator shared_terms_iterator; private: - Node d_true; - - Node d_false; - /** The context */ context::Context* d_context; /** Some statistics */ IntStat d_statSharedTerms; - // Needs to be a map from Nodes as after a backtrack they might not exist + // Needs to be a map from Nodes as after a backtrack they might not exist typedef std::hash_map SharedTermsMap; + /** A map from atoms to a list of shared terms */ SharedTermsMap d_atomsToTerms; @@ -57,14 +55,17 @@ private: context::CDO d_addedSharedTermsSize; typedef context::CDHashMap, theory::Theory::Set, TNodePairHashFunction> SharedTermsTheoriesMap; + /** A map from atoms and subterms to the theories that use it */ SharedTermsTheoriesMap d_termsToTheories; typedef context::CDHashMap AlreadyNotifiedMap; + /** Map from term to theories that have already been notified about the shared term */ AlreadyNotifiedMap d_alreadyNotifiedMap; public: + /** Class for notifications about new shared term equalities */ class SharedTermsNotifyClass { public: @@ -74,6 +75,7 @@ public: }; private: + // Instance of class to send shared term notifications to SharedTermsNotifyClass& d_sharedNotify; @@ -101,21 +103,37 @@ private: void backtrack(); // EENotifyClass: template helper class for d_equalityEngine - handles call-backs - class EENotifyClass { + class EENotifyClass : public theory::eq::EqualityEngineNotify { SharedTermsDatabase& d_sharedTerms; public: EENotifyClass(SharedTermsDatabase& shared): d_sharedTerms(shared) {} - bool notify(TNode propagation) { return true; } // Not used - void notify(TNode t1, TNode t2) { - d_sharedTerms.mergeSharedTerms(t1, t2); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Unreachable(); + return true; + } + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Unreachable(); + return true; + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + if (value) { + d_sharedTerms.mergeSharedTerms(t1, t2); + } + return true; + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + return true; } }; /** The notify class for d_equalityEngine */ EENotifyClass d_EENotify; - /** Equaltity engine */ - theory::uf::EqualityEngine d_equalityEngine; + /** Equality engine */ + theory::eq::EqualityEngine d_equalityEngine; /** Attach a new notify list to an equivalence class representative */ NotifyList* getNewNotifyList(); @@ -123,9 +141,6 @@ private: /** Method called by equalityEngine when a becomes equal to b */ void mergeSharedTerms(TNode a, TNode b); - /** Internal explanation method */ - void explain(TNode literal, std::vector& assumptions); - public: SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context); @@ -179,10 +194,12 @@ public: Node explain(TNode literal); +protected: + /** * This method gets called on backtracks from the context manager. */ - void notify() { + void contextNotifyPop() { backtrack(); } }; diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index 27c1a2b69..958f50276 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -73,16 +73,16 @@ private: /** Helper class to invalidate cache on user pop */ class CacheInvalidator : public context::ContextNotifyObj { bool& d_cacheInvalidated; - + protected: + void contextNotifyPop() { + d_cacheInvalidated = true; + } public: CacheInvalidator(context::Context* context, bool& cacheInvalidated) : context::ContextNotifyObj(context), d_cacheInvalidated(cacheInvalidated) { } - void notify() { - d_cacheInvalidated = true; - } };/* class SubstitutionMap::CacheInvalidator */ /** diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index a3aee985d..c19bdda91 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -124,6 +124,50 @@ void TheoryEngine::preRegister(TNode preprocessed) { // } } +void TheoryEngine::printAssertions(const char* tag) { + if (Debug.isOn(tag)) { + for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) { + Theory* theory = d_theoryTable[theoryId]; + if (theory && d_logicInfo.isTheoryEnabled(theoryId)) { + Debug(tag) << "--------------------------------------------" << std::endl; + Debug(tag) << "Assertions of " << theory->getId() << ": " << std::endl; + context::CDList::const_iterator it = theory->facts_begin(), it_end = theory->facts_end(); + for (unsigned i = 0; it != it_end; ++ it, ++i) { + if ((*it).isPreregistered) { + Debug(tag) << "[" << i << "]: "; + } else { + Debug(tag) << "(" << i << "): "; + } + Debug(tag) << (*it).assertion << endl; + } + + if (d_logicInfo.isSharingEnabled()) { + Debug(tag) << "Shared terms of " << theory->getId() << ": " << std::endl; + context::CDList::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end(); + for (unsigned i = 0; it != it_end; ++ it, ++i) { + Debug(tag) << "[" << i << "]: " << (*it) << endl; + } + } + } + } + + } +} + +template +class scoped_vector_clear { + vector& d_v; +public: + scoped_vector_clear(vector& v) + : d_v(v) { + Assert(!doAssert || d_v.empty()); + } + ~scoped_vector_clear() { + d_v.clear(); + } + +}; + /** * Check all (currently-active) theories for conflicts. * @param effort the effort level to use @@ -143,12 +187,12 @@ void TheoryEngine::check(Theory::Effort effort) { } \ } + // make sure d_propagatedSharedLiterals is cleared on exit + scoped_vector_clear clear_shared_literals(d_propagatedSharedLiterals); + // Do the checking try { - // Clear any leftover propagated shared literals - d_propagatedSharedLiterals.clear(); - // Mark the output channel unused (if this is FULL_EFFORT, and nothing // is done by the theories, no additional check will be needed) d_outputChannelUsed = false; @@ -159,32 +203,10 @@ void TheoryEngine::check(Theory::Effort effort) { while (true) { Debug("theory") << "TheoryEngine::check(" << effort << "): running check" << std::endl; + Assert(d_propagatedSharedLiterals.empty()); if (Debug.isOn("theory::assertions")) { - for (TheoryId theoryId = THEORY_FIRST; theoryId < THEORY_LAST; ++theoryId) { - Theory* theory = d_theoryTable[theoryId]; - if (theory && d_logicInfo.isTheoryEnabled(theoryId)) { - Debug("theory::assertions") << "--------------------------------------------" << std::endl; - Debug("theory::assertions") << "Assertions of " << theory->getId() << ": " << std::endl; - context::CDList::const_iterator it = theory->facts_begin(), it_end = theory->facts_end(); - for (unsigned i = 0; it != it_end; ++ it, ++i) { - if ((*it).isPreregistered) { - Debug("theory::assertions") << "[" << i << "]: "; - } else { - Debug("theory::assertions") << "(" << i << "): "; - } - Debug("theory::assertions") << (*it).assertion << endl; - } - - if (d_logicInfo.isSharingEnabled()) { - Debug("theory::assertions") << "Shared terms of " << theory->getId() << ": " << std::endl; - context::CDList::const_iterator it = theory->shared_terms_begin(), it_end = theory->shared_terms_end(); - for (unsigned i = 0; it != it_end; ++ it, ++i) { - Debug("theory::assertions") << "[" << i << "]: " << (*it) << endl; - } - } - } - } + printAssertions("theory::assertions"); } // Do the checking @@ -232,9 +254,6 @@ void TheoryEngine::check(Theory::Effort effort) { } } - // Clear any leftover propagated shared literals - d_propagatedSharedLiterals.clear(); - Debug("theory") << "TheoryEngine::check(" << effort << "): done, we are " << (d_inConflict ? "unsat" : "sat") << (d_lemmasAdded ? " with new lemmas" : " with no new lemmas") << std::endl; } catch(const theory::Interrupted&) { @@ -243,6 +262,9 @@ void TheoryEngine::check(Theory::Effort effort) { } void TheoryEngine::outputSharedLiterals() { + + scoped_vector_clear clear_shared_literals(d_propagatedSharedLiterals); + // Assert all the shared literals for (unsigned i = 0; i < d_propagatedSharedLiterals.size(); ++ i) { const SharedLiteral& eq = d_propagatedSharedLiterals[i]; @@ -258,8 +280,6 @@ void TheoryEngine::outputSharedLiterals() { } } } - // Clear the equalities - d_propagatedSharedLiterals.clear(); } @@ -269,7 +289,9 @@ void TheoryEngine::combineTheories() { TimerStat::CodeTimer combineTheoriesTimer(d_combineTheoriesTime); + // Care graph we'll be building CareGraph careGraph; + #ifdef CVC4_FOR_EACH_THEORY_STATEMENT #undef CVC4_FOR_EACH_THEORY_STATEMENT #endif @@ -278,6 +300,7 @@ void TheoryEngine::combineTheories() { reinterpret_cast::theory_class*>(theoryOf(THEORY))->getCareGraph(careGraph); \ } + // Call on each parametric theory to give us its care graph CVC4_FOR_EACH_THEORY; // Now add splitters for the ones we are interested in @@ -833,6 +856,8 @@ Node TheoryEngine::getExplanation(TNode node) { } Assert(properExplanation(node, explanation)); + Debug("theory::explain") << "TheoryEngine::getExplanation(" << node << ") => " << explanation << std::endl; + return explanation; } diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 5c73da1f6..2871d5559 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -309,8 +309,10 @@ class TheoryEngine { } struct SharedLiteral { - /** The node/theory pair for the assertion */ - /** THEORY_LAST indicates this is a SAT literal and should be sent to the SAT solver */ + /** + * The node/theory pair for the assertion. THEORY_LAST indicates this is a SAT + * literal and should be sent to the SAT solver + */ NodeTheoryPair toAssert; /** This is the node that we will use to explain it */ Node toExplain; @@ -319,7 +321,7 @@ class TheoryEngine { : toAssert(assertion, receivingTheory), toExplain(original) { } - };/* struct SharedLiteral */ + }; /** * Map from nodes to theories. @@ -728,6 +730,9 @@ private: /** Visitor for collecting shared terms */ SharedTermsVisitor d_sharedTermsVisitor; + /** Prints the assertions to the debug stream */ + void printAssertions(const char* tag); + };/* class TheoryEngine */ }/* CVC4 namespace */ diff --git a/src/theory/uf/Makefile.am b/src/theory/uf/Makefile.am index f25e50ec9..9d95eaa22 100644 --- a/src/theory/uf/Makefile.am +++ b/src/theory/uf/Makefile.am @@ -11,7 +11,7 @@ libuf_la_SOURCES = \ theory_uf_type_rules.h \ theory_uf_rewriter.h \ equality_engine.h \ - equality_engine_impl.h \ + equality_engine.cpp \ symmetry_breaker.h \ symmetry_breaker.cpp diff --git a/src/theory/uf/equality_engine_impl.h b/src/theory/uf/equality_engine.cpp similarity index 73% rename from src/theory/uf/equality_engine_impl.h rename to src/theory/uf/equality_engine.cpp index be12e5f19..b78015c00 100644 --- a/src/theory/uf/equality_engine_impl.h +++ b/src/theory/uf/equality_engine.cpp @@ -17,15 +17,27 @@ ** \todo document this file **/ -#include "cvc4_private.h" - -#pragma once - #include "theory/uf/equality_engine.h" namespace CVC4 { namespace theory { -namespace uf { +namespace eq { + +/** + * Data used in the BFS search through the equality graph. + */ +struct BfsData { + // The current node + EqualityNodeId nodeId; + // The index of the edge we traversed + EqualityEdgeId edgeId; + // Index in the queue of the previous node. Shouldn't be too much of them, at most the size + // of the biggest equivalence class + size_t previousIndex; + + BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) + : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} +}; class ScopedBool { bool& watch; @@ -40,20 +52,63 @@ public: } }; -template -void EqualityEngine::enqueue(const MergeCandidate& candidate) { +EqualityEngineNotifyNone EqualityEngine::s_notifyNone; + +void EqualityEngine::init() { + Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; + Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; + d_true = NodeManager::currentNM()->mkConst(true); + d_false = NodeManager::currentNM()->mkConst(false); + addTerm(d_true); + addTerm(d_false); +} + + +EqualityEngine::EqualityEngine(context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(s_notifyNone) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name) +: ContextNotifyObj(context) +, d_context(context) +, d_performNotify(true) +, d_notify(notify) +, d_applicationLookupsCount(context, 0) +, d_nodesCount(context, 0) +, d_assertedEqualitiesCount(context, 0) +, d_equalityTriggersCount(context, 0) +, d_individualTriggersSize(context, 0) +, d_constantRepresentativesSize(context, 0) +, d_stats(name) +{ + init(); +} + +void EqualityEngine::enqueue(const MergeCandidate& candidate) { Debug("equality") << "EqualityEngine::enqueue(" << candidate.toString(*this) << ")" << std::endl; d_propagationQueue.push(candidate); } -template -EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { +EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2) { Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ")" << std::endl; ++ d_stats.functionTermsCount; // Get another id for this - EqualityNodeId funId = newNode(original, true); + EqualityNodeId funId = newNode(original); FunctionApplication funOriginal(t1, t2); // The function application we're creating EqualityNodeId t1ClassId = getEqualityNode(t1).getFind(); @@ -87,10 +142,9 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, E return funId; } -template -EqualityNodeId EqualityEngine::newNode(TNode node, bool isApplication) { +EqualityNodeId EqualityEngine::newNode(TNode node) { - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ")" << std::endl; + Debug("equality") << "EqualityEngine::newNode(" << node << ")" << std::endl; ++ d_stats.termsCount; @@ -107,20 +161,20 @@ EqualityNodeId EqualityEngine::newNode(TNode node, bool isApplicati d_equalityGraph.push_back(+null_edge); // Mark the no-individual trigger d_nodeIndividualTrigger.push_back(+null_id); + // Mark non-constant by default + d_constantRepresentative.push_back(node.isConst() ? newId : +null_id); // Add the equality node to the nodes d_equalityNodes.push_back(EqualityNode(newId)); // Increase the counters d_nodesCount = d_nodesCount + 1; - Debug("equality") << "EqualityEngine::newNode(" << node << ", " << (isApplication ? "function" : "regular") << ") => " << newId << std::endl; + Debug("equality") << "EqualityEngine::newNode(" << node << ") => " << newId << std::endl; return newId; } - -template -void EqualityEngine::addTerm(TNode t) { +void EqualityEngine::addTerm(TNode t) { Debug("equality") << "EqualityEngine::addTerm(" << t << ")" << std::endl; @@ -148,47 +202,40 @@ void EqualityEngine::addTerm(TNode t) { } } else { // Otherwise we just create the new id - result = newNode(t, false); + result = newNode(t); } Debug("equality") << "EqualityEngine::addTerm(" << t << ") => " << result << std::endl; } -template -bool EqualityEngine::hasTerm(TNode t) const { +bool EqualityEngine::hasTerm(TNode t) const { return d_nodeIds.find(t) != d_nodeIds.end(); } -template -EqualityNodeId EqualityEngine::getNodeId(TNode node) const { +EqualityNodeId EqualityEngine::getNodeId(TNode node) const { Assert(hasTerm(node), node.toString().c_str()); return (*d_nodeIds.find(node)).second; } -template -EqualityNode& EqualityEngine::getEqualityNode(TNode t) { +EqualityNode& EqualityEngine::getEqualityNode(TNode t) { return getEqualityNode(getNodeId(t)); } -template -EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) { +EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template -const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const { +const EqualityNode& EqualityEngine::getEqualityNode(TNode t) const { return getEqualityNode(getNodeId(t)); } -template -const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const { +const EqualityNode& EqualityEngine::getEqualityNode(EqualityNodeId nodeId) const { Assert(nodeId < d_equalityNodes.size()); return d_equalityNodes[nodeId]; } -template -void EqualityEngine::addEqualityInternal(TNode t1, TNode t2, TNode reason) { +void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) { Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl; @@ -204,55 +251,35 @@ void EqualityEngine::addEqualityInternal(TNode t1, TNode t2, TNode propagate(); } -template -void EqualityEngine::addPredicate(TNode t, bool polarity, TNode reason) { - +void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) { Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl; - - addEqualityInternal(t, polarity ? d_true : d_false, reason); -} - -template -void EqualityEngine::addEquality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; - - addEqualityInternal(t1, t2, reason); - - Node equality = t1.eqNode(t2); - addEqualityInternal(equality, d_true, reason); + Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead"); + assertEqualityInternal(t, polarity ? d_true : d_false, reason); } -template -void EqualityEngine::addDisequality(TNode t1, TNode t2, TNode reason) { - - Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl; - - Node equality1 = t1.eqNode(t2); - addEqualityInternal(equality1, d_false, reason); - - Node equality2 = t2.eqNode(t1); - addEqualityInternal(equality2, d_false, reason); +void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { + Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl; + if (polarity) { + // Add equality between terms + assertEqualityInternal(eq[0], eq[1], reason); + // Add eq = true for dis-equality propagation + assertEqualityInternal(eq, d_true, reason); + } else { + assertEqualityInternal(eq, d_false, reason); + Node eqSymm = eq[1].eqNode(eq[0]); + assertEqualityInternal(eqSymm, d_false, reason); + } } - -template -TNode EqualityEngine::getRepresentative(TNode t) const { - +TNode EqualityEngine::getRepresentative(TNode t) const { Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ")" << std::endl; - Assert(hasTerm(t)); - - // Both following commands are semantically const EqualityNodeId representativeId = getEqualityNode(t).getFind(); - Debug("equality::internal") << "EqualityEngine::getRepresentative(" << t << ") => " << d_nodes[representativeId] << std::endl; - return d_nodes[representativeId]; } -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) const { +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; Assert(hasTerm(t1)); @@ -267,8 +294,7 @@ bool EqualityEngine::areEqual(TNode t1, TNode t2) const { return rep1 == rep2; } -template -void EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { +bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; @@ -357,10 +383,30 @@ void EqualityEngine::merge(EqualityNode& class1, EqualityNode& clas // Now merge the lists class1.merge(class2); - // Notfiy the triggers - EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; + // Check for constants + EqualityNodeId class2constId = d_constantRepresentative[class2Id]; + if (class2constId != +null_id) { + EqualityNodeId class1constId = d_constantRepresentative[class1Id]; + if (class1constId != +null_id) { + if (d_performNotify) { + TNode const1 = d_nodes[class1constId]; + TNode const2 = d_nodes[class2constId]; + if (!d_notify.eqNotifyConstantTermMerge(const1, const2)) { + return false; + } + } + } else { + // If the class we're merging in is constant, mark the representative as constant + d_constantRepresentative[class1Id] = d_constantRepresentative[class2Id]; + d_constantRepresentatives.push_back(class1Id); + d_constantRepresentativesSize = d_constantRepresentativesSize + 1; + } + } + + // Notify the trigger term merges EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id]; if (class2triggerId != +null_id) { + EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; if (class1triggerId == +null_id) { // If class1 is not an individual trigger, but class2 is, mark it d_nodeIndividualTrigger[class1Id] = class2triggerId; @@ -370,14 +416,18 @@ void EqualityEngine::merge(EqualityNode& class1, EqualityNode& clas } else { // Notify when done if (d_performNotify) { - d_notify.notify(d_nodes[class1triggerId], d_nodes[class2triggerId]); + if (!d_notify.eqNotifyTriggerTermEquality(d_nodes[class1triggerId], d_nodes[class2triggerId], true)) { + return false; + } } } } + + // Everything fine + return true; } -template -void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { +void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id) { Debug("equality") << "EqualityEngine::undoMerge(" << class1.getFind() << "," << class2Id << ")" << std::endl; @@ -409,8 +459,7 @@ void EqualityEngine::undoMerge(EqualityNode& class1, EqualityNode& } -template -void EqualityEngine::backtrack() { +void EqualityEngine::backtrack() { Debug("equality::backtrack") << "backtracking" << std::endl; @@ -453,6 +502,14 @@ void EqualityEngine::backtrack() { d_individualTriggers.resize(d_individualTriggersSize); } + if (d_constantRepresentatives.size() > d_constantRepresentativesSize) { + // Unset the constant representatives + for (int i = d_constantRepresentatives.size() - 1, i_end = d_constantRepresentativesSize; i >= i_end; -- i) { + d_constantRepresentative[d_constantRepresentatives[i]] = +null_id; + } + d_constantRepresentatives.resize(d_constantRepresentativesSize); + } + if (d_equalityTriggers.size() > d_equalityTriggersCount) { // Unlink the triggers from the lists for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) { @@ -492,13 +549,13 @@ void EqualityEngine::backtrack() { d_applications.resize(d_nodesCount); d_nodeTriggers.resize(d_nodesCount); d_nodeIndividualTrigger.resize(d_nodesCount); + d_constantRepresentative.resize(d_nodesCount); d_equalityGraph.resize(d_nodesCount); d_equalityNodes.resize(d_nodesCount); } } -template -void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { +void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { Debug("equality") << "EqualityEngine::addGraphEdge(" << d_nodes[t1] << "," << d_nodes[t2] << "," << reason << ")" << std::endl; EqualityEdgeId edge = d_equalityEdges.size(); d_equalityEdges.push_back(EqualityEdge(t2, d_equalityGraph[t1], type, reason)); @@ -511,8 +568,7 @@ void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId } } -template -std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { +std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { std::stringstream out; bool first = true; if (edgeId == null_edge) { @@ -529,70 +585,52 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) co return out.str(); } -template -void EqualityEngine::explainEquality(TNode t1, TNode t2, std::vector& equalities) { +void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector& equalities) { Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl; // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms (they might not be there) addTerm(t1); addTerm(t2); - Assert(getRepresentative(t1) == getRepresentative(t2), - "Cannot explain an equality, because the two terms are not equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - t1.toString().c_str(), getRepresentative(t1).toString().c_str(), - t2.toString().c_str(), getRepresentative(t2).toString().c_str()); - - // Get the explanation - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); - getExplanation(t1Id, t2Id, equalities); - + if (polarity) { + // Get the explanation + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); + getExplanation(t1Id, t2Id, equalities); + } else { + // Add the equality + Node equality = t1.eqNode(t2); + addTerm(equality); + + // Get the explanation + EqualityNodeId equalityId = getNodeId(equality); + EqualityNodeId falseId = getNodeId(d_false); + getExplanation(equalityId, falseId, equalities); + } } -template -void EqualityEngine::explainDisequality(TNode t1, TNode t2, std::vector& equalities) { - Debug("equality") << "EqualityEngine::explainDisequality(" << t1 << "," << t2 << ")" << std::endl; +void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector& assertions) { + Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl; // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms - addTerm(t1); - addTerm(t2); - - // Add the equality - Node equality = t1.eqNode(t2); - addTerm(equality); - - Assert(getRepresentative(equality) == getRepresentative(d_false), - "Cannot explain the dis-equality, because the two terms are not dis-equal!\n" - "The representative of %s\n" - " is %s\n" - "The representative of %s\n" - " is %s", - equality.toString().c_str(), getRepresentative(equality).toString().c_str(), - d_false.toString().c_str(), getRepresentative(d_false).toString().c_str()); - - // Get the explanation - EqualityNodeId equalityId = getNodeId(equality); - EqualityNodeId falseId = getNodeId(d_false); - getExplanation(equalityId, falseId, equalities); + addTerm(p); + // Get the explanation + getExplanation(getNodeId(p), getNodeId(polarity ? d_true : d_false), assertions); } - -template -void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector& equalities) const { +void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, std::vector& equalities) const { Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; + Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind()); + // If the nodes are the same, we're done if (t1Id == t2Id) return; @@ -682,15 +720,28 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNo } } -template -void EqualityEngine::addTriggerDisequality(TNode t1, TNode t2, TNode trigger) { - Node equality = t1.eqNode(t2); - addTerm(equality); - addTriggerEquality(equality, d_false, trigger); +void EqualityEngine::addTriggerEquality(TNode eq) { + Assert(eq.getKind() == kind::EQUAL); + // Add the terms + addTerm(eq); + // Positive trigger + addTriggerEqualityInternal(eq[0], eq[1], eq, true); + // Negative trigger + addTriggerEqualityInternal(eq, d_false, eq, false); +} + +void EqualityEngine::addTriggerPredicate(TNode predicate) { + Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL); + Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates"); + // Add the term + addTerm(predicate); + // Positive trigger + addTriggerEqualityInternal(predicate, d_true, predicate, true); + // Negative trigger + addTriggerEqualityInternal(predicate, d_false, predicate, false); } -template -void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode trigger) { +void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << ", " << t2 << ", " << trigger << ")" << std::endl; @@ -713,9 +764,9 @@ void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode t TriggerId t1NewTriggerId = d_equalityTriggers.size(); TriggerId t2NewTriggerId = t1NewTriggerId | 1; d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId)); - d_equalityTriggersOriginal.push_back(trigger); + d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); // Update the counters d_equalityTriggersCount = d_equalityTriggersCount + 2; @@ -728,7 +779,7 @@ void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode t if (t1classId == t2classId) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; if (d_performNotify) { - d_notify.notify(trigger); // Don't care about the return value + d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value } } @@ -739,8 +790,7 @@ void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode t Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl; } -template -void EqualityEngine::propagate() { +void EqualityEngine::propagate() { Debug("equality") << "EqualityEngine::propagate()" << std::endl; @@ -783,25 +833,29 @@ void EqualityEngine::propagate() { if (node2.getSize() > node1.getSize()) { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; d_assertedEqualities.push_back(Equality(t2classId, t1classId)); - merge(node2, node1, triggers); + done = !merge(node2, node1, triggers); } else { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl; d_assertedEqualities.push_back(Equality(t1classId, t2classId)); - merge(node1, node2, triggers); + done = !merge(node1, node2, triggers); } // Notify the triggers - if (d_performNotify) { - for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) { + if (d_performNotify && !done) { + for (size_t trigger_i = 0, trigger_end = triggers.size(); trigger_i < trigger_end && !done; ++ trigger_i) { + const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[triggers[trigger_i]]; // Notify the trigger and exit if it fails - done = !d_notify.notify(d_equalityTriggersOriginal[triggers[trigger]]); + if (triggerInfo.trigger.getKind() == kind::EQUAL) { + done = !d_notify.eqNotifyTriggerEquality(triggerInfo.trigger, triggerInfo.polarity); + } else { + done = !d_notify.eqNotifyTriggerPredicate(triggerInfo.trigger, triggerInfo.polarity); + } } } } } -template -void EqualityEngine::debugPrintGraph() const { +void EqualityEngine::debugPrintGraph() const { for (EqualityNodeId nodeId = 0; nodeId < d_nodes.size(); ++ nodeId) { Debug("equality::graph") << d_nodes[nodeId] << " " << nodeId << "(" << getEqualityNode(nodeId).getFind() << "):"; @@ -817,11 +871,10 @@ void EqualityEngine::debugPrintGraph() const { } } -template -bool EqualityEngine::areEqual(TNode t1, TNode t2) +bool EqualityEngine::areEqual(TNode t1, TNode t2) { // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms addTerm(t1); @@ -832,17 +885,18 @@ bool EqualityEngine::areEqual(TNode t1, TNode t2) return equal; } -template -bool EqualityEngine::areDisequal(TNode t1, TNode t2) +bool EqualityEngine::areDisequal(TNode t1, TNode t2) { // Don't notify during this check - ScopedBool turnOfNotify(d_performNotify, false); + ScopedBool turnOffNotify(d_performNotify, false); // Add the terms addTerm(t1); addTerm(t2); // Check (t1 = t2) = false + // No need to check the symmetric version: we can only deduce a disequality from an existing + // diseqality, and each of those is asserted in the symmetric version also Node equality = t1.eqNode(t2); addTerm(equality); if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) { @@ -853,16 +907,14 @@ bool EqualityEngine::areDisequal(TNode t1, TNode t2) return false; } -template -size_t EqualityEngine::getSize(TNode t) +size_t EqualityEngine::getSize(TNode t) { // Add the term addTerm(t); return getEqualityNode(getEqualityNode(t).getFind()).getSize(); } -template -void EqualityEngine::addTriggerTerm(TNode t) +void EqualityEngine::addTriggerTerm(TNode t) { Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl; @@ -877,7 +929,7 @@ void EqualityEngine::addTriggerTerm(TNode t) if (d_nodeIndividualTrigger[classId] != +null_id) { // No need to keep it, just propagate the existing individual triggers if (d_performNotify) { - d_notify.notify(t, d_nodes[d_nodeIndividualTrigger[classId]]); + d_notify.eqNotifyTriggerTermEquality(t, d_nodes[d_nodeIndividualTrigger[classId]], true); } } else { // Add it to the list for backtracking @@ -888,23 +940,20 @@ void EqualityEngine::addTriggerTerm(TNode t) } } -template -bool EqualityEngine::isTriggerTerm(TNode t) const { +bool EqualityEngine::isTriggerTerm(TNode t) const { if (!hasTerm(t)) return false; EqualityNodeId classId = getEqualityNode(t).getFind(); return d_nodeIndividualTrigger[classId] != +null_id; } -template -TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const { +TNode EqualityEngine::getTriggerTermRepresentative(TNode t) const { Assert(isTriggerTerm(t)); EqualityNodeId classId = getEqualityNode(t).getFind(); return d_nodes[d_nodeIndividualTrigger[classId]]; } -template -void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { +void EqualityEngine::storeApplicationLookup(FunctionApplication& funNormalized, EqualityNodeId funId) { Assert(d_applicationLookup.find(funNormalized) == d_applicationLookup.end()); d_applicationLookup[funNormalized] = funId; d_applicationLookups.push_back(funNormalized); @@ -914,8 +963,7 @@ void EqualityEngine::storeApplicationLookup(FunctionApplication& fu Assert(d_applicationLookupsCount == d_applicationLookups.size()); } -template -void EqualityEngine::getUseListTerms(TNode t, std::set& output) { +void EqualityEngine::getUseListTerms(TNode t, std::set& output) { if (hasTerm(t)) { // Get the equivalence class EqualityNodeId classId = getEqualityNode(t).getFind(); diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index dccd5ba56..f9c10d1b6 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -35,7 +35,7 @@ namespace CVC4 { namespace theory { -namespace uf { +namespace eq { /** Id of the node */ typedef size_t EqualityNodeId; @@ -213,9 +213,74 @@ public: } }; -template +/** + * Interface for equality engine notifications. All the notifications + * are safe as TNodes, but not necessarily for negations. + */ +class EqualityEngineNotify { + + friend class EqualityEngine; + +public: + + virtual ~EqualityEngineNotify() {}; + + /** + * Notifies about a trigger equality that became true or false. + * + * @param eq the equality that became true or false + * @param value the value of the equality + */ + virtual bool eqNotifyTriggerEquality(TNode equality, bool value) = 0; + + /** + * Notifies about a trigger predicate that became true or false. + * + * @param predicate the trigger predicate that bacame true or false + * @param value the value of the predicate + */ + virtual bool eqNotifyTriggerPredicate(TNode predicate, bool value) = 0; + + /** + * Notifies about the merge of two trigger terms. + * + * @param t1 a term marked as trigger + * @param t2 a term marked as trigger + * @param value true if equal, false if dis-equal + */ + virtual bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) = 0; + + /** + * Notifies about the merge of two constant terms. + * + * @param t1 a constant term + * @param t2 a constnat term + */ + virtual bool eqNotifyConstantTermMerge(TNode t1, TNode t2) = 0; +}; + +/** + * Implementation of the notification interface that ignores all the + * notifications. + */ +class EqualityEngineNotifyNone : public EqualityEngineNotify { +public: + bool eqNotifyTriggerEquality(TNode equality, bool value) { return true; } + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { return true; } + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { return true; } + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { return true; } +}; + + +/** + * Class for keeping an incremental congurence closure over a set of terms. It provides + * notifications via an EqualityEngineNotify object. + */ class EqualityEngine : public context::ContextNotifyObj { + /** Default implementation of the notification object */ + static EqualityEngineNotifyNone s_notifyNone; + public: /** Statistics about the equality engine instance */ @@ -226,21 +291,26 @@ public: IntStat termsCount; /** Number of function terms managed by the system */ IntStat functionTermsCount; + /** Number of constant terms managed by the system */ + IntStat constantTermsCount; Statistics(std::string name) : mergesCount(name + "::mergesCount", 0), termsCount(name + "::termsCount", 0), - functionTermsCount(name + "::functionTermsCount", 0) + functionTermsCount(name + "::functionTermsCount", 0), + constantTermsCount(name + "::constantTermsCount", 0) { StatisticsRegistry::registerStat(&mergesCount); StatisticsRegistry::registerStat(&termsCount); StatisticsRegistry::registerStat(&functionTermsCount); + StatisticsRegistry::registerStat(&constantTermsCount); } ~Statistics() { StatisticsRegistry::unregisterStat(&mergesCount); StatisticsRegistry::unregisterStat(&termsCount); StatisticsRegistry::unregisterStat(&functionTermsCount); + StatisticsRegistry::unregisterStat(&constantTermsCount); } }; @@ -282,7 +352,7 @@ private: bool d_performNotify; /** The class to notify when a representative changes for a term */ - NotifyClass d_notify; + EqualityEngineNotify& d_notify; /** The map of kinds to be treated as function applications */ KindMap d_congruenceKinds; @@ -428,8 +498,11 @@ private: /** Returns the id of the node */ EqualityNodeId getNodeId(TNode node) const; - /** Merge the class2 into class1 */ - void merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers); + /** + * Merge the class2 into class1 + * @return true if ok, false if to break out + */ + bool merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers); /** Undo the mereg of class2 into class1 */ void undoMerge(EqualityNode& class1, EqualityNode& class2, EqualityNodeId class2Id); @@ -437,29 +510,13 @@ private: /** Backtrack the information if necessary */ void backtrack(); - /** - * Data used in the BFS search through the equality graph. - */ - struct BfsData { - // The current node - EqualityNodeId nodeId; - // The index of the edge we traversed - EqualityEdgeId edgeId; - // Index in the queue of the previous node. Shouldn't be too much of them, at most the size - // of the biggest equivalence class - size_t previousIndex; - - BfsData(EqualityNodeId nodeId = null_id, EqualityEdgeId edgeId = null_edge, size_t prev = 0) - : nodeId(nodeId), edgeId(edgeId), previousIndex(prev) {} - }; - /** * Trigger that will be updated */ struct Trigger { /** The current class id of the LHS of the trigger */ EqualityNodeId classId; - /** Next trigger for class 1 */ + /** Next trigger for class */ TriggerId nextTrigger; Trigger(EqualityNodeId classId = null_id, TriggerId nextTrigger = null_trigger) @@ -473,10 +530,20 @@ private: */ std::vector d_equalityTriggers; + struct TriggerInfo { + /** The trigger itself */ + Node trigger; + /** Polarity of the trigger */ + bool polarity; + TriggerInfo() {} + TriggerInfo(Node trigger, bool polarity) + : trigger(trigger), polarity(polarity) {} + }; + /** * Vector of original equalities of the triggers. */ - std::vector d_equalityTriggersOriginal; + std::vector d_equalityTriggersOriginal; /** * Context dependent count of triggers @@ -504,6 +571,19 @@ private: */ std::vector d_nodeIndividualTrigger; + /** + * Map from ids to the id of the constant that is the representative. + */ + std::vector d_constantRepresentative; + + /** + * Size of the constant representatives list. + */ + context::CDO d_constantRepresentativesSize; + + /** The list of representatives that became constant. */ + std::vector d_constantRepresentatives; + /** * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId. */ @@ -516,7 +596,7 @@ private: EqualityNodeId newApplicationNode(TNode original, EqualityNodeId t1, EqualityNodeId t2); /** Add a new node to the database */ - EqualityNodeId newNode(TNode t, bool isApplication); + EqualityNodeId newNode(TNode t); struct MergeCandidate { EqualityNodeId t1Id, t2Id; @@ -561,44 +641,41 @@ private: /** * Adds an equality of terms t1 and t2 to the database. */ - void addEqualityInternal(TNode t1, TNode t2, TNode reason); + void assertEqualityInternal(TNode t1, TNode t2, TNode reason); -public: + /** + * Adds a trigger equality to the database with the trigger node and polarity for notification. + */ + void addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigger, bool polarity); /** - * Initialize the equality engine, given the owning class. This will initialize the notifier with - * the owner information. - */ - EqualityEngine(NotifyClass& notify, context::Context* context, std::string name) - : ContextNotifyObj(context), - d_context(context), - d_performNotify(true), - d_notify(notify), - d_applicationLookupsCount(context, 0), - d_nodesCount(context, 0), - d_assertedEqualitiesCount(context, 0), - d_equalityTriggersCount(context, 0), - d_individualTriggersSize(context, 0), - d_stats(name) - { - Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; - Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; - d_true = NodeManager::currentNM()->mkConst(true); - d_false = NodeManager::currentNM()->mkConst(false); + * This method gets called on backtracks from the context manager. + */ + void contextNotifyPop() { + backtrack(); } /** - * Just a destructor. + * Constructor initialization stuff. */ - virtual ~EqualityEngine() throw(AssertionException) {} + void init(); + +public: /** - * This method gets called on backtracks from the context manager. + * Initialize the equality engine, given the notification class. */ - void notify() { - backtrack(); - } + EqualityEngine(EqualityEngineNotify& notify, context::Context* context, std::string name); + + /** + * Initialize the equality engine with no notification class. + */ + EqualityEngine(context::Context* context, std::string name); + + /** + * Just a destructor. + */ + virtual ~EqualityEngine() throw(AssertionException) {} /** * Adds a term to the term database. @@ -629,77 +706,91 @@ public: bool hasTerm(TNode t) const; /** - * Adds aa predicate t with given polarity + * Adds a predicate p with given polarity. The predicate asserted + * should be in the coungruence closure kinds (otherwise it's + * useless. + * + * @param p the (non-negated) predicate + * @param polarity true if asserting the predicate, false if + * asserting the negated predicate + * @param the reason to keep for building explanations */ - void addPredicate(TNode t, bool polarity, TNode reason); + void assertPredicate(TNode p, bool polarity, TNode reason); /** - * Adds an equality t1 = t2 to the database. + * Adds an equality eq with the given polarity to the database. + * + * @param eq the (non-negated) equality + * @param polarity true if asserting the equality, false if + * asserting the negated equality + * @param the reason to keep for building explanations */ - void addEquality(TNode t1, TNode t2, TNode reason); + void assertEquality(TNode eq, bool polarity, TNode reason); /** - * Adds an dis-equality t1 != t2 to the database. - */ - void addDisequality(TNode t1, TNode t2, TNode reason); - - /** - * Returns the representative of the term t. + * Returns the current representative of the term t. */ TNode getRepresentative(TNode t) const; /** - * Add all the terms where the given term appears in (directly or implicitly). + * Add all the terms where the given term appears as a first child + * (directly or implicitly). */ void getUseListTerms(TNode t, std::set& output); /** - * Returns true if the two nodes are in the same class. + * Returns true if the two nodes are in the same equivalence class. */ bool areEqual(TNode t1, TNode t2) const; /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the equality t1 = t2 begin true of false. + * Returns the reasons (added when asserting) that imply it + * in the assertions vector. */ - void explainEquality(TNode t1, TNode t2, std::vector& equalities); + void explainEquality(TNode t1, TNode t2, bool polarity, std::vector& assertions); /** - * Get an explanation of the equality t1 = t2. Returns the asserted equalities that - * imply t1 = t2. Returns TNodes as the assertion equalities should be hashed somewhere - * else. + * Get an explanation of the predicate being true or false. + * Returns the reasons (added when asserting) that imply imply it + * in the assertions vector. */ - void explainDisequality(TNode t1, TNode t2, std::vector& equalities); + void explainPredicate(TNode p, bool polarity, std::vector& assertions); /** - * Add term to the trigger terms. The notify class will get notified when two - * trigger terms become equal. Thihs will only happen on trigger term - * representatives. + * Add term to the trigger terms. The notify class will get notified + * when two trigger terms become equal or dis-equal. The notification + * will not happen on all the terms, but only on the ones that are + * represent the class. */ void addTriggerTerm(TNode t); /** - * Returns true if t is a trigger term or equal to some other trigger term. + * Returns true if t is a trigger term or in the same equivalence + * class as some other trigger term. */ bool isTriggerTerm(TNode t) const; /** - * Returns the representative trigger term (isTriggerTerm(t)) should be true. + * Returns the representative trigger term of the given term. + * + * @param t the term to check where isTriggerTerm(t) should be true */ TNode getTriggerTermRepresentative(TNode t) const; /** - * Adds a notify trigger for equality t1 = t2, i.e. when t1 = t2 the notify will be called with - * trigger. + * Adds a notify trigger for equality. When equality becomes true eqNotifyTriggerEquality + * will be called with value = true, and when equality becomes false eqNotifyTriggerEquality + * will be called with value = false. */ - void addTriggerEquality(TNode t1, TNode t2, TNode trigger); + void addTriggerEquality(TNode equality); /** - * Adds a notify trigger for dis-equality t1 != t2, i.e. when t1 != t2 the notify will be called with - * trigger. + * Adds a notify trigger for the predicate p. When the predicate becomes true + * eqNotifyTriggerPredicate will be called with value = true, and when equality becomes false + * eqNotifyTriggerPredicate will be called with value = false. */ - void addTriggerDisequality(TNode t1, TNode t2, TNode trigger); + void addTriggerPredicate(TNode predicate); /** * Check whether the two terms are equal. @@ -712,7 +803,7 @@ public: bool areDisequal(TNode t1, TNode t2); /** - * Return the number of nodes in the equivalence class contianing t + * Return the number of nodes in the equivalence class containing t * Adds t if not already there. */ size_t getSize(TNode t); diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index ec28dad75..cddd01b07 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -18,13 +18,11 @@ **/ #include "theory/uf/theory_uf.h" -#include "theory/uf/equality_engine_impl.h" using namespace std; - -namespace CVC4 { -namespace theory { -namespace uf { +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::uf; /** Constructs a new instance of TheoryUF w.r.t. the provided context.*/ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo) : @@ -40,12 +38,6 @@ TheoryUF::TheoryUF(context::Context* c, context::UserContext* u, OutputChannel& d_equalityEngine.addFunctionKind(kind::APPLY_UF); d_equalityEngine.addFunctionKind(kind::EQUAL); - // The boolean constants - d_true = NodeManager::currentNM()->mkConst(true); - d_false = NodeManager::currentNM()->mkConst(false); - d_equalityEngine.addTerm(d_true); - d_equalityEngine.addTerm(d_false); - d_equalityEngine.addTriggerEquality(d_true, d_false, d_false); }/* TheoryUF::TheoryUF() */ static Node mkAnd(const std::vector& conjunctions) { @@ -91,23 +83,12 @@ void TheoryUF::check(Effort level) { } // Do the work - switch (fact.getKind()) { - case kind::EQUAL: - d_equalityEngine.addEquality(fact[0], fact[1], fact); - break; - case kind::APPLY_UF: - d_equalityEngine.addPredicate(fact, true, fact); - break; - case kind::NOT: - if (fact[0].getKind() == kind::APPLY_UF) { - d_equalityEngine.addPredicate(fact[0], false, fact); - } else { - // Assert the dis-equality - d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); - } - break; - default: - Unreachable(); + bool polarity = fact.getKind() != kind::NOT; + TNode atom = polarity ? fact : fact[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.assertEquality(atom, polarity, fact); + } else { + d_equalityEngine.assertPredicate(atom, polarity, fact); } } @@ -139,10 +120,8 @@ void TheoryUF::propagate(Effort level) { Debug("uf") << "TheoryUF::propagate(): in conflict, normalized = " << normalized << std::endl; Node negatedLiteral; std::vector assumptions; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; @@ -157,21 +136,17 @@ void TheoryUF::preRegisterTerm(TNode node) { switch (node.getKind()) { case kind::EQUAL: - // Add the terms - d_equalityEngine.addTerm(node[0]); - d_equalityEngine.addTerm(node[1]); // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node[0], node[1], node); - d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); + d_equalityEngine.addTriggerEquality(node); break; case kind::APPLY_UF: - // Function applications/predicates - d_equalityEngine.addTerm(node); // Maybe it's a predicate if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerEquality(node, d_true, node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerPredicate(node); + } else { + // Function applications/predicates + d_equalityEngine.addTerm(node); } // Remember the function and predicate terms d_functionsTerms.push_back(node); @@ -194,26 +169,20 @@ bool TheoryUF::propagate(TNode literal) { // See if the literal has been asserted already Node normalized = Rewriter::rewrite(literal); - bool satValue = false; - bool isAsserted = normalized == d_false || d_valuation.hasSatValue(normalized, satValue); - // If asserted, we're done or in conflict - if (isAsserted) { - if (!satValue) { + // If asserted and is false, we're done or in conflict + // Note that even trivial equalities have a SAT value (i.e. 1 = 2 -> false) + bool satValue = false; + if (d_valuation.hasSatValue(normalized, satValue) && !satValue) { Debug("uf") << "TheoryUF::propagate(" << literal << ", normalized = " << normalized << ") => conflict" << std::endl; std::vector assumptions; Node negatedLiteral; - if (normalized != d_false) { - negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); - assumptions.push_back(negatedLiteral); - } + negatedLiteral = normalized.getKind() == kind::NOT ? (Node) normalized[0] : normalized.notNode(); + assumptions.push_back(negatedLiteral); explain(literal, assumptions); d_conflictNode = mkAnd(assumptions); d_conflict = true; return false; - } - // Propagate even if already known in SAT - could be a new equation between shared terms - // (terms that weren't shared when the literal was asserted!) } // Nothing, just enqueue it for propagation and mark it as asserted already @@ -224,36 +193,14 @@ bool TheoryUF::propagate(TNode literal) { }/* TheoryUF::propagate(TNode) */ void TheoryUF::explain(TNode literal, std::vector& assumptions) { - TNode lhs, rhs; - switch (literal.getKind()) { - case kind::EQUAL: - lhs = literal[0]; - rhs = literal[1]; - break; - case kind::APPLY_UF: - lhs = literal; - rhs = d_true; - break; - case kind::NOT: - if (literal[0].getKind() == kind::EQUAL) { - // Disequalities - d_equalityEngine.explainDisequality(literal[0][0], literal[0][1], assumptions); - return; - } else { - // Predicates - lhs = literal[0]; - rhs = d_false; - break; - } - case kind::CONST_BOOLEAN: - // we get to explain true = false, since we set false to be the trigger of this - lhs = d_true; - rhs = d_false; - break; - default: - Unreachable(); + // Do the work + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); } - d_equalityEngine.explainEquality(lhs, rhs, assumptions); } Node TheoryUF::explain(TNode literal) { @@ -508,7 +455,3 @@ void TheoryUF::computeCareGraph() { } } }/* TheoryUF::computeCareGraph() */ - -}/* CVC4::theory::uf namespace */ -}/* CVC4::theory namespace */ -}/* CVC4 namespace */ diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 6956390f5..9017928b7 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -39,21 +39,46 @@ namespace uf { class TheoryUF : public Theory { public: - class NotifyClass { + class NotifyClass : public eq::EqualityEngineNotify { TheoryUF& d_uf; public: NotifyClass(TheoryUF& uf): d_uf(uf) {} - bool notify(TNode propagation) { - Debug("uf") << "NotifyClass::notify(" << propagation << ")" << std::endl; - // Just forward to uf - return d_uf.propagate(propagation); + bool eqNotifyTriggerEquality(TNode equality, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(equality); + } else { + // We use only literal triggers so taking not is safe + return d_uf.propagate(equality.notNode()); + } } - - void notify(TNode t1, TNode t2) { - Debug("uf") << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; - Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); - d_uf.propagate(equality); + + bool eqNotifyTriggerPredicate(TNode predicate, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_uf.propagate(predicate); + } else { + return d_uf.propagate(predicate.notNode()); + } + } + + bool eqNotifyTriggerTermEquality(TNode t1, TNode t2, bool value) { + Debug("uf") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << std::endl; + if (value) { + return d_uf.propagate(t1.eqNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2).notNode()); + } + } + + bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { + Debug("uf") << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl; + if (Theory::theoryOf(t1) == THEORY_BOOL) { + return d_uf.propagate(t1.iffNode(t2)); + } else { + return d_uf.propagate(t1.eqNode(t2)); + } } }; @@ -63,7 +88,7 @@ private: NotifyClass d_notify; /** Equaltity engine */ - EqualityEngine d_equalityEngine; + eq::EqualityEngine d_equalityEngine; /** Are we in conflict */ context::CDO d_conflict; @@ -72,7 +97,8 @@ private: Node d_conflictNode; /** - * Should be called to propagate the literal. + * Should be called to propagate the literal. We use a node here + * since some of the propagated literals are not kept anywhere. */ bool propagate(TNode literal); @@ -90,12 +116,6 @@ private: /** All the function terms that the theory has seen */ context::CDList d_functionsTerms; - /** True node for predicates = true */ - Node d_true; - - /** True node for predicates = false */ - Node d_false; - /** Symmetry analyzer */ SymmetryBreaker d_symb; diff --git a/src/util/configuration.cpp b/src/util/configuration.cpp index 66b0a2f90..6f01d6cf4 100644 --- a/src/util/configuration.cpp +++ b/src/util/configuration.cpp @@ -162,7 +162,7 @@ bool Configuration::isDebugTag(char const *tag){ return true; } } -#endif * CVC4_DEBUG */ +#endif /* CVC4_DEBUG */ return false; } diff --git a/test/unit/context/context_black.h b/test/unit/context/context_black.h index 33863e848..1a50d0637 100644 --- a/test/unit/context/context_black.h +++ b/test/unit/context/context_black.h @@ -37,7 +37,7 @@ struct MyContextNotifyObj : public ContextNotifyObj { nCalls(0) { } - void notify() { + void contextNotifyPop() { ++nCalls; } }; diff --git a/test/unit/prop/cnf_stream_black.h b/test/unit/prop/cnf_stream_black.h index 63ba95b57..c24104acc 100644 --- a/test/unit/prop/cnf_stream_black.h +++ b/test/unit/prop/cnf_stream_black.h @@ -59,6 +59,14 @@ public: return d_nextVar++; } + SatVariable trueVar() { + return d_nextVar++; + } + + SatVariable falseVar() { + return d_nextVar++; + } + void addClause(SatClause& c, bool lemma) { d_addClauseCalled = true; } -- 2.30.2