(proof-new) Make shared solver proof producing (#5169)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 2 Oct 2020 19:55:31 +0000 (14:55 -0500)
committerGitHub <noreply@github.com>
Fri, 2 Oct 2020 19:55:31 +0000 (14:55 -0500)
This makes the shared terms database use a proof equality engine as a layer on top of its equality engine, analogous to how this done in theories.

src/theory/combination_engine.cpp
src/theory/combination_engine.h
src/theory/shared_solver.cpp
src/theory/shared_solver.h
src/theory/shared_solver_distributed.cpp
src/theory/shared_solver_distributed.h
src/theory/shared_terms_database.cpp
src/theory/shared_terms_database.h

index 32af150547356f324669554f16d0ef8df3628143..5e242659f238c350073e99dff8fc721b66f0888d 100644 (file)
@@ -28,6 +28,7 @@ CombinationEngine::CombinationEngine(TheoryEngine& te,
                                      const std::vector<Theory*>& paraTheories,
                                      ProofNodeManager* pnm)
     : d_te(te),
+      d_pnm(pnm),
       d_logicInfo(te.getLogicInfo()),
       d_paraTheories(paraTheories),
       d_eemanager(nullptr),
@@ -46,7 +47,7 @@ void CombinationEngine::finishInit()
   if (options::eeMode() == options::EqEngineMode::DISTRIBUTED)
   {
     // use the distributed shared solver
-    d_sharedSolver.reset(new SharedSolverDistributed(d_te));
+    d_sharedSolver.reset(new SharedSolverDistributed(d_te, d_pnm));
     // make the distributed equality engine manager
     d_eemanager.reset(
         new EqEngineManagerDistributed(d_te, *d_sharedSolver.get()));
index daafc1f67d55a98e91c495bcab30888192a34f42..4413da603846488975848ff3206923ec77415fc7 100644 (file)
@@ -111,6 +111,8 @@ class CombinationEngine
   void sendLemma(TrustNode trn, TheoryId atomsTo);
   /** Reference to the theory engine */
   TheoryEngine& d_te;
+  /** The proof node manager */
+  ProofNodeManager* d_pnm;
   /** Logic info of theory engine (cached) */
   const LogicInfo& d_logicInfo;
   /** List of parametric theories of theory engine */
index 794d3ca7c49789eb9d07012816efd417aebc428e..24d7d29cfe66e81d1f1138ecdcf9cd322c19b2b1 100644 (file)
@@ -26,10 +26,10 @@ namespace theory {
 // In distributed equality engine management, shared terms database also
 // maintains an equality engine. In central equality engine management,
 // it does not.
-SharedSolver::SharedSolver(TheoryEngine& te)
+SharedSolver::SharedSolver(TheoryEngine& te, ProofNodeManager* pnm)
     : d_te(te),
       d_logicInfo(te.getLogicInfo()),
-      d_sharedTerms(&d_te, d_te.getSatContext()),
+      d_sharedTerms(&d_te, d_te.getSatContext(), d_te.getUserContext(), pnm),
       d_sharedTermsVisitor(d_sharedTerms)
 {
 }
index d3604facac8ffe6ef3b9343d8b54f3e7585a4fa4..c3d95f3c44daa0481aa3913acc3428069082dedb 100644 (file)
@@ -18,6 +18,7 @@
 #define CVC4__THEORY__SHARED_SOLVER__H
 
 #include "expr/node.h"
+#include "expr/proof_node_manager.h"
 #include "theory/ee_setup_info.h"
 #include "theory/logic_info.h"
 #include "theory/shared_terms_database.h"
@@ -42,7 +43,7 @@ namespace theory {
 class SharedSolver
 {
  public:
-  SharedSolver(TheoryEngine& te);
+  SharedSolver(TheoryEngine& te, ProofNodeManager* pnm);
   virtual ~SharedSolver() {}
   //------------------------------------- initialization
   /**
index 5975d3dd8c2c78404e2aa08f38599eb6f878db6c..c868ed2061cf766e71a8312ec33a3f45482263cd 100644 (file)
@@ -19,8 +19,9 @@
 namespace CVC4 {
 namespace theory {
 
-SharedSolverDistributed::SharedSolverDistributed(TheoryEngine& te)
-    : SharedSolver(te)
+SharedSolverDistributed::SharedSolverDistributed(TheoryEngine& te,
+                                                 ProofNodeManager* pnm)
+    : SharedSolver(te, pnm)
 {
 }
 
@@ -67,7 +68,7 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id)
   TrustNode texp;
   if (id == THEORY_BUILTIN)
   {
-    // explanation based on the specific solver
+    // explanation using the shared terms database
     texp = d_sharedTerms.explain(literal);
     Trace("shared-solver")
         << "\tTerm was propagated by THEORY_BUILTIN. Explanation: "
index 45c7eafb3026cb3d03c87f2d428ff354e9d56923..de6e29743b67624f489d7ec58be23023b4e8f269 100644 (file)
@@ -30,7 +30,7 @@ namespace theory {
 class SharedSolverDistributed : public SharedSolver
 {
  public:
-  SharedSolverDistributed(TheoryEngine& te);
+  SharedSolverDistributed(TheoryEngine& te, ProofNodeManager* pnm);
   virtual ~SharedSolverDistributed() {}
   //------------------------------------- initialization
   /**
index 92c66e83b6112a1e8db5148dc971df13959ef609..edf512e4b71f73eacde47bb15e0c019846865850 100644 (file)
@@ -24,7 +24,9 @@ using namespace CVC4::theory;
 namespace CVC4 {
 
 SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
-                                         context::Context* context)
+                                         context::Context* context,
+                                         context::UserContext* userContext,
+                                         ProofNodeManager* pnm)
     : ContextNotifyObj(context),
       d_statSharedTerms("theory::shared_terms", 0),
       d_addedSharedTermsSize(context, 0),
@@ -35,7 +37,11 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine,
       d_theoryEngine(theoryEngine),
       d_inConflict(context, false),
       d_conflictPolarity(),
-      d_equalityEngine(nullptr)
+      d_satContext(context),
+      d_userContext(userContext),
+      d_equalityEngine(nullptr),
+      d_pfee(nullptr),
+      d_pnm(pnm)
 {
   smtStatisticsRegistry()->registerStat(&d_statSharedTerms);
 }
@@ -47,7 +53,14 @@ SharedTermsDatabase::~SharedTermsDatabase()
 
 void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee)
 {
+  Assert(ee != nullptr);
   d_equalityEngine = ee;
+  // if proofs are enabled, make the proof equality engine
+  if (d_pnm != nullptr)
+  {
+    d_pfee.reset(
+        new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm));
+  }
 }
 
 bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi)
@@ -253,40 +266,31 @@ bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) {
   return true;
 }
 
-static Node mkAnd(const std::vector<TNode>& conjunctions) {
-  Assert(conjunctions.size() > 0);
-
-  std::set<TNode> all;
-  all.insert(conjunctions.begin(), conjunctions.end());
-
-  if (all.size() == 1) {
-    // All the same, or just one
-    return conjunctions[0];
+void SharedTermsDatabase::checkForConflict()
+{
+  if (!d_inConflict)
+  {
+    return;
   }
-
-  NodeBuilder<> conjunction(kind::AND);
-  std::set<TNode>::const_iterator it = all.begin();
-  std::set<TNode>::const_iterator it_end = all.end();
-  while (it != it_end) {
-    conjunction << *it;
-    ++ it;
+  d_inConflict = false;
+  TrustNode trnc;
+  if (d_pfee != nullptr)
+  {
+    Node conflict = d_conflictLHS.eqNode(d_conflictRHS);
+    conflict = d_conflictPolarity ? conflict : conflict.notNode();
+    trnc = d_pfee->assertConflict(conflict);
   }
-
-  return conjunction;
-}
-
-void SharedTermsDatabase::checkForConflict() {
-  Assert(d_equalityEngine != nullptr);
-  if (d_inConflict) {
-    d_inConflict = false;
+  else
+  {
+    // standard explain
     std::vector<TNode> assumptions;
     d_equalityEngine->explainEquality(
         d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions);
-    Node conflict = mkAnd(assumptions);
-    TrustNode tconf = TrustNode::mkTrustConflict(conflict);
-    d_theoryEngine->conflict(tconf, THEORY_BUILTIN);
-    d_conflictLHS = d_conflictRHS = Node::null();
+    Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions);
+    trnc = TrustNode::mkTrustConflict(conflictNode, nullptr);
   }
+  d_theoryEngine->conflict(trnc, THEORY_BUILTIN);
+  d_conflictLHS = d_conflictRHS = Node::null();
 }
 
 bool SharedTermsDatabase::isKnown(TNode literal) const {
@@ -300,15 +304,16 @@ bool SharedTermsDatabase::isKnown(TNode literal) const {
   }
 }
 
-TrustNode SharedTermsDatabase::explain(TNode literal) const
+theory::TrustNode SharedTermsDatabase::explain(TNode literal) const
 {
-  Assert(d_equalityEngine != nullptr);
-  bool polarity = literal.getKind() != kind::NOT;
-  TNode atom = polarity ? literal : literal[0];
-  Assert(atom.getKind() == kind::EQUAL);
-  std::vector<TNode> assumptions;
-  d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
-  Node exp = mkAnd(assumptions);
+  if (d_pfee != nullptr)
+  {
+    // use the proof equality engine if it exists
+    return d_pfee->explain(literal);
+  }
+  // otherwise, explain without proofs
+  Node exp = d_equalityEngine->mkExplainLit(literal);
+  // no proof generator
   return TrustNode::mkTrustPropExp(literal, exp, nullptr);
 }
 
index 558d6fc939ae5ef3bd3061e7e919d69b61ce6db4..693e93228f62d9bc55509b61026c5953009cd561 100644 (file)
 
 #include "context/cdhashset.h"
 #include "expr/node.h"
+#include "expr/proof_node_manager.h"
 #include "theory/ee_setup_info.h"
 #include "theory/theory_id.h"
+#include "theory/trust_node.h"
 #include "theory/uf/equality_engine.h"
+#include "theory/uf/proof_equality_engine.h"
 #include "util/statistics_registry.h"
 
 namespace CVC4 {
@@ -31,17 +34,14 @@ namespace CVC4 {
 class TheoryEngine;
 
 class SharedTermsDatabase : public context::ContextNotifyObj {
-
-public:
-
+ public:
   /** A container for a list of shared terms */
   typedef std::vector<TNode> shared_terms_list;
 
   /** The iterator to go through the shared terms list */
   typedef shared_terms_list::const_iterator shared_terms_iterator;
 
-private:
-
+ private:
   /** Some statistics */
   IntStat d_statSharedTerms;
 
@@ -73,8 +73,7 @@ private:
   typedef context::CDHashSet<Node, NodeHashFunction> RegisteredEqualitiesSet;
   RegisteredEqualitiesSet d_registeredEqualities;
 
-private:
-
+ private:
   /** This method removes all the un-necessary stuff from the maps */
   void backtrack();
 
@@ -151,9 +150,18 @@ private:
    */
   void checkForConflict();
 
-public:
-
-  SharedTermsDatabase(TheoryEngine* theoryEngine, context::Context* context);
+ public:
+  /**
+   * @param theoryEngine The parent theory engine
+   * @param context The SAT context
+   * @param userContext The user context
+   * @param pnm The proof node manager to use, which is non-null if proofs
+   * are enabled.
+   */
+  SharedTermsDatabase(TheoryEngine* theoryEngine,
+                      context::Context* context,
+                      context::UserContext* userContext,
+                      ProofNodeManager* pnm);
   ~SharedTermsDatabase();
 
   //-------------------------------------------- initialization
@@ -258,9 +266,16 @@ public:
    * This method gets called on backtracks from the context manager.
    */
   void contextNotifyPop() override { backtrack(); }
-
+  /** The SAT search context. */
+  context::Context* d_satContext;
+  /** The user level assertion context. */
+  context::UserContext* d_userContext;
   /** Equality engine */
   theory::eq::EqualityEngine* d_equalityEngine;
+  /** Proof equality engine */
+  std::unique_ptr<theory::eq::ProofEqEngine> d_pfee;
+  /** The proof node manager */
+  ProofNodeManager* d_pnm;
 };
 
 }