From: Andrew Reynolds Date: Wed, 12 Aug 2020 03:53:58 +0000 (-0500) Subject: Prepare theory of sets for dynamic allocation of equality engine (#4868) X-Git-Tag: cvc5-1.0.0~3023 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=b5b2858;p=cvc5.git Prepare theory of sets for dynamic allocation of equality engine (#4868) In forthcoming PRs, Theory objects will be assigned equality engine objects dynamically. This PR prepares the theory of sets for this update, which involves refactoring of its internal members. --- diff --git a/src/theory/sets/cardinality_extension.cpp b/src/theory/sets/cardinality_extension.cpp index dc3c77353..1c12c71e4 100644 --- a/src/theory/sets/cardinality_extension.cpp +++ b/src/theory/sets/cardinality_extension.cpp @@ -31,19 +31,15 @@ namespace sets { CardinalityExtension::CardinalityExtension(SolverState& s, InferenceManager& im, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u) : d_state(s), d_im(im), - d_ee(e), d_card_processed(u), d_finite_type_constants_processed(false) { d_true = NodeManager::currentNM()->mkConst(true); d_zero = NodeManager::currentNM()->mkConst(Rational(0)); - // we do congruence over cardinality - d_ee.addFunctionKind(CARD); } void CardinalityExtension::reset() @@ -60,7 +56,7 @@ void CardinalityExtension::registerTerm(Node n) Assert(n.getKind() == CARD); TypeNode tnc = n[0].getType().getSetElementType(); d_t_card_enabled[tnc] = true; - Node r = d_ee.getRepresentative(n[0]); + Node r = d_state.getRepresentative(n[0]); if (d_eqc_to_card_term.find(r) == d_eqc_to_card_term.end()) { d_eqc_to_card_term[r] = n; @@ -144,7 +140,7 @@ void CardinalityExtension::checkCardinalityExtended(TypeNode& t) for (Node& representative : representatives) { // the universe set is a subset of itself - if (representative != d_ee.getRepresentative(univ)) + if (representative != d_state.getRepresentative(univ)) { // here we only add representatives with variables to avoid adding // infinite equivalent generated terms to the cardinality graph @@ -399,7 +395,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc, true_sib = 1; } Node u = Rewriter::rewrite(nm->mkNode(UNION, n[0], n[1])); - if (!d_ee.hasTerm(u)) + if (!d_state.hasTerm(u)) { u = Node::null(); } @@ -413,7 +409,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc, // parent equal siblings for (unsigned e = 0; e < true_sib; e++) { - if (d_ee.hasTerm(sib[e]) && !d_state.areEqual(n[e], sib[e])) + if (d_state.hasTerm(sib[e]) && !d_state.areEqual(n[e], sib[e])) { conc.push_back(n[e].eqNode(sib[e])); } @@ -518,7 +514,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc, for (unsigned k = 0, numcp = card_parents.size(); k < numcp; k++) { Node cpk = card_parents[k]; - Node eqcc = d_ee.getRepresentative(cpk); + Node eqcc = d_state.getRepresentative(cpk); Trace("sets-debug") << "Check card parent " << k << "/" << card_parents.size() << std::endl; @@ -833,7 +829,7 @@ void CardinalityExtension::checkNormalForm(Node eqc, Trace("sets-nf") << " Intro split : " << o0 << " against " << o1 << ", term is " << intro << std::endl; intro_sets.push_back(intro); - Assert(!d_ee.hasTerm(intro)); + Assert(!d_state.hasTerm(intro)); return; } } diff --git a/src/theory/sets/cardinality_extension.h b/src/theory/sets/cardinality_extension.h index 570530108..b71af8a43 100644 --- a/src/theory/sets/cardinality_extension.h +++ b/src/theory/sets/cardinality_extension.h @@ -69,7 +69,6 @@ class CardinalityExtension */ CardinalityExtension(SolverState& s, InferenceManager& im, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u); @@ -164,8 +163,6 @@ class CardinalityExtension SolverState& d_state; /** Reference to the inference manager for the theory of sets */ InferenceManager& d_im; - /** Reference to the equality engine of theory of sets */ - eq::EqualityEngine& d_ee; /** register cardinality term * * This method add lemmas corresponding to the definition of diff --git a/src/theory/sets/inference_manager.cpp b/src/theory/sets/inference_manager.cpp index ac0a8205b..f99dad91e 100644 --- a/src/theory/sets/inference_manager.cpp +++ b/src/theory/sets/inference_manager.cpp @@ -27,12 +27,10 @@ namespace sets { InferenceManager::InferenceManager(TheorySetsPrivate& p, SolverState& s, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u) : d_parent(p), d_state(s), - d_ee(e), d_sentLemma(false), d_addedFact(false), d_lemmas_produced(u), diff --git a/src/theory/sets/inference_manager.h b/src/theory/sets/inference_manager.h index 2a5a7703c..ba6be9905 100644 --- a/src/theory/sets/inference_manager.h +++ b/src/theory/sets/inference_manager.h @@ -40,7 +40,6 @@ class InferenceManager public: InferenceManager(TheorySetsPrivate& p, SolverState& s, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u); /** reset @@ -118,8 +117,6 @@ class InferenceManager TheorySetsPrivate& d_parent; /** Reference to the state object for the theory of sets */ SolverState& d_state; - /** Reference to the equality engine of theory of sets */ - eq::EqualityEngine& d_ee; /** pending lemmas */ std::vector d_pendingLemmas; /** sent lemma diff --git a/src/theory/sets/solver_state.cpp b/src/theory/sets/solver_state.cpp index d3c23454e..f3371cf61 100644 --- a/src/theory/sets/solver_state.cpp +++ b/src/theory/sets/solver_state.cpp @@ -26,15 +26,20 @@ namespace theory { namespace sets { SolverState::SolverState(TheorySetsPrivate& p, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u) - : d_conflict(c), d_parent(p), d_ee(e), d_proxy(u), d_proxy_to_term(u) + : d_conflict(c), d_parent(p), d_ee(nullptr), d_proxy(u), d_proxy_to_term(u) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); } +void SolverState::finishInit(eq::EqualityEngine* ee) +{ + Assert(ee != nullptr); + d_ee = ee; +} + void SolverState::reset() { d_set_eqc.clear(); @@ -69,8 +74,8 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n) { if (r.isConst()) { - Node s = d_ee.getRepresentative(n[1]); - Node x = d_ee.getRepresentative(n[0]); + Node s = d_ee->getRepresentative(n[1]); + Node x = d_ee->getRepresentative(n[0]); int pindex = r == d_true ? 0 : (r == d_false ? 1 : -1); if (pindex != -1) { @@ -99,7 +104,7 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n) { // singleton lemma getProxy(n); - Node re = d_ee.getRepresentative(n[0]); + Node re = d_ee->getRepresentative(n[0]); if (d_singleton_index.find(re) == d_singleton_index.end()) { d_singleton_index[re] = n; @@ -122,8 +127,8 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n) } else { - Node r1 = d_ee.getRepresentative(n[0]); - Node r2 = d_ee.getRepresentative(n[1]); + Node r1 = d_ee->getRepresentative(n[0]); + Node r2 = d_ee->getRepresentative(n[1]); std::map& binr1 = d_bop_index[nk][r1]; std::map::iterator itb = binr1.find(r2); if (itb == binr1.end()) @@ -164,15 +169,26 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n) } } +Node SolverState::getRepresentative(Node a) const +{ + if (d_ee->hasTerm(a)) + { + return d_ee->getRepresentative(a); + } + return a; +} + +bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); } + bool SolverState::areEqual(Node a, Node b) const { if (a == b) { return true; } - if (d_ee.hasTerm(a) && d_ee.hasTerm(b)) + if (d_ee->hasTerm(a) && d_ee->hasTerm(b)) { - return d_ee.areEqual(a, b); + return d_ee->areEqual(a, b); } return false; } @@ -183,13 +199,15 @@ bool SolverState::areDisequal(Node a, Node b) const { return false; } - else if (d_ee.hasTerm(a) && d_ee.hasTerm(b)) + else if (d_ee->hasTerm(a) && d_ee->hasTerm(b)) { - return d_ee.areDisequal(a, b, false); + return d_ee->areDisequal(a, b, false); } return a.isConst() && b.isConst(); } +eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; } + void SolverState::setConflict() { d_conflict = true; } void SolverState::setConflict(Node conf) { @@ -279,9 +297,9 @@ bool SolverState::isEntailed(Node n, bool polarity) const return true; } // check members cache - if (polarity && d_ee.hasTerm(n[1])) + if (polarity && d_ee->hasTerm(n[1])) { - Node r = d_ee.getRepresentative(n[1]); + Node r = d_ee->getRepresentative(n[1]); if (d_parent.isMember(n[0], r)) { return true; @@ -310,8 +328,8 @@ bool SolverState::isEntailed(Node n, bool polarity) const bool SolverState::isSetDisequalityEntailed(Node r1, Node r2) const { - Assert(d_ee.hasTerm(r1) && d_ee.getRepresentative(r1) == r1); - Assert(d_ee.hasTerm(r2) && d_ee.getRepresentative(r2) == r2); + Assert(d_ee->hasTerm(r1) && d_ee->getRepresentative(r1) == r1); + Assert(d_ee->hasTerm(r2) && d_ee->getRepresentative(r2) == r2); TypeNode tn = r1.getType(); Node re = getEmptySetEqClass(tn); for (unsigned e = 0; e < 2; e++) @@ -433,7 +451,7 @@ Node SolverState::getProxy(Node n) Node SolverState::getCongruent(Node n) const { - Assert(d_ee.hasTerm(n)); + Assert(d_ee->hasTerm(n)); std::map::const_iterator it = d_congruent.find(n); if (it == d_congruent.end()) { diff --git a/src/theory/sets/solver_state.h b/src/theory/sets/solver_state.h index 0b301dbb7..dce90c2d3 100644 --- a/src/theory/sets/solver_state.h +++ b/src/theory/sets/solver_state.h @@ -48,17 +48,21 @@ class SolverState public: SolverState(TheorySetsPrivate& p, - eq::EqualityEngine& e, context::Context* c, context::UserContext* u); - //-------------------------------- initialize + /** + * Finish initialize, there ee is a pointer to the official equality engine + * of theory of strings. + */ + void finishInit(eq::EqualityEngine* ee); + //-------------------------------- initialize per check /** reset, clears the data structures maintained by this class. */ void reset(); /** register equivalence class whose type is tn */ void registerEqc(TypeNode tn, Node r); /** register term n of type tnn in the equivalence class of r */ void registerTerm(Node r, TypeNode tnn, Node n); - //-------------------------------- end initialize + //-------------------------------- end initialize per check /** Are we currently in conflict? */ bool isInConflict() const { return d_conflict; } /** @@ -68,10 +72,19 @@ class SolverState void setConflict(); /** Set conf is a conflict node to be sent on the output channel. */ void setConflict(Node conf); + /** + * Get the representative of a in the equality engine of this class, or a + * itself if it is not registered as a term. + */ + Node getRepresentative(Node a) const; + /** Is a registered as a term in the equality engine of this class? */ + bool hasTerm(Node a) const; /** Is a=b according to equality reasoning in the current context? */ bool areEqual(Node a, Node b) const; /** Is a!=b according to equality reasoning in the current context? */ bool areDisequal(Node a, Node b) const; + /** get equality engine */ + eq::EqualityEngine* getEqualityEngine() const; /** add equality to explanation * * This adds a = b to exp if a and b are syntactically disequal. The equality @@ -220,8 +233,8 @@ class SolverState context::CDO d_conflict; /** Reference to the parent theory of sets */ TheorySetsPrivate& d_parent; - /** Reference to the equality engine of theory of sets */ - eq::EqualityEngine& d_ee; + /** Pointer to the official equality engine of theory of sets */ + eq::EqualityEngine* d_ee; /** The list of all equivalence classes of type set in the current context */ std::vector d_set_eqc; /** Maps types to the equivalence class containing empty set of that type */ diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index 9c680cc64..17caac4f7 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -34,7 +34,9 @@ TheorySets::TheorySets(context::Context* c, const LogicInfo& logicInfo, ProofNodeManager* pnm) : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm), - d_internal(new TheorySetsPrivate(*this, c, u)) + d_internal(new TheorySetsPrivate(*this, c, u)), + d_notify(*d_internal.get()), + d_equalityEngine(d_notify, c, "theory::sets::ee", true) { // Do not move me to the header. // The constructor + destructor are not in the header as d_internal is a @@ -57,6 +59,27 @@ void TheorySets::finishInit() d_valuation.setUnevaluatedKind(COMPREHENSION); // choice is used to eliminate witness d_valuation.setUnevaluatedKind(WITNESS); + + // functions we are doing congruence over + d_equalityEngine.addFunctionKind(kind::SINGLETON); + d_equalityEngine.addFunctionKind(kind::UNION); + d_equalityEngine.addFunctionKind(kind::INTERSECTION); + d_equalityEngine.addFunctionKind(kind::SETMINUS); + d_equalityEngine.addFunctionKind(kind::MEMBER); + d_equalityEngine.addFunctionKind(kind::SUBSET); + // relation operators + d_equalityEngine.addFunctionKind(PRODUCT); + d_equalityEngine.addFunctionKind(JOIN); + d_equalityEngine.addFunctionKind(TRANSPOSE); + d_equalityEngine.addFunctionKind(TCLOSURE); + d_equalityEngine.addFunctionKind(JOIN_IMAGE); + d_equalityEngine.addFunctionKind(IDEN); + d_equalityEngine.addFunctionKind(APPLY_CONSTRUCTOR); + // we do congruence over cardinality + d_equalityEngine.addFunctionKind(CARD); + + // finish initialization internally + d_internal->finishInit(); } void TheorySets::addSharedTerm(TNode n) { @@ -171,14 +194,96 @@ void TheorySets::propagate(Effort e) { d_internal->propagate(e); } -void TheorySets::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_internal->setMasterEqualityEngine(eq); -} - bool TheorySets::isEntailed( Node n, bool pol ) { return d_internal->isEntailed( n, pol ); } +eq::EqualityEngine* TheorySets::getEqualityEngine() +{ + return &d_equalityEngine; +} + +/**************************** eq::NotifyClass *****************************/ + +bool TheorySets::NotifyClass::eqNotifyTriggerEquality(TNode equality, + bool value) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyTriggerEquality: equality = " + << equality << " value = " << value << std::endl; + if (value) + { + return d_theory.propagate(equality); + } + else + { + // We use only literal triggers so taking not is safe + return d_theory.propagate(equality.notNode()); + } +} + +bool TheorySets::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, + bool value) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyTriggerPredicate: predicate = " + << predicate << " value = " << value << std::endl; + if (value) + { + return d_theory.propagate(predicate); + } + else + { + return d_theory.propagate(predicate.notNode()); + } +} + +bool TheorySets::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, + TNode t1, + TNode t2, + bool value) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyTriggerTermEquality: tag = " << tag + << " t1 = " << t1 << " t2 = " << t2 << " value = " << value + << std::endl; + d_theory.propagate(value ? t1.eqNode(t2) : t1.eqNode(t2).negate()); + return true; +} + +void TheorySets::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge " + << " t1 = " << t1 << " t2 = " << t2 << std::endl; + d_theory.conflict(t1, t2); +} + +void TheorySets::NotifyClass::eqNotifyNewClass(TNode t) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:" + << " t = " << t << std::endl; + d_theory.eqNotifyNewClass(t); +} + +void TheorySets::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:" + << " t1 = " << t1 << " t2 = " << t2 << std::endl; + d_theory.eqNotifyPreMerge(t1, t2); +} + +void TheorySets::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:" + << " t1 = " << t1 << " t2 = " << t2 << std::endl; + d_theory.eqNotifyPostMerge(t1, t2); +} + +void TheorySets::NotifyClass::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) +{ + Debug("sets-eq") << "[sets-eq] eqNotifyDisequal:" + << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason + << std::endl; + d_theory.eqNotifyDisequal(t1, t2, reason); +} + }/* CVC4::theory::sets namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index e81412ba9..f1b59e419 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -33,6 +33,8 @@ class TheorySetsScrutinize; class TheorySets : public Theory { + friend class TheorySetsPrivate; + friend class TheorySetsRels; public: /** Constructs a new instance of TheorySets w.r.t. the provided contexts. */ TheorySets(context::Context* c, @@ -43,10 +45,13 @@ class TheorySets : public Theory ProofNodeManager* pnm); ~TheorySets() override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ TheoryRewriter* getTheoryRewriter() override; - /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization + void addSharedTerm(TNode) override; void check(Effort) override; bool collectModelInfo(TheoryModel* m) override; @@ -60,15 +65,36 @@ class TheorySets : public Theory PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; void presolve() override; void propagate(Effort) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; bool isEntailed(Node n, bool pol); - + /* equality engine */ + virtual eq::EqualityEngine* getEqualityEngine() override; private: - friend class TheorySetsPrivate; - friend class TheorySetsScrutinize; - friend class TheorySetsRels; - + /** Functions to handle callbacks from equality engine */ + class NotifyClass : public eq::EqualityEngineNotify + { + public: + NotifyClass(TheorySetsPrivate& theory) : d_theory(theory) {} + bool eqNotifyTriggerEquality(TNode equality, bool value) override; + bool eqNotifyTriggerPredicate(TNode predicate, bool value) override; + bool eqNotifyTriggerTermEquality(TheoryId tag, + TNode t1, + TNode t2, + bool value) override; + void eqNotifyConstantTermMerge(TNode t1, TNode t2) override; + void eqNotifyNewClass(TNode t) override; + void eqNotifyPreMerge(TNode t1, TNode t2) override; + void eqNotifyPostMerge(TNode t1, TNode t2) override; + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override; + + private: + TheorySetsPrivate& d_theory; + }; + /** The internal theory */ std::unique_ptr d_internal; + /** Instance of the above class */ + NotifyClass d_notify; + /** Equality engine */ + eq::EqualityEngine d_equalityEngine; }; /* class TheorySets */ }/* CVC4::theory::sets namespace */ diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index 4c3affe99..25ee3167e 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -43,27 +43,16 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external, d_keep(c), d_full_check_incomplete(false), d_external(external), - d_notify(*this), - d_equalityEngine(d_notify, c, "theory::sets::ee", true), - d_state(*this, d_equalityEngine, c, u), - d_im(*this, d_state, d_equalityEngine, c, u), - d_rels(new TheorySetsRels(d_state, d_im, d_equalityEngine, u)), - d_cardSolver( - new CardinalityExtension(d_state, d_im, d_equalityEngine, c, u)), + d_state(*this, c, u), + d_im(*this, d_state, c, u), + d_rels(new TheorySetsRels(d_state, d_im, u)), + d_cardSolver(new CardinalityExtension(d_state, d_im, c, u)), d_rels_enabled(false), d_card_enabled(false) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); d_zero = NodeManager::currentNM()->mkConst(Rational(0)); - - d_equalityEngine.addFunctionKind(kind::SINGLETON); - d_equalityEngine.addFunctionKind(kind::UNION); - d_equalityEngine.addFunctionKind(kind::INTERSECTION); - d_equalityEngine.addFunctionKind(kind::SETMINUS); - - d_equalityEngine.addFunctionKind(kind::MEMBER); - d_equalityEngine.addFunctionKind(kind::SUBSET); } TheorySetsPrivate::~TheorySetsPrivate() @@ -74,6 +63,13 @@ TheorySetsPrivate::~TheorySetsPrivate() } } +void TheorySetsPrivate::finishInit() +{ + d_equalityEngine = d_external.getEqualityEngine(); + Assert(d_equalityEngine != nullptr); + d_state.finishInit(d_equalityEngine); +} + void TheorySetsPrivate::eqNotifyNewClass(TNode t) { if (t.getKind() == kind::SINGLETON || t.getKind() == kind::EMPTYSET) @@ -240,13 +236,13 @@ TheorySetsPrivate::EqcInfo* TheorySetsPrivate::getOrMakeEqcInfo(TNode n, bool TheorySetsPrivate::areCareDisequal(Node a, Node b) { - if (d_equalityEngine.isTriggerTerm(a, THEORY_SETS) - && d_equalityEngine.isTriggerTerm(b, THEORY_SETS)) + if (d_equalityEngine->isTriggerTerm(a, THEORY_SETS) + && d_equalityEngine->isTriggerTerm(b, THEORY_SETS)) { TNode a_shared = - d_equalityEngine.getTriggerTermRepresentative(a, THEORY_SETS); + d_equalityEngine->getTriggerTermRepresentative(a, THEORY_SETS); TNode b_shared = - d_equalityEngine.getTriggerTermRepresentative(b, THEORY_SETS); + d_equalityEngine->getTriggerTermRepresentative(b, THEORY_SETS); EqualityStatus eqStatus = d_external.d_valuation.getEqualityStatus(a_shared, b_shared); if (eqStatus == EQUALITY_FALSE_AND_PROPAGATED || eqStatus == EQUALITY_FALSE @@ -260,8 +256,8 @@ bool TheorySetsPrivate::areCareDisequal(Node a, Node b) bool TheorySetsPrivate::isMember(Node x, Node s) { - Assert(d_equalityEngine.hasTerm(s) - && d_equalityEngine.getRepresentative(s) == s); + Assert(d_equalityEngine->hasTerm(s) + && d_equalityEngine->getRepresentative(s) == s); NodeIntMap::iterator mem_i = d_members.find(s); if (mem_i != d_members.end()) { @@ -286,18 +282,18 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp) { if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.assertEquality(atom, polarity, exp); + d_equalityEngine->assertEquality(atom, polarity, exp); } else { - d_equalityEngine.assertPredicate(atom, polarity, exp); + d_equalityEngine->assertPredicate(atom, polarity, exp); } if (!d_state.isInConflict()) { if (atom.getKind() == kind::MEMBER && polarity) { // check if set has a value, if so, we can propagate - Node r = d_equalityEngine.getRepresentative(atom[1]); + Node r = d_equalityEngine->getRepresentative(atom[1]); EqcInfo* e = getOrMakeEqcInfo(r, true); if (e) { @@ -354,7 +350,7 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp) void TheorySetsPrivate::fullEffortReset() { - Assert(d_equalityEngine.consistent()); + Assert(d_equalityEngine->consistent()); d_full_check_incomplete = false; d_most_common_type.clear(); d_most_common_type_term.clear(); @@ -380,7 +376,7 @@ void TheorySetsPrivate::fullEffortCheck() Trace("sets-eqc") << "Equality Engine:" << std::endl; std::map eqcTypeCount; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while (!eqcs_i.isFinished()) { Node eqc = (*eqcs_i); @@ -398,7 +394,7 @@ void TheorySetsPrivate::fullEffortCheck() tnct = eqc; } Trace("sets-eqc") << "[" << eqc << "] : "; - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); while (!eqc_i.isFinished()) { Node n = (*eqc_i); @@ -624,7 +620,7 @@ void TheorySetsPrivate::checkDownwardsClosure() { Node mem = it2.second; Node eq_set = nv; - Assert(d_equalityEngine.areEqual(mem[1], eq_set)); + Assert(d_equalityEngine->areEqual(mem[1], eq_set)); if (mem[1] != eq_set) { Trace("sets-debug") << "Downwards closure based on " << mem @@ -761,7 +757,7 @@ void TheorySetsPrivate::checkUpwardsClosure() } if (valid) { - Node rr = d_equalityEngine.getRepresentative(term); + Node rr = d_equalityEngine->getRepresentative(term); if (!isMember(x, rr)) { Node kk = d_state.getProxy(term); @@ -785,7 +781,7 @@ void TheorySetsPrivate::checkUpwardsClosure() for (const std::pair& itm2m : r2mem) { Node x = itm2m.second[0]; - Node rr = d_equalityEngine.getRepresentative(term); + Node rr = d_equalityEngine->getRepresentative(term); if (!isMember(x, rr)) { std::vector exp; @@ -882,10 +878,10 @@ void TheorySetsPrivate::checkDisequalities() } Node deq = (*it).first; // check if it is already satisfied - Assert(d_equalityEngine.hasTerm(deq[0]) - && d_equalityEngine.hasTerm(deq[1])); - Node r1 = d_equalityEngine.getRepresentative(deq[0]); - Node r2 = d_equalityEngine.getRepresentative(deq[1]); + Assert(d_equalityEngine->hasTerm(deq[0]) + && d_equalityEngine->hasTerm(deq[1])); + Node r1 = d_equalityEngine->getRepresentative(deq[0]); + Node r2 = d_equalityEngine->getRepresentative(deq[1]); bool is_sat = d_state.isSetDisequalityEntailed(r1, r2); Trace("sets-debug") << "Check disequality " << deq << ", is_sat = " << is_sat << std::endl; @@ -1005,7 +1001,7 @@ void TheorySetsPrivate::addSharedTerm(TNode n) { Debug("sets") << "[sets] TheorySetsPrivate::addSharedTerm( " << n << ")" << std::endl; - d_equalityEngine.addTriggerTerm(n, THEORY_SETS); + d_equalityEngine->addTriggerTerm(n, THEORY_SETS); } void TheorySetsPrivate::addCarePairs(TNodeTrie* t1, @@ -1028,21 +1024,21 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1, { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); Assert(!d_state.areDisequal(x, y)); Assert(!areCareDisequal(x, y)); - if (!d_equalityEngine.areEqual(x, y)) + if (!d_equalityEngine->areEqual(x, y)) { Trace("sets-cg") << "Arg #" << k << " is " << x << " " << y << std::endl; - if (d_equalityEngine.isTriggerTerm(x, THEORY_SETS) - && d_equalityEngine.isTriggerTerm(y, THEORY_SETS)) + if (d_equalityEngine->isTriggerTerm(x, THEORY_SETS) + && d_equalityEngine->isTriggerTerm(y, THEORY_SETS)) { - TNode x_shared = - d_equalityEngine.getTriggerTermRepresentative(x, THEORY_SETS); - TNode y_shared = - d_equalityEngine.getTriggerTermRepresentative(y, THEORY_SETS); + TNode x_shared = d_equalityEngine->getTriggerTermRepresentative( + x, THEORY_SETS); + TNode y_shared = d_equalityEngine->getTriggerTermRepresentative( + y, THEORY_SETS); currentPairs.push_back(make_pair(x_shared, y_shared)); } else if (isCareArg(f1, k) && isCareArg(f2, k)) @@ -1092,7 +1088,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1, ++it2; for (; it2 != t1->d_data.end(); ++it2) { - if (!d_equalityEngine.areDisequal(it->first, it2->first, false)) + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) { if (!areCareDisequal(it->first, it2->first)) { @@ -1110,7 +1106,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1, { for (std::pair& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -1140,9 +1136,9 @@ void TheorySetsPrivate::computeCareGraph() // populate indices for (TNode f1 : it.second) { - Assert(d_equalityEngine.hasTerm(f1)); + Assert(d_equalityEngine->hasTerm(f1)); Trace("sets-cg-debug") << "...build for " << f1 << std::endl; - Assert(d_equalityEngine.hasTerm(f1)); + Assert(d_equalityEngine->hasTerm(f1)); // break into index based on operator, and type of first argument (since // some operators are parametric) TypeNode tn = f1[0].getType(); @@ -1150,7 +1146,7 @@ void TheorySetsPrivate::computeCareGraph() bool hasCareArg = false; for (unsigned j = 0; j < f1.getNumChildren(); j++) { - reps.push_back(d_equalityEngine.getRepresentative(f1[j])); + reps.push_back(d_equalityEngine->getRepresentative(f1[j])); if (isCareArg(f1, j)) { hasCareArg = true; @@ -1184,7 +1180,7 @@ void TheorySetsPrivate::computeCareGraph() bool TheorySetsPrivate::isCareArg(Node n, unsigned a) { - if (d_equalityEngine.isTriggerTerm(n[a], THEORY_SETS)) + if (d_equalityEngine->isTriggerTerm(n[a], THEORY_SETS)) { return true; } @@ -1201,13 +1197,13 @@ bool TheorySetsPrivate::isCareArg(Node n, unsigned a) EqualityStatus TheorySetsPrivate::getEqualityStatus(TNode a, TNode b) { - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b, false)) + if (d_equalityEngine->areDisequal(a, b, false)) { // The terms are implied to be dis-equal return EQUALITY_FALSE; @@ -1272,7 +1268,7 @@ bool TheorySetsPrivate::collectModelInfo(TheoryModel* m) d_external.computeRelevantTerms(termSet); // Assert equalities and disequalities to the model - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } @@ -1429,11 +1425,6 @@ OutputChannel* TheorySetsPrivate::getOutputChannel() Valuation& TheorySetsPrivate::getValuation() { return d_external.d_valuation; } -void TheorySetsPrivate::setMasterEqualityEngine(eq::EqualityEngine* eq) -{ - d_equalityEngine.setMasterEqualityEngine(eq); -} - void TheorySetsPrivate::conflict(TNode a, TNode b) { Node conf = explain(a.eqNode(b)); @@ -1453,11 +1444,11 @@ Node TheorySetsPrivate::explain(TNode literal) if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); } else if (atom.getKind() == kind::MEMBER) { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } else { @@ -1475,10 +1466,10 @@ void TheorySetsPrivate::preRegisterTerm(TNode node) << std::endl; switch (node.getKind()) { - case kind::EQUAL: d_equalityEngine.addTriggerEquality(node); break; - case kind::MEMBER: d_equalityEngine.addTriggerPredicate(node); break; - case kind::CARD: d_equalityEngine.addTriggerTerm(node, THEORY_SETS); break; - default: d_equalityEngine.addTerm(node); break; + case kind::EQUAL: d_equalityEngine->addTriggerEquality(node); break; + case kind::MEMBER: d_equalityEngine->addTriggerPredicate(node); break; + case kind::CARD: d_equalityEngine->addTriggerTerm(node, THEORY_SETS); break; + default: d_equalityEngine->addTerm(node); break; } } @@ -1538,92 +1529,6 @@ Node TheorySetsPrivate::getChooseFunction(const TypeNode& setType) void TheorySetsPrivate::presolve() { d_state.reset(); } -/**************************** eq::NotifyClass *****************************/ -/**************************** eq::NotifyClass *****************************/ -/**************************** eq::NotifyClass *****************************/ - -bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerEquality(TNode equality, - bool value) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyTriggerEquality: equality = " - << equality << " value = " << value << std::endl; - if (value) - { - return d_theory.propagate(equality); - } - else - { - // We use only literal triggers so taking not is safe - return d_theory.propagate(equality.notNode()); - } -} - -bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, - bool value) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyTriggerPredicate: predicate = " - << predicate << " value = " << value << std::endl; - if (value) - { - return d_theory.propagate(predicate); - } - else - { - return d_theory.propagate(predicate.notNode()); - } -} - -bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, - TNode t1, - TNode t2, - bool value) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyTriggerTermEquality: tag = " << tag - << " t1 = " << t1 << " t2 = " << t2 << " value = " << value - << std::endl; - d_theory.propagate(value ? t1.eqNode(t2) : t1.eqNode(t2).negate()); - return true; -} - -void TheorySetsPrivate::NotifyClass::eqNotifyConstantTermMerge(TNode t1, - TNode t2) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge " - << " t1 = " << t1 << " t2 = " << t2 << std::endl; - d_theory.conflict(t1, t2); -} - -void TheorySetsPrivate::NotifyClass::eqNotifyNewClass(TNode t) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:" - << " t = " << t << std::endl; - d_theory.eqNotifyNewClass(t); -} - -void TheorySetsPrivate::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:" - << " t1 = " << t1 << " t2 = " << t2 << std::endl; - d_theory.eqNotifyPreMerge(t1, t2); -} - -void TheorySetsPrivate::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:" - << " t1 = " << t1 << " t2 = " << t2 << std::endl; - d_theory.eqNotifyPostMerge(t1, t2); -} - -void TheorySetsPrivate::NotifyClass::eqNotifyDisequal(TNode t1, - TNode t2, - TNode reason) -{ - Debug("sets-eq") << "[sets-eq] eqNotifyDisequal:" - << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason - << std::endl; - d_theory.eqNotifyDisequal(t1, t2, reason); -} - } // namespace sets } // namespace theory } // namespace CVC4 diff --git a/src/theory/sets/theory_sets_private.h b/src/theory/sets/theory_sets_private.h index c65c86795..2779a42b7 100644 --- a/src/theory/sets/theory_sets_private.h +++ b/src/theory/sets/theory_sets_private.h @@ -163,7 +163,11 @@ class TheorySetsPrivate { TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } - void setMasterEqualityEngine(eq::EqualityEngine* eq); + /** + * Finish initialize, called after the equality engine of theory sets has + * been determined. + */ + void finishInit(); void addSharedTerm(TNode); @@ -219,37 +223,18 @@ class TheorySetsPrivate { /** get the valuation */ Valuation& getValuation(); - private: - TheorySets& d_external; - - /** Functions to handle callbacks from equality engine */ - class NotifyClass : public eq::EqualityEngineNotify { - TheorySetsPrivate& d_theory; - - public: - NotifyClass(TheorySetsPrivate& theory): d_theory(theory) {} - bool eqNotifyTriggerEquality(TNode equality, bool value) override; - bool eqNotifyTriggerPredicate(TNode predicate, bool value) override; - bool eqNotifyTriggerTermEquality(TheoryId tag, - TNode t1, - TNode t2, - bool value) override; - void eqNotifyConstantTermMerge(TNode t1, TNode t2) override; - void eqNotifyNewClass(TNode t) override; - void eqNotifyPreMerge(TNode t1, TNode t2) override; - void eqNotifyPostMerge(TNode t1, TNode t2) override; - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override; - } d_notify; - - /** Equality engine */ - eq::EqualityEngine d_equalityEngine; - /** Proagate out to output channel */ bool propagate(TNode); /** generate and send out conflict node */ void conflict(TNode, TNode); - + + private: + TheorySets& d_external; + + /** Pointer to the equality engine of theory of sets */ + eq::EqualityEngine* d_equalityEngine; + bool isCareArg( Node n, unsigned a ); public: diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index f4b1d3be9..82f32337e 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -34,19 +34,11 @@ typedef std::map< Node, std::map< Node, std::unordered_set< Node, NodeHashFuncti TheorySetsRels::TheorySetsRels(SolverState& s, InferenceManager& im, - eq::EqualityEngine& e, context::UserContext* u) - : d_state(s), d_im(im), d_ee(e), d_shared_terms(u) + : d_state(s), d_im(im), d_shared_terms(u) { d_trueNode = NodeManager::currentNM()->mkConst(true); d_falseNode = NodeManager::currentNM()->mkConst(false); - d_ee.addFunctionKind(PRODUCT); - d_ee.addFunctionKind(JOIN); - d_ee.addFunctionKind(TRANSPOSE); - d_ee.addFunctionKind(TCLOSURE); - d_ee.addFunctionKind(JOIN_IMAGE); - d_ee.addFunctionKind(IDEN); - d_ee.addFunctionKind(APPLY_CONSTRUCTOR); } TheorySetsRels::~TheorySetsRels() {} @@ -185,10 +177,11 @@ void TheorySetsRels::check(Theory::Effort level) void TheorySetsRels::collectRelsInfo() { Trace("rels") << "[sets-rels] Start collecting relational terms..." << std::endl; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_ee); + eq::EqualityEngine* ee = d_state.getEqualityEngine(); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee); while( !eqcs_i.isFinished() ){ Node eqc_rep = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc_rep, &d_ee); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc_rep, ee); TypeNode erType = eqc_rep.getType(); Trace("rels-ee") << "[sets-rels-ee] Eqc term representative: " << eqc_rep << " with type " << eqc_rep.getType() << std::endl; @@ -1139,24 +1132,17 @@ void TheorySetsRels::check(Theory::Effort level) } Node TheorySetsRels::getRepresentative( Node t ) { - if (d_ee.hasTerm(t)) - { - return d_ee.getRepresentative(t); - } - else - { - return t; - } + return d_state.getRepresentative(t); } - bool TheorySetsRels::hasTerm(Node a) { return d_ee.hasTerm(a); } + bool TheorySetsRels::hasTerm(Node a) { return d_state.hasTerm(a); } bool TheorySetsRels::areEqual( Node a, Node b ){ Assert(a.getType() == b.getType()); Trace("rels-eq") << "[sets-rels]**** checking equality between " << a << " and " << b << std::endl; if(a == b) { return true; } else if( hasTerm( a ) && hasTerm( b ) ){ - return d_ee.areEqual(a, b); + return d_state.areEqual(a, b); } else if(a.getType().isTuple()) { bool equal = true; for(unsigned int i = 0; i < a.getType().getTupleLength(); i++) { diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 0d8ace50f..60715ff57 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -67,7 +67,6 @@ class TheorySetsRels { public: TheorySetsRels(SolverState& s, InferenceManager& im, - eq::EqualityEngine& e, context::UserContext* u); ~TheorySetsRels(); @@ -90,8 +89,6 @@ private: SolverState& d_state; /** Reference to the inference manager for the theory of sets */ InferenceManager& d_im; - /** Reference to the equality engine of theory of sets */ - eq::EqualityEngine& d_ee; /** A list of pending inferences to process */ std::vector d_pending; NodeSet d_shared_terms;