Dynamic allocation of equality engine for shared solver (#5152)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 30 Sep 2020 15:06:58 +0000 (10:06 -0500)
committerGitHub <noreply@github.com>
Wed, 30 Sep 2020 15:06:58 +0000 (10:06 -0500)
This updates shared solver to have dynamic allocation of equality engine, analogous to theory solvers.

src/theory/ee_manager_distributed.cpp
src/theory/ee_manager_distributed.h
src/theory/shared_solver_distributed.cpp
src/theory/shared_terms_database.cpp
src/theory/shared_terms_database.h

index 360b1257bda14c067371b2d2e466c48edf5ecf66..3fb5fc0cee53a4c39990bfdc8336d3652a6ede40 100644 (file)
@@ -34,6 +34,18 @@ EqEngineManagerDistributed::~EqEngineManagerDistributed()
 void EqEngineManagerDistributed::initializeTheories()
 {
   context::Context* c = d_te.getSatContext();
+  // initialize the shared solver
+  EeSetupInfo esis;
+  if (d_sharedSolver.needsEqualityEngine(esis))
+  {
+    // allocate an equality engine for the shared terms database
+    d_stbEqualityEngine.reset(allocateEqualityEngine(esis, c));
+    d_sharedSolver.setEqualityEngine(d_stbEqualityEngine.get());
+  }
+  else
+  {
+    Unhandled() << "Expected shared solver to use equality engine";
+  }
 
   // allocate equality engines per theory
   for (TheoryId theoryId = theory::THEORY_FIRST;
index 90beb0d3b67a6c799d2cc024694875f04fddbfae..c7c1e7f4c89d8770638c03ab2ba06ca4204e7691 100644 (file)
@@ -89,6 +89,8 @@ class EqEngineManagerDistributed : public EqEngineManager
   std::unique_ptr<MasterNotifyClass> d_masterEENotify;
   /** The master equality engine. */
   std::unique_ptr<eq::EqualityEngine> d_masterEqualityEngine;
+  /** The equality engine of the shared solver / shared terms database. */
+  std::unique_ptr<eq::EqualityEngine> d_stbEqualityEngine;
 };
 
 }  // namespace theory
index 873c81db134553ffb1f87954675c58ed1fd561c8..5975d3dd8c2c78404e2aa08f38599eb6f878db6c 100644 (file)
@@ -67,9 +67,8 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id)
   TrustNode texp;
   if (id == THEORY_BUILTIN)
   {
-    // explanation using the shared terms database
-    Node exp = d_sharedTerms.explain(literal);
-    texp = TrustNode::mkTrustPropExp(literal, exp, nullptr);
+    // explanation based on the specific solver
+    texp = d_sharedTerms.explain(literal);
     Trace("shared-solver")
         << "\tTerm was propagated by THEORY_BUILTIN. Explanation: "
         << texp.getNode() << std::endl;
@@ -77,6 +76,7 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id)
   else
   {
     // By default, we ask the individual theory for the explanation.
+    // It is possible that a centralized approach could preempt this.
     texp = d_te.theoryOf(id)->explain(literal);
     Trace("shared-solver") << "\tTerm was propagated by owner theory: " << id
                            << ". Explanation: " << texp.getNode() << std::endl;
index 2f9ad74e0e92ea1cb206ce800b57e53d71de1f46..92c66e83b6112a1e8db5148dc971df13959ef609 100644 (file)
@@ -32,10 +32,11 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
       d_alreadyNotifiedMap(context),
       d_registeredEqualities(context),
       d_EENotify(*this),
-      d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true),
       d_theoryEngine(theoryEngine),
       d_inConflict(context, false),
-      d_conflictPolarity() {
+      d_conflictPolarity(),
+      d_equalityEngine(nullptr)
+{
   smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
 }
 
@@ -46,7 +47,7 @@ SharedTermsDatabase::~SharedTermsDatabase()
 
 void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
 {
-  // TODO (project #39): dynamic allocation of equality engine here
+  d_equalityEngine = ee;
 }
 
 bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
@@ -57,8 +58,9 @@ bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
 }
 
 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
   d_registeredEqualities.insert(equality);
