Simplify interface to instantiate (#5926)
[cvc5.git] / src / theory / shared_terms_database.cpp
index 99584b167e86392b59db6e58459c0d9e8328aa97..29bdc03f84d865cc3769855bcbec7dcc27169fb7 100644 (file)
@@ -2,10 +2,10 @@
 /*! \file shared_terms_database.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Dejan Jovanovic, Morgan Deters, Tim King
+ **   Andrew Reynolds, Dejan Jovanovic, Morgan Deters
  ** This file is part of the CVC4 project.
  ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
@@ -24,7 +24,9 @@ using namespace CVC4::theory;
 namespace CVC4 {
 
 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
-                                         context::Context* context)
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
     : ContextNotifyObj(context),
       d_statSharedTerms("theory::shared_terms", 0),
       d_addedSharedTermsSize(context, 0),
@@ -32,10 +34,15 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
       d_alreadyNotifiedMap(context),
       d_registeredEqualities(context),
       d_EENotify(*this),
-      d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true),
       d_theoryEngine(theoryEngine),
       d_inConflict(context, false),
-      d_conflictPolarity() {
+      d_conflictPolarity(),
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
+{
   smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
 }
 
@@ -44,15 +51,39 @@ SharedTermsDatabase::~SharedTermsDatabase()
   smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
 }
 
+void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
+{
+  Assert(ee != nullptr);
+  d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
+  }
+}
+
+bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
+{
+  esi.d_notify = &d_EENotify;
+  esi.d_name = "SharedTermsDatabase";
+  return true;
+}
+
 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
   d_registeredEqualities.insert(equality);
-  d_equalityEngine.addTriggerPredicate(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
   checkForConflict();
 }
 
-
-void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
-  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
+void SharedTermsDatabase::addSharedTerm(TNode atom,
+                                        TNode term,
+                                        TheoryIdSet theories)
+{
+  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", "
+                    << term << ", " << TheoryIdSetUtil::setToString(theories)
+                    << ")" << std::endl;
 
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
@@ -64,7 +95,8 @@ void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theo
     d_termsToTheories[search_pair] = theories;
   } else {
     Assert(theories != (*find).second);
-    d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
+    d_termsToTheories[search_pair] =
+        TheoryIdSetUtil::setUnion(theories, (*find).second);
   }
 }
 
@@ -94,25 +126,27 @@ void SharedTermsDatabase::backtrack() {
   d_addedSharedTerms.resize(d_addedSharedTermsSize);
 }
 
-Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
+TheoryIdSet SharedTermsDatabase::getTheoriesToNotify(TNode atom,
+                                                     TNode term) const
+{
   // Get the theories that share this term from this atom
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
   Assert(find != d_termsToTheories.end());
 
   // Get the theories that were already notified
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
 
   // Return the ones that haven't been notified yet
-  return Theory::setDifference((*find).second, alreadyNotified);
+  return TheoryIdSetUtil::setDifference((*find).second, alreadyNotified);
 }
 
-
-Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
+TheoryIdSet SharedTermsDatabase::getNotifiedTheories(TNode term) const
+{
   // Get the theories that were already notified
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
@@ -142,15 +176,16 @@ bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNod
   return true;
 }
 
-void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
-
+void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
+{
   // Find out if there are any new theories that were notified about this term
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
-  Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
+  TheoryIdSet newlyNotified =
+      TheoryIdSetUtil::setDifference(theories, alreadyNotified);
 
   // If no new theories were notified, we are done
   if (newlyNotified == 0) {
@@ -160,12 +195,21 @@ void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
   Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
 
   // First update the set of notified theories for this term
-  d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
+  d_alreadyNotifiedMap[term] =
+      TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
+
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
+  }
 
   // Mark the shared terms in the equality engine
   theory::TheoryId currentTheory;
-  while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
-    d_equalityEngine.addTriggerTerm(term, currentTheory);
+  while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
+         != THEORY_LAST)
+  {
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
   }
 
   // Check for any conflits
@@ -173,32 +217,42 @@ void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
 }
 
 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areEqual(a,b);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
     return false;
   }
 }
 
 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areDisequal(a,b,false);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // one (or both) are in the equality engine
     return false;
   }
 }
 
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
+
 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
 {
+  Assert(d_equalityEngine != nullptr);
   Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
   // Add it to the equality engine
-  d_equalityEngine.assertEquality(equality, polarity, reason);
+  d_equalityEngine->assertEquality(equality, polarity, reason);
   // Check for conflict
   checkForConflict();
 }
@@ -212,56 +266,55 @@ bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
   return true;
 }
 
-static Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
-
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
-
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
   }
-
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
   }
-
-  return conjunction;
-}
-
-void SharedTermsDatabase::checkForConflict() {
-  if (d_inConflict) {
-    d_inConflict = false;
+  else
+  {
+    // standard explain
     std::vector<TNode> assumptions;
-    d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
-    Node conflict = mkAnd(assumptions);
-    d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
-    d_conflictLHS = d_conflictRHS = Node::null();
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
   }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
 }
 
 bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
   bool polarity = literal.getKind() != kind::NOT;
   TNode equality = polarity ? literal : literal[0];
   if (polarity) {
-    return d_equalityEngine.areEqual(equality[0], equality[1]);
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
   } else {
-    return d_equalityEngine.areDisequal(equality[0], equality[1], false);
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
   }
 }
 
-Node SharedTermsDatabase::explain(TNode literal) const {
-  bool polarity = literal.getKind() != kind::NOT;
-  TNode atom = polarity ? literal : literal[0];
-  Assert(atom.getKind() == kind::EQUAL);
-  std::vector<TNode> assumptions;
-  d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
-  return mkAnd(assumptions);
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
+{
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
+  }
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
 
 } /* namespace CVC4 */