Simplify interface to instantiate (#5926)
[cvc5.git] / src / theory / shared_terms_database.cpp
index ced845a2749fa9afd03518be1bc24a45bf0a0fe9..29bdc03f84d865cc3769855bcbec7dcc27169fb7 100644 (file)
@@ -1,56 +1,89 @@
 /*********************                                                        */
 /*! \file shared_terms_database.cpp
  ** \verbatim
- ** Original author: dejan
- ** Major contributors: mdeters
- ** Minor contributors (to current version): ajreynol, barrett
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009-2012  New York University and The University of Iowa
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds, Dejan Jovanovic, Morgan Deters
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
  **
  ** [[ Add lengthier description here ]]
  ** \todo document this file
  **/
 
-
 #include "theory/shared_terms_database.h"
+
+#include "smt/smt_statistics_registry.h"
 #include "theory/theory_engine.h"
 
 using namespace std;
-using namespace CVC4;
-using namespace theory;
-
-SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, context::Context* context)
-: ContextNotifyObj(context)
-, d_context(context)
-, d_statSharedTerms("theory::shared_terms", 0)
-, d_addedSharedTermsSize(context, 0)
-, d_termsToTheories(context)
-, d_alreadyNotifiedMap(context)
-, d_registeredEqualities(context)
-, d_EENotify(*this)
-, d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
-, d_theoryEngine(theoryEngine)
-, d_inConflict(context, false)
+using namespace CVC4::theory;
+
+namespace CVC4 {
+
+SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
+    : ContextNotifyObj(context),
+      d_statSharedTerms("theory::shared_terms", 0),
+      d_addedSharedTermsSize(context, 0),
+      d_termsToTheories(context),
+      d_alreadyNotifiedMap(context),
+      d_registeredEqualities(context),
+      d_EENotify(*this),
+      d_theoryEngine(theoryEngine),
+      d_inConflict(context, false),
+      d_conflictPolarity(),
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
+{
+  smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
+}
+
+SharedTermsDatabase::~SharedTermsDatabase()
 {
-  StatisticsRegistry::registerStat(&d_statSharedTerms);
+  smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
+}
+
+void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
+{
+  Assert(ee != nullptr);
+  d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
+  }
 }
 
-SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
+bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
 {
-  StatisticsRegistry::unregisterStat(&d_statSharedTerms);
+  esi.d_notify = &d_EENotify;
+  esi.d_name = "SharedTermsDatabase";
+  return true;
 }
 
 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
   d_registeredEqualities.insert(equality);
-  d_equalityEngine.addTriggerEquality(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
   checkForConflict();
 }
 
-
-void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
-  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl;
+void SharedTermsDatabase::addSharedTerm(TNode atom,
+                                        TNode term,
+                                        TheoryIdSet theories)
+{
+  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", "
+                    << term << ", " << TheoryIdSetUtil::setToString(theories)
+                    << ")" << std::endl;
 
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
@@ -62,7 +95,8 @@ void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theo
     d_termsToTheories[search_pair] = theories;
   } else {
     Assert(theories != (*find).second);
-    d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second);
+    d_termsToTheories[search_pair] =
+        TheoryIdSetUtil::setUnion(theories, (*find).second);
   }
 }
 
@@ -92,25 +126,27 @@ void SharedTermsDatabase::backtrack() {
   d_addedSharedTerms.resize(d_addedSharedTermsSize);
 }
 
-Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
+TheoryIdSet SharedTermsDatabase::getTheoriesToNotify(TNode atom,
+                                                     TNode term) const
+{
   // Get the theories that share this term from this atom
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
   Assert(find != d_termsToTheories.end());
 
   // Get the theories that were already notified
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
 
   // Return the ones that haven't been notified yet
-  return Theory::setDifference((*find).second, alreadyNotified);
+  return TheoryIdSetUtil::setDifference((*find).second, alreadyNotified);
 }
 
-
-Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
+TheoryIdSet SharedTermsDatabase::getNotifiedTheories(TNode term) const
+{
   // Get the theories that were already notified
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
@@ -122,7 +158,7 @@ Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
 
 bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
 {
-  Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
+  Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
 
   if (d_inConflict) {
     return false;
@@ -131,24 +167,25 @@ bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNod
   // Propagate away
   Node equality = a.eqNode(b);
   if (value) {
-    d_theoryEngine->assertToTheory(equality, theory, THEORY_BUILTIN);
+    d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
   } else {
-    d_theoryEngine->assertToTheory(equality.notNode(), theory, THEORY_BUILTIN);
+    d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
   }
 
   // As you were
   return true;
 }
 
-void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
-
+void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
+{
   // Find out if there are any new theories that were notified about this term
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
-  Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
+  TheoryIdSet newlyNotified =
+      TheoryIdSetUtil::setDifference(theories, alreadyNotified);
 
   // If no new theories were notified, we are done
   if (newlyNotified == 0) {
@@ -158,12 +195,21 @@ void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
   Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
 
   // First update the set of notified theories for this term
-  d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
+  d_alreadyNotifiedMap[term] =
+      TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
+
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
+  }
 
   // Mark the shared terms in the equality engine
   theory::TheoryId currentTheory;
-  while ((currentTheory = Theory::setPop(newlyNotified)) != THEORY_LAST) {
-    d_equalityEngine.addTriggerTerm(term, currentTheory);
+  while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
+         != THEORY_LAST)
+  {
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
   }
 
   // Check for any conflits
@@ -171,32 +217,42 @@ void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
 }
 
 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areEqual(a,b);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
     return false;
   }
 }
 
 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areDisequal(a,b,false);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // one (or both) are in the equality engine
     return false;
   }
 }
 
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
+
 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
 {
+  Assert(d_equalityEngine != nullptr);
   Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
   // Add it to the equality engine
-  d_equalityEngine.assertEquality(equality, polarity, reason);
+  d_equalityEngine->assertEquality(equality, polarity, reason);
   // Check for conflict
   checkForConflict();
 }
@@ -210,54 +266,55 @@ bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
   return true;
 }
 
-static Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
-
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
-
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
   }
-
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
   }
-
-  return conjunction;
-}
-
-void SharedTermsDatabase::checkForConflict() {
-  if (d_inConflict) {
-    d_inConflict = false;
+  else
+  {
+    // standard explain
     std::vector<TNode> assumptions;
-    d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
-    Node conflict = mkAnd(assumptions);
-    d_theoryEngine->conflict(conflict, THEORY_BUILTIN);
-    d_conflictLHS = d_conflictRHS = Node::null();
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
   }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
 }
 
 bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
   bool polarity = literal.getKind() != kind::NOT;
   TNode equality = polarity ? literal : literal[0];
   if (polarity) {
-    return d_equalityEngine.areEqual(equality[0], equality[1]);
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
   } else {
-    return d_equalityEngine.areDisequal(equality[0], equality[1], false);
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
   }
 }
 
-Node SharedTermsDatabase::explain(TNode literal) const {
-  bool polarity = literal.getKind() != kind::NOT;
-  TNode atom = polarity ? literal : literal[0];
-  Assert(atom.getKind() == kind::EQUAL);
-  std::vector<TNode> assumptions;
-  d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
-  return mkAnd(assumptions);
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
+{
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
+  }
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
+
+} /* namespace CVC4 */