-  d_equalityEngine.addTriggerPredicate(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
   checkForConflict();
 }
 
@@ -183,12 +185,18 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
   d_alreadyNotifiedMap[term] =
       TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
 
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
+  }
+
   // Mark the shared terms in the equality engine
   theory::TheoryId currentTheory;
   while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
          != THEORY_LAST)
   {
-    d_equalityEngine.addTriggerTerm(term, currentTheory);
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
   }
 
   // Check for any conflits
@@ -196,32 +204,42 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
 }
 
 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areEqual(a,b);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
     return false;
   }
 }
 
 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areDisequal(a,b,false);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // one (or both) are in the equality engine
     return false;
   }
 }
 
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
+
 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
 {
+  Assert(d_equalityEngine != nullptr);
   Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
   // Add it to the equality engine
-  d_equalityEngine.assertEquality(equality, polarity, reason);
+  d_equalityEngine->assertEquality(equality, polarity, reason);
   // Check for conflict
   checkForConflict();
 }
@@ -258,10 +276,12 @@ static Node mkAnd(const std::vector<TNode>& conjunctions) {
 }
 
 void SharedTermsDatabase::checkForConflict() {
+  Assert(d_equalityEngine != nullptr);
   if (d_inConflict) {
     d_inConflict = false;
     std::vector<TNode> assumptions;
-    d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
     Node conflict = mkAnd(assumptions);
     TrustNode tconf = TrustNode::mkTrustConflict(conflict);
     d_theoryEngine->conflict(tconf, THEORY_BUILTIN);
@@ -270,22 +290,26 @@ void SharedTermsDatabase::checkForConflict() {
 }
 
 bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
   bool polarity = literal.getKind() != kind::NOT;
   TNode equality = polarity ? literal : literal[0];
   if (polarity) {
-    return d_equalityEngine.areEqual(equality[0], equality[1]);
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
   } else {
-    return d_equalityEngine.areDisequal(equality[0], equality[1], false);
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
   }
 }
 
-Node SharedTermsDatabase::explain(TNode literal) const {
+TrustNode SharedTermsDatabase::explain(TNode literal) const
+{
+  Assert(d_equalityEngine != nullptr);
   bool polarity = literal.getKind() != kind::NOT;
   TNode atom = polarity ? literal : literal[0];
   Assert(atom.getKind() == kind::EQUAL);
   std::vector<TNode> assumptions;
-  d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
-  return mkAnd(assumptions);
+  d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
+  Node exp = mkAnd(assumptions);
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
 
 } /* namespace CVC4 */
index 369a35b34fd67749aad5bd7decd47c5e0945c709..558d6fc939ae5ef3bd3061e7e919d69b61ce6db4 100644 (file)
@@ -111,9 +111,6 @@ private:
   /** The notify class for d_equalityEngine */
   EENotifyClass d_EENotify;
 
-  /** Equality engine */
-  theory::eq::EqualityEngine d_equalityEngine;
-
   /**
    * Method called by equalityEngine when a becomes (dis-)equal to b and a and b are shared with
    * the theory. Returns false if there is a direct conflict (via rewrite for example).
@@ -182,7 +179,7 @@ public:
   /**
    * Returns an explanation of the propagation that came from the database.
    */
-  Node explain(TNode literal) const;
+  theory::TrustNode explain(TNode literal) const;
 
   /**
    * Add an equality to propagate.
@@ -254,14 +251,16 @@ public:
   /**
    * get equality engine
    */
-  theory::eq::EqualityEngine* getEqualityEngine() { return &d_equalityEngine; }
-
-protected:
+  theory::eq::EqualityEngine* getEqualityEngine();
 
+ protected:
   /**
    * This method gets called on backtracks from the context manager.
    */
- void contextNotifyPop() override { backtrack(); }
+  void contextNotifyPop() override { backtrack(); }
+
+  /** Equality engine */
+  theory::eq::EqualityEngine* d_equalityEngine;
 };
 
 }