From 26601663d6cc8fb8613df5a1d253eba8743e57cb Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 2 Oct 2020 14:55:31 -0500 Subject: [PATCH] (proof-new) Make shared solver proof producing (#5169) This makes the shared terms database use a proof equality engine as a layer on top of its equality engine, analogous to how this done in theories. --- src/theory/combination_engine.cpp | 3 +- src/theory/combination_engine.h | 2 + src/theory/shared_solver.cpp | 4 +- src/theory/shared_solver.h | 3 +- src/theory/shared_solver_distributed.cpp | 7 +- src/theory/shared_solver_distributed.h | 2 +- src/theory/shared_terms_database.cpp | 81 +++++++++++++----------- src/theory/shared_terms_database.h | 37 +++++++---- 8 files changed, 82 insertions(+), 57 deletions(-) diff --git a/src/theory/combination_engine.cpp b/src/theory/combination_engine.cpp index 32af15054..5e242659f 100644 --- a/src/theory/combination_engine.cpp +++ b/src/theory/combination_engine.cpp @@ -28,6 +28,7 @@ CombinationEngine::CombinationEngine(TheoryEngine& te, const std::vector& paraTheories, ProofNodeManager* pnm) : d_te(te), + d_pnm(pnm), d_logicInfo(te.getLogicInfo()), d_paraTheories(paraTheories), d_eemanager(nullptr), @@ -46,7 +47,7 @@ void CombinationEngine::finishInit() if (options::eeMode() == options::EqEngineMode::DISTRIBUTED) { // use the distributed shared solver - d_sharedSolver.reset(new SharedSolverDistributed(d_te)); + d_sharedSolver.reset(new SharedSolverDistributed(d_te, d_pnm)); // make the distributed equality engine manager d_eemanager.reset( new EqEngineManagerDistributed(d_te, *d_sharedSolver.get())); diff --git a/src/theory/combination_engine.h b/src/theory/combination_engine.h index daafc1f67..4413da603 100644 --- a/src/theory/combination_engine.h +++ b/src/theory/combination_engine.h @@ -111,6 +111,8 @@ class CombinationEngine void sendLemma(TrustNode trn, TheoryId atomsTo); /** Reference to the theory engine */ TheoryEngine& d_te; + /** The proof node manager */ + ProofNodeManager* d_pnm; /** Logic info of theory engine (cached) */ const LogicInfo& d_logicInfo; /** List of parametric theories of theory engine */ diff --git a/src/theory/shared_solver.cpp b/src/theory/shared_solver.cpp index 794d3ca7c..24d7d29cf 100644 --- a/src/theory/shared_solver.cpp +++ b/src/theory/shared_solver.cpp @@ -26,10 +26,10 @@ namespace theory { // In distributed equality engine management, shared terms database also // maintains an equality engine. In central equality engine management, // it does not. -SharedSolver::SharedSolver(TheoryEngine& te) +SharedSolver::SharedSolver(TheoryEngine& te, ProofNodeManager* pnm) : d_te(te), d_logicInfo(te.getLogicInfo()), - d_sharedTerms(&d_te, d_te.getSatContext()), + d_sharedTerms(&d_te, d_te.getSatContext(), d_te.getUserContext(), pnm), d_sharedTermsVisitor(d_sharedTerms) { } diff --git a/src/theory/shared_solver.h b/src/theory/shared_solver.h index d3604faca..c3d95f3c4 100644 --- a/src/theory/shared_solver.h +++ b/src/theory/shared_solver.h @@ -18,6 +18,7 @@ #define CVC4__THEORY__SHARED_SOLVER__H #include "expr/node.h" +#include "expr/proof_node_manager.h" #include "theory/ee_setup_info.h" #include "theory/logic_info.h" #include "theory/shared_terms_database.h" @@ -42,7 +43,7 @@ namespace theory { class SharedSolver { public: - SharedSolver(TheoryEngine& te); + SharedSolver(TheoryEngine& te, ProofNodeManager* pnm); virtual ~SharedSolver() {} //------------------------------------- initialization /** diff --git a/src/theory/shared_solver_distributed.cpp b/src/theory/shared_solver_distributed.cpp index 5975d3dd8..c868ed206 100644 --- a/src/theory/shared_solver_distributed.cpp +++ b/src/theory/shared_solver_distributed.cpp @@ -19,8 +19,9 @@ namespace CVC4 { namespace theory { -SharedSolverDistributed::SharedSolverDistributed(TheoryEngine& te) - : SharedSolver(te) +SharedSolverDistributed::SharedSolverDistributed(TheoryEngine& te, + ProofNodeManager* pnm) + : SharedSolver(te, pnm) { } @@ -67,7 +68,7 @@ TrustNode SharedSolverDistributed::explain(TNode literal, TheoryId id) TrustNode texp; if (id == THEORY_BUILTIN) { - // explanation based on the specific solver + // explanation using the shared terms database texp = d_sharedTerms.explain(literal); Trace("shared-solver") << "\tTerm was propagated by THEORY_BUILTIN. Explanation: " diff --git a/src/theory/shared_solver_distributed.h b/src/theory/shared_solver_distributed.h index 45c7eafb3..de6e29743 100644 --- a/src/theory/shared_solver_distributed.h +++ b/src/theory/shared_solver_distributed.h @@ -30,7 +30,7 @@ namespace theory { class SharedSolverDistributed : public SharedSolver { public: - SharedSolverDistributed(TheoryEngine& te); + SharedSolverDistributed(TheoryEngine& te, ProofNodeManager* pnm); virtual ~SharedSolverDistributed() {} //------------------------------------- initialization /** diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 92c66e83b..edf512e4b 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -24,7 +24,9 @@ using namespace CVC4::theory; namespace CVC4 { SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, - context::Context* context) + context::Context* context, + context::UserContext* userContext, + ProofNodeManager* pnm) : ContextNotifyObj(context), d_statSharedTerms("theory::shared_terms", 0), d_addedSharedTermsSize(context, 0), @@ -35,7 +37,11 @@ SharedTermsDatabase::SharedTermsDatabase(TheoryEngine* theoryEngine, d_theoryEngine(theoryEngine), d_inConflict(context, false), d_conflictPolarity(), - d_equalityEngine(nullptr) + d_satContext(context), + d_userContext(userContext), + d_equalityEngine(nullptr), + d_pfee(nullptr), + d_pnm(pnm) { smtStatisticsRegistry()->registerStat(&d_statSharedTerms); } @@ -47,7 +53,14 @@ SharedTermsDatabase::~SharedTermsDatabase() void SharedTermsDatabase::setEqualityEngine(eq::EqualityEngine* ee) { + Assert(ee != nullptr); d_equalityEngine = ee; + // if proofs are enabled, make the proof equality engine + if (d_pnm != nullptr) + { + d_pfee.reset( + new eq::ProofEqEngine(d_satContext, d_userContext, *ee, d_pnm)); + } } bool SharedTermsDatabase::needsEqualityEngine(EeSetupInfo& esi) @@ -253,40 +266,31 @@ bool SharedTermsDatabase::propagateEquality(TNode equality, bool polarity) { return true; } -static Node mkAnd(const std::vector& conjunctions) { - Assert(conjunctions.size() > 0); - - std::set all; - all.insert(conjunctions.begin(), conjunctions.end()); - - if (all.size() == 1) { - // All the same, or just one - return conjunctions[0]; +void SharedTermsDatabase::checkForConflict() +{ + if (!d_inConflict) + { + return; } - - NodeBuilder<> conjunction(kind::AND); - std::set::const_iterator it = all.begin(); - std::set::const_iterator it_end = all.end(); - while (it != it_end) { - conjunction << *it; - ++ it; + d_inConflict = false; + TrustNode trnc; + if (d_pfee != nullptr) + { + Node conflict = d_conflictLHS.eqNode(d_conflictRHS); + conflict = d_conflictPolarity ? conflict : conflict.notNode(); + trnc = d_pfee->assertConflict(conflict); } - - return conjunction; -} - -void SharedTermsDatabase::checkForConflict() { - Assert(d_equalityEngine != nullptr); - if (d_inConflict) { - d_inConflict = false; + else + { + // standard explain std::vector assumptions; d_equalityEngine->explainEquality( d_conflictLHS, d_conflictRHS, d_conflictPolarity, assumptions); - Node conflict = mkAnd(assumptions); - TrustNode tconf = TrustNode::mkTrustConflict(conflict); - d_theoryEngine->conflict(tconf, THEORY_BUILTIN); - d_conflictLHS = d_conflictRHS = Node::null(); + Node conflictNode = NodeManager::currentNM()->mkAnd(assumptions); + trnc = TrustNode::mkTrustConflict(conflictNode, nullptr); } + d_theoryEngine->conflict(trnc, THEORY_BUILTIN); + d_conflictLHS = d_conflictRHS = Node::null(); } bool SharedTermsDatabase::isKnown(TNode literal) const { @@ -300,15 +304,16 @@ bool SharedTermsDatabase::isKnown(TNode literal) const { } } -TrustNode SharedTermsDatabase::explain(TNode literal) const +theory::TrustNode SharedTermsDatabase::explain(TNode literal) const { - Assert(d_equalityEngine != nullptr); - bool polarity = literal.getKind() != kind::NOT; - TNode atom = polarity ? literal : literal[0]; - Assert(atom.getKind() == kind::EQUAL); - std::vector assumptions; - d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); - Node exp = mkAnd(assumptions); + if (d_pfee != nullptr) + { + // use the proof equality engine if it exists + return d_pfee->explain(literal); + } + // otherwise, explain without proofs + Node exp = d_equalityEngine->mkExplainLit(literal); + // no proof generator return TrustNode::mkTrustPropExp(literal, exp, nullptr); } diff --git a/src/theory/shared_terms_database.h b/src/theory/shared_terms_database.h index 558d6fc93..693e93228 100644 --- a/src/theory/shared_terms_database.h +++ b/src/theory/shared_terms_database.h @@ -21,9 +21,12 @@ #include "context/cdhashset.h" #include "expr/node.h" +#include "expr/proof_node_manager.h" #include "theory/ee_setup_info.h" #include "theory/theory_id.h" +#include "theory/trust_node.h" #include "theory/uf/equality_engine.h" +#include "theory/uf/proof_equality_engine.h" #include "util/statistics_registry.h" namespace CVC4 { @@ -31,17 +34,14 @@ namespace CVC4 { class TheoryEngine; class SharedTermsDatabase : public context::ContextNotifyObj { - -public: - + public: /** A container for a list of shared terms */ typedef std::vector shared_terms_list; /** The iterator to go through the shared terms list */ typedef shared_terms_list::const_iterator shared_terms_iterator; -private: - + private: /** Some statistics */ IntStat d_statSharedTerms; @@ -73,8 +73,7 @@ private: typedef context::CDHashSet RegisteredEqualitiesSet; RegisteredEqualitiesSet d_registeredEqualities; -private: - + private: /** This method removes all the un-necessary stuff from the maps */ void backtrack(); @@ -151,9 +150,18 @@ private: */ void checkForConflict(); -public: - - SharedTermsDatabase(TheoryEngine* theoryEngine, context::Context* context); + public: + /** + * @param theoryEngine The parent theory engine + * @param context The SAT context + * @param userContext The user context + * @param pnm The proof node manager to use, which is non-null if proofs + * are enabled. + */ + SharedTermsDatabase(TheoryEngine* theoryEngine, + context::Context* context, + context::UserContext* userContext, + ProofNodeManager* pnm); ~SharedTermsDatabase(); //-------------------------------------------- initialization @@ -258,9 +266,16 @@ public: * This method gets called on backtracks from the context manager. */ void contextNotifyPop() override { backtrack(); } - + /** The SAT search context. */ + context::Context* d_satContext; + /** The user level assertion context. */ + context::UserContext* d_userContext; /** Equality engine */ theory::eq::EqualityEngine* d_equalityEngine; + /** Proof equality engine */ + std::unique_ptr d_pfee; + /** The proof node manager */ + ProofNodeManager* d_pnm; }; } -- 2.30.2