From: Andrew Reynolds Date: Tue, 3 Aug 2021 16:50:51 +0000 (-0500) Subject: Refactor shared solver to use theory builtin inference manager (#6960) X-Git-Tag: cvc5-1.0.0~1418 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2c2981d419bdf5a7bbf424f62266883724e85168;p=cvc5.git Refactor shared solver to use theory builtin inference manager (#6960) This ensures that e.g. COMBINATION_SPLIT shows up under theory::builtin::inferencesLemmas, and -t im. It also removes outdated interfaces from OutputChannel, and makes the feature TheoryEngine::ensureLemmaAtoms more modular, which was required for making these interfaces more consistent. It also ensures that TheoryBuiltin has an inference manager, which will simplify special casing in #6956. --- diff --git a/src/theory/builtin/theory_builtin.cpp b/src/theory/builtin/theory_builtin.cpp index 80c5dba3d..1db03d22b 100644 --- a/src/theory/builtin/theory_builtin.cpp +++ b/src/theory/builtin/theory_builtin.cpp @@ -31,8 +31,13 @@ TheoryBuiltin::TheoryBuiltin(context::Context* c, Valuation valuation, const LogicInfo& logicInfo, ProofNodeManager* pnm) - : Theory(THEORY_BUILTIN, c, u, out, valuation, logicInfo, pnm) + : Theory(THEORY_BUILTIN, c, u, out, valuation, logicInfo, pnm), + d_state(c, u, valuation), + d_im(*this, d_state, pnm, "theory::builtin::") { + // indicate we are using the default theory state and inference managers + d_theoryState = &d_state; + d_inferManager = &d_im; } TheoryRewriter* TheoryBuiltin::getTheoryRewriter() { return &d_rewriter; } diff --git a/src/theory/builtin/theory_builtin.h b/src/theory/builtin/theory_builtin.h index 27492cb53..72485f0ea 100644 --- a/src/theory/builtin/theory_builtin.h +++ b/src/theory/builtin/theory_builtin.h @@ -21,6 +21,8 @@ #include "theory/builtin/proof_checker.h" #include "theory/builtin/theory_builtin_rewriter.h" #include "theory/theory.h" +#include "theory/theory_inference_manager.h" +#include "theory/theory_state.h" namespace cvc5 { namespace theory { @@ -51,6 +53,10 @@ class TheoryBuiltin : public Theory TheoryBuiltinRewriter d_rewriter; /** Proof rule checker */ BuiltinProofRuleChecker d_checker; + /** A (default) theory state object */ + TheoryState d_state; + /** A (default) inference manager */ + TheoryInferenceManager d_im; }; /* class TheoryBuiltin */ } // namespace builtin diff --git a/src/theory/ee_manager_central.cpp b/src/theory/ee_manager_central.cpp index f98bce970..23946512e 100644 --- a/src/theory/ee_manager_central.cpp +++ b/src/theory/ee_manager_central.cpp @@ -298,7 +298,8 @@ void EqEngineManagerCentral::eqNotifyConstantTermMerge(TNode t1, TNode t2) Node conflict = d_centralEqualityEngine.mkExplainLit(lit); Trace("eem-central") << "...explained conflict of " << lit << " ... " << conflict << std::endl; - d_sharedSolver.sendConflict(TrustNode::mkTrustConflict(conflict)); + d_sharedSolver.sendConflict(TrustNode::mkTrustConflict(conflict), + InferenceId::EQ_CONSTANT_MERGE); return; } diff --git a/src/theory/engine_output_channel.cpp b/src/theory/engine_output_channel.cpp index b1f35821f..ee07dcd57 100644 --- a/src/theory/engine_output_channel.cpp +++ b/src/theory/engine_output_channel.cpp @@ -60,31 +60,7 @@ void EngineOutputChannel::safePoint(Resource r) void EngineOutputChannel::lemma(TNode lemma, LemmaProperty p) { - Trace("theory::lemma") << "EngineOutputChannel<" << d_theory << ">::lemma(" - << lemma << ")" - << ", properties = " << p << std::endl; - ++d_statistics.lemmas; - d_engine->d_outputChannelUsed = true; - - TrustNode tlem = TrustNode::mkTrustLemma(lemma); - d_engine->lemma(tlem, - p, - isLemmaPropertySendAtoms(p) ? d_theory : theory::THEORY_LAST, - d_theory); -} - -void EngineOutputChannel::splitLemma(TNode lemma, bool removable) -{ - Trace("theory::lemma") << "EngineOutputChannel<" << d_theory << ">::lemma(" - << lemma << ")" << std::endl; - ++d_statistics.lemmas; - d_engine->d_outputChannelUsed = true; - - Trace("pf::explain") << "EngineOutputChannel::splitLemma( " << lemma << " )" - << std::endl; - TrustNode tlem = TrustNode::mkTrustLemma(lemma); - LemmaProperty p = removable ? LemmaProperty::REMOVABLE : LemmaProperty::NONE; - d_engine->lemma(tlem, p, d_theory); + trustedLemma(TrustNode::mkTrustLemma(lemma), p); } bool EngineOutputChannel::propagate(TNode literal) @@ -172,10 +148,13 @@ void EngineOutputChannel::trustedLemma(TrustNode plem, LemmaProperty p) } ++d_statistics.lemmas; d_engine->d_outputChannelUsed = true; + if (isLemmaPropertySendAtoms(p)) + { + d_engine->ensureLemmaAtoms(plem.getNode(), d_theory); + } // now, call the normal interface for lemma d_engine->lemma(plem, p, - isLemmaPropertySendAtoms(p) ? d_theory : theory::THEORY_LAST, d_theory); } diff --git a/src/theory/engine_output_channel.h b/src/theory/engine_output_channel.h index dcf8fba55..cc1d8ece7 100644 --- a/src/theory/engine_output_channel.h +++ b/src/theory/engine_output_channel.h @@ -53,8 +53,6 @@ class EngineOutputChannel : public theory::OutputChannel void lemma(TNode lemma, LemmaProperty p = LemmaProperty::NONE) override; - void splitLemma(TNode lemma, bool removable = false) override; - void demandRestart() override; void requirePhase(TNode n, bool phase) override; diff --git a/src/theory/output_channel.cpp b/src/theory/output_channel.cpp index 2d85f4de6..5ab91b6e0 100644 --- a/src/theory/output_channel.cpp +++ b/src/theory/output_channel.cpp @@ -77,8 +77,6 @@ std::ostream& operator<<(std::ostream& out, LemmaProperty p) return out; } -void OutputChannel::split(TNode n) { splitLemma(n.orNode(n.notNode())); } - void OutputChannel::trustedConflict(TrustNode pconf) { Unreachable() << "OutputChannel::trustedConflict: no implementation" diff --git a/src/theory/output_channel.h b/src/theory/output_channel.h index b681dad17..80115d438 100644 --- a/src/theory/output_channel.h +++ b/src/theory/output_channel.h @@ -121,16 +121,6 @@ class OutputChannel { */ virtual void lemma(TNode n, LemmaProperty p = LemmaProperty::NONE) = 0; - /** - * Request a split on a new theory atom. This is equivalent to - * calling lemma({OR n (NOT n)}). - * - * @param n - a theory atom; must be of Boolean type - */ - void split(TNode n); - - virtual void splitLemma(TNode n, bool removable = false) = 0; - /** * If a decision is made on n, it must be in the phase specified. * Note that this is enforced *globally*, i.e., it is completely diff --git a/src/theory/shared_solver.cpp b/src/theory/shared_solver.cpp index b020a3938..95558ead1 100644 --- a/src/theory/shared_solver.cpp +++ b/src/theory/shared_solver.cpp @@ -19,6 +19,7 @@ #include "theory/ee_setup_info.h" #include "theory/logic_info.h" #include "theory/theory_engine.h" +#include "theory/theory_inference_manager.h" namespace cvc5 { namespace theory { @@ -35,7 +36,7 @@ SharedSolver::SharedSolver(TheoryEngine& te, ProofNodeManager* pnm) d_sharedTerms(&d_te, d_te.getSatContext(), d_te.getUserContext(), pnm), d_preRegistrationVisitor(&te, d_te.getSatContext()), d_sharedTermsVisitor(&te, d_sharedTerms, d_te.getSatContext()), - d_out(te.theoryOf(THEORY_BUILTIN)->getOutputChannel()) + d_im(te.theoryOf(THEORY_BUILTIN)->getInferenceManager()) { } @@ -113,9 +114,9 @@ bool SharedSolver::propagateLit(TNode predicate, bool value) { if (value) { - return d_out.propagate(predicate); + return d_im->propagateLit(predicate); } - return d_out.propagate(predicate.notNode()); + return d_im->propagateLit(predicate.notNode()); } bool SharedSolver::propagateSharedEquality(theory::TheoryId theory, @@ -141,11 +142,18 @@ bool SharedSolver::isShared(TNode t) const { return d_sharedTerms.isShared(t); } void SharedSolver::sendLemma(TrustNode trn, TheoryId atomsTo, InferenceId id) { - Trace("im") << "(lemma " << id << " " << trn.getProven() << ")" << std::endl; - d_te.lemma(trn, LemmaProperty::NONE, atomsTo); + // Do we need to check atoms + if (atomsTo != theory::THEORY_LAST) + { + d_te.ensureLemmaAtoms(trn.getNode(), atomsTo); + } + d_im->trustedLemma(trn, id); } -void SharedSolver::sendConflict(TrustNode trn) { d_out.trustedConflict(trn); } +void SharedSolver::sendConflict(TrustNode trn, InferenceId id) +{ + d_im->trustedConflict(trn, id); +} } // namespace theory } // namespace cvc5 diff --git a/src/theory/shared_solver.h b/src/theory/shared_solver.h index e2cda0fbc..a7f9ceff5 100644 --- a/src/theory/shared_solver.h +++ b/src/theory/shared_solver.h @@ -33,7 +33,7 @@ class TheoryEngine; namespace theory { struct EeSetupInfo; -class OutputChannel; +class TheoryInferenceManager; /** * A base class for shared solver. The shared solver is the component of theory @@ -124,7 +124,7 @@ class SharedSolver /** Send lemma to the theory engine, atomsTo is the theory to send atoms to */ void sendLemma(TrustNode trn, TheoryId atomsTo, InferenceId id); /** Send conflict to the theory engine */ - void sendConflict(TrustNode trn); + void sendConflict(TrustNode trn, InferenceId id); protected: /** Solver-specific pre-register shared */ @@ -139,8 +139,8 @@ class SharedSolver PreRegisterVisitor d_preRegistrationVisitor; /** Visitor for collecting shared terms */ SharedTermsVisitor d_sharedTermsVisitor; - /** Output channel of theory builtin */ - OutputChannel& d_out; + /** Theory inference manager of theory builtin */ + TheoryInferenceManager* d_im; }; } // namespace theory diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 63fd6d9b7..fb93403b9 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -1246,6 +1246,16 @@ struct AtomsCollect { } }; +void TheoryEngine::ensureLemmaAtoms(TNode n, theory::TheoryId atomsTo) +{ + Assert(atomsTo != THEORY_LAST); + Debug("theory::atoms") << "TheoryEngine::ensureLemmaAtoms(" << n << ", " + << atomsTo << ")" << endl; + AtomsCollect collectAtoms; + NodeVisitor::run(collectAtoms, n); + ensureLemmaAtoms(collectAtoms.getAtoms(), atomsTo); +} + void TheoryEngine::ensureLemmaAtoms(const std::vector& atoms, theory::TheoryId atomsTo) { for (unsigned i = 0; i < atoms.size(); ++ i) { @@ -1314,7 +1324,6 @@ void TheoryEngine::ensureLemmaAtoms(const std::vector& atoms, theory::The void TheoryEngine::lemma(TrustNode tlemma, theory::LemmaProperty p, - theory::TheoryId atomsTo, theory::TheoryId from) { // For resource-limiting (also does a time check). @@ -1346,15 +1355,6 @@ void TheoryEngine::lemma(TrustNode tlemma, tlemma.debugCheckClosed("te-proof-debug", "TheoryEngine::lemma_initial"); } - // Do we need to check atoms - if (atomsTo != theory::THEORY_LAST) { - Debug("theory::atoms") << "TheoryEngine::lemma(" << node << ", " << atomsTo - << ")" << endl; - AtomsCollect collectAtoms; - NodeVisitor::run(collectAtoms, node); - ensureLemmaAtoms(collectAtoms.getAtoms(), atomsTo); - } - if(Dump.isOn("t-lemmas")) { // we dump the negation of the lemma, to show validity of the lemma Node n = lemma.negate(); @@ -1504,7 +1504,7 @@ void TheoryEngine::conflict(TrustNode tconflict, TheoryId theoryId) // When only one theory, the conflict should need no processing Assert(properConflict(conflict)); // pass the trust node that was sent from the theory - lemma(tconflict, LemmaProperty::REMOVABLE, THEORY_LAST, theoryId); + lemma(tconflict, LemmaProperty::REMOVABLE, theoryId); } } diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index cfcdcce13..5d16f04ba 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -277,11 +277,13 @@ class TheoryEngine { */ void lemma(TrustNode node, theory::LemmaProperty p, - theory::TheoryId atomsTo = theory::THEORY_LAST, theory::TheoryId from = theory::THEORY_LAST); - /** Enusre that the given atoms are send to the given theory */ - void ensureLemmaAtoms(const std::vector& atoms, theory::TheoryId theory); + /** Ensure atoms from the given node are sent to the given theory */ + void ensureLemmaAtoms(TNode n, theory::TheoryId atomsTo); + /** Ensure that the given atoms are sent to the given theory */ + void ensureLemmaAtoms(const std::vector& atoms, + theory::TheoryId atomsTo); /** sort inference module */ std::unique_ptr d_sortInfer; diff --git a/test/unit/test_smt.h b/test/unit/test_smt.h index 672693366..4226f8095 100644 --- a/test/unit/test_smt.h +++ b/test/unit/test_smt.h @@ -138,8 +138,6 @@ class DummyOutputChannel : public cvc5::theory::OutputChannel void setIncomplete(theory::IncompleteId id) override {} void handleUserAttribute(const char* attr, theory::Theory* t) override {} - void splitLemma(TNode n, bool removable = false) override { push(LEMMA, n); } - void clear() { d_callHistory.clear(); } Node getIthNode(int i) const diff --git a/test/unit/theory/theory_white.cpp b/test/unit/theory/theory_white.cpp index 5a469ed97..94021a9e3 100644 --- a/test/unit/theory/theory_white.cpp +++ b/test/unit/theory/theory_white.cpp @@ -100,7 +100,7 @@ TEST_F(TestTheoryWhite, outputChannel) { Node n = d_atom0.orNode(d_atom1); d_outputChannel.lemma(n); - d_outputChannel.split(d_atom0); + d_outputChannel.lemma(d_atom0.orNode(d_atom0.notNode())); Node s = d_atom0.orNode(d_atom0.notNode()); ASSERT_EQ(d_outputChannel.d_callHistory.size(), 2u); ASSERT_EQ(d_outputChannel.d_callHistory[0], std::make_pair(LEMMA, n));