From: Andrew Reynolds Date: Mon, 17 Aug 2020 19:38:16 +0000 (-0500) Subject: Dynamic allocation of equality engine in Theory (#4890) X-Git-Tag: cvc5-1.0.0~2994 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=4f82b6eb7cc921ba2c6470a5ca0027be8dfc04e9;p=cvc5.git Dynamic allocation of equality engine in Theory (#4890) This commit updates Theory so that equality engines are allocated dynamically. The plan is to make this configurable based on the theory combination method. The fundamental changes include: - Add `d_equalityEngine` (raw) pointer to Theory, which is the "official" equality engine of the theory. - Standardize methods for initializing Theory. This is now made more explicit in the documentation in theory.h, and includes a method `finishInitStandalone` for users of Theory that don't have an associated TheoryEngine. - Refactor TheoryEngine::finishInit, including how Theory is initialized to incorporate the new policy. - Things related to master equality engine are now specific to EqEngineManagerDistributed and hence can be removed from TheoryEngine. This will be further refactored in forthcoming PRs. Note that the majority of changes are due to changing `d_equalityEngine.` to `d_equalityEngine->` throughout. --- diff --git a/src/proof/theory_proof.cpp b/src/proof/theory_proof.cpp index 3103557c8..b47fd6a1e 100644 --- a/src/proof/theory_proof.cpp +++ b/src/proof/theory_proof.cpp @@ -1093,6 +1093,12 @@ void TheoryProof::printTheoryLemmaProof(std::vector& lemma, InternalError() << "can't generate theory-proof for " << ProofManager::currentPM()->getLogic(); } + // must perform initialization on the theory + if (th != nullptr) + { + // finish init, standalone version + th->finishInitStandalone(); + } Debug("pf::tp") << "TheoryProof::printTheoryLemmaProof - calling th->ProduceProofs()" << std::endl; th->produceProofs(); diff --git a/src/theory/arith/congruence_manager.cpp b/src/theory/arith/congruence_manager.cpp index ab3b485a8..a70339c01 100644 --- a/src/theory/arith/congruence_manager.cpp +++ b/src/theory/arith/congruence_manager.cpp @@ -42,16 +42,29 @@ ArithCongruenceManager::ArithCongruenceManager( d_constraintDatabase(cd), d_setupLiteral(setup), d_avariables(avars), - d_ee(d_notify, c, "theory::arith::ArithCongruenceManager", true) + d_ee(nullptr) { - d_ee.addFunctionKind(kind::NONLINEAR_MULT); - d_ee.addFunctionKind(kind::EXPONENTIAL); - d_ee.addFunctionKind(kind::SINE); - d_ee.addFunctionKind(kind::IAND); } ArithCongruenceManager::~ArithCongruenceManager() {} +bool ArithCongruenceManager::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::arith::ArithCongruenceManager"; + return true; +} + +void ArithCongruenceManager::finishInit(eq::EqualityEngine* ee) +{ + Assert(ee != nullptr); + d_ee = ee; + d_ee->addFunctionKind(kind::NONLINEAR_MULT); + d_ee->addFunctionKind(kind::EXPONENTIAL); + d_ee->addFunctionKind(kind::SINE); + d_ee->addFunctionKind(kind::IAND); +} + ArithCongruenceManager::Statistics::Statistics(): d_watchedVariables("theory::arith::congruence::watchedVariables", 0), d_watchedVariableIsZero("theory::arith::congruence::watchedVariableIsZero", 0), @@ -141,10 +154,6 @@ bool ArithCongruenceManager::canExplain(TNode n) const { return d_explanationMap.find(n) != d_explanationMap.end(); } -void ArithCongruenceManager::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_ee.setMasterEqualityEngine(eq); -} - Node ArithCongruenceManager::externalToInternal(TNode n) const{ Assert(canExplain(n)); ExplainMap::const_iterator iter = d_explanationMap.find(n); @@ -320,9 +329,9 @@ bool ArithCongruenceManager::propagate(TNode x){ void ArithCongruenceManager::explain(TNode literal, std::vector& assumptions) { if (literal.getKind() != kind::NOT) { - d_ee.explainEquality(literal[0], literal[1], true, assumptions); + d_ee->explainEquality(literal[0], literal[1], true, assumptions); } else { - d_ee.explainEquality(literal[0][0], literal[0][1], false, assumptions); + d_ee->explainEquality(literal[0][0], literal[0][1], false, assumptions); } } @@ -392,9 +401,9 @@ void ArithCongruenceManager::assertionToEqualityEngine(bool isEquality, ArithVar Trace("arith-ee") << "Assert " << eq << ", pol " << isEquality << ", reason " << reason << std::endl; if(isEquality){ - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); }else{ - d_ee.assertEquality(eq, false, reason); + d_ee->assertEquality(eq, false, reason); } } @@ -417,7 +426,7 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP c){ d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); } void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ @@ -441,11 +450,11 @@ void ArithCongruenceManager::equalsConstant(ConstraintCP lb, ConstraintCP ub){ d_keepAlive.push_back(reason); Trace("arith-ee") << "Assert equalsConstant2 " << eq << ", reason " << reason << std::endl; - d_ee.assertEquality(eq, true, reason); + d_ee->assertEquality(eq, true, reason); } void ArithCongruenceManager::addSharedTerm(Node x){ - d_ee.addTriggerTerm(x, THEORY_ARITH); + d_ee->addTriggerTerm(x, THEORY_ARITH); } }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arith/congruence_manager.h b/src/theory/arith/congruence_manager.h index aeb72ec94..f3b5641b4 100644 --- a/src/theory/arith/congruence_manager.h +++ b/src/theory/arith/congruence_manager.h @@ -95,7 +95,8 @@ private: const ArithVariables& d_avariables; - eq::EqualityEngine d_ee; + /** The equality engine being used by this class */ + eq::EqualityEngine* d_ee; void raiseConflict(Node conflict); public: @@ -108,8 +109,6 @@ public: bool canExplain(TNode n) const; - void setMasterEqualityEngine(eq::EqualityEngine* eq); - private: Node externalToInternal(TNode n) const; @@ -138,6 +137,19 @@ public: ArithCongruenceManager(context::Context* satContext, ConstraintDatabase&, SetupLiteralCallBack, const ArithVariables&, RaiseEqualityEngineConflict raiseConflict); ~ArithCongruenceManager(); + //--------------------------------- initialization + /** + * Returns true if we need an equality engine, see + * Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi); + /** + * Finish initialize. This class is instructed by TheoryArithPrivate to use + * the equality engine ee. + */ + void finishInit(eq::EqualityEngine* ee); + //--------------------------------- end initialization + Node explain(TNode literal); void explain(TNode lit, NodeBuilder<>& out); @@ -166,10 +178,8 @@ public: void addSharedTerm(Node x); - - eq::EqualityEngine * getEqualityEngine() { return &d_ee; } -private: + private: class Statistics { public: IntStat d_watchedVariables; diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index bc6e18a83..b95b5e243 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -56,10 +56,10 @@ TheoryRewriter* TheoryArith::getTheoryRewriter() return d_internal->getTheoryRewriter(); } -void TheoryArith::preRegisterTerm(TNode n){ - d_internal->preRegisterTerm(n); +bool TheoryArith::needsEqualityEngine(EeSetupInfo& esi) +{ + return d_internal->needsEqualityEngine(esi); } - void TheoryArith::finishInit() { if (getLogicInfo().isTheoryEnabled(THEORY_ARITH) @@ -72,17 +72,17 @@ void TheoryArith::finishInit() d_valuation.setUnevaluatedKind(kind::SINE); d_valuation.setUnevaluatedKind(kind::PI); } + // finish initialize internally + d_internal->finishInit(); } +void TheoryArith::preRegisterTerm(TNode n) { d_internal->preRegisterTerm(n); } + TrustNode TheoryArith::expandDefinition(Node node) { return d_internal->expandDefinition(node); } -void TheoryArith::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_internal->setMasterEqualityEngine(eq); -} - void TheoryArith::addSharedTerm(TNode n){ d_internal->addSharedTerm(n); } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 30de7bbad..ad3b91b07 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -55,19 +55,28 @@ class TheoryArith : public Theory { ProofNodeManager* pnm = nullptr); virtual ~TheoryArith(); + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if this theory needs an equality engine, which is assigned + * to it (d_equalityEngine) by the equality engine manager during + * TheoryEngine::finishInit, prior to calling finishInit for this theory. + * If this method returns true, it stores instructions for the notifications + * this Theory wishes to receive from its equality engine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization /** * Does non-context dependent setup for a node connected to a theory. */ void preRegisterTerm(TNode n) override; - void finishInit() override; - TrustNode expandDefinition(Node node) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - void check(Effort e) override; bool needsCheckLastEffort() override; void propagate(Effort e) override; diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index 6f47ffb0e..8ca99d369 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -134,7 +134,7 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), d_attemptSolSimplex( d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)), - d_nonlinearExtension(NULL), + d_nonlinearExtension(nullptr), d_pass1SDP(NULL), d_otherSDP(NULL), d_lastContextIntegerAttempted(c, -1), @@ -159,12 +159,6 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, d_statistics(), d_opElim(pnm, logicInfo) { - // only need to create if non-linear logic - if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) - { - d_nonlinearExtension = new nl::NonlinearExtension( - containing, d_congruenceManager.getEqualityEngine()); - } } TheoryArithPrivate::~TheoryArithPrivate(){ @@ -173,6 +167,24 @@ TheoryArithPrivate::~TheoryArithPrivate(){ if(d_nonlinearExtension != NULL) { delete d_nonlinearExtension; } } +TheoryRewriter* TheoryArithPrivate::getTheoryRewriter() { return &d_rewriter; } +bool TheoryArithPrivate::needsEqualityEngine(EeSetupInfo& esi) +{ + return d_congruenceManager.needsEqualityEngine(esi); +} +void TheoryArithPrivate::finishInit() +{ + eq::EqualityEngine* ee = d_containing.getEqualityEngine(); + Assert(ee != nullptr); + d_congruenceManager.finishInit(ee); + const LogicInfo& logicInfo = getLogicInfo(); + // only need to create nonlinear extension if non-linear logic + if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) + { + d_nonlinearExtension = new nl::NonlinearExtension(d_containing, ee); + } +} + static bool contains(const ConstraintCPVec& v, ConstraintP con){ for(unsigned i = 0, N = v.size(); i < N; ++i){ if(v[i] == con){ @@ -227,10 +239,6 @@ static void resolve(ConstraintCPVec& buf, ConstraintP c, const ConstraintCPVec& // return safeConstructNary(nb); } -void TheoryArithPrivate::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_congruenceManager.setMasterEqualityEngine(eq); -} - TheoryArithPrivate::ModelException::ModelException(TNode n, const char* msg) { stringstream ss; diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index 42ec7f47b..4c4aedf00 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -427,7 +427,17 @@ private: ProofNodeManager* pnm); ~TheoryArithPrivate(); - TheoryRewriter* getTheoryRewriter() { return &d_rewriter; } + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter(); + /** + * Returns true if we need an equality engine, see + * Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi); + /** finish initialize */ + void finishInit(); + //--------------------------------- end initialization /** * Does non-context dependent setup for a node connected to a theory. @@ -435,8 +445,6 @@ private: void preRegisterTerm(TNode n); TrustNode expandDefinition(Node node); - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void check(Theory::Effort e); bool needsCheckLastEffort(); void propagate(Theory::Effort e); diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 245da617b..85759b75f 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -88,7 +88,6 @@ TheoryArrays::TheoryArrays(context::Context* c, d_isPreRegistered(c), d_mayEqualEqualityEngine(c, name + "theory::arrays::mayEqual", true), d_notify(*this), - d_equalityEngine(d_notify, c, name + "theory::arrays", true), d_conflict(c, false), d_backtracker(c), d_infoMap(c, &d_backtracker, name), @@ -112,7 +111,7 @@ TheoryArrays::TheoryArrays(context::Context* c, d_readTableContext(new context::Context()), d_arrayMerges(c), d_inCheckModel(false), - d_proofReconstruction(&d_equalityEngine), + d_proofReconstruction(nullptr), d_dstrat(new TheoryArraysDecisionStrategy(this)), d_dstratInit(false) { @@ -133,27 +132,6 @@ TheoryArrays::TheoryArrays(context::Context* c, // The preprocessing congruence kinds d_ppEqualityEngine.addFunctionKind(kind::SELECT); d_ppEqualityEngine.addFunctionKind(kind::STORE); - - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::SELECT); - if (d_ccStore) { - d_equalityEngine.addFunctionKind(kind::STORE); - } - if (d_useArrTable) { - d_equalityEngine.addFunctionKind(kind::ARR_TABLE_FUN); - } - - d_reasonRow = d_equalityEngine.getFreshMergeReasonType(); - d_reasonRow1 = d_equalityEngine.getFreshMergeReasonType(); - d_reasonExt = d_equalityEngine.getFreshMergeReasonType(); - - d_proofReconstruction.setRowMergeTag(d_reasonRow); - d_proofReconstruction.setRow1MergeTag(d_reasonRow1); - d_proofReconstruction.setExtMergeTag(d_reasonExt); - - d_equalityEngine.addPathReconstructionTrigger(d_reasonRow, &d_proofReconstruction); - d_equalityEngine.addPathReconstructionTrigger(d_reasonRow1, &d_proofReconstruction); - d_equalityEngine.addPathReconstructionTrigger(d_reasonExt, &d_proofReconstruction); } TheoryArrays::~TheoryArrays() { @@ -179,8 +157,45 @@ TheoryArrays::~TheoryArrays() { smtStatisticsRegistry()->unregisterStat(&d_numSetModelValConflicts); } -void TheoryArrays::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryArrays::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryArrays::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = d_instanceName + "theory::arrays::ee"; + return true; +} + +void TheoryArrays::finishInit() +{ + Assert(d_equalityEngine != nullptr); + + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::SELECT); + if (d_ccStore) + { + d_equalityEngine->addFunctionKind(kind::STORE); + } + if (d_useArrTable) + { + d_equalityEngine->addFunctionKind(kind::ARR_TABLE_FUN); + } + + d_proofReconstruction.reset(new ArrayProofReconstruction(d_equalityEngine)); + d_reasonRow = d_equalityEngine->getFreshMergeReasonType(); + d_reasonRow1 = d_equalityEngine->getFreshMergeReasonType(); + d_reasonExt = d_equalityEngine->getFreshMergeReasonType(); + + d_proofReconstruction->setRowMergeTag(d_reasonRow); + d_proofReconstruction->setRow1MergeTag(d_reasonRow1); + d_proofReconstruction->setExtMergeTag(d_reasonExt); + + d_equalityEngine->addPathReconstructionTrigger(d_reasonRow, + d_proofReconstruction.get()); + d_equalityEngine->addPathReconstructionTrigger(d_reasonRow1, + d_proofReconstruction.get()); + d_equalityEngine->addPathReconstructionTrigger(d_reasonExt, + d_proofReconstruction.get()); } ///////////////////////////////////////////////////////////////////////////// @@ -427,9 +442,10 @@ void TheoryArrays::explain(TNode literal, std::vector& assumptions, //eq::EqProof * eqp = new eq::EqProof; // eq::EqProof * eqp = NULL; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions, proof); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, proof); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions, proof); + d_equalityEngine->explainPredicate(atom, polarity, assumptions, proof); } if (Debug.isOn("pf::array")) { @@ -469,7 +485,8 @@ TNode TheoryArrays::weakEquivGetRepIndex(TNode node, TNode index) { return node; } index2 = d_infoMap.getWeakEquivIndex(node); - if (index2.isNull() || !d_equalityEngine.areEqual(index, index2)) { + if (index2.isNull() || !d_equalityEngine->areEqual(index, index2)) + { node = pointer; } else { @@ -493,7 +510,8 @@ void TheoryArrays::visitAllLeaves(TNode reason, vector& conjunctions) { conjunctions.push_back(reason); break; case kind::EQUAL: - d_equalityEngine.explainEquality(reason[0], reason[1], true, conjunctions); + d_equalityEngine->explainEquality( + reason[0], reason[1], true, conjunctions); break; default: Unreachable(); @@ -511,10 +529,11 @@ void TheoryArrays::weakEquivBuildCond(TNode node, TNode index, vector& co index2 = d_infoMap.getWeakEquivIndex(node); if (index2.isNull()) { // Null index means these two nodes became equal: explain the equality. - d_equalityEngine.explainEquality(node, pointer, true, conjunctions); + d_equalityEngine->explainEquality(node, pointer, true, conjunctions); node = pointer; } - else if (!d_equalityEngine.areEqual(index, index2)) { + else if (!d_equalityEngine->areEqual(index, index2)) + { // If indices are not equal in current context, need to add that to the lemma. Node reason = index.eqNode(index2).notNode(); d_permRef.push_back(reason); @@ -556,7 +575,8 @@ void TheoryArrays::weakEquivMakeRepIndex(TNode node) { TNode index2 = d_infoMap.getWeakEquivIndex(secondary); Node reason; TNode next; - while (index2.isNull() || !d_equalityEngine.areEqual(index, index2)) { + while (index2.isNull() || !d_equalityEngine->areEqual(index, index2)) + { next = d_infoMap.getWeakEquivPointer(secondary); d_infoMap.setWeakEquivSecondary(node, next); reason = d_infoMap.getWeakEquivSecondaryReason(node); @@ -590,13 +610,13 @@ void TheoryArrays::weakEquivAddSecondary(TNode index, TNode arrayFrom, TNode arr TNode pointer, indexRep; if (!index.isNull()) { index_trail.push_back(index); - marked.insert(d_equalityEngine.getRepresentative(index)); + marked.insert(d_equalityEngine->getRepresentative(index)); } while (arrayFrom != arrayTo) { index = d_infoMap.getWeakEquivIndex(arrayFrom); pointer = d_infoMap.getWeakEquivPointer(arrayFrom); if (!index.isNull()) { - indexRep = d_equalityEngine.getRepresentative(index); + indexRep = d_equalityEngine->getRepresentative(index); if (marked.find(indexRep) == marked.end() && weakEquivGetRepIndex(arrayFrom, index) != arrayTo) { weakEquivMakeRepIndex(arrayFrom); d_infoMap.setWeakEquivSecondary(arrayFrom, arrayTo); @@ -639,7 +659,7 @@ void TheoryArrays::checkWeakEquiv(bool arraysMerged) { || !secondary.isNull()); if (!pointer.isNull()) { if (index.isNull()) { - Assert(d_equalityEngine.areEqual(n, pointer)); + Assert(d_equalityEngine->areEqual(n, pointer)); } else { Assert( @@ -677,16 +697,17 @@ void TheoryArrays::preRegisterTermInternal(TNode node) case kind::EQUAL: // Add the trigger for equality // NOTE: note that if the equality is true or false already, it might not be added - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); break; case kind::SELECT: { // Invariant: array terms should be preregistered before being added to the equality engine - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { Assert(d_isPreRegistered.find(node) != d_isPreRegistered.end()); return; } // Reads - TNode store = d_equalityEngine.getRepresentative(node[0]); + TNode store = d_equalityEngine->getRepresentative(node[0]); // The may equal needs the store d_mayEqualEqualityEngine.addTerm(store); @@ -694,15 +715,15 @@ void TheoryArrays::preRegisterTermInternal(TNode node) if (node.getType().isArray()) { d_mayEqualEqualityEngine.addTerm(node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } Assert((d_isPreRegistered.insert(node), true)); - Assert(d_equalityEngine.getRepresentative(store) == store); + Assert(d_equalityEngine->getRepresentative(store) == store); d_infoMap.addIndex(store, node[1]); // Synchronize d_constReadsContext with SAT context @@ -712,7 +733,7 @@ void TheoryArrays::preRegisterTermInternal(TNode node) } // Record read in sharing data structure - TNode index = d_equalityEngine.getRepresentative(node[1]); + TNode index = d_equalityEngine->getRepresentative(node[1]); if (!options::arraysWeakEquivalence() && index.isConst()) { CTNodeList* temp; CNodeNListMap::iterator it = d_constReads.find(index); @@ -734,12 +755,13 @@ void TheoryArrays::preRegisterTermInternal(TNode node) break; } case kind::STORE: { - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { break; } - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); - TNode a = d_equalityEngine.getRepresentative(node[0]); + TNode a = d_equalityEngine->getRepresentative(node[0]); if (node.isConst()) { // Can't use d_mayEqualEqualityEngine to merge node with a because they are both constants, @@ -761,12 +783,13 @@ void TheoryArrays::preRegisterTermInternal(TNode node) TNode v = node[2]; NodeManager* nm = NodeManager::currentNM(); Node ni = nm->mkNode(kind::SELECT, node, i); - if (!d_equalityEngine.hasTerm(ni)) { + if (!d_equalityEngine->hasTerm(ni)) + { preRegisterTermInternal(ni); } // Apply RIntro1 Rule - d_equalityEngine.assertEquality(ni.eqNode(v), true, d_true, d_reasonRow1); + d_equalityEngine->assertEquality(ni.eqNode(v), true, d_true, d_reasonRow1); d_infoMap.addStore(node, node); d_infoMap.addInStore(a, node); @@ -787,7 +810,8 @@ void TheoryArrays::preRegisterTermInternal(TNode node) break; } case kind::STORE_ALL: { - if (d_equalityEngine.hasTerm(node)) { + if (d_equalityEngine->hasTerm(node)) + { break; } ArrayStoreAll storeAll = node.getConst(); @@ -798,7 +822,7 @@ void TheoryArrays::preRegisterTermInternal(TNode node) d_infoMap.setConstArr(node, node); d_mayEqualEqualityEngine.addTerm(node); Assert(d_mayEqualEqualityEngine.getRepresentative(node) == node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); d_defValues[node] = defaultValue; break; } @@ -807,19 +831,19 @@ void TheoryArrays::preRegisterTermInternal(TNode node) if (node.getType().isArray()) { // The may equal needs the node d_mayEqualEqualityEngine.addTerm(node); - d_equalityEngine.addTriggerTerm(node, THEORY_ARRAYS); - Assert(d_equalityEngine.getSize(node) == 1); + d_equalityEngine->addTriggerTerm(node, THEORY_ARRAYS); + Assert(d_equalityEngine->getSize(node) == 1); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } break; } // Invariant: preregistered terms are exactly the terms in the equality engine // Disabled, see comment above for kind::EQUAL - // Assert(d_equalityEngine.hasTerm(node) || - // !d_equalityEngine.consistent()); + // Assert(d_equalityEngine->hasTerm(node) || + // !d_equalityEngine->consistent()); } @@ -830,7 +854,7 @@ void TheoryArrays::preRegisterTerm(TNode node) // Note: do this here instead of in preRegisterTermInternal to prevent internal select // terms from being propagated out (as this results in an assertion failure). if (node.getKind() == kind::SELECT && node.getType().isBoolean()) { - d_equalityEngine.addTriggerPredicate(node); + d_equalityEngine->addTriggerPredicate(node); } } @@ -862,7 +886,7 @@ Node TheoryArrays::explain(TNode literal, eq::EqProof* proof) { void TheoryArrays::addSharedTerm(TNode t) { Debug("arrays::sharing") << spaces(getSatContext()->getLevel()) << "TheoryArrays::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_ARRAYS); + d_equalityEngine->addTriggerTerm(t, THEORY_ARRAYS); if (t.getType().isArray()) { d_sharedArrays.insert(t); } @@ -876,12 +900,14 @@ void TheoryArrays::addSharedTerm(TNode t) { EqualityStatus TheoryArrays::getEqualityStatus(TNode a, TNode b) { - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - else if (d_equalityEngine.areDisequal(a, b, false)) { + else if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -895,16 +921,19 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) TNode x = r1[1]; TNode y = r2[1]; - Assert(d_equalityEngine.isTriggerTerm(x, THEORY_ARRAYS)); + Assert(d_equalityEngine->isTriggerTerm(x, THEORY_ARRAYS)); - if (d_equalityEngine.hasTerm(x) && d_equalityEngine.hasTerm(y) && - (d_equalityEngine.areEqual(x,y) || d_equalityEngine.areDisequal(x,y,false))) { + if (d_equalityEngine->hasTerm(x) && d_equalityEngine->hasTerm(y) + && (d_equalityEngine->areEqual(x, y) + || d_equalityEngine->areDisequal(x, y, false))) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): equality known, skipping" << std::endl; return; } // If the terms are already known to be equal, we are also in good shape - if (d_equalityEngine.areEqual(r1, r2)) { + if (d_equalityEngine->areEqual(r1, r2)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): equal, skipping" << std::endl; return; } @@ -913,8 +942,9 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) // If arrays are known to be disequal, or cannot become equal, we can continue Assert(d_mayEqualEqualityEngine.hasTerm(r1[0]) && d_mayEqualEqualityEngine.hasTerm(r2[0])); - if (r1[0].getType() != r2[0].getType() || - d_equalityEngine.areDisequal(r1[0], r2[0], false)) { + if (r1[0].getType() != r2[0].getType() + || d_equalityEngine->areDisequal(r1[0], r2[0], false)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): arrays can't be equal, skipping" << std::endl; return; } @@ -923,14 +953,17 @@ void TheoryArrays::checkPair(TNode r1, TNode r2) } } - if (!d_equalityEngine.isTriggerTerm(y, THEORY_ARRAYS)) { + if (!d_equalityEngine->isTriggerTerm(y, THEORY_ARRAYS)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): not connected to shared terms, skipping" << std::endl; return; } // Get representative trigger terms - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_ARRAYS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_ARRAYS); + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_ARRAYS); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_ARRAYS); EqualityStatus eqStatusDomain = d_valuation.getEqualityStatus(x_shared, y_shared); switch (eqStatusDomain) { case EQUALITY_TRUE_AND_PROPAGATED: @@ -999,14 +1032,16 @@ void TheoryArrays::computeCareGraph() TNode r1 = d_reads[i]; Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): checking read " << r1 << std::endl; - Assert(d_equalityEngine.hasTerm(r1)); + Assert(d_equalityEngine->hasTerm(r1)); TNode x = r1[1]; - if (!d_equalityEngine.isTriggerTerm(x, THEORY_ARRAYS)) { + if (!d_equalityEngine->isTriggerTerm(x, THEORY_ARRAYS)) + { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): not connected to shared terms, skipping" << std::endl; continue; } - Node x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_ARRAYS); + Node x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_ARRAYS); // Get the model value of index and find all reads that read from that same model value: these are the pairs we have to check // Also, insert this read in the list at the proper index @@ -1034,12 +1069,12 @@ void TheoryArrays::computeCareGraph() // We don't know the model value for x. Just do brute force examination of all pairs of reads for (unsigned j = 0; j < size; ++j) { TNode r2 = d_reads[j]; - Assert(d_equalityEngine.hasTerm(r2)); + Assert(d_equalityEngine->hasTerm(r2)); checkPair(r1,r2); } for (unsigned j = 0; j < d_constReadsList.size(); ++j) { TNode r2 = d_constReadsList[j]; - Assert(d_equalityEngine.hasTerm(r2)); + Assert(d_equalityEngine->hasTerm(r2)); checkPair(r1,r2); } } @@ -1064,7 +1099,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) NodeManager* nm = NodeManager::currentNM(); std::vector arrays; bool computeRep, isArray; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); for (; !eqcs_i.isFinished(); ++eqcs_i) { Node eqc = (*eqcs_i); isArray = eqc.getType().isArray(); @@ -1072,7 +1107,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) continue; } computeRep = false; - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); for (; !eqc_i.isFinished(); ++eqc_i) { Node n = *eqc_i; // If this EC is an array type and it contains something other than STORE nodes, we have to compute a representative explicitly @@ -1095,30 +1130,36 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) bool changed; do { changed = false; - eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eqcs_i = eq::EqClassesIterator(d_equalityEngine); for (; !eqcs_i.isFinished(); ++eqcs_i) { Node eqc = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); for (; !eqc_i.isFinished(); ++eqc_i) { Node n = *eqc_i; if (n.getKind() == kind::SELECT && termSet.find(n) != termSet.end()) { // Find all terms equivalent to n[0] and get corresponding read terms - Node array_eqc = d_equalityEngine.getRepresentative(n[0]); - eq::EqClassIterator array_eqc_i = eq::EqClassIterator(array_eqc, &d_equalityEngine); + Node array_eqc = d_equalityEngine->getRepresentative(n[0]); + eq::EqClassIterator array_eqc_i = + eq::EqClassIterator(array_eqc, d_equalityEngine); for (; !array_eqc_i.isFinished(); ++array_eqc_i) { Node arr = *array_eqc_i; - if (arr.getKind() == kind::STORE && - termSet.find(arr) != termSet.end() && - !d_equalityEngine.areEqual(arr[1],n[1])) { + if (arr.getKind() == kind::STORE + && termSet.find(arr) != termSet.end() + && !d_equalityEngine->areEqual(arr[1], n[1])) + { Node r = nm->mkNode(kind::SELECT, arr, n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(a) read: " << r << endl; termSet.insert(r); changed = true; } r = nm->mkNode(kind::SELECT, arr[0], n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(b) read: " << r << endl; termSet.insert(r); changed = true; @@ -1132,16 +1173,21 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) for(; it < instores->size(); ++it) { TNode instore = (*instores)[it]; Assert(instore.getKind() == kind::STORE); - if (termSet.find(instore) != termSet.end() && - !d_equalityEngine.areEqual(instore[1],n[1])) { + if (termSet.find(instore) != termSet.end() + && !d_equalityEngine->areEqual(instore[1], n[1])) + { Node r = nm->mkNode(kind::SELECT, instore, n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(c) read: " << r << endl; termSet.insert(r); changed = true; } r = nm->mkNode(kind::SELECT, instore[0], n[1]); - if (termSet.find(r) == termSet.end() && d_equalityEngine.hasTerm(r)) { + if (termSet.find(r) == termSet.end() + && d_equalityEngine->hasTerm(r)) + { Trace("arrays::collectModelInfo") << "TheoryArrays::collectModelInfo, adding RIntro2(d) read: " << r << endl; termSet.insert(r); changed = true; @@ -1154,7 +1200,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) } while (changed); // Send the equality engine information to the model - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } @@ -1166,7 +1212,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) Node n = *set_it; // If this term is a select, record that the EC rep of its store parameter is being read from using this term if (n.getKind() == kind::SELECT) { - selects[d_equalityEngine.getRepresentative(n[0])].push_back(n); + selects[d_equalityEngine->getRepresentative(n[0])].push_back(n); } } @@ -1177,7 +1223,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) // Compute all default values already in use //if (fullModel) { for (size_t i=0; igetRepresentative(arrays[i]); d_mayEqualEqualityEngine.addTerm(nrep); // add the term in case it isn't there already TNode mayRep = d_mayEqualEqualityEngine.getRepresentative(nrep); it = d_defValues.find(mayRep); @@ -1190,7 +1236,7 @@ bool TheoryArrays::collectModelInfo(TheoryModel* m) // Loop through all array equivalence classes that need a representative computed for (size_t i=0; igetRepresentative(n); //if (fullModel) { // Compute default value for this array - there is one default value for every mayEqual equivalence class @@ -1280,9 +1326,9 @@ Node TheoryArrays::getSkolem(TNode ref, const string& name, const TypeNode& type } else { skolem = (*it).second; - if (d_equalityEngine.hasTerm(ref) && - d_equalityEngine.hasTerm(skolem) && - d_equalityEngine.areEqual(ref, skolem)) { + if (d_equalityEngine->hasTerm(ref) && d_equalityEngine->hasTerm(skolem) + && d_equalityEngine->areEqual(ref, skolem)) + { makeEqual = false; } } @@ -1294,7 +1340,7 @@ Node TheoryArrays::getSkolem(TNode ref, const string& name, const TypeNode& type if (makeEqual) { Node d = skolem.eqNode(ref); Debug("arrays-model-based") << "Asserting skolem equality " << d << endl; - d_equalityEngine.assertEquality(d, true, d_true); + d_equalityEngine->assertEquality(d, true, d_true); Assert(!d_conflict); d_skolemAssertions.push_back(d); d_skolemIndex = d_skolemIndex + 1; @@ -1328,13 +1374,15 @@ void TheoryArrays::check(Effort e) { if (!assertion.d_isPreregistered) { if (atom.getKind() == kind::EQUAL) { - if (!d_equalityEngine.hasTerm(atom[0])) { + if (!d_equalityEngine->hasTerm(atom[0])) + { Assert(atom[0].isConst()); - d_equalityEngine.addTerm(atom[0]); + d_equalityEngine->addTerm(atom[0]); } - if (!d_equalityEngine.hasTerm(atom[1])) { + if (!d_equalityEngine->hasTerm(atom[1])) + { Assert(atom[1].isConst()); - d_equalityEngine.addTerm(atom[1]); + d_equalityEngine->addTerm(atom[1]); } } } @@ -1342,17 +1390,19 @@ void TheoryArrays::check(Effort e) { // Do the work switch (fact.getKind()) { case kind::EQUAL: - d_equalityEngine.assertEquality(fact, true, fact); + d_equalityEngine->assertEquality(fact, true, fact); break; case kind::SELECT: - d_equalityEngine.assertPredicate(fact, true, fact); + d_equalityEngine->assertPredicate(fact, true, fact); break; case kind::NOT: if (fact[0].getKind() == kind::SELECT) { - d_equalityEngine.assertPredicate(fact[0], false, fact); - } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1], false)) { + d_equalityEngine->assertPredicate(fact[0], false, fact); + } + else if (!d_equalityEngine->areDisequal(fact[0][0], fact[0][1], false)) + { // Assert the dis-equality - d_equalityEngine.assertEquality(fact[0], false, fact); + d_equalityEngine->assertEquality(fact[0], false, fact); // Apply ArrDiseq Rule if diseq is between arrays if(fact[0][0].getType().isArray() && !d_conflict) { @@ -1396,18 +1446,26 @@ void TheoryArrays::check(Effort e) { // when we output the lemma. However, in replay need the lemma to be propagated, and so we // preregister manually. if (d_proofsEnabled) { - if (!d_equalityEngine.hasTerm(ak)) { preRegisterTermInternal(ak); } - if (!d_equalityEngine.hasTerm(bk)) { preRegisterTermInternal(bk); } + if (!d_equalityEngine->hasTerm(ak)) + { + preRegisterTermInternal(ak); + } + if (!d_equalityEngine->hasTerm(bk)) + { + preRegisterTermInternal(bk); + } } - if (options::arraysPropagate() > 0 && d_equalityEngine.hasTerm(ak) && d_equalityEngine.hasTerm(bk)) { + if (options::arraysPropagate() > 0 && d_equalityEngine->hasTerm(ak) + && d_equalityEngine->hasTerm(bk)) + { // Propagate witness disequality - might produce a conflict d_permRef.push_back(lemma); Debug("pf::array") << "Asserting to the equality engine:" << std::endl << "\teq = " << eq << std::endl << "\treason = " << fact << std::endl; - d_equalityEngine.assertEquality(eq, false, fact, d_reasonExt); + d_equalityEngine->assertEquality(eq, false, fact, d_reasonExt); ++d_numProp; } @@ -1465,7 +1523,7 @@ void TheoryArrays::check(Effort e) { // Find the bucket for this read. mayRep = d_mayEqualEqualityEngine.getRepresentative(r[0]); - iRep = d_equalityEngine.getRepresentative(r[1]); + iRep = d_equalityEngine->getRepresentative(r[1]); std::pair key(mayRep, iRep); ReadBucketMap::iterator rbm_it = d_readBucketTable.find(key); if (rbm_it == d_readBucketTable.end()) @@ -1484,20 +1542,21 @@ void TheoryArrays::check(Effort e) { const TNode& r2 = *ctnl_it; Assert(r2.getKind() == kind::SELECT); Assert(mayRep == d_mayEqualEqualityEngine.getRepresentative(r2[0])); - Assert(iRep == d_equalityEngine.getRepresentative(r2[1])); - if (d_equalityEngine.areEqual(r, r2)) { + Assert(iRep == d_equalityEngine->getRepresentative(r2[1])); + if (d_equalityEngine->areEqual(r, r2)) + { continue; } if (weakEquivGetRepIndex(r[0], r[1]) == weakEquivGetRepIndex(r2[0], r[1])) { // add lemma: r[1] = r2[1] /\ cond(r[0],r2[0]) => r = r2 vector conjunctions; - Assert(d_equalityEngine.areEqual(r, Rewriter::rewrite(r))); - Assert(d_equalityEngine.areEqual(r2, Rewriter::rewrite(r2))); + Assert(d_equalityEngine->areEqual(r, Rewriter::rewrite(r))); + Assert(d_equalityEngine->areEqual(r2, Rewriter::rewrite(r2))); Node lemma = Rewriter::rewrite(r).eqNode(Rewriter::rewrite(r2)).negate(); d_permRef.push_back(lemma); conjunctions.push_back(lemma); if (r[1] != r2[1]) { - d_equalityEngine.explainEquality(r[1], r2[1], true, conjunctions); + d_equalityEngine->explainEquality(r[1], r2[1], true, conjunctions); } // TODO: get smaller lemmas by eliminating shared parts of path weakEquivBuildCond(r[0], r[1], conjunctions); @@ -1648,8 +1707,8 @@ void TheoryArrays::mergeArrays(TNode a, TNode b) // Normally, a is its own representative, but it's possible for a to have // been merged with another array after it got queued up by the equality engine, // so we take its representative to be safe. - a = d_equalityEngine.getRepresentative(a); - Assert(d_equalityEngine.getRepresentative(b) == a); + a = d_equalityEngine->getRepresentative(a); + Assert(d_equalityEngine->getRepresentative(b) == a); Trace("arrays-merge") << spaces(getSatContext()->getLevel()) << "Arrays::merge: (" << a << ", " << b << ")\n"; if (options::arraysOptimizeLinear() && !options::arraysWeakEquivalence()) { @@ -1759,7 +1818,7 @@ void TheoryArrays::checkStore(TNode a) { TNode b = a[0]; TNode i = a[1]; - TNode brep = d_equalityEngine.getRepresentative(b); + TNode brep = d_equalityEngine->getRepresentative(b); if (!options::arraysOptimizeLinear() || d_infoMap.isNonLinear(brep)) { const CTNodeList* js = d_infoMap.getIndices(brep); @@ -1786,17 +1845,18 @@ void TheoryArrays::checkRowForIndex(TNode i, TNode a) d_infoMap.getInfo(a)->print(); } Assert(a.getType().isArray()); - Assert(d_equalityEngine.getRepresentative(a) == a); + Assert(d_equalityEngine->getRepresentative(a) == a); TNode constArr = d_infoMap.getConstArr(a); if (!constArr.isNull()) { ArrayStoreAll storeAll = constArr.getConst(); Node defValue = storeAll.getValue(); Node selConst = NodeManager::currentNM()->mkNode(kind::SELECT, constArr, i); - if (!d_equalityEngine.hasTerm(selConst)) { + if (!d_equalityEngine->hasTerm(selConst)) + { preRegisterTermInternal(selConst); } - d_equalityEngine.assertEquality(selConst.eqNode(defValue), true, d_true); + d_equalityEngine->assertEquality(selConst.eqNode(defValue), true, d_true); } const CTNodeList* stores = d_infoMap.getStores(a); @@ -1848,7 +1908,8 @@ void TheoryArrays::checkRowLemmas(TNode a, TNode b) for( ; it < i_a->size(); ++it) { TNode i = (*i_a)[it]; Node selConst = NodeManager::currentNM()->mkNode(kind::SELECT, constArr, i); - if (!d_equalityEngine.hasTerm(selConst)) { + if (!d_equalityEngine->hasTerm(selConst)) + { preRegisterTermInternal(selConst); } } @@ -1901,8 +1962,8 @@ void TheoryArrays::propagate(RowLemmaType lem) std::tie(a, b, i, j) = lem; Assert(a.getType().isArray() && b.getType().isArray()); - if (d_equalityEngine.areEqual(a,b) || - d_equalityEngine.areEqual(i,j)) { + if (d_equalityEngine->areEqual(a, b) || d_equalityEngine->areEqual(i, j)) + { return; } @@ -1911,14 +1972,15 @@ void TheoryArrays::propagate(RowLemmaType lem) Node bj = nm->mkNode(kind::SELECT, b, j); // Try to avoid introducing new read terms: track whether these already exist - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); bool bothExist = ajExists && bjExists; // If propagating, check propagations int prop = options::arraysPropagate(); if (prop > 0) { - if (d_equalityEngine.areDisequal(i,j,true) && (bothExist || prop > 1)) { + if (d_equalityEngine->areDisequal(i, j, true) && (bothExist || prop > 1)) + { Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating aj = bj ("<assertEquality(aj_eq_bj, true, reason, d_reasonRow); ++d_numProp; return; } - if (bothExist && d_equalityEngine.areDisequal(aj,bj,true)) { + if (bothExist && d_equalityEngine->areDisequal(aj, bj, true)) + { Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<assertEquality(i_eq_j, true, reason, d_reasonRow); ++d_numProp; return; } @@ -1958,8 +2021,8 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) std::tie(a, b, i, j) = lem; Assert(a.getType().isArray() && b.getType().isArray()); - if (d_equalityEngine.areEqual(a,b) || - d_equalityEngine.areEqual(i,j)) { + if (d_equalityEngine->areEqual(a, b) || d_equalityEngine->areEqual(i, j)) + { return; } @@ -1968,8 +2031,8 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Node bj = nm->mkNode(kind::SELECT, b, j); // Try to avoid introducing new read terms: track whether these already exist - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); bool bothExist = ajExists && bjExists; // If propagating, check propagations @@ -1981,13 +2044,16 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) // If equivalent lemma already exists, don't enqueue this one if (d_useArrTable) { Node tableEntry = NodeManager::currentNM()->mkNode(kind::ARR_TABLE_FUN, a, b, i, j); - if (d_equalityEngine.getSize(tableEntry) != 1) { + if (d_equalityEngine->getSize(tableEntry) != 1) + { return; } } // Prefer equality between indexes so as not to introduce new read terms - if (options::arraysEagerIndexSplitting() && !bothExist && !d_equalityEngine.areDisequal(i,j, false)) { + if (options::arraysEagerIndexSplitting() && !bothExist + && !d_equalityEngine->areDisequal(i, j, false)) + { Node i_eq_j; if (!d_proofsEnabled) { i_eq_j = d_valuation.ensureLiteral(i.eqNode(j)); // TODO: think about this @@ -2008,20 +2074,22 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) if (!ajExists) { preRegisterTermInternal(aj); } - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - d_equalityEngine.assertEquality(aj.eqNode(aj2), true, d_true); + d_equalityEngine->assertEquality(aj.eqNode(aj2), true, d_true); } Node bj2 = Rewriter::rewrite(bj); if (bj != bj2) { if (!bjExists) { preRegisterTermInternal(bj); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(bj.eqNode(bj2), true, d_true); + d_equalityEngine->assertEquality(bj.eqNode(bj2), true, d_true); } if (aj2 == bj2) { return; @@ -2031,20 +2099,22 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) Node eq1 = aj2.eqNode(bj2); Node eq1_r = Rewriter::rewrite(eq1); if (eq1_r == d_true) { - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(eq1, true, d_true); + d_equalityEngine->assertEquality(eq1, true, d_true); return; } Node eq2 = i.eqNode(j); Node eq2_r = Rewriter::rewrite(eq2); if (eq2_r == d_true) { - d_equalityEngine.assertEquality(eq2, true, d_true); + d_equalityEngine->assertEquality(eq2, true, d_true); return; } @@ -2089,14 +2159,16 @@ bool TheoryArrays::dischargeLemmas() NodeManager* nm = NodeManager::currentNM(); Node aj = nm->mkNode(kind::SELECT, a, j); Node bj = nm->mkNode(kind::SELECT, b, j); - bool ajExists = d_equalityEngine.hasTerm(aj); - bool bjExists = d_equalityEngine.hasTerm(bj); + bool ajExists = d_equalityEngine->hasTerm(aj); + bool bjExists = d_equalityEngine->hasTerm(bj); // Check for redundant lemma // TODO: more checks possible (i.e. check d_RowAlreadyAdded in context) - if (!d_equalityEngine.hasTerm(i) || !d_equalityEngine.hasTerm(j) || d_equalityEngine.areEqual(i,j) || - !d_equalityEngine.hasTerm(a) || !d_equalityEngine.hasTerm(b) || d_equalityEngine.areEqual(a,b) || - (ajExists && bjExists && d_equalityEngine.areEqual(aj,bj))) { + if (!d_equalityEngine->hasTerm(i) || !d_equalityEngine->hasTerm(j) + || d_equalityEngine->areEqual(i, j) || !d_equalityEngine->hasTerm(a) + || !d_equalityEngine->hasTerm(b) || d_equalityEngine->areEqual(a, b) + || (ajExists && bjExists && d_equalityEngine->areEqual(aj, bj))) + { continue; } @@ -2114,21 +2186,22 @@ bool TheoryArrays::dischargeLemmas() if (!ajExists) { preRegisterTermInternal(aj); } - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - d_equalityEngine.assertEquality(aj.eqNode(aj2), true, d_true); + d_equalityEngine->assertEquality(aj.eqNode(aj2), true, d_true); } Node bj2 = Rewriter::rewrite(bj); if (bj != bj2) { if (!bjExists) { preRegisterTermInternal(bj); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(bj.eqNode(bj2), true, d_true); - + d_equalityEngine->assertEquality(bj.eqNode(bj2), true, d_true); } if (aj2 == bj2) { continue; @@ -2138,20 +2211,22 @@ bool TheoryArrays::dischargeLemmas() Node eq1 = aj2.eqNode(bj2); Node eq1_r = Rewriter::rewrite(eq1); if (eq1_r == d_true) { - if (!d_equalityEngine.hasTerm(aj2)) { + if (!d_equalityEngine->hasTerm(aj2)) + { preRegisterTermInternal(aj2); } - if (!d_equalityEngine.hasTerm(bj2)) { + if (!d_equalityEngine->hasTerm(bj2)) + { preRegisterTermInternal(bj2); } - d_equalityEngine.assertEquality(eq1, true, d_true); + d_equalityEngine->assertEquality(eq1, true, d_true); continue; } Node eq2 = i.eqNode(j); Node eq2_r = Rewriter::rewrite(eq2); if (eq2_r == d_true) { - d_equalityEngine.assertEquality(eq2, true, d_true); + d_equalityEngine->assertEquality(eq2, true, d_true); continue; } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 116b0f43b..f1cd2ea14 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -148,9 +148,18 @@ class TheoryArrays : public Theory { std::string name = ""); ~TheoryArrays(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization std::string identify() const override { return std::string("TheoryArrays"); } @@ -353,9 +362,6 @@ class TheoryArrays : public Theory { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict? */ context::CDO d_conflict; @@ -460,7 +466,7 @@ class TheoryArrays : public Theory { int d_topLevel; /** An equality-engine callback for proof reconstruction */ - ArrayProofReconstruction d_proofReconstruction; + std::unique_ptr d_proofReconstruction; /** * The decision strategy for the theory of arrays, which calls the @@ -493,9 +499,6 @@ class TheoryArrays : public Theory { */ Node getNextDecisionRequest(); - public: - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } - };/* class TheoryArrays */ }/* CVC4::theory::arrays namespace */ diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index c49909fe6..48ec81a1e 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -35,55 +35,65 @@ using namespace CVC4::theory::bv::utils; CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt) : SubtheorySolver(c, bv), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::bv::ee", true), d_slicer(new Slicer()), d_isComplete(c, true), d_lemmaThreshold(16), d_useSlicer(false), d_preregisterCalled(false), d_checkCalled(false), + d_bv(bv), d_extTheory(extt), d_reasons(c) { - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT, true); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS, true); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_EXTRACT, true); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT); - // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE); - d_equalityEngine.addFunctionKind(kind::BITVECTOR_TO_NAT); - d_equalityEngine.addFunctionKind(kind::INT_TO_BITVECTOR); } CoreSolver::~CoreSolver() {} -void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +bool CoreSolver::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::bv::ee"; + return true; +} + +void CoreSolver::finishInit() +{ + // use the parent's equality engine, which may be the one we allocated above + d_equalityEngine = d_bv->getEqualityEngine(); + + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::BITVECTOR_CONCAT, true); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_AND); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_OR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NAND); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_XNOR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_COMP); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_MULT, true); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_PLUS, true); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_EXTRACT, true); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SUB); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_NEG); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UDIV); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UREM); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SDIV); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SREM); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SMOD); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SHL); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_LSHR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ASHR); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_ULE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_UGE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SLE); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGT); + // d_equalityEngine->addFunctionKind(kind::BITVECTOR_SGE); + d_equalityEngine->addFunctionKind(kind::BITVECTOR_TO_NAT); + d_equalityEngine->addFunctionKind(kind::INT_TO_BITVECTOR); } void CoreSolver::enableSlicer() { @@ -95,13 +105,14 @@ void CoreSolver::enableSlicer() { void CoreSolver::preRegister(TNode node) { d_preregisterCalled = true; if (node.getKind() == kind::EQUAL) { - d_equalityEngine.addTriggerEquality(node); - if (d_useSlicer) { - d_slicer->processEquality(node); - AlwaysAssert(!d_checkCalled); + d_equalityEngine->addTriggerEquality(node); + if (d_useSlicer) + { + d_slicer->processEquality(node); + AlwaysAssert(!d_checkCalled); } } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); // Register with the extended theory, for context-dependent simplification. // Notice we do this for registered terms but not internally generated // equivalence classes. The two should roughly cooincide. Since ExtTheory is @@ -115,9 +126,9 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } } @@ -224,14 +235,14 @@ void CoreSolver::buildModel() TNodeSet constants; TNodeSet constants_in_eq_engine; // collect constants in equality engine - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while (!eqcs_i.isFinished()) { TNode repr = *eqcs_i; if (repr.getKind() == kind::CONST_BITVECTOR) { // must check if it's just the constant - eq::EqClassIterator it(repr, &d_equalityEngine); + eq::EqClassIterator it(repr, d_equalityEngine); if (!(++it).isFinished() || true) { constants.insert(repr); @@ -243,7 +254,7 @@ void CoreSolver::buildModel() // build repr to value map - eqcs_i = eq::EqClassesIterator(&d_equalityEngine); + eqcs_i = eq::EqClassesIterator(d_equalityEngine); while (!eqcs_i.isFinished()) { TNode repr = *eqcs_i; @@ -351,15 +362,16 @@ bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) { if (predicate.getKind() == kind::EQUAL) { if (negated) { // dis-equality - d_equalityEngine.assertEquality(predicate, false, reason); + d_equalityEngine->assertEquality(predicate, false, reason); } else { // equality - d_equalityEngine.assertEquality(predicate, true, reason); + d_equalityEngine->assertEquality(predicate, true, reason); } } else { // Adding predicate if the congruence over it is turned on - if (d_equalityEngine.isFunctionKind(predicate.getKind())) { - d_equalityEngine.assertPredicate(predicate, !negated, reason); + if (d_equalityEngine->isFunctionKind(predicate.getKind())) + { + d_equalityEngine->assertPredicate(predicate, !negated, reason); } } } @@ -408,7 +420,7 @@ bool CoreSolver::storePropagation(TNode literal) { void CoreSolver::conflict(TNode a, TNode b) { std::vector assumptions; - d_equalityEngine.explainEquality(a, b, true, assumptions); + d_equalityEngine->explainEquality(a, b, true, assumptions); Node conflict = flattenAnd(assumptions); d_bv->setConflict(conflict); } @@ -434,7 +446,7 @@ bool CoreSolver::collectModelInfo(TheoryModel* m, bool fullModel) } set termSet; d_bv->computeRelevantTerms(termSet); - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } @@ -457,7 +469,7 @@ bool CoreSolver::collectModelInfo(TheoryModel* m, bool fullModel) Node CoreSolver::getModelValue(TNode var) { Debug("bitvector-model") << "CoreSolver::getModelValue (" << var <<")"; Assert(isComplete()); - TNode repr = d_equalityEngine.getRepresentative(var); + TNode repr = d_equalityEngine->getRepresentative(var); Node result = Node(); if (repr.getKind() == kind::CONST_BITVECTOR) { result = repr; @@ -472,6 +484,35 @@ Node CoreSolver::getModelValue(TNode var) { return result; } +void CoreSolver::addSharedTerm(TNode t) +{ + d_equalityEngine->addTriggerTerm(t, THEORY_BV); +} + +EqualityStatus CoreSolver::getEqualityStatus(TNode a, TNode b) +{ + if (d_equalityEngine->areEqual(a, b)) + { + // The terms are implied to be equal + return EQUALITY_TRUE; + } + if (d_equalityEngine->areDisequal(a, b, false)) + { + // The terms are implied to be dis-equal + return EQUALITY_FALSE; + } + return EQUALITY_UNKNOWN; +} + +bool CoreSolver::hasTerm(TNode node) const +{ + return d_equalityEngine->hasTerm(node); +} +void CoreSolver::addTermToEqualityEngine(TNode node) +{ + d_equalityEngine->addTerm(node); +} + CoreSolver::Statistics::Statistics() : d_numCallstoCheck("theory::bv::CoreSolver::NumCallsToCheck", 0) , d_slicerEnabled("theory::bv::CoreSolver::SlicerEnabled", false) diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index ea652e7cd..33f119e5f 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -70,9 +70,6 @@ class CoreSolver : public SubtheorySolver { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equality engine */ - eq::EqualityEngine d_equalityEngine; - /** Store a propagation to the bv solver */ bool storePropagation(TNode literal); @@ -88,6 +85,10 @@ class CoreSolver : public SubtheorySolver { bool d_preregisterCalled; bool d_checkCalled; + /** Pointer to the parent theory solver that owns this */ + TheoryBV* d_bv; + /** Pointer to the equality engine of the parent */ + eq::EqualityEngine* d_equalityEngine; /** Pointer to the extended theory module. */ ExtTheory* d_extTheory; @@ -100,36 +101,23 @@ class CoreSolver : public SubtheorySolver { Node getBaseDecomposition(TNode a); bool isCompleteForTerm(TNode term, TNodeBoolMap& seen); Statistics d_statistics; -public: - CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt); - ~CoreSolver(); - bool isComplete() override { return d_isComplete; } - void setMasterEqualityEngine(eq::EqualityEngine* eq); - void preRegister(TNode node) override; - bool check(Theory::Effort e) override; - void explain(TNode literal, std::vector& assumptions) override; - bool collectModelInfo(TheoryModel* m, bool fullModel) override; - Node getModelValue(TNode var) override; - void addSharedTerm(TNode t) override - { - d_equalityEngine.addTriggerTerm(t, THEORY_BV); - } - EqualityStatus getEqualityStatus(TNode a, TNode b) override - { - if (d_equalityEngine.areEqual(a, b)) { - // The terms are implied to be equal - return EQUALITY_TRUE; - } - if (d_equalityEngine.areDisequal(a, b, false)) { - // The terms are implied to be dis-equal - return EQUALITY_FALSE; - } - return EQUALITY_UNKNOWN; - } - bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); } - void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); } + + public: + CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt); + ~CoreSolver(); + bool needsEqualityEngine(EeSetupInfo& esi); + void finishInit(); + bool isComplete() override { return d_isComplete; } + void preRegister(TNode node) override; + bool check(Theory::Effort e) override; + void explain(TNode literal, std::vector& assumptions) override; + bool collectModelInfo(TheoryModel* m, bool fullModel) override; + Node getModelValue(TNode var) override; + void addSharedTerm(TNode t) override; + EqualityStatus getEqualityStatus(TNode a, TNode b) override; + bool hasTerm(TNode node) const; + void addTermToEqualityEngine(TNode node); void enableSlicer(); - eq::EqualityEngine * getEqualityEngine() { return &d_equalityEngine; } }; diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 0a4499c11..ced320d92 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -113,13 +113,31 @@ TheoryBV::TheoryBV(context::Context* c, TheoryBV::~TheoryBV() {} -void TheoryBV::setMasterEqualityEngine(eq::EqualityEngine* eq) { - if (options::bitblastMode() == options::BitblastMode::EAGER) +TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi) +{ + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) { - return; + return core->needsEqualityEngine(esi); } - if (options::bitvectorEqualitySolver()) { - dynamic_cast(d_subtheoryMap[SUB_CORE])->setMasterEqualityEngine(eq); + // otherwise we don't use an equality engine + return false; +} + +void TheoryBV::finishInit() +{ + // these kinds are semi-evaluated in getModelValue (applications of this + // kind are treated as variables) + d_valuation.setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV); + d_valuation.setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM); + + CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; + if (core) + { + // must finish initialization in the core solver + core->finishInit(); } } @@ -185,16 +203,6 @@ Node TheoryBV::getBVDivByZero(Kind k, unsigned width) { Unreachable(); } -void TheoryBV::finishInit() -{ - // these kinds are semi-evaluated in getModelValue (applications of this - // kind are treated as variables) - TheoryModel* tm = d_valuation.getModel(); - Assert(tm != nullptr); - tm->setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV); - tm->setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM); -} - TrustNode TheoryBV::expandDefinition(Node node) { Debug("bitvector-expandDefinition") << "TheoryBV::expandDefinition(" << node << ")" << std::endl; @@ -582,16 +590,6 @@ void TheoryBV::propagate(Effort e) { } } - -eq::EqualityEngine * TheoryBV::getEqualityEngine() { - CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE]; - if( core ){ - return core->getEqualityEngine(); - }else{ - return NULL; - } -} - bool TheoryBV::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) { eq::EqualityEngine * ee = getEqualityEngine(); if( ee ){ diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index b0991c8b0..0e8877359 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -77,11 +77,18 @@ class TheoryBV : public Theory { ~TheoryBV(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization TrustNode expandDefinition(Node node) override; @@ -99,8 +106,6 @@ class TheoryBV : public Theory { std::string identify() const override { return std::string("TheoryBV"); } - /** equality engine */ - eq::EqualityEngine* getEqualityEngine() override; bool getCurrentSubstitution(int effort, std::vector& vars, std::vector& subs, diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 832324d4b..4b38ad6bd 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -51,7 +51,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c, d_infer_exp(c), d_term_sk(u), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::datatypes", true), d_labels(c), d_selector_apps(c), d_conflict(c, false), @@ -64,13 +63,6 @@ TheoryDatatypes::TheoryDatatypes(Context* c, d_lemmas_produced_c(u), d_sygusExtension(nullptr) { - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::APPLY_CONSTRUCTOR); - d_equalityEngine.addFunctionKind(kind::APPLY_SELECTOR_TOTAL); - //d_equalityEngine.addFunctionKind(kind::DT_SIZE); - //d_equalityEngine.addFunctionKind(kind::DT_HEIGHT_BOUND); - d_equalityEngine.addFunctionKind(kind::APPLY_TESTER); - //d_equalityEngine.addFunctionKind(kind::APPLY_UF); d_true = NodeManager::currentNM()->mkConst( true ); d_zero = NodeManager::currentNM()->mkConst( Rational(0) ); @@ -86,8 +78,32 @@ TheoryDatatypes::~TheoryDatatypes() { } } -void TheoryDatatypes::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryDatatypes::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryDatatypes::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::datatypes::ee"; + return true; +} + +void TheoryDatatypes::finishInit() +{ + Assert(d_equalityEngine != nullptr); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::APPLY_CONSTRUCTOR); + d_equalityEngine->addFunctionKind(kind::APPLY_SELECTOR_TOTAL); + d_equalityEngine->addFunctionKind(kind::APPLY_TESTER); + // We could but don't do congruence for DT_SIZE and DT_HEIGHT_BOUND here. + // It also could make sense in practice to do congruence for APPLY_UF, but + // this is not done. + if (getQuantifiersEngine() && options::sygus()) + { + d_sygusExtension.reset( + new SygusExtension(this, getQuantifiersEngine(), getSatContext())); + // do congruence on evaluation functions + d_equalityEngine->addFunctionKind(kind::DT_SYGUS_EVAL); + } } TheoryDatatypes::EqcInfo* TheoryDatatypes::getOrMakeEqcInfo( TNode n, bool doMake ){ @@ -193,7 +209,7 @@ void TheoryDatatypes::check(Effort e) { do { d_addedFact = false; std::map< TypeNode, Node > rec_singletons; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node n = (*eqcs_i); //TODO : avoid irrelevant (pre-registered but not asserted) terms here? @@ -479,9 +495,9 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){ bool polarity = fact.getKind() != kind::NOT; TNode atom = polarity ? fact : fact[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.assertEquality( atom, polarity, exp ); + d_equalityEngine->assertEquality(atom, polarity, exp); }else{ - d_equalityEngine.assertPredicate( atom, polarity, exp ); + d_equalityEngine->assertPredicate(atom, polarity, exp); } doPendingMerges(); // could be sygus-specific @@ -527,37 +543,27 @@ void TheoryDatatypes::preRegisterTerm(TNode n) { switch (n.getKind()) { case kind::EQUAL: // Add the trigger for equality - d_equalityEngine.addTriggerEquality(n); + d_equalityEngine->addTriggerEquality(n); break; case kind::APPLY_TESTER: // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerPredicate(n); + d_equalityEngine->addTriggerPredicate(n); break; default: // Function applications/predicates - d_equalityEngine.addTerm(n); + d_equalityEngine->addTerm(n); if (d_sygusExtension) { std::vector< Node > lemmas; d_sygusExtension->preRegisterTerm(n, lemmas); doSendLemmas( lemmas ); } - //d_equalityEngine.addTriggerTerm(n, THEORY_DATATYPES); + // d_equalityEngine->addTriggerTerm(n, THEORY_DATATYPES); break; } flushPendingFacts(); } -void TheoryDatatypes::finishInit() { - if (getQuantifiersEngine() && options::sygus()) - { - d_sygusExtension.reset( - new SygusExtension(this, getQuantifiersEngine(), getSatContext())); - // do congruence on evaluation functions - d_equalityEngine.addFunctionKind(kind::DT_SYGUS_EVAL); - } -} - TrustNode TheoryDatatypes::expandDefinition(Node n) { NodeManager* nm = NodeManager::currentNM(); @@ -727,7 +733,7 @@ TrustNode TheoryDatatypes::ppRewrite(TNode in) void TheoryDatatypes::addSharedTerm(TNode t) { Debug("datatypes") << "TheoryDatatypes::addSharedTerm(): " << t << " " << t.getType().isBoolean() << endl; - d_equalityEngine.addTriggerTerm(t, THEORY_DATATYPES); + d_equalityEngine->addTriggerTerm(t, THEORY_DATATYPES); Debug("datatypes") << "TheoryDatatypes::addSharedTerm() finished" << std::endl; } @@ -776,14 +782,14 @@ void TheoryDatatypes::addAssumptions( std::vector& assumptions, std::vect void TheoryDatatypes::explainEquality( TNode a, TNode b, bool polarity, std::vector& assumptions ) { if( a!=b ){ std::vector tassumptions; - d_equalityEngine.explainEquality(a, b, polarity, tassumptions); + d_equalityEngine->explainEquality(a, b, polarity, tassumptions); addAssumptions( assumptions, tassumptions ); } } void TheoryDatatypes::explainPredicate( TNode p, bool polarity, std::vector& assumptions ) { std::vector tassumptions; - d_equalityEngine.explainPredicate(p, polarity, tassumptions); + d_equalityEngine->explainPredicate(p, polarity, tassumptions); addAssumptions( assumptions, tassumptions ); } @@ -1367,12 +1373,14 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { } EqualityStatus TheoryDatatypes::getEqualityStatus(TNode a, TNode b){ - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -1395,15 +1403,20 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, for (unsigned k = 0; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); Assert(!areDisequal(x, y)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ + if (!d_equalityEngine->areEqual(x, y)) + { Trace("dt-cg") << "Arg #" << k << " is " << x << " " << y << std::endl; - if( d_equalityEngine.isTriggerTerm(x, THEORY_DATATYPES) && d_equalityEngine.isTriggerTerm(y, THEORY_DATATYPES) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_DATATYPES); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_DATATYPES); + if (d_equalityEngine->isTriggerTerm(x, THEORY_DATATYPES) + && d_equalityEngine->isTriggerTerm(y, THEORY_DATATYPES)) + { + TNode x_shared = d_equalityEngine->getTriggerTermRepresentative( + x, THEORY_DATATYPES); + TNode y_shared = d_equalityEngine->getTriggerTermRepresentative( + y, THEORY_DATATYPES); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -1432,7 +1445,8 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, std::map::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1, n_pairs ); } @@ -1445,7 +1459,7 @@ void TheoryDatatypes::addCarePairs(TNodeTrie* t1, { for (std::pair& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -1468,7 +1482,7 @@ void TheoryDatatypes::computeCareGraph(){ unsigned functionTerms = d_functionTerms.size(); for( unsigned i=0; ihasTerm(f1)); Trace("dt-cg-debug") << "...build for " << f1 << std::endl; //break into index based on operator, and type of first argument (since some operators are parametric) Node op = f1.getOperator(); @@ -1476,8 +1490,9 @@ void TheoryDatatypes::computeCareGraph(){ std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=0; jgetRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_DATATYPES)) + { has_trigger_arg = true; } } @@ -1502,7 +1517,8 @@ void TheoryDatatypes::computeCareGraph(){ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) { - Trace("dt-cmi") << "Datatypes : Collect model info " << d_equalityEngine.consistent() << std::endl; + Trace("dt-cmi") << "Datatypes : Collect model info " + << d_equalityEngine->consistent() << std::endl; Trace("dt-model") << std::endl; printModelDebug( "dt-model" ); Trace("dt-model") << std::endl; @@ -1513,13 +1529,13 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) getRelevantTerms(termSet); //combine the equality engine - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { return false; } //get all constructors - eq::EqClassesIterator eqccs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqccs_i = eq::EqClassesIterator(d_equalityEngine); std::vector< Node > cons; std::vector< Node > nodes; std::map< Node, Node > eqc_cons; @@ -1558,7 +1574,8 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) bool addCons = false; TypeNode tt = eqc.getType(); const DType& dt = tt.getDType(); - if( !d_equalityEngine.hasTerm( eqc ) ){ + if (!d_equalityEngine->hasTerm(eqc)) + { Assert(false); }else{ Trace("dt-cmi") << "NOTICE : Datatypes: no constructor in equivalence class " << eqc << std::endl; @@ -1578,12 +1595,6 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) bool cfinite = dt[ i ].isInterpretedFinite( tt ); if( pcons[i] && (r==1)==cfinite ){ neqc = utils::getInstCons(eqc, dt, i); - //for( unsigned j=0; jaddTerm(n_ic); Debug("dt-enum") << "Made instantiate cons " << n_ic << std::endl; } d_inst_map[n][index] = n_ic; @@ -1824,7 +1835,7 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){ void TheoryDatatypes::checkCycles() { Trace("datatypes-cycle-check") << "Check acyclicity" << std::endl; std::vector< Node > cdt_eqc; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node eqc = (*eqcs_i); TypeNode tn = eqc.getType(); @@ -2115,15 +2126,13 @@ bool TheoryDatatypes::mustCommunicateFact( Node n, Node exp ){ } } -bool TheoryDatatypes::hasTerm( TNode a ){ - return d_equalityEngine.hasTerm( a ); -} +bool TheoryDatatypes::hasTerm(TNode a) { return d_equalityEngine->hasTerm(a); } bool TheoryDatatypes::areEqual( TNode a, TNode b ){ if( a==b ){ return true; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areEqual( a, b ); + return d_equalityEngine->areEqual(a, b); }else{ return false; } @@ -2133,7 +2142,7 @@ bool TheoryDatatypes::areDisequal( TNode a, TNode b ){ if( a==b ){ return false; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areDisequal( a, b, false ); + return d_equalityEngine->areDisequal(a, b, false); }else{ //TODO : constants here? return false; @@ -2141,11 +2150,15 @@ bool TheoryDatatypes::areDisequal( TNode a, TNode b ){ } bool TheoryDatatypes::areCareDisequal( TNode x, TNode y ) { - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_DATATYPES) && d_equalityEngine.isTriggerTerm(y, THEORY_DATATYPES) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_DATATYPES); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_DATATYPES); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_DATATYPES) + && d_equalityEngine->isTriggerTerm(y, THEORY_DATATYPES)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_DATATYPES); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_DATATYPES); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -2156,7 +2169,7 @@ bool TheoryDatatypes::areCareDisequal( TNode x, TNode y ) { TNode TheoryDatatypes::getRepresentative( TNode a ){ if( hasTerm( a ) ){ - return d_equalityEngine.getRepresentative( a ); + return d_equalityEngine->getRepresentative(a); }else{ return a; } @@ -2172,7 +2185,7 @@ void TheoryDatatypes::printModelDebug( const char* c ){ } Trace( c ) << "Datatypes model : " << std::endl; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ Node eqc = (*eqcs_i); //if( !eqc.getType().isBoolean() ){ @@ -2182,7 +2195,7 @@ void TheoryDatatypes::printModelDebug( const char* c ){ Trace( c ) << eqc << " : " << eqc.getType() << " : " << std::endl; Trace( c ) << " { "; //add terms to model - eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); while( !eqc_i.isFinished() ){ if( (*eqc_i)!=eqc ){ Trace( c ) << (*eqc_i) << " "; @@ -2248,7 +2261,7 @@ void TheoryDatatypes::getRelevantTerms( std::set& termSet ) { << std::endl; //also include non-singleton equivalence classes TODO : revisit this - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine); while( !eqcs_i.isFinished() ){ TNode r = (*eqcs_i); bool addedFirst = false; @@ -2256,7 +2269,7 @@ void TheoryDatatypes::getRelevantTerms( std::set& termSet ) { TypeNode rtn = r.getType(); if (!rtn.isBoolean()) { - eq::EqClassIterator eqc_i = eq::EqClassIterator(r, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(r, d_equalityEngine); while (!eqc_i.isFinished()) { TNode n = (*eqc_i); @@ -2296,7 +2309,7 @@ std::pair TheoryDatatypes::entailmentCheck(TNode lit) if( atom.getKind()==APPLY_TESTER ){ Node n = atom[0]; if( hasTerm( n ) ){ - Node r = d_equalityEngine.getRepresentative( n ); + Node r = d_equalityEngine->getRepresentative(n); EqcInfo * ei = getOrMakeEqcInfo( r, false ); int l_index = getLabelIndex( ei, r ); int t_index = static_cast(utils::indexOf(atom.getOperator())); diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index 422a01f07..a68caca94 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -145,8 +145,6 @@ private: private: /** The notify class */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; /** information necessary for equivalence classes */ std::map< Node, EqcInfo* > d_eqc_info; /** map from nodes to their instantiated equivalent for each constructor type */ @@ -269,9 +267,18 @@ private: ProofNodeManager* pnm = nullptr); ~TheoryDatatypes(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization /** propagate */ void propagate(Effort effort) override; @@ -295,7 +302,6 @@ private: void check(Effort e) override; bool needsCheckLastEffort() override; void preRegisterTerm(TNode n) override; - void finishInit() override; TrustNode expandDefinition(Node n) override; TrustNode ppRewrite(TNode n) override; void presolve() override; @@ -307,8 +313,6 @@ private: { return std::string("TheoryDatatypes"); } - /** equality engine */ - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } bool getCurrentSubstitution(int effort, std::vector& vars, std::vector& subs, diff --git a/src/theory/ee_manager_distributed.cpp b/src/theory/ee_manager_distributed.cpp index 21237816f..eb12ce893 100644 --- a/src/theory/ee_manager_distributed.cpp +++ b/src/theory/ee_manager_distributed.cpp @@ -61,6 +61,7 @@ void EqEngineManagerDistributed::finishInit() } // allocate the equality engine eet.d_allocEe.reset(allocateEqualityEngine(esi, c)); + eet.d_usedEe = eet.d_allocEe.get(); } const LogicInfo& logicInfo = d_te.getLogicInfo(); diff --git a/src/theory/ee_manager_distributed.h b/src/theory/ee_manager_distributed.h index 3de1898d7..8cac225be 100644 --- a/src/theory/ee_manager_distributed.h +++ b/src/theory/ee_manager_distributed.h @@ -41,6 +41,9 @@ namespace theory { */ struct EeTheoryInfo { + EeTheoryInfo() : d_usedEe(nullptr) {} + /** The equality engine that the theory uses (if it exists) */ + eq::EqualityEngine* d_usedEe; /** The equality engine allocated by this theory (if it exists) */ std::unique_ptr d_allocEe; }; diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index a4cff8c95..f5cc16ea9 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -107,7 +107,6 @@ TheoryFp::TheoryFp(context::Context* c, ProofNodeManager* pnm) : Theory(THEORY_FP, c, u, out, valuation, logicInfo, pnm), d_notification(*this), - d_equalityEngine(d_notification, c, "theory::fp::ee", true), d_registeredTerms(u), d_conv(u), d_expansionRequested(false), @@ -122,60 +121,74 @@ TheoryFp::TheoryFp(context::Context* c, floatToRealMap(u), abstractionMap(u) { +} /* TheoryFp::TheoryFp() */ + +TheoryRewriter* TheoryFp::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryFp::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notification; + esi.d_name = "theory::fp::ee"; + return true; +} + +void TheoryFp::finishInit() +{ + Assert(d_equalityEngine != nullptr); // Kinds that are to be handled in the congruence closure - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ABS); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_NEG); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_PLUS); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_SUB); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MULT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_DIV); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_FMA); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_SQRT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_REM); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_RTI); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MIN); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MAX); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MIN_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_MAX_TOTAL); - - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_EQ); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_LEQ); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_LT); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_GEQ); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_GT); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISSN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISZ); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISINF); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISNAN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISNEG); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_ISPOS); - - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_REAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); - d_equalityEngine.addFunctionKind( + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ABS); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_NEG); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_PLUS); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_SUB); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MULT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_DIV); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_FMA); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_SQRT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_REM); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_RTI); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MIN); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MAX); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MIN_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_MAX_TOTAL); + + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_EQ); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_LEQ); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_LT); + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_GEQ); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_GT); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISSN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISZ); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISINF); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISNAN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISNEG); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_ISPOS); + + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_IEEE_BITVECTOR); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_FLOATINGPOINT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_REAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_SIGNED_BITVECTOR); + d_equalityEngine->addFunctionKind( kind::FLOATINGPOINT_TO_FP_UNSIGNED_BITVECTOR); - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_FP_GENERIC); // + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_FP_GENERIC); // // Needed in parsing, should be rewritten away - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_UBV); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_SBV); // Removed - // d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_REAL); // Removed - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_UBV_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_SBV_TOTAL); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_TO_REAL_TOTAL); - - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_NAN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_INF); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_ZERO); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGN); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_EXPONENT); - d_equalityEngine.addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND); - d_equalityEngine.addFunctionKind(kind::ROUNDINGMODE_BITBLAST); -} /* TheoryFp::TheoryFp() */ + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_UBV); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_SBV); // Removed + // d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_REAL); // Removed + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_UBV_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_SBV_TOTAL); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_TO_REAL_TOTAL); + + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_NAN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_INF); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_ZERO); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGN); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_EXPONENT); + d_equalityEngine->addFunctionKind(kind::FLOATINGPOINT_COMPONENT_SIGNIFICAND); + d_equalityEngine->addFunctionKind(kind::ROUNDINGMODE_BITBLAST); +} Node TheoryFp::minUF(Node node) { Assert(node.getKind() == kind::FLOATINGPOINT_MIN); @@ -803,11 +816,11 @@ void TheoryFp::registerTerm(TNode node) { // Add to the equality engine if (k == kind::EQUAL) { - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); } else { - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } // Give the expansion of classifications in terms of equalities @@ -961,22 +974,22 @@ void TheoryFp::check(Effort level) { if (negated) { Debug("fp-eq") << "TheoryFp::check(): adding dis-equality " << fact[0] << std::endl; - d_equalityEngine.assertEquality(predicate, false, fact); - + d_equalityEngine->assertEquality(predicate, false, fact); } else { Debug("fp-eq") << "TheoryFp::check(): adding equality " << fact << std::endl; - d_equalityEngine.assertEquality(predicate, true, fact); + d_equalityEngine->assertEquality(predicate, true, fact); } } else { // A system-wide invariant; predicates are registered before they are // asserted Assert(isRegistered(predicate)); - if (d_equalityEngine.isFunctionKind(predicate.getKind())) { + if (d_equalityEngine->isFunctionKind(predicate.getKind())) + { Debug("fp-eq") << "TheoryFp::check(): adding predicate " << predicate << " is " << !negated << std::endl; - d_equalityEngine.assertPredicate(predicate, !negated, fact); + d_equalityEngine->assertPredicate(predicate, !negated, fact); } } } @@ -1007,10 +1020,6 @@ void TheoryFp::check(Effort level) { } /* TheoryFp::check() */ -void TheoryFp::setMasterEqualityEngine(eq::EqualityEngine *eq) { - d_equalityEngine.setMasterEqualityEngine(eq); -} - TrustNode TheoryFp::explain(TNode n) { Trace("fp") << "TheoryFp::explain(): explain " << n << std::endl; @@ -1022,9 +1031,9 @@ TrustNode TheoryFp::explain(TNode n) bool polarity = n.getKind() != kind::NOT; TNode atom = polarity ? n : n[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } Node exp = helper::buildConjunct(assumptions); @@ -1177,7 +1186,7 @@ void TheoryFp::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { << " = " << t2 << std::endl; std::vector assumptions; - d_theorySolver.d_equalityEngine.explainEquality(t1, t2, true, assumptions); + d_theorySolver.d_equalityEngine->explainEquality(t1, t2, true, assumptions); Node conflict = helper::buildConjunct(assumptions); diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index a1dd8a731..02e7e4232 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -42,8 +42,18 @@ class TheoryFp : public Theory { Valuation valuation, const LogicInfo& logicInfo, ProofNodeManager* pnm = nullptr); - - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization TrustNode expandDefinition(Node node) override; @@ -60,8 +70,6 @@ class TheoryFp : public Theory { std::string identify() const override { return "THEORY_FP"; } - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; - TrustNode explain(TNode n) override; protected: @@ -86,7 +94,6 @@ class TheoryFp : public Theory { friend NotifyClass; NotifyClass d_notification; - eq::EqualityEngine d_equalityEngine; /** General utility **/ void registerTerm(TNode node); diff --git a/src/theory/quantifiers/theory_quantifiers.cpp b/src/theory/quantifiers/theory_quantifiers.cpp index 1475446fe..04e83032b 100644 --- a/src/theory/quantifiers/theory_quantifiers.cpp +++ b/src/theory/quantifiers/theory_quantifiers.cpp @@ -64,6 +64,7 @@ TheoryQuantifiers::TheoryQuantifiers(Context* c, TheoryQuantifiers::~TheoryQuantifiers() { } +TheoryRewriter* TheoryQuantifiers::getTheoryRewriter() { return &d_rewriter; } void TheoryQuantifiers::finishInit() { // quantifiers are not evaluated in getModelValue diff --git a/src/theory/quantifiers/theory_quantifiers.h b/src/theory/quantifiers/theory_quantifiers.h index 3168af195..c378f3537 100644 --- a/src/theory/quantifiers/theory_quantifiers.h +++ b/src/theory/quantifiers/theory_quantifiers.h @@ -42,10 +42,13 @@ class TheoryQuantifiers : public Theory { ProofNodeManager* pnm = nullptr); ~TheoryQuantifiers(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization + void preRegisterTerm(TNode n) override; void presolve() override; void ppNotifyAssertions(const std::vector& assertions) override; diff --git a/src/theory/sep/theory_sep.cpp b/src/theory/sep/theory_sep.cpp index 4dfdb9fa5..edb5dd0ae 100644 --- a/src/theory/sep/theory_sep.cpp +++ b/src/theory/sep/theory_sep.cpp @@ -48,7 +48,6 @@ TheorySep::TheorySep(context::Context* c, : Theory(THEORY_SEP, c, u, out, valuation, logicInfo, pnm), d_lemmas_produced_c(u), d_notify(*this), - d_equalityEngine(d_notify, c, "theory::sep::ee", true), d_conflict(c, false), d_reduce(u), d_infer(c), @@ -58,10 +57,6 @@ TheorySep::TheorySep(context::Context* c, d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); d_bounds_init = false; - - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::SEP_PTO); - //d_equalityEngine.addFunctionKind(kind::SEP_STAR); } TheorySep::~TheorySep() { @@ -70,8 +65,21 @@ TheorySep::~TheorySep() { } } -void TheorySep::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheorySep::getTheoryRewriter() { return &d_rewriter; } + +bool TheorySep::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::sep::ee"; + return true; +} + +void TheorySep::finishInit() +{ + Assert(d_equalityEngine != nullptr); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::SEP_PTO); + // we could but don't do congruence on SEP_STAR here. } Node TheorySep::mkAnd( std::vector< TNode >& assumptions ) { @@ -126,9 +134,10 @@ void TheorySep::explain(TNode literal, std::vector& assumptions) { bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality( atom[0], atom[1], polarity, assumptions, NULL ); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, NULL); } else { - d_equalityEngine.explainPredicate( atom, polarity, assumptions ); + d_equalityEngine->explainPredicate(atom, polarity, assumptions); } } } @@ -155,17 +164,19 @@ TrustNode TheorySep::explain(TNode literal) void TheorySep::addSharedTerm(TNode t) { Debug("sep") << "TheorySep::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_SEP); + d_equalityEngine->addTriggerTerm(t, THEORY_SEP); } EqualityStatus TheorySep::getEqualityStatus(TNode a, TNode b) { - Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b)); - if (d_equalityEngine.areEqual(a, b)) { + Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)); + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - else if (d_equalityEngine.areDisequal(a, b, false)) { + else if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -211,7 +222,7 @@ bool TheorySep::collectModelInfo(TheoryModel* m) computeRelevantTerms(termSet); // Send the equality engine information to the model - return m->assertEqualityEngine(&d_equalityEngine, &termSet); + return m->assertEqualityEngine(d_equalityEngine, &termSet); } void TheorySep::postProcessModel( TheoryModel* m ){ @@ -490,16 +501,16 @@ void TheorySep::check(Effort e) { if( !is_spatial ){ Trace("sep-assert") << "Asserting " << atom << ", pol = " << polarity << " to EE..." << std::endl; if( s_atom.getKind()==kind::EQUAL ){ - d_equalityEngine.assertEquality(atom, polarity, fact); + d_equalityEngine->assertEquality(atom, polarity, fact); }else{ - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } Trace("sep-assert") << "Done asserting " << atom << " to EE." << std::endl; }else if( s_atom.getKind()==kind::SEP_PTO ){ Node pto_lbl = NodeManager::currentNM()->mkNode( kind::SINGLETON, s_atom[0] ); Assert(s_lbl == pto_lbl); Trace("sep-assert") << "Asserting " << s_atom << std::endl; - d_equalityEngine.assertPredicate(s_atom, polarity, fact); + d_equalityEngine->assertPredicate(s_atom, polarity, fact); //associate the equivalence class of the lhs with this pto Node r = getRepresentative( s_lbl ); HeapAssertInfo * ei = getOrMakeEqcInfo( r, true ); @@ -619,11 +630,11 @@ void TheorySep::check(Effort e) { Trace("sep-process") << "---" << std::endl; } if(Trace.isOn("sep-eqc")) { - eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator(d_equalityEngine); Trace("sep-eqc") << "EQC:" << std::endl; while( !eqcs2_i.isFinished() ){ Node eqc = (*eqcs2_i); - eq::EqClassIterator eqc2_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc2_i = eq::EqClassIterator(eqc, d_equalityEngine); Trace("sep-eqc") << "Eqc( " << eqc << " ) : { "; while( !eqc2_i.isFinished() ) { if( (*eqc2_i)!=eqc ){ @@ -1552,22 +1563,21 @@ void TheorySep::computeLabelModel( Node lbl ) { } Node TheorySep::getRepresentative( Node t ) { - if( d_equalityEngine.hasTerm( t ) ){ - return d_equalityEngine.getRepresentative( t ); + if (d_equalityEngine->hasTerm(t)) + { + return d_equalityEngine->getRepresentative(t); }else{ return t; } } -bool TheorySep::hasTerm( Node a ){ - return d_equalityEngine.hasTerm( a ); -} +bool TheorySep::hasTerm(Node a) { return d_equalityEngine->hasTerm(a); } bool TheorySep::areEqual( Node a, Node b ){ if( a==b ){ return true; }else if( hasTerm( a ) && hasTerm( b ) ){ - return d_equalityEngine.areEqual( a, b ); + return d_equalityEngine->areEqual(a, b); }else{ return false; } @@ -1577,7 +1587,8 @@ bool TheorySep::areDisequal( Node a, Node b ){ if( a==b ){ return false; }else if( hasTerm( a ) && hasTerm( b ) ){ - if( d_equalityEngine.areDisequal( a, b, false ) ){ + if (d_equalityEngine->areDisequal(a, b, false)) + { return true; } } @@ -1743,9 +1754,9 @@ void TheorySep::doPendingFacts() { bool pol = d_pending[i].getKind()!=kind::NOT; Trace("sep-pending") << "Sep : Assert to EE : " << atom << ", pol = " << pol << std::endl; if( atom.getKind()==kind::EQUAL ){ - d_equalityEngine.assertEquality(atom, pol, d_pending_exp[i]); + d_equalityEngine->assertEquality(atom, pol, d_pending_exp[i]); }else{ - d_equalityEngine.assertPredicate(atom, pol, d_pending_exp[i]); + d_equalityEngine->assertPredicate(atom, pol, d_pending_exp[i]); } } }else{ diff --git a/src/theory/sep/theory_sep.h b/src/theory/sep/theory_sep.h index 7c6ce38c4..84a7025f0 100644 --- a/src/theory/sep/theory_sep.h +++ b/src/theory/sep/theory_sep.h @@ -74,9 +74,18 @@ class TheorySep : public Theory { ProofNodeManager* pnm = nullptr); ~TheorySep(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ + void finishInit() override; + //--------------------------------- end initialization std::string identify() const override { return std::string("TheorySep"); } @@ -202,9 +211,6 @@ class TheorySep : public Theory { /** The notify class for d_equalityEngine */ NotifyClass d_notify; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict? */ context::CDO d_conflict; std::vector< Node > d_pending_exp; @@ -326,7 +332,6 @@ class TheorySep : public Theory { void doPendingFacts(); public: - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } void initializeBounds(); };/* class TheorySep */ diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index bf81099a7..fd9af488f 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -35,8 +35,7 @@ TheorySets::TheorySets(context::Context* c, ProofNodeManager* pnm) : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm), d_internal(new TheorySetsPrivate(*this, c, u)), - d_notify(*d_internal.get()), - d_equalityEngine(d_notify, c, "theory::sets::ee", true) + d_notify(*d_internal.get()) { // Do not move me to the header. // The constructor + destructor are not in the header as d_internal is a @@ -54,29 +53,38 @@ TheoryRewriter* TheorySets::getTheoryRewriter() return d_internal->getTheoryRewriter(); } +bool TheorySets::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = "theory::sets::ee"; + return true; +} + void TheorySets::finishInit() { + Assert(d_equalityEngine != nullptr); + d_valuation.setUnevaluatedKind(COMPREHENSION); // choice is used to eliminate witness d_valuation.setUnevaluatedKind(WITNESS); // functions we are doing congruence over - d_equalityEngine.addFunctionKind(kind::SINGLETON); - d_equalityEngine.addFunctionKind(kind::UNION); - d_equalityEngine.addFunctionKind(kind::INTERSECTION); - d_equalityEngine.addFunctionKind(kind::SETMINUS); - d_equalityEngine.addFunctionKind(kind::MEMBER); - d_equalityEngine.addFunctionKind(kind::SUBSET); + d_equalityEngine->addFunctionKind(kind::SINGLETON); + d_equalityEngine->addFunctionKind(kind::UNION); + d_equalityEngine->addFunctionKind(kind::INTERSECTION); + d_equalityEngine->addFunctionKind(kind::SETMINUS); + d_equalityEngine->addFunctionKind(kind::MEMBER); + d_equalityEngine->addFunctionKind(kind::SUBSET); // relation operators - d_equalityEngine.addFunctionKind(PRODUCT); - d_equalityEngine.addFunctionKind(JOIN); - d_equalityEngine.addFunctionKind(TRANSPOSE); - d_equalityEngine.addFunctionKind(TCLOSURE); - d_equalityEngine.addFunctionKind(JOIN_IMAGE); - d_equalityEngine.addFunctionKind(IDEN); - d_equalityEngine.addFunctionKind(APPLY_CONSTRUCTOR); + d_equalityEngine->addFunctionKind(PRODUCT); + d_equalityEngine->addFunctionKind(JOIN); + d_equalityEngine->addFunctionKind(TRANSPOSE); + d_equalityEngine->addFunctionKind(TCLOSURE); + d_equalityEngine->addFunctionKind(JOIN_IMAGE); + d_equalityEngine->addFunctionKind(IDEN); + d_equalityEngine->addFunctionKind(APPLY_CONSTRUCTOR); // we do congruence over cardinality - d_equalityEngine.addFunctionKind(CARD); + d_equalityEngine->addFunctionKind(CARD); // finish initialization internally d_internal->finishInit(); @@ -198,16 +206,6 @@ bool TheorySets::isEntailed( Node n, bool pol ) { return d_internal->isEntailed( n, pol ); } -eq::EqualityEngine* TheorySets::getEqualityEngine() -{ - return &d_equalityEngine; -} - -void TheorySets::setMasterEqualityEngine(eq::EqualityEngine* eq) -{ - d_equalityEngine.setMasterEqualityEngine(eq); -} - /**************************** eq::NotifyClass *****************************/ bool TheorySets::NotifyClass::eqNotifyTriggerEquality(TNode equality, diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index 84291346b..cb8fdfbc3 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -48,6 +48,12 @@ class TheorySets : public Theory //--------------------------------- initialization /** get the official theory rewriter of this theory */ TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; /** finish initialization */ void finishInit() override; //--------------------------------- end initialization @@ -65,10 +71,7 @@ class TheorySets : public Theory PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override; void presolve() override; void propagate(Effort) override; - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; bool isEntailed(Node n, bool pol); - /* equality engine */ - virtual eq::EqualityEngine* getEqualityEngine() override; private: /** Functions to handle callbacks from equality engine */ class NotifyClass : public eq::EqualityEngineNotify @@ -92,9 +95,7 @@ class TheorySets : public Theory /** The internal theory */ std::unique_ptr d_internal; /** Instance of the above class */ - NotifyClass d_notify; - /** Equality engine */ - eq::EqualityEngine d_equalityEngine; + NotifyClass d_notify; }; /* class TheorySets */ }/* CVC4::theory::sets namespace */ diff --git a/src/theory/strings/solver_state.cpp b/src/theory/strings/solver_state.cpp index a554ac595..8634478fd 100644 --- a/src/theory/strings/solver_state.cpp +++ b/src/theory/strings/solver_state.cpp @@ -27,11 +27,10 @@ namespace strings { SolverState::SolverState(context::Context* c, context::UserContext* u, - eq::EqualityEngine& ee, Valuation& v) : d_context(c), d_ucontext(u), - d_ee(ee), + d_ee(nullptr), d_eeDisequalities(c), d_valuation(v), d_conflict(c, false), @@ -48,19 +47,25 @@ SolverState::~SolverState() } } +void SolverState::finishInit(eq::EqualityEngine* ee) +{ + Assert(ee != nullptr); + d_ee = ee; +} + context::Context* SolverState::getSatContext() const { return d_context; } context::UserContext* SolverState::getUserContext() const { return d_ucontext; } Node SolverState::getRepresentative(Node t) const { - if (d_ee.hasTerm(t)) + if (d_ee->hasTerm(t)) { - return d_ee.getRepresentative(t); + return d_ee->getRepresentative(t); } return t; } -bool SolverState::hasTerm(Node a) const { return d_ee.hasTerm(a); } +bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); } bool SolverState::areEqual(Node a, Node b) const { @@ -70,7 +75,7 @@ bool SolverState::areEqual(Node a, Node b) const } else if (hasTerm(a) && hasTerm(b)) { - return d_ee.areEqual(a, b); + return d_ee->areEqual(a, b); } return false; } @@ -83,17 +88,17 @@ bool SolverState::areDisequal(Node a, Node b) const } else if (hasTerm(a) && hasTerm(b)) { - Node ar = d_ee.getRepresentative(a); - Node br = d_ee.getRepresentative(b); + Node ar = d_ee->getRepresentative(a); + Node br = d_ee->getRepresentative(b); return (ar != br && ar.isConst() && br.isConst()) - || d_ee.areDisequal(ar, br, false); + || d_ee->areDisequal(ar, br, false); } Node ar = getRepresentative(a); Node br = getRepresentative(b); return ar != br && ar.isConst() && br.isConst(); } -eq::EqualityEngine* SolverState::getEqualityEngine() const { return &d_ee; } +eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; } const context::CDList& SolverState::getDisequalityList() const { @@ -105,7 +110,7 @@ void SolverState::eqNotifyNewClass(TNode t) Kind k = t.getKind(); if (k == STRING_LENGTH || k == STRING_TO_CODE) { - Node r = d_ee.getRepresentative(t[0]); + Node r = d_ee->getRepresentative(t[0]); EqcInfo* ei = getOrMakeEqcInfo(r); if (k == STRING_LENGTH) { @@ -317,14 +322,14 @@ void SolverState::separateByLength( NodeManager* nm = NodeManager::currentNM(); for (const Node& eqc : n) { - Assert(d_ee.getRepresentative(eqc) == eqc); + Assert(d_ee->getRepresentative(eqc) == eqc); TypeNode tnEqc = eqc.getType(); EqcInfo* ei = getOrMakeEqcInfo(eqc, false); Node lt = ei ? ei->d_lengthTerm : Node::null(); if (!lt.isNull()) { lt = nm->mkNode(STRING_LENGTH, lt); - Node r = d_ee.getRepresentative(lt); + Node r = d_ee->getRepresentative(lt); std::pair lkey(r, tnEqc); if (eqc_to_leqc.find(lkey) == eqc_to_leqc.end()) { diff --git a/src/theory/strings/solver_state.h b/src/theory/strings/solver_state.h index 8d3162b38..0322abdb7 100644 --- a/src/theory/strings/solver_state.h +++ b/src/theory/strings/solver_state.h @@ -46,9 +46,13 @@ class SolverState public: SolverState(context::Context* c, context::UserContext* u, - eq::EqualityEngine& ee, Valuation& v); ~SolverState(); + /** + * Finish initialize, ee is a pointer to the official equality engine + * of theory of strings. + */ + void finishInit(eq::EqualityEngine* ee); /** Get the SAT context */ context::Context* getSatContext() const; /** Get the user context */ @@ -186,8 +190,8 @@ class SolverState context::Context* d_context; /** Pointer to the user context object used by the theory of strings. */ context::UserContext* d_ucontext; - /** Reference to equality engine of the theory of strings. */ - eq::EqualityEngine& d_ee; + /** Pointer to equality engine of the theory of strings. */ + eq::EqualityEngine* d_ee; /** * The (SAT-context-dependent) list of disequalities that have been asserted * to the equality engine above. diff --git a/src/theory/strings/term_registry.cpp b/src/theory/strings/term_registry.cpp index f28db4c35..71b45915f 100644 --- a/src/theory/strings/term_registry.cpp +++ b/src/theory/strings/term_registry.cpp @@ -37,12 +37,10 @@ typedef expr::Attribute StringsProxyVarAttribute; TermRegistry::TermRegistry(SolverState& s, - eq::EqualityEngine& ee, OutputChannel& out, SequencesStatistics& statistics, ProofNodeManager* pnm) : d_state(s), - d_ee(ee), d_out(out), d_statistics(statistics), d_hasStrCode(false), @@ -129,6 +127,7 @@ void TermRegistry::preRegisterTerm(TNode n) { return; } + eq::EqualityEngine* ee = d_state.getEqualityEngine(); d_preregisteredTerms.insert(n); Trace("strings-preregister") << "TheoryString::preregister : " << n << std::endl; @@ -156,15 +155,15 @@ void TermRegistry::preRegisterTerm(TNode n) ss << "Equality between regular expressions is not supported"; throw LogicException(ss.str()); } - d_ee.addTriggerEquality(n); + ee->addTriggerEquality(n); return; } else if (k == STRING_IN_REGEXP) { d_out.requirePhase(n, true); - d_ee.addTriggerPredicate(n); - d_ee.addTerm(n[0]); - d_ee.addTerm(n[1]); + ee->addTriggerPredicate(n); + ee->addTerm(n[0]); + ee->addTerm(n[1]); return; } else if (k == STRING_TO_CODE) @@ -196,17 +195,17 @@ void TermRegistry::preRegisterTerm(TNode n) } } } - d_ee.addTerm(n); + ee->addTerm(n); } else if (tn.isBoolean()) { // Get triggered for both equal and dis-equal - d_ee.addTriggerPredicate(n); + ee->addTriggerPredicate(n); } else { // Function applications/predicates - d_ee.addTerm(n); + ee->addTerm(n); } // Set d_functionsTerms stores all function applications that are // relevant to theory combination. Notice that this is a subset of @@ -216,7 +215,7 @@ void TermRegistry::preRegisterTerm(TNode n) // Concatenation terms do not need to be considered here because // their arguments have string type and do not introduce any shared // terms. - if (n.hasOperator() && d_ee.isFunctionKind(k) && k != STRING_CONCAT) + if (n.hasOperator() && ee->isFunctionKind(k) && k != STRING_CONCAT) { d_functionsTerms.push_back(n); } @@ -313,7 +312,7 @@ void TermRegistry::registerType(TypeNode tn) { // preregister the empty word for the type Node emp = Word::mkEmptyWord(tn); - if (!d_ee.hasTerm(emp)) + if (!d_state.hasTerm(emp)) { preRegisterTerm(emp); } diff --git a/src/theory/strings/term_registry.h b/src/theory/strings/term_registry.h index 2048abec1..45fb40073 100644 --- a/src/theory/strings/term_registry.h +++ b/src/theory/strings/term_registry.h @@ -50,7 +50,6 @@ class TermRegistry public: TermRegistry(SolverState& s, - eq::EqualityEngine& ee, OutputChannel& out, SequencesStatistics& statistics, ProofNodeManager* pnm); @@ -220,8 +219,6 @@ class TermRegistry uint32_t d_cardSize; /** Reference to the solver state of the theory of strings. */ SolverState& d_state; - /** Reference to equality engine of the theory of strings. */ - eq::EqualityEngine& d_ee; /** Reference to the output channel of the theory of strings. */ OutputChannel& d_out; /** Reference to the statistics for the theory of strings/sequences. */ diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 0ad887d2f..b23765313 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -43,9 +43,8 @@ TheoryStrings::TheoryStrings(context::Context* c, : Theory(THEORY_STRINGS, c, u, out, valuation, logicInfo, pnm), d_notify(*this), d_statistics(), - d_equalityEngine(d_notify, c, "theory::strings::ee", true), - d_state(c, u, d_equalityEngine, d_valuation), - d_termReg(d_state, d_equalityEngine, out, d_statistics, nullptr), + d_state(c, u, d_valuation), + d_termReg(d_state, out, d_statistics, nullptr), d_extTheory(this), d_im(c, u, d_state, d_termReg, d_extTheory, out, d_statistics), d_rewriter(&d_statistics.d_rewrites), @@ -67,30 +66,6 @@ TheoryStrings::TheoryStrings(context::Context* c, d_statistics), d_stringsFmf(c, u, valuation, d_termReg) { - bool eagerEval = options::stringEagerEval(); - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::STRING_LENGTH, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_CONCAT, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_IN_REGEXP, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TO_CODE, eagerEval); - d_equalityEngine.addFunctionKind(kind::SEQ_UNIT, eagerEval); - - // extended functions - d_equalityEngine.addFunctionKind(kind::STRING_STRCTN, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_LEQ, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_SUBSTR, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_UPDATE, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_ITOS, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STOI, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRIDOF, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPLALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REPLACE_RE, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REPLACE_RE_ALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_STRREPLALL, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TOLOWER, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_TOUPPER, eagerEval); - d_equalityEngine.addFunctionKind(kind::STRING_REV, eagerEval); d_zero = NodeManager::currentNM()->mkConst( Rational( 0 ) ); d_one = NodeManager::currentNM()->mkConst( Rational( 1 ) ); @@ -113,26 +88,63 @@ TheoryStrings::~TheoryStrings() { } TheoryRewriter* TheoryStrings::getTheoryRewriter() { return &d_rewriter; } -std::string TheoryStrings::identify() const -{ - return std::string("TheoryStrings"); -} -eq::EqualityEngine* TheoryStrings::getEqualityEngine() + +bool TheoryStrings::needsEqualityEngine(EeSetupInfo& esi) { - return &d_equalityEngine; + esi.d_notify = &d_notify; + esi.d_name = "theory::strings::ee"; + return true; } + void TheoryStrings::finishInit() { + Assert(d_equalityEngine != nullptr); + // witness is used to eliminate str.from_code d_valuation.setUnevaluatedKind(WITNESS); + + bool eagerEval = options::stringEagerEval(); + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::STRING_LENGTH, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_CONCAT, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_IN_REGEXP, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TO_CODE, eagerEval); + d_equalityEngine->addFunctionKind(kind::SEQ_UNIT, eagerEval); + // extended functions + d_equalityEngine->addFunctionKind(kind::STRING_STRCTN, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_LEQ, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_SUBSTR, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_UPDATE, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_ITOS, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STOI, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRIDOF, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPLALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REPLACE_RE, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REPLACE_RE_ALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_STRREPLALL, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TOLOWER, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_TOUPPER, eagerEval); + d_equalityEngine->addFunctionKind(kind::STRING_REV, eagerEval); + + d_state.finishInit(d_equalityEngine); +} + +std::string TheoryStrings::identify() const +{ + return std::string("TheoryStrings"); } bool TheoryStrings::areCareDisequal( TNode x, TNode y ) { - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_STRINGS) && d_equalityEngine.isTriggerTerm(y, THEORY_STRINGS) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_STRINGS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_STRINGS); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_STRINGS) + && d_equalityEngine->isTriggerTerm(y, THEORY_STRINGS)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_STRINGS); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_STRINGS); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -141,14 +153,10 @@ bool TheoryStrings::areCareDisequal( TNode x, TNode y ) { return false; } -void TheoryStrings::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); -} - void TheoryStrings::addSharedTerm(TNode t) { Debug("strings") << "TheoryStrings::addSharedTerm(): " << t << " " << t.getType().isBoolean() << endl; - d_equalityEngine.addTriggerTerm(t, THEORY_STRINGS); + d_equalityEngine->addTriggerTerm(t, THEORY_STRINGS); if (options::stringExp()) { d_esolver.addSharedTerm(t); @@ -157,12 +165,15 @@ void TheoryStrings::addSharedTerm(TNode t) { } EqualityStatus TheoryStrings::getEqualityStatus(TNode a, TNode b) { - if( d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b) ){ - if (d_equalityEngine.areEqual(a, b)) { + if (d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b)) + { + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -251,7 +262,7 @@ bool TheoryStrings::collectModelInfo(TheoryModel* m) // Compute terms appearing in assertions and shared terms computeRelevantTerms(termSet); // assert the (relevant) portion of the equality engine to the model - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { Unreachable() << "TheoryStrings::collectModelInfo: failed to assert equality engine" @@ -670,14 +681,15 @@ void TheoryStrings::check(Effort e) { << "Theory of strings " << e << " effort check " << std::endl; if(Trace.isOn("strings-eqc")) { for( unsigned t=0; t<2; t++ ) { - eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator( &d_equalityEngine ); + eq::EqClassesIterator eqcs2_i = eq::EqClassesIterator(d_equalityEngine); Trace("strings-eqc") << (t==0 ? "STRINGS:" : "OTHER:") << std::endl; while( !eqcs2_i.isFinished() ){ Node eqc = (*eqcs2_i); bool print = (t == 0 && eqc.getType().isStringLike()) || (t == 1 && !eqc.getType().isStringLike()); if (print) { - eq::EqClassIterator eqc2_i = eq::EqClassIterator( eqc, &d_equalityEngine ); + eq::EqClassIterator eqc2_i = + eq::EqClassIterator(eqc, d_equalityEngine); Trace("strings-eqc") << "Eqc( " << eqc << " ) : { "; while( !eqc2_i.isFinished() ) { if( (*eqc2_i)!=eqc && (*eqc2_i).getKind()!=kind::EQUAL ){ @@ -779,20 +791,26 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, if( t2!=NULL ){ Node f1 = t1->getData(); Node f2 = t2->getData(); - if( !d_equalityEngine.areEqual( f1, f2 ) ){ + if (!d_equalityEngine->areEqual(f1, f2)) + { Trace("strings-cg-debug") << "TheoryStrings::computeCareGraph(): checking function " << f1 << " and " << f2 << std::endl; vector< pair > currentPairs; for (unsigned k = 0; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - Assert(!d_equalityEngine.areDisequal(x, y, false)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + Assert(!d_equalityEngine->areDisequal(x, y, false)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ - if( d_equalityEngine.isTriggerTerm(x, THEORY_STRINGS) && d_equalityEngine.isTriggerTerm(y, THEORY_STRINGS) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_STRINGS); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_STRINGS); + if (!d_equalityEngine->areEqual(x, y)) + { + if (d_equalityEngine->isTriggerTerm(x, THEORY_STRINGS) + && d_equalityEngine->isTriggerTerm(y, THEORY_STRINGS)) + { + TNode x_shared = d_equalityEngine->getTriggerTermRepresentative( + x, THEORY_STRINGS); + TNode y_shared = d_equalityEngine->getTriggerTermRepresentative( + y, THEORY_STRINGS); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -820,7 +838,8 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, std::map::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1 ); } @@ -833,7 +852,7 @@ void TheoryStrings::addCarePairs(TNodeTrie* t1, { for (std::pair& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -862,8 +881,9 @@ void TheoryStrings::computeCareGraph(){ std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=0; jgetRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_STRINGS)) + { has_trigger_arg = true; } } @@ -889,7 +909,7 @@ void TheoryStrings::checkRegisterTermsPreNormalForm() const std::vector& seqc = d_bsolver.getStringEqc(); for (const Node& eqc : seqc) { - eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine); + eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine); while (!eqc_i.isFinished()) { Node n = (*eqc_i); diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 2fb827429..500daac1f 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -70,20 +70,24 @@ class TheoryStrings : public Theory { const LogicInfo& logicInfo, ProofNodeManager* pnm); ~TheoryStrings(); + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; /** finish initialization */ void finishInit() override; - /** Get the theory rewriter of this class */ - TheoryRewriter* getTheoryRewriter() override; - /** Set the master equality engine */ - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- end initialization /** Identify this theory */ std::string identify() const override; /** Propagate */ void propagate(Effort e) override; /** Explain */ TrustNode explain(TNode literal) override; - /** Get the equality engine */ - eq::EqualityEngine* getEqualityEngine() override; /** Get current substitution */ bool getCurrentSubstitution(int effort, std::vector& vars, @@ -268,8 +272,6 @@ class TheoryStrings : public Theory { * theories is collected in this object. */ SequencesStatistics d_statistics; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; /** The solver state object */ SolverState d_state; /** The term registry for this theory */ diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index f1bfd052d..4f0cbdb6a 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -63,7 +63,6 @@ Theory::Theory(TheoryId id, ProofNodeManager* pnm, std::string name) : d_id(id), - d_instanceName(name), d_satContext(satContext), d_userContext(userContext), d_logicInfo(logicInfo), @@ -74,12 +73,15 @@ Theory::Theory(TheoryId id, d_careGraph(NULL), d_quantEngine(NULL), d_decManager(nullptr), + d_instanceName(name), d_checkTime(getStatsPrefix(id) + name + "::checkTime"), d_computeCareGraphTime(getStatsPrefix(id) + name + "::computeCareGraphTime"), d_sharedTerms(satContext), d_out(&out), d_valuation(valuation), + d_equalityEngine(nullptr), + d_allocEqualityEngine(nullptr), d_proofsEnabled(false) { smtStatisticsRegistry()->registerStat(&d_checkTime); @@ -91,7 +93,43 @@ Theory::~Theory() { smtStatisticsRegistry()->unregisterStat(&d_computeCareGraphTime); } -bool Theory::needsEqualityEngine(EeSetupInfo& esi) { return false; } +bool Theory::needsEqualityEngine(EeSetupInfo& esi) +{ + // by default, this theory does not use an (official) equality engine + return false; +} + +void Theory::setEqualityEngine(eq::EqualityEngine* ee) +{ + // set the equality engine pointer + d_equalityEngine = ee; +} +void Theory::setQuantifiersEngine(QuantifiersEngine* qe) +{ + Assert(d_quantEngine == nullptr); + d_quantEngine = qe; +} + +void Theory::setDecisionManager(DecisionManager* dm) +{ + Assert(d_decManager == nullptr); + Assert(dm != nullptr); + d_decManager = dm; +} + +void Theory::finishInitStandalone() +{ + EeSetupInfo esi; + if (needsEqualityEngine(esi)) + { + // always associated with the same SAT context as the theory (d_satContext) + d_allocEqualityEngine.reset(new eq::EqualityEngine( + *esi.d_notify, d_satContext, esi.d_name, esi.d_constantsAreTriggers)); + // use it as the official equality engine + d_equalityEngine = d_allocEqualityEngine.get(); + } + finishInit(); +} TheoryId Theory::theoryOf(options::TheoryOfMode mode, TNode node) { @@ -410,17 +448,10 @@ void Theory::getCareGraph(CareGraph* careGraph) { d_careGraph = NULL; } -void Theory::setQuantifiersEngine(QuantifiersEngine* qe) { - Assert(d_quantEngine == NULL); - Assert(qe != NULL); - d_quantEngine = qe; -} - -void Theory::setDecisionManager(DecisionManager* dm) +eq::EqualityEngine* Theory::getEqualityEngine() { - Assert(d_decManager == nullptr); - Assert(dm != nullptr); - d_decManager = dm; + // get the assigned equality engine, which is a pointer stored in this class + return d_equalityEngine; } }/* CVC4::theory namespace */ diff --git a/src/theory/theory.h b/src/theory/theory.h index ef06732fb..4feeac394 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -77,11 +77,35 @@ namespace eq { * RegisteredAttr works. (If you need multiple instances of the same * theory, you'll have to write a multiplexed theory that dispatches * all calls to them.) + * + * NOTE: A Theory has a special way of being initialized. The owner of a Theory + * is either: + * + * (A) Using Theory as a standalone object, not associated with a TheoryEngine. + * In this case, simply call the public initialization method + * (Theory::finishInitStandalone). + * + * (B) TheoryEngine, which determines how the Theory acts in accordance with + * its theory combination policy. We require the following steps in order: + * (B.1) Get information about whether the theory wishes to use an equality + * eninge, and more specifically which equality engine notifications the Theory + * would like to be notified of (Theory::needsEqualityEngine). + * (B.2) Set the equality engine of the theory (Theory::setEqualityEngine), + * which we refer to as the "official equality engine" of this Theory. The + * equality engine passed to the theory must respect the contract(s) specified + * by the equality engine setup information (EeSetupInfo) returned in the + * previous step. + * (B.3) Set the other required utilities including setQuantifiersEngine and + * setDecisionManager. + * (B.4) Call the private initialization method (Theory::finishInit). + * + * Initialization of the second form happens during TheoryEngine::finishInit, + * after the quantifiers engine and model objects have been set up. */ class Theory { - private: friend class ::CVC4::TheoryEngine; + private: // Disallow default construction, copy, assignment. Theory() = delete; Theory(const Theory&) = delete; @@ -90,11 +114,6 @@ class Theory { /** An integer identifying the type of the theory. */ TheoryId d_id; - /** Name of this theory instance. Along with the TheoryId this should provide - * an unique string identifier for each instance of a Theory class. We need - * this to ensure unique statistics names over multiple theory instances. */ - std::string d_instanceName; - /** The SAT search context for the Theory. */ context::Context* d_satContext; @@ -137,6 +156,10 @@ class Theory { DecisionManager* d_decManager; protected: + /** Name of this theory instance. Along with the TheoryId this should provide + * an unique string identifier for each instance of a Theory class. We need + * this to ensure unique statistics names over multiple theory instances. */ + std::string d_instanceName; // === STATISTICS === /** time spent in check calls */ @@ -222,7 +245,15 @@ class Theory { * theory engine (and other theories). */ Valuation d_valuation; - + /** + * Pointer to the official equality engine of this theory, which is owned by + * the equality engine manager of TheoryEngine. + */ + eq::EqualityEngine* d_equalityEngine; + /** + * The official equality engine, if we allocated it. + */ + std::unique_ptr d_allocEqualityEngine; /** * Whether proofs are enabled * @@ -264,17 +295,33 @@ class Theory { * its value must be computed (approximated) by the non-linear solver. */ bool isLegalElimination(TNode x, TNode val); + //--------------------------------- private initialization + /** + * Called to set the official equality engine. This should be done by + * TheoryEngine only. + */ + void setEqualityEngine(eq::EqualityEngine* ee); + /** Called to set the quantifiers engine. */ + void setQuantifiersEngine(QuantifiersEngine* qe); + /** Called to set the decision manager. */ + void setDecisionManager(DecisionManager* dm); + /** + * Finish theory initialization. At this point, options and the logic + * setting are final, the master equality engine and quantifiers + * engine (if any) are initialized, and the official equality engine of this + * theory has been assigned. This base class implementation + * does nothing. This should be called by TheoryEngine only. + */ + virtual void finishInit() {} + //--------------------------------- end private initialization public: //--------------------------------- initialization /** - * @return The theory rewriter associated with this theory. This is primarily - * called for the purposes of initializing the rewriter. + * @return The theory rewriter associated with this theory. */ virtual TheoryRewriter* getTheoryRewriter() = 0; /** - * !!!! TODO: use this method (https://github.com/orgs/CVC4/projects/39). - * * Returns true if this theory needs an equality engine for checking * satisfiability. * @@ -288,6 +335,13 @@ class Theory { * a notifications class (eq::EqualityEngineNotify). */ virtual bool needsEqualityEngine(EeSetupInfo& esi); + /** + * Finish theory initialization, standalone version. This is used to + * initialize this class if it is not associated with a theory engine. + * This allocates the official equality engine of this Theory and then + * calls the finishInit method above. + */ + void finishInitStandalone(); //--------------------------------- end initialization /** @@ -450,14 +504,6 @@ class Theory { /** Get the decision manager associated to this theory. */ DecisionManager* getDecisionManager() { return d_decManager; } - /** - * Finish theory initialization. At this point, options and the logic - * setting are final, and the master equality engine and quantifiers - * engine (if any) are initialized. This base class implementation - * does nothing. - */ - virtual void finishInit() { } - /** * Expand definitions in the term node. This returns a term that is * equivalent to node. It wraps this term in a TrustNode of kind @@ -513,14 +559,9 @@ class Theory { virtual void addSharedTerm(TNode n) { } /** - * Called to set the master equality engine. + * Get the official equality engine of this theory. */ - virtual void setMasterEqualityEngine(eq::EqualityEngine* eq) { } - - /** Called to set the quantifiers engine. */ - void setQuantifiersEngine(QuantifiersEngine* qe); - /** Called to set the decision manager. */ - void setDecisionManager(DecisionManager* dm); + eq::EqualityEngine* getEqualityEngine(); /** * Return the current theory care graph. Theories should overload @@ -855,9 +896,6 @@ class Theory { */ virtual std::pair entailmentCheck(TNode lit); - /* equality engine TODO: use? */ - virtual eq::EqualityEngine* getEqualityEngine() { return NULL; } - /* get current substitution at an effort * input : vars * output : subs, exp diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index a88db4494..07c160058 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -42,6 +42,7 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/care_graph.h" #include "theory/decision_manager.h" +#include "theory/ee_manager_distributed.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/fmf/model_engine.h" #include "theory/quantifiers/theory_quantifiers.h" @@ -129,20 +130,21 @@ std::string getTheoryString(theory::TheoryId id) } void TheoryEngine::finishInit() { - //initialize the quantifiers engine, master equality engine, model, model builder - if( d_logicInfo.isQuantified() ) { + // initialize the quantifiers engine + if (d_logicInfo.isQuantified()) + { // initialize the quantifiers engine d_quantEngine = new QuantifiersEngine(d_context, d_userContext, this); - Assert(d_masterEqualityEngine == 0); - d_masterEqualityEngine = new eq::EqualityEngine(d_masterEENotify,getSatContext(), "theory::master", false); + } - for(TheoryId theoryId = theory::THEORY_FIRST; theoryId != theory::THEORY_LAST; ++ theoryId) { - if (d_theoryTable[theoryId]) { - d_theoryTable[theoryId]->setQuantifiersEngine(d_quantEngine); - d_theoryTable[theoryId]->setMasterEqualityEngine(d_masterEqualityEngine); - } - } + // Initialize the equality engine architecture for all theories, which + // includes the master equality engine. + d_eeDistributed.reset(new EqEngineManagerDistributed(*this)); + d_eeDistributed->finishInit(); + // Initialize the model and model builder. + if (d_logicInfo.isQuantified()) + { d_curr_model_builder = d_quantEngine->getModelBuilder(); d_curr_model = d_quantEngine->getModel(); } else { @@ -150,25 +152,32 @@ void TheoryEngine::finishInit() { d_userContext, "DefaultModel", options::assignFunctionValues()); d_aloc_curr_model = true; } + //make the default builder, e.g. in the case that the quantifiers engine does not have a model builder if( d_curr_model_builder==NULL ){ d_curr_model_builder = new theory::TheoryEngineModelBuilder(this); d_aloc_curr_model_builder = true; } + // finish initializing the theories for(TheoryId theoryId = theory::THEORY_FIRST; theoryId != theory::THEORY_LAST; ++ theoryId) { - if (d_theoryTable[theoryId]) { - // set the decision manager for the theory - d_theoryTable[theoryId]->setDecisionManager(d_decManager.get()); - // finish initializing the theory - d_theoryTable[theoryId]->finishInit(); + Theory* t = d_theoryTable[theoryId]; + if (t == nullptr) + { + continue; } - } -} - -void TheoryEngine::eqNotifyNewClass(TNode t){ - if (d_logicInfo.isQuantified()) { - d_quantEngine->eqNotifyNewClass( t ); + // setup the pointers to the utilities + const EeTheoryInfo* eeti = d_eeDistributed->getEeTheoryInfo(theoryId); + Assert(eeti != nullptr); + // the theory's official equality engine is the one specified by the + // equality engine manager + t->setEqualityEngine(eeti->d_usedEe); + // set the quantifiers engine + t->setQuantifiersEngine(d_quantEngine); + // set the decision manager for the theory + t->setDecisionManager(d_decManager.get()); + // finish initializing the theory + t->finishInit(); } } @@ -182,8 +191,7 @@ TheoryEngine::TheoryEngine(context::Context* context, d_userContext(userContext), d_logicInfo(logicInfo), d_sharedTerms(this, context), - d_masterEqualityEngine(nullptr), - d_masterEENotify(*this), + d_eeDistributed(nullptr), d_quantEngine(nullptr), d_decManager(new DecisionManager(userContext)), d_curr_model(nullptr), @@ -252,8 +260,6 @@ TheoryEngine::~TheoryEngine() { delete d_quantEngine; - delete d_masterEqualityEngine; - smtStatisticsRegistry()->unregisterStat(&d_combineTheoriesTime); smtStatisticsRegistry()->unregisterStat(&d_arithSubstitutionsAdded); } @@ -537,9 +543,12 @@ void TheoryEngine::check(Theory::Effort effort) { Debug("theory") << ", need check = " << (needCheck() ? "YES" : "NO") << endl; if( Theory::fullEffort(effort) && !d_inConflict && !needCheck()) { - // case where we are about to answer SAT - if( d_masterEqualityEngine != NULL ){ - AlwaysAssert(d_masterEqualityEngine->consistent()); + // case where we are about to answer SAT, the master equality engine, + // if it exists, must be consistent. + eq::EqualityEngine* mee = getMasterEqualityEngine(); + if (mee != NULL) + { + AlwaysAssert(mee->consistent()); } if (d_curr_model->isBuilt()) { @@ -1793,6 +1802,17 @@ void TheoryEngine::staticInitializeBVOptions( } } +SharedTermsDatabase* TheoryEngine::getSharedTermsDatabase() +{ + return &d_sharedTerms; +} + +theory::eq::EqualityEngine* TheoryEngine::getMasterEqualityEngine() +{ + Assert(d_eeDistributed != nullptr); + return d_eeDistributed->getMasterEqualityEngine(); +} + void TheoryEngine::getExplanation(std::vector& explanationVector, LemmaProofRecipe* proofRecipe) { Assert(explanationVector.size() > 0); diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index 081d53098..aa23aa29b 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -90,6 +90,7 @@ struct NodeTheoryPairHashFunction { namespace theory { class TheoryModel; class TheoryEngineModelBuilder; + class EqEngineManagerDistributed; namespace eq { class EqualityEngine; @@ -148,43 +149,13 @@ class TheoryEngine { SharedTermsDatabase d_sharedTerms; /** - * Master equality engine, to share with theories. + * The distributed equality manager. This class is responsible for + * configuring the theories of this class for handling equalties + * in a "distributed" fashion, i.e. each theory maintains a unique + * instance of an equality engine. These equality engines are memory + * managed by this class. */ - theory::eq::EqualityEngine* d_masterEqualityEngine; - - /** notify class for master equality engine */ - class NotifyClass : public theory::eq::EqualityEngineNotify { - TheoryEngine& d_te; - public: - NotifyClass(TheoryEngine& te): d_te(te) {} - bool eqNotifyTriggerEquality(TNode equality, bool value) override - { - return true; - } - bool eqNotifyTriggerPredicate(TNode predicate, bool value) override - { - return true; - } - bool eqNotifyTriggerTermEquality(theory::TheoryId tag, - TNode t1, - TNode t2, - bool value) override - { - return true; - } - void eqNotifyConstantTermMerge(TNode t1, TNode t2) override {} - void eqNotifyNewClass(TNode t) override { d_te.eqNotifyNewClass(t); } - void eqNotifyMerge(TNode t1, TNode t2) override {} - void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override - { - } - };/* class TheoryEngine::NotifyClass */ - NotifyClass d_masterEENotify; - - /** - * notification methods - */ - void eqNotifyNewClass(TNode t); + std::unique_ptr d_eeDistributed; /** * The quantifiers engine @@ -389,7 +360,13 @@ class TheoryEngine { d_propEngine = propEngine; } - /** Called when all initialization of options/logic is done */ + /** + * Called when all initialization of options/logic is done, after theory + * objects have been created. + * + * This initializes the quantifiers engine, the "official" equality engines + * of each theory as required, and the model and model builder utilities. + */ void finishInit(); /** @@ -759,13 +736,9 @@ public: public: void staticInitializeBVOptions(const std::vector& assertions); - Node ppSimpITE(TNode assertion); - /** Returns false if an assertion simplified to false. */ - bool donePPSimpITE(std::vector& assertions); - - SharedTermsDatabase* getSharedTermsDatabase() { return &d_sharedTerms; } + SharedTermsDatabase* getSharedTermsDatabase(); - theory::eq::EqualityEngine* getMasterEqualityEngine() { return d_masterEqualityEngine; } + theory::eq::EqualityEngine* getMasterEqualityEngine(); SortInference* getSortInference() { return &d_sortInfer; } diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index 862a906a0..4f9c3bed5 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -54,16 +54,12 @@ TheoryUF::TheoryUF(context::Context* c, * so make sure it's initialized first. */ d_thss(nullptr), d_ho(nullptr), - d_equalityEngine(d_notify, c, instanceName + "theory::uf::ee", true), d_conflict(c, false), d_functionsTerms(c), d_symb(u, instanceName) { d_true = NodeManager::currentNM()->mkConst( true ); - // The kinds we are treating as function application in congruence - d_equalityEngine.addFunctionKind(kind::APPLY_UF, false, options::ufHo()); - ProofChecker* pc = pnm != nullptr ? pnm->getChecker() : nullptr; if (pc != nullptr) { @@ -74,11 +70,17 @@ TheoryUF::TheoryUF(context::Context* c, TheoryUF::~TheoryUF() { } -void TheoryUF::setMasterEqualityEngine(eq::EqualityEngine* eq) { - d_equalityEngine.setMasterEqualityEngine(eq); +TheoryRewriter* TheoryUF::getTheoryRewriter() { return &d_rewriter; } + +bool TheoryUF::needsEqualityEngine(EeSetupInfo& esi) +{ + esi.d_notify = &d_notify; + esi.d_name = d_instanceName + "theory::uf::ee"; + return true; } void TheoryUF::finishInit() { + Assert(d_equalityEngine != nullptr); // combined cardinality constraints are not evaluated in getModelValue d_valuation.setUnevaluatedKind(kind::COMBINED_CARDINALITY_CONSTRAINT); // Initialize the cardinality constraints solver if the logic includes UF, @@ -90,9 +92,11 @@ void TheoryUF::finishInit() { d_thss.reset(new CardinalityExtension( getSatContext(), getUserContext(), *d_out, this)); } + // The kinds we are treating as function application in congruence + d_equalityEngine->addFunctionKind(kind::APPLY_UF, false, options::ufHo()); if (options::ufHo()) { - d_equalityEngine.addFunctionKind(kind::HO_APPLY); + d_equalityEngine->addFunctionKind(kind::HO_APPLY); d_ho.reset(new HoExtension(*this, getSatContext(), getUserContext())); } } @@ -148,7 +152,7 @@ void TheoryUF::check(Effort level) { bool polarity = fact.getKind() != kind::NOT; TNode atom = polarity ? fact : fact[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.assertEquality(atom, polarity, fact); + d_equalityEngine->assertEquality(atom, polarity, fact); if( options::ufHo() && options::ufHoExt() ){ if( !polarity && !d_conflict && atom[0].getType().isFunction() ){ // apply extensionality eagerly using the ho extension @@ -169,10 +173,10 @@ void TheoryUF::check(Effort level) { } //needed for models if( options::produceModels() ){ - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } } else { - d_equalityEngine.assertPredicate(atom, polarity, fact); + d_equalityEngine->assertPredicate(atom, polarity, fact); } } @@ -198,7 +202,7 @@ Node TheoryUF::getOperatorForApplyTerm( TNode node ) { if( node.getKind()==kind::APPLY_UF ){ return node.getOperator(); }else{ - return d_equalityEngine.getRepresentative( node[0] ); + return d_equalityEngine->getRepresentative(node[0]); } } @@ -242,17 +246,17 @@ void TheoryUF::preRegisterTerm(TNode node) { switch (node.getKind()) { case kind::EQUAL: // Add the trigger for equality - d_equalityEngine.addTriggerEquality(node); + d_equalityEngine->addTriggerEquality(node); break; case kind::APPLY_UF: case kind::HO_APPLY: // Maybe it's a predicate if (node.getType().isBoolean()) { // Get triggered for both equal and dis-equal - d_equalityEngine.addTriggerPredicate(node); + d_equalityEngine->addTriggerPredicate(node); } else { // Function applications/predicates - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); } // Remember the function and predicate terms d_functionsTerms.push_back(node); @@ -263,7 +267,7 @@ void TheoryUF::preRegisterTerm(TNode node) { break; default: // Variables etc - d_equalityEngine.addTerm(node); + d_equalityEngine->addTerm(node); break; } }/* TheoryUF::preRegisterTerm() */ @@ -294,9 +298,10 @@ void TheoryUF::explain(TNode literal, std::vector& assumptions, eq::EqPro bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; if (atom.getKind() == kind::EQUAL) { - d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions, pf); + d_equalityEngine->explainEquality( + atom[0], atom[1], polarity, assumptions, pf); } else { - d_equalityEngine.explainPredicate(atom, polarity, assumptions, pf); + d_equalityEngine->explainPredicate(atom, polarity, assumptions, pf); } if( pf ){ Debug("pf::uf") << std::endl; @@ -331,7 +336,7 @@ bool TheoryUF::collectModelInfo(TheoryModel* m) // Compute terms appearing in assertions and shared terms computeRelevantTerms(termSet); - if (!m->assertEqualityEngine(&d_equalityEngine, &termSet)) + if (!m->assertEqualityEngine(d_equalityEngine, &termSet)) { Trace("uf") << "Collect model info fail UF" << std::endl; return false; @@ -495,13 +500,15 @@ void TheoryUF::ppStaticLearn(TNode n, NodeBuilder<>& learned) { EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { // Check for equality (simplest) - if (d_equalityEngine.areEqual(a, b)) { + if (d_equalityEngine->areEqual(a, b)) + { // The terms are implied to be equal return EQUALITY_TRUE; } // Check for disequality - if (d_equalityEngine.areDisequal(a, b, false)) { + if (d_equalityEngine->areDisequal(a, b, false)) + { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -512,15 +519,19 @@ EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { void TheoryUF::addSharedTerm(TNode t) { Debug("uf::sharing") << "TheoryUF::addSharedTerm(" << t << ")" << std::endl; - d_equalityEngine.addTriggerTerm(t, THEORY_UF); + d_equalityEngine->addTriggerTerm(t, THEORY_UF); } bool TheoryUF::areCareDisequal(TNode x, TNode y){ - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - if( d_equalityEngine.isTriggerTerm(x, THEORY_UF) && d_equalityEngine.isTriggerTerm(y, THEORY_UF) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_UF); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_UF); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + if (d_equalityEngine->isTriggerTerm(x, THEORY_UF) + && d_equalityEngine->isTriggerTerm(y, THEORY_UF)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_UF); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_UF); EqualityStatus eqStatus = d_valuation.getEqualityStatus(x_shared, y_shared); if( eqStatus==EQUALITY_FALSE_AND_PROPAGATED || eqStatus==EQUALITY_FALSE || eqStatus==EQUALITY_FALSE_IN_MODEL ){ return true; @@ -538,21 +549,27 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, if( t2!=NULL ){ Node f1 = t1->getData(); Node f2 = t2->getData(); - if( !d_equalityEngine.areEqual( f1, f2 ) ){ + if (!d_equalityEngine->areEqual(f1, f2)) + { Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking function " << f1 << " and " << f2 << std::endl; vector< pair > currentPairs; unsigned arg_start_index = getArgumentStartIndexForApplyTerm( f1 ); for (unsigned k = arg_start_index; k < f1.getNumChildren(); ++ k) { TNode x = f1[k]; TNode y = f2[k]; - Assert(d_equalityEngine.hasTerm(x)); - Assert(d_equalityEngine.hasTerm(y)); - Assert(!d_equalityEngine.areDisequal(x, y, false)); + Assert(d_equalityEngine->hasTerm(x)); + Assert(d_equalityEngine->hasTerm(y)); + Assert(!d_equalityEngine->areDisequal(x, y, false)); Assert(!areCareDisequal(x, y)); - if( !d_equalityEngine.areEqual( x, y ) ){ - if( d_equalityEngine.isTriggerTerm(x, THEORY_UF) && d_equalityEngine.isTriggerTerm(y, THEORY_UF) ){ - TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_UF); - TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_UF); + if (!d_equalityEngine->areEqual(x, y)) + { + if (d_equalityEngine->isTriggerTerm(x, THEORY_UF) + && d_equalityEngine->isTriggerTerm(y, THEORY_UF)) + { + TNode x_shared = + d_equalityEngine->getTriggerTermRepresentative(x, THEORY_UF); + TNode y_shared = + d_equalityEngine->getTriggerTermRepresentative(y, THEORY_UF); currentPairs.push_back(make_pair(x_shared, y_shared)); } } @@ -580,7 +597,8 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, std::map::iterator it2 = it; ++it2; for( ; it2 != t1->d_data.end(); ++it2 ){ - if( !d_equalityEngine.areDisequal(it->first, it2->first, false) ){ + if (!d_equalityEngine->areDisequal(it->first, it2->first, false)) + { if( !areCareDisequal(it->first, it2->first) ){ addCarePairs( &it->second, &it2->second, arity, depth+1 ); } @@ -593,7 +611,7 @@ void TheoryUF::addCarePairs(TNodeTrie* t1, { for (std::pair& tt2 : t2->d_data) { - if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false)) + if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false)) { if (!areCareDisequal(tt1.first, tt2.first)) { @@ -621,8 +639,9 @@ void TheoryUF::computeCareGraph() { std::vector< TNode > reps; bool has_trigger_arg = false; for( unsigned j=arg_start_index; jgetRepresentative(f1[j])); + if (d_equalityEngine->isTriggerTerm(f1[j], THEORY_UF)) + { has_trigger_arg = true; } } diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index 345547301..001c947e9 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -116,9 +116,6 @@ private: /** the higher-order solver extension (or nullptr if it does not exist) */ std::unique_ptr d_ho; - /** Equaltity engine */ - eq::EqualityEngine d_equalityEngine; - /** Are we in conflict */ context::CDO d_conflict; @@ -186,10 +183,18 @@ private: ~TheoryUF(); - TheoryRewriter* getTheoryRewriter() override { return &d_rewriter; } - - void setMasterEqualityEngine(eq::EqualityEngine* eq) override; + //--------------------------------- initialization + /** get the official theory rewriter of this theory */ + TheoryRewriter* getTheoryRewriter() override; + /** + * Returns true if we need an equality engine. If so, we initialize the + * information regarding how it should be setup. For details, see the + * documentation in Theory::needsEqualityEngine. + */ + bool needsEqualityEngine(EeSetupInfo& esi) override; + /** finish initialization */ void finishInit() override; + //--------------------------------- end initialization void check(Effort) override; TrustNode expandDefinition(Node node) override; @@ -210,8 +215,6 @@ private: std::string identify() const override { return "THEORY_UF"; } - eq::EqualityEngine* getEqualityEngine() override { return &d_equalityEngine; } - /** get a pointer to the uf with cardinality */ CardinalityExtension* getCardinalityExtension() const { return d_thss.get(); } /** are we in conflict? */