From 7127be18692e2fd32bd2dfce53e50c105ed8a25d Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 30 Sep 2020 10:06:58 -0500 Subject: [PATCH] Dynamic allocation of equality engine for shared solver (#5152) This updates shared solver to have dynamic allocation of equality engine, analogous to theory solvers. --- src/theory/ee_manager_distributed.cpp | 12 +++++ src/theory/ee_manager_distributed.h | 2 + src/theory/shared_solver_distributed.cpp | 6 +-- src/theory/shared_terms_database.cpp | 64 ++++++++++++++++-------- src/theory/shared_terms_database.h | 15 +++--- 5 files changed, 68 insertions(+), 31 deletions(-) diff --git a/src/theory/ee_manager_distributed.cpp b/src/theory/ee_manager_distributed.cpp index 360b1257b..3fb5fc0ce 100644 --- a/src/theory/ee_manager_distributed.cpp +++ b/src/theory/ee_manager_distributed.cpp @@ -34,6 +34,18 @@ EqEngineManagerDistributed::~EqEngineManagerDistributed() void EqEngineManagerDistributed::initializeTheories() { context::Context* c = d_te.getSatContext(); + // initialize the shared solver + EeSetupInfo esis; + if (d_sharedSolver.needsEqualityEngine(esis)) + { + // allocate an equality engine for the shared terms database + d_stbEqualityEngine.reset(allocateEqualityEngine(esis, c)); + d_sharedSolver.setEqualityEngine(d_stbEqualityEngine.get()); + } + else + { + Unhandled() << "Expected shared solver to use equality engine"; + } // allocate equality engines per theory for (TheoryId theoryId = theory::THEORY_FIRST; diff --git a/src/theory/ee_manager_distributed.h b/src/theory/ee_manager_distributed.h index 90beb0d3b..c7c1e7f4c 100644 --- a/src/theory/ee_manager_distributed.h +++ b/src/theory/ee_manager_distributed.h @@ -89,6 +89,8 @@ class EqEngineManagerDistributed : public EqEngineManager std::unique_ptr d_masterEENotify; /** The master equality engine. */ std::unique_ptr d_masterEqualityEngine; + /** The equality engine of the shared solver / shared terms database. */ + std::unique_ptr d_stbEqualityEngine; }; } // namespace theory diff --git a/src/theory/shared_solver_distributed.cpp b/src/theory/shared_solver_distributed.cpp index 873c81db1..5975d3dd8 100644 --- a/src/theory/shared_solver_distributed.cpp +++ b/src/theory/shared_solver_distributed.cpp @@ -67,9 +67,8 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id) TrustNode texp; if (id == THEORY_BUILTIN) { - // explanation using the shared terms database - Node exp = d_sharedTerms.explain(literal); - texp = TrustNode::mkTrustPropExp(literal, exp, nullptr); + // explanation based on the specific solver + texp = d_sharedTerms.explain(literal); Trace("shared-solver") << "\tTerm was propagated by THEORY_BUILTIN. Explanation: " << texp.getNode() << std::endl; @@ -77,6 +76,7 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id) else { // By default, we ask the individual theory for the explanation. + // It is possible that a centralized approach could preempt this. texp = d_te.theoryOf(id)->explain(literal); Trace("shared-solver") << "\tTerm was propagated by owner theory: " << id << ". Explanation: " << texp.getNode() << std::endl; diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 2f9ad74e0..92c66e83b 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -32,10 +32,11 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, d_alreadyNotifiedMap(context), d_registeredEqualities(context), d_EENotify(*this), - d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true), d_theoryEngine(theoryEngine), d_inConflict(context, false), - d_conflictPolarity() { + d_conflictPolarity(), + d_equalityEngine(nullptr) +{ smtStatisticsRegistry()->registerStat(&d_statSharedTerms); } @@ -46,7 +47,7 @@ SharedTermsDatabase::~SharedTermsDatabase() void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee) { - // TODO (project #39): dynamic allocation of equality engine here + d_equalityEngine = ee; } bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi) @@ -57,8 +58,9 @@ bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi) } void SharedTermsDatabase::addEqualityToPropagate(TNode equality) { + Assert(d_equalityEngine != nullptr); d_registeredEqualities.insert(equality); - d_equalityEngine.addTriggerPredicate(equality); + d_equalityEngine->addTriggerPredicate(equality); checkForConflict(); } @@ -183,12 +185,18 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories) d_alreadyNotifiedMap[term] = TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified); + if (d_equalityEngine == nullptr) + { + // if we are not assigned an equality engine, there is nothing to do + return; + } + // Mark the shared terms in the equality engine theory::TheoryId currentTheory; while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified)) != THEORY_LAST) { - d_equalityEngine.addTriggerTerm(term, currentTheory); + d_equalityEngine->addTriggerTerm(term, currentTheory); } // Check for any conflits @@ -196,32 +204,42 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories) } bool SharedTermsDatabase::areEqual(TNode a, TNode b) const { - if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) { - return d_equalityEngine.areEqual(a,b); + Assert(d_equalityEngine != nullptr); + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + return d_equalityEngine->areEqual(a, b); } else { - Assert(d_equalityEngine.hasTerm(a) || a.isConst()); - Assert(d_equalityEngine.hasTerm(b) || b.isConst()); + Assert(d_equalityEngine->hasTerm(a) || a.isConst()); + Assert(d_equalityEngine->hasTerm(b) || b.isConst()); // since one (or both) of them is a constant, and the other is in the equality engine, they are not same return false; } } bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const { - if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) { - return d_equalityEngine.areDisequal(a,b,false); + Assert(d_equalityEngine != nullptr); + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + return d_equalityEngine->areDisequal(a, b, false); } else { - Assert(d_equalityEngine.hasTerm(a) || a.isConst()); - Assert(d_equalityEngine.hasTerm(b) || b.isConst()); + Assert(d_equalityEngine->hasTerm(a) || a.isConst()); + Assert(d_equalityEngine->hasTerm(b) || b.isConst()); // one (or both) are in the equality engine return false; } } +theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine() +{ + return d_equalityEngine; +} + void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason) { + Assert(d_equalityEngine != nullptr); Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl; // Add it to the equality engine - d_equalityEngine.assertEquality(equality, polarity, reason); + d_equalityEngine->assertEquality(equality, polarity, reason); // Check for conflict checkForConflict(); } @@ -258,10 +276,12 @@ static Node mkAnd(const std::vector& conjunctions) { } void SharedTermsDatabase::checkForConflict() { + Assert(d_equalityEngine != nullptr); if (d_inConflict) { d_inConflict = false; std::vector assumptions; - d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions); + d_equalityEngine->explainEquality( + d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions); Node conflict = mkAnd(assumptions); TrustNode tconf = TrustNode::mkTrustConflict(conflict); d_theoryEngine->conflict(tconf, THEORY_BUILTIN); @@ -270,22 +290,26 @@ void SharedTermsDatabase::checkForConflict() { } bool SharedTermsDatabase::isKnown(TNode literal) const { + Assert(d_equalityEngine != nullptr); bool polarity = literal.getKind() != kind::NOT; TNode equality = polarity ? literal : literal[0]; if (polarity) { - return d_equalityEngine.areEqual(equality[0], equality[1]); + return d_equalityEngine->areEqual(equality[0], equality[1]); } else { - return d_equalityEngine.areDisequal(equality[0], equality[1], false); + return d_equalityEngine->areDisequal(equality[0], equality[1], false); } } -Node SharedTermsDatabase::explain(TNode literal) const { +TrustNode SharedTermsDatabase::explain(TNode literal) const +{ + Assert(d_equalityEngine != nullptr); bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; Assert(atom.getKind() == kind::EQUAL); std::vector assumptions; - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); - return mkAnd(assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); + Node exp = mkAnd(assumptions); + return TrustNode::mkTrustPropExp(literal, exp, nullptr); } } /* namespace CVC4 */ diff --git a/src/theory/shared_terms_database.h b/src/theory/shared_terms_database.h index 369a35b34..558d6fc93 100644 --- a/src/theory/shared_terms_database.h +++ b/src/theory/shared_terms_database.h @@ -111,9 +111,6 @@ private: /** The notify class for d_equalityEngine */ EENotifyClass d_EENotify; - /** Equality engine */ - theory::eq::EqualityEngine d_equalityEngine; - /** * Method called by equalityEngine when a becomes (dis-)equal to b and a and b are shared with * the theory. Returns false if there is a direct conflict (via rewrite for example). @@ -182,7 +179,7 @@ public: /** * Returns an explanation of the propagation that came from the database. */ - Node explain(TNode literal) const; + theory::TrustNode explain(TNode literal) const; /** * Add an equality to propagate. @@ -254,14 +251,16 @@ public: /** * get equality engine */ - theory::eq::EqualityEngine* getEqualityEngine() { return &d_equalityEngine; } - -protected: + theory::eq::EqualityEngine* getEqualityEngine(); + protected: /** * This method gets called on backtracks from the context manager. */ - void contextNotifyPop() override { backtrack(); } + void contextNotifyPop() override { backtrack(); } + + /** Equality engine */ + theory::eq::EqualityEngine* d_equalityEngine; }; } -- 2.30.2