Simplify interface to instantiate (#5926)
[cvc5.git] / src / theory / shared_terms_database.cpp
index 0c893482a33feaaffbe2f012dcf34901fac602cd..29bdc03f84d865cc3769855bcbec7dcc27169fb7 100644 (file)
@@ -1,15 +1,13 @@
 /*********************                                                        */
 /*! \file shared_terms_database.cpp
  ** \verbatim
- ** Original author: dejan
- ** Major contributors: none
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds, Dejan Jovanovic, Morgan Deters
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
  **
  ** [[ Add lengthier description here ]]
  ** \todo document this file
 
 #include "theory/shared_terms_database.h"
 
+#include "smt/smt_statistics_registry.h"
+#include "theory/theory_engine.h"
+
 using namespace std;
-using namespace CVC4;
-using namespace theory;
-
-SharedTermsDatabase::SharedTermsDatabase(SharedTermsNotifyClass& notify, context::Context* context)
-  : ContextNotifyObj(context),
-    d_context(context), 
-    d_statSharedTerms("theory::shared_terms", 0),
-    d_addedSharedTermsSize(context, 0),
-    d_termsToTheories(context),
-    d_alreadyNotifiedMap(context),
-    d_sharedNotify(notify),
-    d_termToNotifyList(context),
-    d_allocatedNLSize(0),
-    d_allocatedNLNext(context, 0),
-    d_EENotify(*this),
-    d_equalityEngine(d_EENotify, context, "SharedTermsDatabase")
+using namespace CVC4::theory;
+
+namespace CVC4 {
+
+SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
+    : ContextNotifyObj(context),
+      d_statSharedTerms("theory::shared_terms", 0),
+      d_addedSharedTermsSize(context, 0),
+      d_termsToTheories(context),
+      d_alreadyNotifiedMap(context),
+      d_registeredEqualities(context),
+      d_EENotify(*this),
+      d_theoryEngine(theoryEngine),
+      d_inConflict(context, false),
+      d_conflictPolarity(),
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
 {
-  StatisticsRegistry::registerStat(&d_statSharedTerms);
+  smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
 }
 
-SharedTermsDatabase::~SharedTermsDatabase() throw(AssertionException)
+SharedTermsDatabase::~SharedTermsDatabase()
 {
-  StatisticsRegistry::unregisterStat(&d_statSharedTerms);
-  for (unsigned i = 0; i < d_allocatedNLSize; ++i) {
-    d_allocatedNL[i]->deleteSelf();
+  smtStatisticsRegistry()->unregisterStat(&d_statSharedTerms);
+}
+
+void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
+{
+  Assert(ee != nullptr);
+  d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
   }
 }
 
-void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theories) {
-  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", " << term << ", " << Theory::setToString(theories) << ")" << std::endl; 
+bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
+{
+  esi.d_notify = &d_EENotify;
+  esi.d_name = "SharedTermsDatabase";
+  return true;
+}
+
+void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
+  d_registeredEqualities.insert(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
+  checkForConflict();
+}
+
+void SharedTermsDatabase::addSharedTerm(TNode atom,
+                                        TNode term,
+                                        TheoryIdSet theories)
+{
+  Debug("register") << "SharedTermsDatabase::addSharedTerm(" << atom << ", "
+                    << term << ", " << TheoryIdSetUtil::setToString(theories)
+                    << ")" << std::endl;
 
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
@@ -57,23 +93,21 @@ void SharedTermsDatabase::addSharedTerm(TNode atom, TNode term, Theory::Set theo
     d_addedSharedTerms.push_back(atom);
     d_addedSharedTermsSize = d_addedSharedTermsSize + 1;
     d_termsToTheories[search_pair] = theories;
-    if (!d_equalityEngine.hasTerm(term)) {
-      d_equalityEngine.addTriggerTerm(term, THEORY_UF);
-    }
   } else {
     Assert(theories != (*find).second);
-    d_termsToTheories[search_pair] = Theory::setUnion(theories, (*find).second); 
+    d_termsToTheories[search_pair] =
+        TheoryIdSetUtil::setUnion(theories, (*find).second);
   }
 }
 
 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::begin(TNode atom) const {
   Assert(hasSharedTerms(atom));
-  return d_atomsToTerms.find(atom)->second.begin();  
+  return d_atomsToTerms.find(atom)->second.begin();
 }
 
 SharedTermsDatabase::shared_terms_iterator SharedTermsDatabase::end(TNode atom) const {
   Assert(hasSharedTerms(atom));
-  return d_atomsToTerms.find(atom)->second.end();  
+  return d_atomsToTerms.find(atom)->second.end();
 }
 
 bool SharedTermsDatabase::hasSharedTerms(TNode atom) const {
@@ -87,30 +121,32 @@ void SharedTermsDatabase::backtrack() {
     list.pop_back();
     if (list.empty()) {
       d_atomsToTerms.erase(atom);
-    } 
+    }
   }
   d_addedSharedTerms.resize(d_addedSharedTermsSize);
 }
 
-Theory::Set SharedTermsDatabase::getTheoriesToNotify(TNode atom, TNode term) const {
-  // Get the theories that share this term from this atom 
+TheoryIdSet SharedTermsDatabase::getTheoriesToNotify(TNode atom,
+                                                     TNode term) const
+{
+  // Get the theories that share this term from this atom
   std::pair<TNode, TNode> search_pair(atom, term);
   SharedTermsTheoriesMap::iterator find = d_termsToTheories.find(search_pair);
-  Assert(find != d_termsToTheories.end());  
-  
+  Assert(find != d_termsToTheories.end());
+
   // Get the theories that were already notified
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
-  
+
   // Return the ones that haven't been notified yet
-  return Theory::setDifference((*find).second, alreadyNotified);
+  return TheoryIdSetUtil::setDifference((*find).second, alreadyNotified);
 }
 
-
-Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
+TheoryIdSet SharedTermsDatabase::getNotifiedTheories(TNode term) const
+{
   // Get the theories that were already notified
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
@@ -120,113 +156,36 @@ Theory::Set SharedTermsDatabase::getNotifiedTheories(TNode term) const {
   }
 }
 
-
-SharedTermsDatabase::NotifyList* SharedTermsDatabase::getNewNotifyList()
+bool SharedTermsDatabase::propagateSharedEquality(TheoryId theory, TNode a, TNode b, bool value)
 {
-  NotifyList* retval;
-  if (d_allocatedNLSize == d_allocatedNLNext) {
-    retval = new (true) NotifyList(d_context);
-    d_allocatedNL.push_back(retval);
-    d_allocatedNLNext = ++d_allocatedNLSize;
-  }
-  else {
-    retval = d_allocatedNL[d_allocatedNLNext];
-    d_allocatedNLNext = d_allocatedNLNext + 1;
-  }
-  Assert(retval->empty());
-  return retval;
-}
-
-
-void SharedTermsDatabase::mergeSharedTerms(TNode a, TNode b)
-{
-  // Note: a is the new representative
-  Debug("shared-terms-database") << "SharedTermsDatabase::mergeSharedTerms(" << a << "," << b << ")" << endl;
+  Debug("shared-terms-database") << "SharedTermsDatabase::newEquality(" << theory << "," << a << "," << b << ", " << (value ? "true" : "false") << ")" << endl;
 
-  NotifyList* pnlLeft = NULL;
-  NotifyList* pnlRight = NULL;
-
-  TermToNotifyList::iterator it = d_termToNotifyList.find(a);
-  if (it == d_termToNotifyList.end()) {
-    pnlLeft = getNewNotifyList();
-    d_termToNotifyList[a] = pnlLeft;
-  }
-  else {
-    pnlLeft = (*it).second;
+  if (d_inConflict) {
+    return false;
   }
-  it = d_termToNotifyList.find(b);
-  if (it != d_termToNotifyList.end()) {
-    pnlRight = (*it).second;
-  }
-
-  // Get theories interested in EC for lhs
-  Theory::Set lhsSet = getNotifiedTheories(a);
-  Theory::Set rhsSet = getNotifiedTheories(b);
-  NotifyList::iterator nit;
-  TNode left, right;
-
-  for (TheoryId currentTheory = THEORY_FIRST; currentTheory != THEORY_LAST; ++ currentTheory) {
-
-    if (Theory::setContains(currentTheory, rhsSet)) {
-      right = b;
-    }
-    else if (pnlRight != NULL &&
-             ((nit = pnlRight->find(currentTheory)) != pnlRight->end())) {
-      right = (*nit).second;
-    }
-    else {
-      // no match for right: continue
-      continue;
-    }
-
-    if (Theory::setContains(currentTheory, lhsSet)) {
-      left = a;
-    }
-    else if ((nit = pnlLeft->find(currentTheory)) != pnlLeft->end()) {
-      left = (*nit).second;
-    }
-    else {
-      // no match for left: insert right into left
-      (*pnlLeft)[currentTheory] = right;
-      continue;
-    }
 
-    // New shared equality: notify the client
-
-    // TODO: add propagation of disequalities?
-
-    assertEq(left.eqNode(right), currentTheory);
+  // Propagate away
+  Node equality = a.eqNode(b);
+  if (value) {
+    d_theoryEngine->assertToTheory(equality, equality, theory, THEORY_BUILTIN);
+  } else {
+    d_theoryEngine->assertToTheory(equality.notNode(), equality.notNode(), theory, THEORY_BUILTIN);
   }
 
+  // As you were
+  return true;
 }
-  
 
-void SharedTermsDatabase::assertEq(TNode equality, TheoryId theory)
+void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
 {
-  Debug("shared-terms-database") << "SharedTermsDatabase::assertEq(" << equality << ") to theory " << theory << endl;
-  Node normalized = Rewriter::rewriteEquality(theory, equality);
-  if (normalized.getKind() != kind::CONST_BOOLEAN || !normalized.getConst<bool>()) {
-    // Notify client
-    d_sharedNotify.notify(normalized, equality, theory);
-  }
-}
-
-
-// term was just part of an assertion that makes it shared for theories
-// Let's mark that the set theories has now been notified
-// In addition, we make sure the equivalence class containing term knows a
-// representative for each theory in theories.
-// Finally, if the EC already knows a rep for a theory that was just notified, we
-// have to tell the theory that these two terms are equal.
-void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
-
   // Find out if there are any new theories that were notified about this term
-  Theory::Set alreadyNotified = 0;
+  TheoryIdSet alreadyNotified = 0;
   AlreadyNotifiedMap::iterator theoriesFind = d_alreadyNotifiedMap.find(term);
   if (theoriesFind != d_alreadyNotifiedMap.end()) {
     alreadyNotified = (*theoriesFind).second;
   }
-  Theory::Set newlyNotified = Theory::setDifference(theories, alreadyNotified);
+  TheoryIdSet newlyNotified =
+      TheoryIdSetUtil::setDifference(theories, alreadyNotified);
 
   // If no new theories were notified, we are done
   if (newlyNotified == 0) {
@@ -236,117 +195,126 @@ void SharedTermsDatabase::markNotified(TNode term, Theory::Set theories) {
   Debug("shared-terms-database") << "SharedTermsDatabase::markNotified(" << term << ")" << endl;
 
   // First update the set of notified theories for this term
-  d_alreadyNotifiedMap[term] = Theory::setUnion(newlyNotified, alreadyNotified);
-
-  // Now get the representative of the equivalence class and find out which theories it represents
-  TNode rep = d_equalityEngine.getRepresentative(term);
-  if (rep != term) {
-    alreadyNotified = 0;
-    theoriesFind = d_alreadyNotifiedMap.find(rep);
-    if (theoriesFind != d_alreadyNotifiedMap.end()) {
-      alreadyNotified = (*theoriesFind).second;
-    }
-  }
+  d_alreadyNotifiedMap[term] =
+      TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
 
-  // For each theory that is newly notified
-  for (TheoryId theory = THEORY_FIRST; theory != THEORY_LAST; ++ theory) {
-    if (Theory::setContains(theory, newlyNotified)) {
-
-      Debug("shared-terms-database") << "SharedTermsDatabase::markNotified: processing theory " << theory << endl;
-
-      if (Theory::setContains(theory, alreadyNotified)) {
-        // rep represents this theory already, need to assert that term = rep
-        Assert(rep != term);
-        assertEq(rep.eqNode(term), theory);
-      }
-      else {
-        // Get the list of terms representing theories for this EC
-        TermToNotifyList::iterator it = d_termToNotifyList.find(rep);
-        if (it == d_termToNotifyList.end()) {
-          // No need to do anything - no list associated with this EC
-          Assert(term == rep);
-        }
-        else {
-          NotifyList* pnl = (*it).second;
-          Assert(pnl != NULL);
-
-          // Check if this theory is already represented
-          NotifyList::iterator nit = pnl->find(theory);
-          if (nit != pnl->end()) {
-            // Already have a representative for this theory, assert term equal to it
-            assertEq((*nit).second.eqNode(term), theory);
-          }
-          else {
-            // if term == rep, no need to do anything, as term will represent the theory via alreadyNotifiedMap
-            if (term != rep) {
-              // No term in this EC represents this theory, so add term as a new representative
-              Debug("shared-terms-database") << "SharedTermsDatabase::markNotified: adding " << term << " to representative " << rep << " for theory " << theory << endl;
-              (*pnl)[theory] = term;
-            }
-          }
-        }
-      }
-    }
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
   }
-}
 
+  // Mark the shared terms in the equality engine
+  theory::TheoryId currentTheory;
+  while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
+         != THEORY_LAST)
+  {
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
+  }
 
-bool SharedTermsDatabase::areEqual(TNode a, TNode b) {
-  return d_equalityEngine.areEqual(a,b);
+  // Check for any conflits
+  checkForConflict();
 }
 
-
-bool SharedTermsDatabase::areDisequal(TNode a, TNode b) {
-  return d_equalityEngine.areDisequal(a,b,false);
+bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
+  } else {
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
+    // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
+    return false;
+  }
 }
 
-void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason)
-{
-  bool negated = literal.getKind() == kind::NOT;
-  TNode atom = negated ? literal[0] : literal;
-  if (negated) {
-    Assert(!d_equalityEngine.areDisequal(atom[0], atom[1],false));
-    d_equalityEngine.assertEquality(atom, false, reason);
-    //    !!! need to send this out
-  }
-  else {
-    Assert(!d_equalityEngine.areEqual(atom[0], atom[1]));
-    d_equalityEngine.assertEquality(atom, true, reason);
+bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
+  } else {
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
+    // one (or both) are in the equality engine
+    return false;
   }
 }
 
-static Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
 
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
+void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
+{
+  Assert(d_equalityEngine != nullptr);
+  Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
+  // Add it to the equality engine
+  d_equalityEngine->assertEquality(equality, polarity, reason);
+  // Check for conflict
+  checkForConflict();
+}
 
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
+bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
+  if (polarity) {
+    d_theoryEngine->propagate(equality, THEORY_BUILTIN);
+  } else {
+    d_theoryEngine->propagate(equality.notNode(), THEORY_BUILTIN);
   }
+  return true;
+}
 
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
   }
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
+  }
+  else
+  {
+    // standard explain
+    std::vector<TNode> assumptions;
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
+  }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
+}
 
-  return conjunction;
-}/* mkAnd() */
-
+bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
+  bool polarity = literal.getKind() != kind::NOT;
+  TNode equality = polarity ? literal : literal[0];
+  if (polarity) {
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
+  } else {
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
+  }
+}
 
-Node SharedTermsDatabase::explain(TNode literal)
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
 {
-  std::vector<TNode> assumptions;
-  if (literal.getKind() == kind::NOT) {
-    Assert(literal[0].getKind() == kind::EQUAL);
-    d_equalityEngine.explainEquality(literal[0][0], literal[0][1], false, assumptions);
-  } else {
-    Assert(literal.getKind() == kind::EQUAL);
-    d_equalityEngine.explainEquality(literal[0], literal[1], true, assumptions);
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
   }
-  return mkAnd(assumptions);
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
+
+} /* namespace CVC4 */