Simplify interface to instantiate (#5926)
[cvc5.git] / src / theory / shared_terms_database.cpp
index 1687f34806793e8bf2f2b4264964edb2ab61b6cf..29bdc03f84d865cc3769855bcbec7dcc27169fb7 100644 (file)
@@ -1,15 +1,13 @@
 /*********************                                                        */
 /*! \file shared_terms_database.cpp
  ** \verbatim
- ** Original author: dejan
- ** Major contributors: none
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds, Dejan Jovanovic, Morgan Deters
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
  **
  ** [[ Add lengthier description here ]]
  ** \todo document this file
 
 #include "theory/shared_terms_database.h"
 
-using namespace CVC4;
+#include "smt/smt_statistics_registry.h"
+#include "theory/theory_engine.h"
+
+using namespace std;
 using namespace CVC4::theory;
 
-void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
-  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl; 
+namespace CVC4 {
+
+SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
+    : ContextNotifyObj(context),
+      d_statSharedTerms("theory::shared_terms", 0),
+      d_addedSharedTermsSize(context, 0),
+      d_termsToTheories(context),
+      d_alreadyNotifiedMap(context),
+      d_registeredEqualities(context),
+      d_EENotify(*this),
+      d_theoryEngine(theoryEngine),
+      d_inConflict(context, false),
+      d_conflictPolarity(),
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
+{
+  smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
+}
+
+SharedTermsDatabase::~SharedTermsDatabase()
+{
+  smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
+}
+
+void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
+{
+  Assert(ee != nullptr);
+  d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
+  }
+}
+
+bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
+{
+  esi.d_notify = &d_EENotify;
+  esi.d_name = "SharedTermsDatabase";
+  return true;
+}
+
+void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
+  d_registeredEqualities.insert(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
+  checkForConflict();
+}
+
+void SharedTermsDatabase::addSharedTerm(TNode atom,
+                                        TNode term,
+                                        TheoryIdSet theories)
+{
+  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", "
+                    << term << ", " << TheoryIdSetUtil::setToString(theories)
+                    << ")" << std::endl;
+
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
   if (find == d_termsToTheories.end()) {
@@ -32,18 +95,19 @@ void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theo
     d_termsToTheories[search_pair] = theories;
   } else {
     Assert(theories != (*find).second);
-    d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second); 
+    d_termsToTheories[search_pair] =
+        TheoryIdSetUtil::setUnion(theories, (*find).second);
   }
 }
 
 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
   Assert(hasSharedTerms(atom));
-  return d_atomsToTerms.find(atom)->second.begin();  
+  return d_atomsToTerms.find(atom)->second.begin();
 }
 
 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
   Assert(hasSharedTerms(atom));
-  return d_atomsToTerms.find(atom)->second.end();  
+  return d_atomsToTerms.find(atom)->second.end();
 }
 
 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
@@ -57,29 +121,32 @@ void SharedTermsDatabase::backtrack() {
     list.pop_back();
     if (list.empty()) {
       d_atomsToTerms.erase(atom);
-    } 
+    }
   }
   d_addedSharedTerms.resize(d_addedSharedTermsSize);
 }
 
-Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
-  // Get the theories that share this term from this atom 
+TheoryIdSet SharedTermsDatabase::getTheoriesToNotify(TNode atom,
+                                                     TNode term) const
+{
+  // Get the theories that share this term from this atom
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
-  Assert(find != d_termsToTheories.end());  
-  
+  Assert(find != d_termsToTheories.end());
+
   // Get the theories that were already notified
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
-  
+
   // Return the ones that haven't been notified yet
-  return Theory::setDifference((*find).second, alreadyNotified);
+  return TheoryIdSetUtil::setDifference((*find).second, alreadyNotified);
 }
 
-Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
+TheoryIdSet SharedTermsDatabase::getNotifiedTheories(TNode term) const
+{
   // Get the theories that were already notified
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
@@ -89,7 +156,165 @@ Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
   }
 }
 
-void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
-  d_alreadyNotifiedMap[term] = Theory::setUnion(theories, d_alreadyNotifiedMap[term]);
+bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
+{
+  Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
+
+  if (d_inConflict) {
+    return false;
+  }
+
+  // Propagate away
+  Node equality = a.eqNode(b);
+  if (value) {
+    d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
+  } else {
+    d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
+  }
+
+  // As you were
+  return true;
+}
+
+void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
+{
+  // Find out if there are any new theories that were notified about this term
+  TheoryIdSet alreadyNotified = 0;
+  AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
+  if (theoriesFind != d_alreadyNotifiedMap.end()) {
+    alreadyNotified = (*theoriesFind).second;
+  }
+  TheoryIdSet newlyNotified =
+      TheoryIdSetUtil::setDifference(theories, alreadyNotified);
+
+  // If no new theories were notified, we are done
+  if (newlyNotified == 0) {
+    return;
+  }
+
+  Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
+
+  // First update the set of notified theories for this term
+  d_alreadyNotifiedMap[term] =
+      TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
+
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
+  }
+
+  // Mark the shared terms in the equality engine
+  theory::TheoryId currentTheory;
+  while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
+         != THEORY_LAST)
+  {
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
+  }
+
+  // Check for any conflits
+  checkForConflict();
+}
+
+bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
+  } else {
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
+    // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
+    return false;
+  }
+}
+
+bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
+  } else {
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
+    // one (or both) are in the equality engine
+    return false;
+  }
+}
+
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
+
+void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
+{
+  Assert(d_equalityEngine != nullptr);
+  Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
+  // Add it to the equality engine
+  d_equalityEngine->assertEquality(equality, polarity, reason);
+  // Check for conflict
+  checkForConflict();
+}
+
+bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
+  if (polarity) {
+    d_theoryEngine->propagate(equality, THEORY_BUILTIN);
+  } else {
+    d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
+  }
+  return true;
+}
+
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
+  }
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
+  }
+  else
+  {
+    // standard explain
+    std::vector<TNode> assumptions;
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
+  }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
+}
+
+bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
+  bool polarity = literal.getKind() != kind::NOT;
+  TNode equality = polarity ? literal : literal[0];
+  if (polarity) {
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
+  } else {
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
+  }
+}
+
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
+{
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
+  }
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
 
+} /* namespace CVC4 */