From: Andrew Reynolds Date: Wed, 19 Aug 2020 18:36:59 +0000 (-0500) Subject: Make sets and strings solver states inherit from TheoryState (#4918) X-Git-Tag: cvc5-1.0.0~2977 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=31717bf7c014bf1971cabcc9b871de5818278126;p=cvc5.git Make sets and strings solver states inherit from TheoryState (#4918) This is towards the new standard for theory solvers. This PR makes the custom states of sets and strings inherit from the standard base class TheoryState. It also makes a minor change to InferenceManager/SolverState to make sets more in line with the plan for a standard base class InferenceManager. Followup PRs will establish the official TheoryState classes for all other theories (which in most cases will be an instance of the base class). --- diff --git a/src/theory/sets/inference_manager.cpp b/src/theory/sets/inference_manager.cpp index f99dad91e..8f25f6511 100644 --- a/src/theory/sets/inference_manager.cpp +++ b/src/theory/sets/inference_manager.cpp @@ -72,7 +72,7 @@ bool InferenceManager::assertFactRec(Node fact, Node exp, int inferType) if (fact == d_false) { Trace("sets-lemma") << "Conflict : " << exp << std::endl; - d_state.setConflict(exp); + conflict(exp); return true; } return false; @@ -233,6 +233,12 @@ bool InferenceManager::hasProcessed() const bool InferenceManager::hasSentLemma() const { return d_sentLemma; } bool InferenceManager::hasAddedFact() const { return d_addedFact; } +void InferenceManager::conflict(Node conf) +{ + d_parent.getOutputChannel()->conflict(conf); + d_state.notifyInConflict(); +} + } // namespace sets } // namespace theory } // namespace CVC4 diff --git a/src/theory/sets/inference_manager.h b/src/theory/sets/inference_manager.h index ba6be9905..3278b848e 100644 --- a/src/theory/sets/inference_manager.h +++ b/src/theory/sets/inference_manager.h @@ -109,6 +109,12 @@ class InferenceManager /** Have we sent lem as a lemma in the current user context? */ bool hasLemmaCached(Node lem) const; + /** + * Send conflict. + * @param conf The conflict node to be sent on the output channel + */ + void conflict(Node conf); + private: /** constants */ Node d_true; diff --git a/src/theory/sets/solver_state.cpp b/src/theory/sets/solver_state.cpp index f3371cf61..5e5e9d22a 100644 --- a/src/theory/sets/solver_state.cpp +++ b/src/theory/sets/solver_state.cpp @@ -27,19 +27,14 @@ namespace sets { SolverState::SolverState(TheorySetsPrivate& p, context::Context* c, - context::UserContext* u) - : d_conflict(c), d_parent(p), d_ee(nullptr), d_proxy(u), d_proxy_to_term(u) + context::UserContext* u, + Valuation val) + : TheoryState(c, u, val), d_parent(p), d_proxy(u), d_proxy_to_term(u) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); } -void SolverState::finishInit(eq::EqualityEngine* ee) -{ - Assert(ee != nullptr); - d_ee = ee; -} - void SolverState::reset() { d_set_eqc.clear(); @@ -169,52 +164,6 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n) } } -Node SolverState::getRepresentative(Node a) const -{ - if (d_ee->hasTerm(a)) - { - return d_ee->getRepresentative(a); - } - return a; -} - -bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); } - -bool SolverState::areEqual(Node a, Node b) const -{ - if (a == b) - { - return true; - } - if (d_ee->hasTerm(a) && d_ee->hasTerm(b)) - { - return d_ee->areEqual(a, b); - } - return false; -} - -bool SolverState::areDisequal(Node a, Node b) const -{ - if (a == b) - { - return false; - } - else if (d_ee->hasTerm(a) && d_ee->hasTerm(b)) - { - return d_ee->areDisequal(a, b, false); - } - return a.isConst() && b.isConst(); -} - -eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; } - -void SolverState::setConflict() { d_conflict = true; } -void SolverState::setConflict(Node conf) -{ - d_parent.getOutputChannel()->conflict(conf); - d_conflict = true; -} - void SolverState::addEqualityToExp(Node a, Node b, std::vector& exp) const { if (a != b) diff --git a/src/theory/sets/solver_state.h b/src/theory/sets/solver_state.h index dce90c2d3..3a40befbd 100644 --- a/src/theory/sets/solver_state.h +++ b/src/theory/sets/solver_state.h @@ -22,6 +22,7 @@ #include "context/cdhashset.h" #include "theory/sets/skolem_cache.h" +#include "theory/theory_state.h" #include "theory/uf/equality_engine.h" namespace CVC4 { @@ -42,19 +43,15 @@ class TheorySetsPrivate; * to initialize the information in this class regarding full effort checks. * Other query calls are then valid for the remainder of the full effort check. */ -class SolverState +class SolverState : public TheoryState { typedef context::CDHashMap NodeMap; public: SolverState(TheorySetsPrivate& p, context::Context* c, - context::UserContext* u); - /** - * Finish initialize, there ee is a pointer to the official equality engine - * of theory of strings. - */ - void finishInit(eq::EqualityEngine* ee); + context::UserContext* u, + Valuation val); //-------------------------------- initialize per check /** reset, clears the data structures maintained by this class. */ void reset(); @@ -63,28 +60,6 @@ class SolverState /** register term n of type tnn in the equivalence class of r */ void registerTerm(Node r, TypeNode tnn, Node n); //-------------------------------- end initialize per check - /** Are we currently in conflict? */ - bool isInConflict() const { return d_conflict; } - /** - * Indicate that we are in conflict, without a conflict clause. This is - * called, for instance, when we have propagated a conflicting literal. - */ - void setConflict(); - /** Set conf is a conflict node to be sent on the output channel. */ - void setConflict(Node conf); - /** - * Get the representative of a in the equality engine of this class, or a - * itself if it is not registered as a term. - */ - Node getRepresentative(Node a) const; - /** Is a registered as a term in the equality engine of this class? */ - bool hasTerm(Node a) const; - /** Is a=b according to equality reasoning in the current context? */ - bool areEqual(Node a, Node b) const; - /** Is a!=b according to equality reasoning in the current context? */ - bool areDisequal(Node a, Node b) const; - /** get equality engine */ - eq::EqualityEngine* getEqualityEngine() const; /** add equality to explanation * * This adds a = b to exp if a and b are syntactically disequal. The equality @@ -229,12 +204,8 @@ class SolverState /** the empty vector and map */ std::vector d_emptyVec; std::map d_emptyMap; - /** Whether or not we are in conflict. This flag is SAT context dependent. */ - context::CDO d_conflict; /** Reference to the parent theory of sets */ TheorySetsPrivate& d_parent; - /** Pointer to the official equality engine of theory of sets */ - eq::EqualityEngine* d_ee; /** The list of all equivalence classes of type set in the current context */ std::vector d_set_eqc; /** Maps types to the equivalence class containing empty set of that type */ diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index fd9af488f..fc544f46f 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -34,13 +34,11 @@ TheorySets::TheorySets(context::Context* c, const LogicInfo& logicInfo, ProofNodeManager* pnm) : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm), - d_internal(new TheorySetsPrivate(*this, c, u)), + d_internal(new TheorySetsPrivate(*this, c, u, valuation)), d_notify(*d_internal.get()) { - // Do not move me to the header. - // The constructor + destructor are not in the header as d_internal is a - // unique_ptr and TheorySetsPrivate is an opaque type in - // the header (Pimpl). See https://herbsutter.com/gotw/_100/ . + // use the state object as the official theory state + d_theoryState = d_internal->getSolverState(); } TheorySets::~TheorySets() diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index bb9423570..879862d15 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -36,14 +36,15 @@ namespace sets { TheorySetsPrivate::TheorySetsPrivate(TheorySets& external, context::Context* c, - context::UserContext* u) + context::UserContext* u, + Valuation valuation) : d_members(c), d_deq(c), d_termProcessed(u), d_keep(c), d_full_check_incomplete(false), d_external(external), - d_state(*this, c, u), + d_state(*this, c, u, valuation), d_im(*this, d_state, c, u), d_rels(new TheorySetsRels(d_state, d_im, u)), d_cardSolver(new CardinalityExtension(d_state, d_im, c, u)), @@ -67,7 +68,6 @@ void TheorySetsPrivate::finishInit() { d_equalityEngine = d_external.getEqualityEngine(); Assert(d_equalityEngine != nullptr); - d_state.finishInit(d_equalityEngine); } void TheorySetsPrivate::eqNotifyNewClass(TNode t) @@ -178,7 +178,7 @@ void TheorySetsPrivate::eqNotifyMerge(TNode t1, TNode t2) // conflict Trace("sets-prop") << "Propagate eq-mem conflict : " << exp << std::endl; - d_state.setConflict(exp); + d_im.conflict(exp); return; } } @@ -316,7 +316,7 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp) { Trace("sets-prop") << "Propagate mem-eq conflict : " << pexp << std::endl; - d_state.setConflict(pexp); + d_im.conflict(pexp); } } } @@ -1410,7 +1410,7 @@ bool TheorySetsPrivate::propagate(TNode literal) bool ok = d_external.d_out->propagate(literal); if (!ok) { - d_state.setConflict(); + d_state.notifyInConflict(); } return ok; @@ -1426,7 +1426,7 @@ Valuation& TheorySetsPrivate::getValuation() { return d_external.d_valuation; } void TheorySetsPrivate::conflict(TNode a, TNode b) { Node conf = explain(a.eqNode(b)); - d_state.setConflict(conf); + d_im.conflict(conf); Debug("sets") << "[sets] conflict: " << a << " iff " << b << ", explanation " << conf << std::endl; Trace("sets-lemma") << "Equality Conflict : " << conf << std::endl; diff --git a/src/theory/sets/theory_sets_private.h b/src/theory/sets/theory_sets_private.h index 27ea6a9b8..9a786598c 100644 --- a/src/theory/sets/theory_sets_private.h +++ b/src/theory/sets/theory_sets_private.h @@ -156,12 +156,16 @@ class TheorySetsPrivate { */ TheorySetsPrivate(TheorySets& external, context::Context* c, - context::UserContext* u); + context::UserContext* u, + Valuation valuation); ~TheorySetsPrivate(); TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } + /** Get the solver state */ + SolverState* getSolverState() { return &d_state; } + /** * Finish initialize, called after the equality engine of theory sets has * been determined. diff --git a/src/theory/strings/inference_manager.cpp b/src/theory/strings/inference_manager.cpp index 88cf6d958..a8ebd921a 100644 --- a/src/theory/strings/inference_manager.cpp +++ b/src/theory/strings/inference_manager.cpp @@ -171,7 +171,7 @@ void InferenceManager::sendInference(const InferInfo& ii, bool asLemma) // only keep stats if we process it here d_statistics.d_inferences << ii.d_id; d_out.conflict(conf); - d_state.setConflict(); + d_state.notifyInConflict(); return; } Trace("strings-infer-debug") << "...as lemma" << std::endl; @@ -435,7 +435,7 @@ void InferenceManager::assertPendingFact(Node atom, bool polarity, Node exp) Trace("strings-pending") << "Process pending conflict " << pc << std::endl; Node conflictNode = mkExplain(a); - d_state.setConflict(); + d_state.notifyInConflict(); Trace("strings-conflict") << "CONFLICT: Eager prefix : " << conflictNode << std::endl; ++(d_statistics.d_conflictsEagerPrefix); diff --git a/src/theory/strings/solver_state.cpp b/src/theory/strings/solver_state.cpp index 8634478fd..fd0f0174f 100644 --- a/src/theory/strings/solver_state.cpp +++ b/src/theory/strings/solver_state.cpp @@ -28,13 +28,7 @@ namespace strings { SolverState::SolverState(context::Context* c, context::UserContext* u, Valuation& v) - : d_context(c), - d_ucontext(u), - d_ee(nullptr), - d_eeDisequalities(c), - d_valuation(v), - d_conflict(c, false), - d_pendingConflict(c) + : TheoryState(c, u, v), d_eeDisequalities(c), d_pendingConflict(c) { d_zero = NodeManager::currentNM()->mkConst(Rational(0)); } @@ -47,59 +41,6 @@ SolverState::~SolverState() } } -void SolverState::finishInit(eq::EqualityEngine* ee) -{ - Assert(ee != nullptr); - d_ee = ee; -} - -context::Context* SolverState::getSatContext() const { return d_context; } -context::UserContext* SolverState::getUserContext() const { return d_ucontext; } - -Node SolverState::getRepresentative(Node t) const -{ - if (d_ee->hasTerm(t)) - { - return d_ee->getRepresentative(t); - } - return t; -} - -bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); } - -bool SolverState::areEqual(Node a, Node b) const -{ - if (a == b) - { - return true; - } - else if (hasTerm(a) && hasTerm(b)) - { - return d_ee->areEqual(a, b); - } - return false; -} - -bool SolverState::areDisequal(Node a, Node b) const -{ - if (a == b) - { - return false; - } - else if (hasTerm(a) && hasTerm(b)) - { - Node ar = d_ee->getRepresentative(a); - Node br = d_ee->getRepresentative(b); - return (ar != br && ar.isConst() && br.isConst()) - || d_ee->areDisequal(ar, br, false); - } - Node ar = getRepresentative(a); - Node br = getRepresentative(b); - return ar != br && ar.isConst() && br.isConst(); -} - -eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; } - const context::CDList& SolverState::getDisequalityList() const { return d_eeDisequalities; @@ -199,7 +140,7 @@ EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake) return nullptr; } -TheoryModel* SolverState::getModel() const { return d_valuation.getModel(); } +TheoryModel* SolverState::getModel() { return d_valuation.getModel(); } void SolverState::addEndpointsToEqcInfo(Node t, Node concat, Node eqc) { @@ -286,9 +227,6 @@ bool SolverState::isEqualEmptyWord(Node s, Node& emps) return false; } -void SolverState::setConflict() { d_conflict = true; } -bool SolverState::isInConflict() const { return d_conflict; } - void SolverState::setPendingConflictWhen(Node conf) { if (!conf.isNull() && d_pendingConflict.get().isNull()) diff --git a/src/theory/strings/solver_state.h b/src/theory/strings/solver_state.h index 0322abdb7..fc27b847b 100644 --- a/src/theory/strings/solver_state.h +++ b/src/theory/strings/solver_state.h @@ -39,7 +39,7 @@ namespace strings { * (2) Whether the set of assertions is in conflict. * (3) Equivalence class information as in the class above. */ -class SolverState +class SolverState : public TheoryState { typedef context::CDList NodeList; @@ -48,35 +48,7 @@ class SolverState context::UserContext* u, Valuation& v); ~SolverState(); - /** - * Finish initialize, ee is a pointer to the official equality engine - * of theory of strings. - */ - void finishInit(eq::EqualityEngine* ee); - /** Get the SAT context */ - context::Context* getSatContext() const; - /** Get the user context */ - context::UserContext* getUserContext() const; //-------------------------------------- equality information - /** - * Get the representative of t in the equality engine of this class, or t - * itself if it is not registered as a term. - */ - Node getRepresentative(Node t) const; - /** Is t registered as a term in the equality engine of this class? */ - bool hasTerm(Node a) const; - /** - * Are a and b equal according to the equality engine of this class? Also - * returns true if a and b are identical. - */ - bool areEqual(Node a, Node b) const; - /** - * Are a and b disequal according to the equality engine of this class? Also - * returns true if the representative of a and b are distinct constants. - */ - bool areDisequal(Node a, Node b) const; - /** get equality engine */ - eq::EqualityEngine* getEqualityEngine() const; /** * Get the list of disequalities that are currently asserted to the equality * engine. @@ -92,14 +64,6 @@ class SolverState void eqNotifyDisequal(TNode t1, TNode t2, TNode reason); //-------------------------------------- end notifications for equalities //------------------------------------------ conflicts - /** - * Set that the current state of the solver is in conflict. This should be - * called immediately after a call to conflict(...) on the output channel of - * the theory of strings. - */ - void setConflict(); - /** Are we currently in conflict? */ - bool isInConflict() const; /** set pending conflict * * If conf is non-null, this is called when conf is a conjunction of literals @@ -153,7 +117,7 @@ class SolverState */ EqcInfo* getOrMakeEqcInfo(Node eqc, bool doMake = true); /** Get pointer to the model object of the Valuation object */ - TheoryModel* getModel() const; + TheoryModel* getModel(); /** add endpoints to eqc info * @@ -186,21 +150,11 @@ class SolverState private: /** Common constants */ Node d_zero; - /** Pointer to the SAT context object used by the theory of strings. */ - context::Context* d_context; - /** Pointer to the user context object used by the theory of strings. */ - context::UserContext* d_ucontext; - /** Pointer to equality engine of the theory of strings. */ - eq::EqualityEngine* d_ee; /** * The (SAT-context-dependent) list of disequalities that have been asserted * to the equality engine above. */ NodeList d_eeDisequalities; - /** Reference to the valuation of the theory of strings */ - Valuation& d_valuation; - /** Are we in conflict? */ - context::CDO d_conflict; /** The pending conflict if one exists */ context::CDO d_pendingConflict; /** Map from representatives to their equivalence class information */ diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index c78e8dc2a..6d81c742a 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -81,6 +81,8 @@ TheoryStrings::TheoryStrings(context::Context* c, // add checkers d_sProofChecker.registerTo(pc); } + // use the state object as the official theory state + d_theoryState = &d_state; } TheoryStrings::~TheoryStrings() { @@ -126,8 +128,6 @@ void TheoryStrings::finishInit() d_equalityEngine->addFunctionKind(kind::STRING_TOLOWER, eagerEval); d_equalityEngine->addFunctionKind(kind::STRING_TOUPPER, eagerEval); d_equalityEngine->addFunctionKind(kind::STRING_REV, eagerEval); - - d_state.finishInit(d_equalityEngine); } std::string TheoryStrings::identify() const @@ -196,7 +196,7 @@ bool TheoryStrings::propagate(TNode literal) { // Propagate out bool ok = d_out->propagate(literal); if (!ok) { - d_state.setConflict(); + d_state.notifyInConflict(); } return ok; } @@ -762,7 +762,7 @@ void TheoryStrings::conflict(TNode a, TNode b){ if (!d_state.isInConflict()) { Debug("strings-conflict") << "Making conflict..." << std::endl; - d_state.setConflict(); + d_state.notifyInConflict(); TrustNode conflictNode = explain(a.eqNode(b)); Trace("strings-conflict") << "CONFLICT: Eq engine conflict : " << conflictNode.getNode() diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index 9669d97e0..7220e2e1c 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -104,6 +104,10 @@ void Theory::setEqualityEngine(eq::EqualityEngine* ee) { // set the equality engine pointer d_equalityEngine = ee; + if (d_theoryState != nullptr) + { + d_theoryState->setEqualityEngine(ee); + } } void Theory::setQuantifiersEngine(QuantifiersEngine* qe) { @@ -127,7 +131,7 @@ void Theory::finishInitStandalone() d_allocEqualityEngine.reset(new eq::EqualityEngine( *esi.d_notify, d_satContext, esi.d_name, esi.d_constantsAreTriggers)); // use it as the official equality engine - d_equalityEngine = d_allocEqualityEngine.get(); + setEqualityEngine(d_allocEqualityEngine.get()); } finishInit(); } diff --git a/src/theory/theory_state.cpp b/src/theory/theory_state.cpp index bc8e53245..8afe3be96 100644 --- a/src/theory/theory_state.cpp +++ b/src/theory/theory_state.cpp @@ -30,7 +30,7 @@ TheoryState::TheoryState(context::Context* c, { } -void TheoryState::finishInit(eq::EqualityEngine* ee) { d_ee = ee; } +void TheoryState::setEqualityEngine(eq::EqualityEngine* ee) { d_ee = ee; } context::Context* TheoryState::getSatContext() const { return d_context; } diff --git a/src/theory/theory_state.h b/src/theory/theory_state.h index 71197dddc..de6e6d477 100644 --- a/src/theory/theory_state.h +++ b/src/theory/theory_state.h @@ -34,10 +34,10 @@ class TheoryState TheoryState(context::Context* c, context::UserContext* u, Valuation val); virtual ~TheoryState() {} /** - * Finish initialize, ee is a pointer to the official equality engine + * Set equality engine, where ee is a pointer to the official equality engine * of theory. */ - virtual void finishInit(eq::EqualityEngine* ee); + void setEqualityEngine(eq::EqualityEngine* ee); /** Get the SAT context */ context::Context* getSatContext() const; /** Get the user context */