Simplify interface to instantiate (#5926)
[cvc5.git] / src / theory / shared_terms_database.cpp
index 034401b9a0125ba2f1ea5fb4ea45787292a92510..29bdc03f84d865cc3769855bcbec7dcc27169fb7 100644 (file)
@@ -2,10 +2,10 @@
 /*! \file shared_terms_database.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Dejan Jovanovic, Morgan Deters, Tim King
+ **   Andrew Reynolds, Dejan Jovanovic, Morgan Deters
  ** This file is part of the CVC4 project.
  ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
@@ -24,7 +24,9 @@ using namespace CVC4::theory;
 namespace CVC4 {
 
 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
-                                         context::Context* context)
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
     : ContextNotifyObj(context),
       d_statSharedTerms("theory::shared_terms", 0),
       d_addedSharedTermsSize(context, 0),
@@ -32,10 +34,15 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
       d_alreadyNotifiedMap(context),
       d_registeredEqualities(context),
       d_EENotify(*this),
-      d_equalityEngine(d_EENotify, context, "SharedTermsDatabase", true),
       d_theoryEngine(theoryEngine),
       d_inConflict(context, false),
-      d_conflictPolarity() {
+      d_conflictPolarity(),
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
+{
   smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
 }
 
@@ -46,7 +53,14 @@ SharedTermsDatabase::~SharedTermsDatabase()
 
 void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
 {
-  // TODO (project #39): dynamic allocation of equality engine here
+  Assert(ee != nullptr);
+  d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
+  }
 }
 
 bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
@@ -57,8 +71,9 @@ bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
 }
 
 void SharedTermsDatabase::addEqualityToPropagate(TNode equality) {
+  Assert(d_equalityEngine != nullptr);
   d_registeredEqualities.insert(equality);
-  d_equalityEngine.addTriggerPredicate(equality);
+  d_equalityEngine->addTriggerPredicate(equality);
   checkForConflict();
 }
 
@@ -183,12 +198,18 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
   d_alreadyNotifiedMap[term] =
       TheoryIdSetUtil::setUnion(newlyNotified, alreadyNotified);
 
+  if (d_equalityEngine == nullptr)
+  {
+    // if we are not assigned an equality engine, there is nothing to do
+    return;
+  }
+
   // Mark the shared terms in the equality engine
   theory::TheoryId currentTheory;
   while ((currentTheory = TheoryIdSetUtil::setPop(newlyNotified))
          != THEORY_LAST)
   {
-    d_equalityEngine.addTriggerTerm(term, currentTheory);
+    d_equalityEngine->addTriggerTerm(term, currentTheory);
   }
 
   // Check for any conflits
@@ -196,32 +217,42 @@ void SharedTermsDatabase::markNotified(TNode term, TheoryIdSet theories)
 }
 
 bool SharedTermsDatabase::areEqual(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areEqual(a,b);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areEqual(a, b);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // since one (or both) of them is a constant, and the other is in the equality engine, they are not same
     return false;
   }
 }
 
 bool SharedTermsDatabase::areDisequal(TNode a, TNode b) const {
-  if (d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)) {
-    return d_equalityEngine.areDisequal(a,b,false);
+  Assert(d_equalityEngine != nullptr);
+  if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b))
+  {
+    return d_equalityEngine->areDisequal(a, b, false);
   } else {
-    Assert(d_equalityEngine.hasTerm(a) || a.isConst());
-    Assert(d_equalityEngine.hasTerm(b) || b.isConst());
+    Assert(d_equalityEngine->hasTerm(a) || a.isConst());
+    Assert(d_equalityEngine->hasTerm(b) || b.isConst());
     // one (or both) are in the equality engine
     return false;
   }
 }
 
+theory::eq::EqualityEngine* SharedTermsDatabase::getEqualityEngine()
+{
+  return d_equalityEngine;
+}
+
 void SharedTermsDatabase::assertEquality(TNode equality, bool polarity, TNode reason)
 {
+  Assert(d_equalityEngine != nullptr);
   Debug("shared-terms-database::assert") << "SharedTermsDatabase::assertEquality(" << equality << ", " << (polarity ? "true" : "false") << ", " << reason << ")" << endl;
   // Add it to the equality engine
-  d_equalityEngine.assertEquality(equality, polarity, reason);
+  d_equalityEngine->assertEquality(equality, polarity, reason);
   // Check for conflict
   checkForConflict();
 }
@@ -235,57 +266,55 @@ bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
   return true;
 }
 
-static Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
-
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
-
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
   }
-
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
   }
-
-  return conjunction;
-}
-
-void SharedTermsDatabase::checkForConflict() {
-  if (d_inConflict) {
-    d_inConflict = false;
+  else
+  {
+    // standard explain
     std::vector<TNode> assumptions;
-    d_equalityEngine.explainEquality(d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
-    Node conflict = mkAnd(assumptions);
-    TrustNode tconf = TrustNode::mkTrustConflict(conflict);
-    d_theoryEngine->conflict(tconf, THEORY_BUILTIN);
-    d_conflictLHS = d_conflictRHS = Node::null();
+    d_equalityEngine->explainEquality(
+        d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
   }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
 }
 
 bool SharedTermsDatabase::isKnown(TNode literal) const {
+  Assert(d_equalityEngine != nullptr);
   bool polarity = literal.getKind() != kind::NOT;
   TNode equality = polarity ? literal : literal[0];
   if (polarity) {
-    return d_equalityEngine.areEqual(equality[0], equality[1]);
+    return d_equalityEngine->areEqual(equality[0], equality[1]);
   } else {
-    return d_equalityEngine.areDisequal(equality[0], equality[1], false);
+    return d_equalityEngine->areDisequal(equality[0], equality[1], false);
   }
 }
 
-Node SharedTermsDatabase::explain(TNode literal) const {
-  bool polarity = literal.getKind() != kind::NOT;
-  TNode atom = polarity ? literal : literal[0];
-  Assert(atom.getKind() == kind::EQUAL);
-  std::vector<TNode> assumptions;
-  d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
-  return mkAnd(assumptions);
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
+{
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
+  }
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
+  return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
 
 } /* namespace CVC4 */