From: Dejan Jovanović Date: Sun, 27 May 2012 05:44:13 +0000 (+0000) Subject: Committing the work on equality engine, I need to see how it does on the regressions... X-Git-Tag: cvc5-1.0.0~8140 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=b390cfa8f095048472cb3dd0b9ccc22fbd51f411;p=cvc5.git Committing the work on equality engine, I need to see how it does on the regressions. New additions: * areDisequal(x, y) -> areDisequal(x, y, needProof): when asking for a disequality you must say needProof if you will ask for an explanation later. * propagation of shared dis-equalities (not yet complete, once case missing) * changes to the theories that use it, authors should check up on the changes --- diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 44e362f90..15636fc72 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -86,12 +86,19 @@ TheoryArrays::TheoryArrays(context::Context* c, context::UserContext* u, OutputC d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); + // The preprocessing congruence kinds + d_ppEqualityEngine.addFunctionKind(kind::SELECT); + d_ppEqualityEngine.addFunctionKind(kind::STORE); + + // The mayequal congruence kinds + d_mayEqualEqualityEngine.addFunctionKind(kind::SELECT); + d_mayEqualEqualityEngine.addFunctionKind(kind::STORE); + // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::SELECT); if (d_ccStore) { d_equalityEngine.addFunctionKind(kind::STORE); } - d_equalityEngine.addFunctionKind(kind::EQUAL); if (d_useArrTable) { d_equalityEngine.addFunctionKind(kind::ARR_TABLE_FUN); } @@ -118,12 +125,13 @@ TheoryArrays::~TheoryArrays() { Node TheoryArrays::ppRewrite(TNode term) { if (!d_preprocess) return term; + d_ppEqualityEngine.addTerm(term); switch (term.getKind()) { case kind::SELECT: { // select(store(a,i,v),j) = select(a,j) // IF i != j if (term[0].getKind() == kind::STORE && - (d_ppEqualityEngine.areDisequal(term[0][1], term[1]) || + (d_ppEqualityEngine.areDisequal(term[0][1], term[1], false) || (term[0][1].isConst() && term[1].isConst() && term[0][1] != term[1]))) { return NodeBuilder<2>(kind::SELECT) << term[0][0] << term[1]; } @@ -134,7 +142,7 @@ Node TheoryArrays::ppRewrite(TNode term) { // IF i != j and j comes before i in the ordering if (term[0].getKind() == kind::STORE && (term[1] < term[0][1]) && - (d_ppEqualityEngine.areDisequal(term[1], term[0][1]) || + (d_ppEqualityEngine.areDisequal(term[1], term[0][1], false) || (term[0][1].isConst() && term[1].isConst() && term[0][1] != term[1]))) { Node inner = NodeBuilder<3>(kind::STORE) << term[0][0] << term[1] << term[2]; Node outer = NodeBuilder<3>(kind::STORE) << inner << term[0][1] << term[0][2]; @@ -198,7 +206,7 @@ Node TheoryArrays::ppRewrite(TNode term) { NodeBuilder<> hyp(kind::AND); for (j = leftWrites - 1; j > i; --j) { index_j = write_j[1]; - if (d_ppEqualityEngine.areDisequal(index_i, index_j) || + if (d_ppEqualityEngine.areDisequal(index_i, index_j, false) || (index_i.isConst() && index_j.isConst() && index_i != index_j)) { continue; } @@ -374,6 +382,7 @@ void TheoryArrays::preRegisterTerm(TNode node) switch (node.getKind()) { case kind::EQUAL: // Add the trigger for equality + // NOTE: note that if the equality is true or false already, it might not be added d_equalityEngine.addTriggerEquality(node); break; case kind::SELECT: { @@ -385,6 +394,9 @@ void TheoryArrays::preRegisterTerm(TNode node) // Reads TNode store = d_equalityEngine.getRepresentative(node[0]); + // The may equal needs the store + d_mayEqualEqualityEngine.addTerm(store); + // Apply RIntro1 rule to any stores equal to store if not done already const CTNodeList* stores = d_infoMap.getStores(store); CTNodeList::const_iterator it = stores->begin(); @@ -460,7 +472,8 @@ void TheoryArrays::preRegisterTerm(TNode node) break; } // Invariant: preregistered terms are exactly the terms in the equality engine - Assert(d_equalityEngine.hasTerm(node)); + // Disabled, see comment above for kind::EQUAL + // Assert(d_equalityEngine.hasTerm(node) || !d_equalityEngine.consistent()); } @@ -525,7 +538,7 @@ EqualityStatus TheoryArrays::getEqualityStatus(TNode a, TNode b) { // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b)) { + if (d_equalityEngine.areDisequal(a, b, false)) { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -576,10 +589,10 @@ void TheoryArrays::computeCareGraph() if (r1[0] != r2[0]) { // If arrays are known to be disequal, or cannot become equal, we can continue - Assert(d_equalityEngine.hasTerm(r1[0]) && d_equalityEngine.hasTerm(r2[0])); + Assert(d_mayEqualEqualityEngine.hasTerm(r1[0]) && d_mayEqualEqualityEngine.hasTerm(r2[0])); if (r1[0].getType() != r2[0].getType() || (!d_mayEqualEqualityEngine.areEqual(r1[0], r2[0])) || - d_equalityEngine.areDisequal(r1[0], r2[0])) { + d_equalityEngine.areDisequal(r1[0], r2[0], false)) { Debug("arrays::sharing") << "TheoryArrays::computeCareGraph(): arrays can't be equal, skipping" << std::endl; continue; } @@ -704,7 +717,7 @@ void TheoryArrays::check(Effort e) { case kind::NOT: if (fact[0].getKind() == kind::SELECT) { d_equalityEngine.assertPredicate(fact[0], false, fact); - } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1])) { + } else if (!d_equalityEngine.areDisequal(fact[0][0], fact[0][1], false)) { // Assert the dis-equality d_equalityEngine.assertEquality(fact[0], false, fact); @@ -1141,7 +1154,7 @@ void TheoryArrays::queueRowLemma(RowLemmaType lem) // If propagating, check propagations if (d_propagateLemmas) { - if (d_equalityEngine.areDisequal(i,j)) { + if (d_equalityEngine.areDisequal(i,j,true)) { Trace("arrays-lem") << spaces(getSatContext()->getLevel()) <<"Arrays::queueRowLemma: propagating aj = bj ("<getLevel()) <<"Arrays::queueRowLemma: propagating i = j ("<propagateAsDecision(split); } diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 639b03df8..03d7e7d8d 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -228,7 +228,7 @@ class TheoryArrays : public Theory { NotifyClass(TheoryArrays& arrays): d_arrays(arrays) {} bool eqNotifyTriggerEquality(TNode equality, bool value) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; + Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false") << ")" << std::endl; // Just forward to arrays if (value) { return d_arrays.propagate(equality); @@ -242,7 +242,7 @@ class TheoryArrays : public Theory { } bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl; + Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ")" << std::endl; if (value) { if (t1.getType().isArray()) { d_arrays.mergeArrays(t1, t2); @@ -259,7 +259,7 @@ class TheoryArrays : public Theory { } bool eqNotifyConstantTermMerge(TNode t1, TNode t2) { - Debug("arrays") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + Debug("arrays::propagate") << spaces(d_arrays.getSatContext()->getLevel()) << "NotifyClass::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << ")" << std::endl; if (Theory::theoryOf(t1) == THEORY_BOOL) { return d_arrays.propagate(t1.iffNode(t2)); } else { diff --git a/src/theory/bv/bv_subtheory_eq.h b/src/theory/bv/bv_subtheory_eq.h index 241dd5919..356d12a06 100644 --- a/src/theory/bv/bv_subtheory_eq.h +++ b/src/theory/bv/bv_subtheory_eq.h @@ -64,7 +64,7 @@ public: // The terms are implied to be equal return EQUALITY_TRUE; } - if (d_equalityEngine.areDisequal(a, b)) { + if (d_equalityEngine.areDisequal(a, b, false)) { // The terms are implied to be dis-equal return EQUALITY_FALSE; } diff --git a/src/theory/shared_terms_database.cpp b/src/theory/shared_terms_database.cpp index 90037b90b..0c893482a 100644 --- a/src/theory/shared_terms_database.cpp +++ b/src/theory/shared_terms_database.cpp @@ -297,7 +297,7 @@ bool SharedTermsDatabase::areEqual(TNode a, TNode b) { bool SharedTermsDatabase::areDisequal(TNode a, TNode b) { - return d_equalityEngine.areDisequal(a,b); + return d_equalityEngine.areDisequal(a,b,false); } void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason) @@ -305,7 +305,7 @@ void SharedTermsDatabase::processSharedLiteral(TNode literal, TNode reason) bool negated = literal.getKind() == kind::NOT; TNode atom = negated ? literal[0] : literal; if (negated) { - Assert(!d_equalityEngine.areDisequal(atom[0], atom[1])); + Assert(!d_equalityEngine.areDisequal(atom[0], atom[1],false)); d_equalityEngine.assertEquality(atom, false, reason); // !!! need to send this out } diff --git a/src/theory/term_registration_visitor.cpp b/src/theory/term_registration_visitor.cpp index 099871ceb..22b87c32f 100644 --- a/src/theory/term_registration_visitor.cpp +++ b/src/theory/term_registration_visitor.cpp @@ -32,7 +32,6 @@ bool PreRegisterVisitor::alreadyVisited(TNode current, TNode parent) { Debug("register::internal") << "PreRegisterVisitor::alreadyVisited(" << current << "," << parent << ")" << std::endl; - TheoryId currentTheoryId = Theory::theoryOf(current); TheoryId parentTheoryId = Theory::theoryOf(parent); diff --git a/src/theory/uf/equality_engine.cpp b/src/theory/uf/equality_engine.cpp index e60d52c7a..72966d0b3 100644 --- a/src/theory/uf/equality_engine.cpp +++ b/src/theory/uf/equality_engine.cpp @@ -90,6 +90,7 @@ EqualityEngine::EqualityEngine(context::Context* context, std::string name) , d_stats(name) , d_triggerDatabaseSize(context, 0) , d_triggerTermSetUpdatesSize(context, 0) +, d_deducedDisequalitiesSize(context, 0) { init(); } @@ -107,6 +108,7 @@ EqualityEngine::EqualityEngine(EqualityEngineNotify& notify, context::Context* c , d_stats(name) , d_triggerDatabaseSize(context, 0) , d_triggerTermSetUpdatesSize(context, 0) +, d_deducedDisequalitiesSize(context, 0) { init(); } @@ -152,8 +154,8 @@ EqualityNodeId EqualityEngine::newApplicationNode(TNode original, EqualityNodeId } // Add to the use lists - d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes); - d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes); + d_equalityNodes[t1ClassId].usedIn(funId, d_useListNodes); + d_equalityNodes[t2ClassId].usedIn(funId, d_useListNodes); // Return the new id Debug("equality") << "EqualityEngine::newApplicationNode(" << original << ", " << t1 << ", " << t2 << ") => " << funId << std::endl; @@ -205,6 +207,10 @@ void EqualityEngine::addTerm(TNode t) { return; } + if (d_done) { + return; + } + EqualityNodeId result; if (t.getKind() == kind::EQUAL) { @@ -270,6 +276,10 @@ void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) { Debug("equality") << "EqualityEngine::addEqualityInternal(" << t1 << "," << t2 << ")" << std::endl; + if (d_done) { + return; + } + // Add the terms if they are not already in the database addTerm(t1); addTerm(t2); @@ -278,27 +288,85 @@ void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason) { EqualityNodeId t1Id = getNodeId(t1); EqualityNodeId t2Id = getNodeId(t2); enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason)); - - propagate(); } void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason) { Debug("equality") << "EqualityEngine::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl; Assert(t.getKind() != kind::EQUAL, "Use assertEquality instead"); assertEqualityInternal(t, polarity ? d_true : d_false, reason); + propagate(); } void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason) { Debug("equality") << "EqualityEngine::addEquality(" << eq << "," << (polarity ? "true" : "false") << std::endl; if (polarity) { + // If two terms are already equal, don't assert anything + if (hasTerm(eq[0]) && hasTerm(eq[1]) && areEqual(eq[0], eq[1])) { + return; + } // Add equality between terms assertEqualityInternal(eq[0], eq[1], reason); + propagate(); // Add eq = true for dis-equality propagation assertEqualityInternal(eq, d_true, reason); + propagate(); } else { + // If two terms are already dis-equal, don't assert anything + if (hasTerm(eq[0]) && hasTerm(eq[1]) && areDisequal(eq[0], eq[1], false)) { + return; + } + assertEqualityInternal(eq, d_false, reason); - Node eqSymm = eq[1].eqNode(eq[0]); - assertEqualityInternal(eqSymm, d_false, reason); + propagate(); + assertEqualityInternal(eq[1].eqNode(eq[0]), d_false, reason); + propagate(); + + if (d_done) { + return; + } + + // If we are adding a disequality, notify of the shared term representatives + EqualityNodeId a = getNodeId(eq[0]); + EqualityNodeId b = getNodeId(eq[1]); + EqualityNodeId eqId = getNodeId(eq); + TriggerTermSetRef aTriggerRef = d_nodeIndividualTrigger[a]; + TriggerTermSetRef bTriggerRef = d_nodeIndividualTrigger[b]; + if (aTriggerRef != +null_set_id && bTriggerRef != +null_set_id) { + // The sets of trigger terms + TriggerTermSet& aTriggerTerms = getTriggerTermSet(aTriggerRef); + TriggerTermSet& bTriggerTerms = getTriggerTermSet(bTriggerRef); + // Go through and notify the shared dis-equalities + Theory::Set aTags = aTriggerTerms.tags; + Theory::Set bTags = bTriggerTerms.tags; + TheoryId aTag = Theory::setPop(aTags); + TheoryId bTag = Theory::setPop(bTags); + int a_i = 0, b_i = 0; + while (aTag != THEORY_LAST && bTag != THEORY_LAST) { + if (aTag < bTag) { + aTag = Theory::setPop(aTags); + ++ a_i; + } else if (aTag > bTag) { + bTag = Theory::setPop(bTags); + ++ b_i; + } else { + // Same tags, notify + EqualityNodeId aSharedId = aTriggerTerms.triggers[a_i++]; + EqualityNodeId bSharedId = bTriggerTerms.triggers[b_i++]; + d_deducedDisequalityReasons.push_back(EqualityPair(aSharedId, a)); + d_deducedDisequalityReasons.push_back(EqualityPair(bSharedId, b)); + d_deducedDisequalityReasons.push_back(EqualityPair(eqId, d_falseId)); + storePropagatedDisequality(d_nodes[aSharedId], d_nodes[bSharedId], 3); + // We notify even if the it's already been sent (they are not + // disequal at assertion, and we need to notify for each tag) + if (!d_notify.eqNotifyTriggerTermEquality(aTag, d_nodes[aSharedId], d_nodes[bSharedId], false)) { + break; + } + // Pop the next tags + aTag = Theory::setPop(aTags); + bTag = Theory::setPop(bTags); + } + } + } } } @@ -310,26 +378,11 @@ TNode EqualityEngine::getRepresentative(TNode t) const { return d_nodes[representativeId]; } -bool EqualityEngine::areEqual(TNode t1, TNode t2) const { - Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; - - Assert(hasTerm(t1)); - Assert(hasTerm(t2)); - - // Both following commands are semantically const - EqualityNodeId rep1 = getEqualityNode(t1).getFind(); - EqualityNodeId rep2 = getEqualityNode(t2).getFind(); - - Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ") => " << (rep1 == rep2 ? "true" : "false") << std::endl; - - return rep1 == rep2; -} - -bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggers) { +bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vector& triggersFired) { Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << ")" << std::endl; - Assert(triggers.empty()); + Assert(triggersFired.empty()); ++ d_stats.mergesCount; @@ -338,15 +391,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect // Check for constant merges bool isConstant = d_isConstant[class1Id]; - if (isConstant && d_isConstant[class2Id]) { - if (d_performNotify) { - if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) { - // Now merge the so that backtracking is OK - class1.merge(class2); - return false; - } - } - } + // Update class2 representative information Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating class " << class2Id << std::endl; EqualityNodeId currentId = class2Id; @@ -369,8 +414,19 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect trigger.classId = class1Id; // If they became the same, call the trigger if (otherTrigger.classId == class1Id) { + const TriggerInfo& triggerInfo = d_equalityTriggersOriginal[currentTrigger]; + if (triggerInfo.trigger.getKind() == kind::EQUAL && !triggerInfo.polarity) { + TNode equality = triggerInfo.trigger; + EqualityNodeId original = getNodeId(equality); + d_deducedDisequalityReasons.push_back(EqualityPair(original, d_falseId)); + if (!storePropagatedDisequality(equality[0], equality[1], 1)) { + // Go to the next trigger + currentTrigger = trigger.nextTrigger; + continue; + } + } // Id of the real trigger is half the internal one - triggers.push_back(currentTrigger); + triggersFired.push_back(currentTrigger); } } @@ -393,7 +449,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect Debug("equality") << "EqualityEngine::merge(" << class1.getFind() << "," << class2.getFind() << "): updating lookups of node " << currentId << std::endl; // Go through the uselist and check for congruences - UseListNodeId currentUseId = currentNode.getUseList(); + UseListNodeId currentUseId = currentNode.getUseList(); while (currentUseId != null_uselist_id) { // Get the node of the use list UseListNode& useNode = d_useListNodes[currentUseId]; @@ -422,7 +478,7 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect enqueue(MergeCandidate(funId, d_falseId, MERGED_THROUGH_CONSTANTS, TNode::null())); } } - } + } } // Go to the next one in the use list @@ -502,6 +558,15 @@ bool EqualityEngine::merge(EqualityNode& class1, EqualityNode& class2, std::vect } } + // Notify of the constants merge + if (isConstant && d_isConstant[class2Id]) { + if (d_performNotify) { + if (!d_notify.eqNotifyConstantTermMerge(d_nodes[class1Id], d_nodes[class2Id])) { + return false; + } + } + } + // Everything fine return true; } @@ -610,9 +675,9 @@ void EqualityEngine::backtrack() { const FunctionApplication& app = d_applications[i].normalized; if (app.isApplication()) { // Remove b from use-list - getEqualityNode(app.b).removeTopFromUseList(d_useListNodes); + getEqualityNode(app.b).removeTopFromUseList(d_useListNodes); // Remove a from use-list - getEqualityNode(app.a).removeTopFromUseList(d_useListNodes); + getEqualityNode(app.a).removeTopFromUseList(d_useListNodes); } } @@ -626,6 +691,20 @@ void EqualityEngine::backtrack() { d_equalityGraph.resize(d_nodesCount); d_equalityNodes.resize(d_nodesCount); } + + if (d_deducedDisequalities.size() > d_deducedDisequalitiesSize) { + for(int i = d_deducedDisequalities.size() - 1, i_end = (int)d_deducedDisequalitiesSize; i >= i_end; -- i) { + EqualityPair pair = d_deducedDisequalities[i]; + DisequalityReasonRef reason = d_disequalityReasonsMap[pair]; + // Remove from the map + d_disequalityReasonsMap.erase(pair); + std::swap(pair.first, pair.second); + d_disequalityReasonsMap.erase(pair); + // Resize the reasons vector + d_deducedDisequalityReasons.resize(reason.mergesStart); + } + d_deducedDisequalities.resize(d_deducedDisequalitiesSize); + } } void EqualityEngine::addGraphEdge(EqualityNodeId t1, EqualityNodeId t2, MergeReasonType type, TNode reason) { @@ -658,42 +737,35 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const { return out.str(); } -void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector& equalities) { - Debug("equality") << "EqualityEngine::explainEquality(" << t1 << "," << t2 << ")" << std::endl; +void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity, std::vector& equalities) const { + Debug("equality") << "EqualityEngine::explainEquality(" << t1 << ", " << t2 << ", " << (polarity ? "true" : "false") << ")" << std::endl; - // Don't notify during this check - ScopedBool turnOffNotify(d_performNotify, false); + // The terms must be there already + Assert(hasTerm(t1) && hasTerm(t2));; - // Add the terms (they might not be there) - addTerm(t1); - addTerm(t2); + // Get the ids + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); if (polarity) { // Get the explanation - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); getExplanation(t1Id, t2Id, equalities); } else { - // Add the equality - Node equality = t1.eqNode(t2); - addTerm(equality); - - // Get the explanation - EqualityNodeId equalityId = getNodeId(equality); - EqualityNodeId falseId = getNodeId(d_false); - getExplanation(equalityId, falseId, equalities); + // Get the reason for this disequality + EqualityPair pair(t1Id, t2Id); + Assert(d_disequalityReasonsMap.find(pair) != d_disequalityReasonsMap.end(), "Don't ask for stuff I didn't notify you about"); + DisequalityReasonRef reasonRef = d_disequalityReasonsMap.find(pair)->second; + for (unsigned i = reasonRef.mergesStart; i < reasonRef.mergesEnd; ++ i) { + EqualityPair toExplain = d_deducedDisequalityReasons[i]; + getExplanation(toExplain.first, toExplain.second, equalities); + } } } -void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector& assertions) { - Debug("equality") << "EqualityEngine::explainEquality(" << p << ")" << std::endl; - - // Don't notify during this check - ScopedBool turnOffNotify(d_performNotify, false); - - // Add the terms - addTerm(p); - +void EqualityEngine::explainPredicate(TNode p, bool polarity, std::vector& assertions) const { + Debug("equality") << "EqualityEngine::explainPredicate(" << p << ")" << std::endl; + // Must have the term + Assert(hasTerm(p)); // Get the explanation getExplanation(getNodeId(p), polarity ? d_trueId : d_falseId, assertions); } @@ -702,11 +774,16 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st Debug("equality") << "EqualityEngine::getExplanation(" << d_nodes[t1Id] << "," << d_nodes[t2Id] << ")" << std::endl; - // We can only explain the nodes that got merged (or between - // constants since they didn't get merged but we stil added the - // edge in the graph equality - Assert(getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind() || - (d_isConstant[getEqualityNode(t1Id).getFind()] && d_isConstant[getEqualityNode(t2Id).getFind()])); + // We can only explain the nodes that got merged +#ifdef CVC4_ASSERTIONS + bool canExplain = getEqualityNode(t1Id).getFind() == getEqualityNode(t2Id).getFind(); + if (!canExplain) { + Warning() << "Can't explain equality:" << std::endl; + Warning() << d_nodes[t1Id] << " with find " << d_nodes[getEqualityNode(t1Id).getFind()] << std::endl; + Warning() << d_nodes[t2Id] << " with find " << d_nodes[getEqualityNode(t2Id).getFind()] << std::endl; + } + Assert(canExplain); +#endif // If the nodes are the same, we're done if (t1Id == t2Id) return; @@ -778,7 +855,7 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st equalities.push_back(d_equalityEdges[currentEdge].getReason()); break; case MERGED_THROUGH_CONSTANTS: { - // (a = b) == false bacause a and b are different constants + // (a = b) == false because a and b are different constants Debug("equality") << "EqualityEngine::getExplanation(): due to constants, going deeper" << std::endl; EqualityNodeId eqId = currentNode == d_falseId ? edgeNode : currentNode; const FunctionApplication& eq = d_applications[eqId].original; @@ -822,8 +899,34 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNodeId t2Id, st void EqualityEngine::addTriggerEquality(TNode eq) { Assert(eq.getKind() == kind::EQUAL); + + if (d_done) { + return; + } + // Add the terms + addTerm(eq[0]); + addTerm(eq[1]); + + bool skipTrigger = false; + + // If they are equal or disequal already, no need for the trigger + if (areEqual(eq[0], eq[1])) { + d_notify.eqNotifyTriggerEquality(eq, true); + skipTrigger = true; + } + if (areDisequal(eq[0], eq[1], true)) { + d_notify.eqNotifyTriggerEquality(eq, false); + skipTrigger = true; + } + + if (skipTrigger) { + return; + } + + // Add the equality addTerm(eq); + // Positive trigger addTriggerEqualityInternal(eq[0], eq[1], eq, true); // Negative trigger @@ -833,8 +936,30 @@ void EqualityEngine::addTriggerEquality(TNode eq) { void EqualityEngine::addTriggerPredicate(TNode predicate) { Assert(predicate.getKind() != kind::NOT && predicate.getKind() != kind::EQUAL); Assert(d_congruenceKinds.tst(predicate.getKind()), "No point in adding non-congruence predicates"); + + if (d_done) { + return; + } + // Add the term addTerm(predicate); + + bool skipTrigger = false; + + // If it's know already, no need for the trigger + if (areEqual(predicate, d_true)) { + d_notify.eqNotifyTriggerPredicate(predicate, true); + skipTrigger = true; + } + if (areEqual(predicate, d_false)) { + d_notify.eqNotifyTriggerPredicate(predicate, false); + skipTrigger = true; + } + + if (skipTrigger) { + return; + } + // Positive trigger addTriggerEqualityInternal(predicate, d_true, predicate, true); // Negative trigger @@ -848,41 +973,41 @@ void EqualityEngine::addTriggerEqualityInternal(TNode t1, TNode t2, TNode trigge Assert(hasTerm(t1)); Assert(hasTerm(t2)); + if (d_done) { + return; + } + // Get the information about t1 EqualityNodeId t1Id = getNodeId(t1); EqualityNodeId t1classId = getEqualityNode(t1Id).getFind(); + // We will attach it to the class representative, since then we know how to backtrack it TriggerId t1TriggerId = d_nodeTriggers[t1classId]; // Get the information about t2 EqualityNodeId t2Id = getNodeId(t2); EqualityNodeId t2classId = getEqualityNode(t2Id).getFind(); + // We will attach it to the class representative, since then we know how to backtrack it TriggerId t2TriggerId = d_nodeTriggers[t2classId]; Debug("equality") << "EqualityEngine::addTrigger(" << trigger << "): " << t1Id << " (" << t1classId << ") = " << t2Id << " (" << t2classId << ")" << std::endl; // Create the triggers TriggerId t1NewTriggerId = d_equalityTriggers.size(); - TriggerId t2NewTriggerId = t1NewTriggerId | 1; d_equalityTriggers.push_back(Trigger(t1classId, t1TriggerId)); d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); + TriggerId t2NewTriggerId = d_equalityTriggers.size(); d_equalityTriggers.push_back(Trigger(t2classId, t2TriggerId)); d_equalityTriggersOriginal.push_back(TriggerInfo(trigger, polarity)); // Update the counters - d_equalityTriggersCount = d_equalityTriggersCount + 2; + d_equalityTriggersCount = d_equalityTriggers.size(); + Assert(d_equalityTriggers.size() == d_equalityTriggersOriginal.size()); + Assert(d_equalityTriggers.size() % 2 == 0); // Add the trigger to the trigger graph d_nodeTriggers[t1classId] = t1NewTriggerId; d_nodeTriggers[t2classId] = t2NewTriggerId; - // If the trigger is already on, we propagate - if (t1classId == t2classId) { - Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; - if (d_performNotify) { - d_notify.eqNotifyTriggerEquality(trigger, polarity); // Don't care about the return value - } - } - if (Debug.isOn("equality::internal")) { debugPrintGraph(); } @@ -978,39 +1103,105 @@ void EqualityEngine::debugPrintGraph() const { } } -bool EqualityEngine::areEqual(TNode t1, TNode t2) -{ - // Don't notify during this check - ScopedBool turnOffNotify(d_performNotify, false); +bool EqualityEngine::areEqual(TNode t1, TNode t2) const { + Debug("equality") << "EqualityEngine::areEqual(" << t1 << "," << t2 << ")" << std::endl; - // Add the terms - addTerm(t1); - addTerm(t2); - bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind(); + Assert(hasTerm(t1)); + Assert(hasTerm(t2)); - // Return whether the two terms are equal - return equal; + return getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind(); } -bool EqualityEngine::areDisequal(TNode t1, TNode t2) +bool EqualityEngine::areDisequal(TNode t1, TNode t2, bool ensureProof) const { - // Don't notify during this check - ScopedBool turnOffNotify(d_performNotify, false); + Debug("equality") << "EqualityEngine::areDisequal(" << t1 << "," << t2 << ")" << std::endl; // Add the terms - addTerm(t1); - addTerm(t2); + Assert(hasTerm(t1)); + Assert(hasTerm(t2)); + + // Get ids + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); - // Check (t1 = t2) = false - // No need to check the symmetric version: we can only deduce a disequality from an existing - // diseqality, and each of those is asserted in the symmetric version also - Node equality = t1.eqNode(t2); - addTerm(equality); - if (getEqualityNode(equality).getFind() == getEqualityNode(d_false).getFind()) { + // Get equivalence classes + EqualityNodeId t1ClassId = getEqualityNode(t1Id).getFind(); + EqualityNodeId t2ClassId = getEqualityNode(t2Id).getFind(); + + // We are semantically const, for remembering stuff + EqualityEngine* nonConst = const_cast(this); + + // Check for constants + if (d_isConstant[t1ClassId] && d_isConstant[t2ClassId] && t1ClassId != t2ClassId) { + if (ensureProof) { + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, t1ClassId)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, t2ClassId)); + storePropagatedDisequality(t1, t2, 2); + } return true; } - // Return whether the terms are disequal + // Check the equality itself if it exists + Node eq = t1.eqNode(t2); + if (hasTerm(eq)) { + if (getEqualityNode(eq).getFind() == getEqualityNode(d_falseId).getFind()) { + if (ensureProof) { + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(getNodeId(eq), d_falseId)); + storePropagatedDisequality(t1, t2, 1); + } + return true; + } + } + + // Check the other equality itself if it exists + eq = t2.eqNode(t1); + if (hasTerm(eq)) { + if (getEqualityNode(eq).getFind() == getEqualityNode(d_false).getFind()) { + if (ensureProof) { + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(getNodeId(eq), d_falseId)); + storePropagatedDisequality(t1, t2, 1); + } + return true; + } + } + + // Create the equality + FunctionApplication eqNormalized(true, t1ClassId, t2ClassId); + ApplicationIdsMap::const_iterator find = d_applicationLookup.find(eqNormalized); + if (find != d_applicationLookup.end()) { + if (getEqualityNode(find->second).getFind() == getEqualityNode(d_falseId).getFind()) { + if (ensureProof) { + const FunctionApplication original = d_applications[find->second].original; + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, original.a)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, original.b)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.a, t1ClassId)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.b, t2ClassId)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(find->second, d_falseId)); + storePropagatedDisequality(t1, t2, 5); + } + return true; + } + } + + // Check the symmetric disequality + std::swap(eqNormalized.a, eqNormalized.b); + find = d_applicationLookup.find(eqNormalized); + if (find != d_applicationLookup.end()) { + if (getEqualityNode(find->second).getFind() == getEqualityNode(d_falseId).getFind()) { + if (ensureProof) { + const FunctionApplication original = d_applications[find->second].original; + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t2Id, original.a)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(t1Id, original.b)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.a, t2ClassId)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(original.b, t1ClassId)); + nonConst->d_deducedDisequalityReasons.push_back(EqualityPair(find->second, d_falseId)); + storePropagatedDisequality(t1, t2, 5); + } + return true; + } + } + + // Couldn't deduce dis-equalityReturn whether the terms are disequal return false; } @@ -1026,6 +1217,10 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag) Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ", " << tag << ")" << std::endl; Assert(tag != THEORY_LAST); + if (d_done) { + return; + } + // Add the term if it's not already there addTerm(t); @@ -1040,7 +1235,9 @@ void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag) // If the term already is in the equivalence class that a tagged representative, just notify if (d_performNotify) { EqualityNodeId triggerId = getTriggerTermSet(triggerSetRef).getTrigger(tag); - d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true); + if (!d_notify.eqNotifyTriggerTermEquality(tag, t, d_nodes[triggerId], true)) { + d_done = true; + } } } else { @@ -1117,7 +1314,7 @@ void EqualityEngine::getUseListTerms(TNode t, std::set& output) { // Get the current node EqualityNode& currentNode = getEqualityNode(currentId); // Go through the use-list - UseListNodeId currentUseId = currentNode.getUseList(); + UseListNodeId currentUseId = currentNode.getUseList(); while (currentUseId != null_uselist_id) { // Get the node of the use list UseListNode& useNode = d_useListNodes[currentUseId]; @@ -1157,6 +1354,41 @@ EqualityEngine::TriggerTermSetRef EqualityEngine::newTriggerTermSet() { return newTriggerSetRef; } +bool EqualityEngine::storePropagatedDisequality(TNode lhs, TNode rhs, unsigned reasonsCount) const { + + Assert(reasonsCount > 0); + + + EqualityNodeId lhsId = getNodeId(lhs); + EqualityNodeId rhsId = getNodeId(rhs); + + // We are semantically const, just remembering stuff for later + EqualityEngine* nonConst = const_cast(this); + + Assert(d_deducedDisequalityReasons.size() >= reasonsCount); + DisequalityReasonRef ref(d_deducedDisequalityReasons.size() - reasonsCount, d_deducedDisequalityReasons.size()); + +#ifdef CVC4_ASSERTIONS + for (unsigned i = ref.mergesStart; i < ref.mergesEnd; ++ i) { + Assert(getEqualityNode(d_deducedDisequalityReasons[i].first).getFind() == getEqualityNode(d_deducedDisequalityReasons[i].second).getFind()); + } +#endif + + EqualityPair pair(lhsId, rhsId); + DisequalityReasonsMap::const_iterator find = d_disequalityReasonsMap.find(pair); + if (find == d_disequalityReasonsMap.end()) { + nonConst->d_disequalityReasonsMap[pair] = ref; + nonConst->d_deducedDisequalities.push_back(pair); + nonConst->d_deducedDisequalitiesSize = d_deducedDisequalities.size(); + std::swap(pair.first, pair.second); + nonConst->d_disequalityReasonsMap[pair] = ref; + return true; + } else { + nonConst->d_deducedDisequalities.resize(d_deducedDisequalitiesSize); + return false; + } +} + } // Namespace uf } // Namespace theory diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index 8fc57eb48..5ff8ee4dc 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -187,20 +187,6 @@ private: /** A context-dependents count of nodes */ context::CDO d_nodesCount; - /** - * At time of addition a function application can already normalize to something, so - * we keep both the original, and the normalized version. - */ - struct FunctionApplicationPair { - FunctionApplication original; - FunctionApplication normalized; - FunctionApplicationPair() {} - FunctionApplicationPair(const FunctionApplication& original, const FunctionApplication& normalized) - : original(original), normalized(normalized) {} - bool isNull() const { - return !original.isApplication(); - } - }; /** Map from ids to the applications */ std::vector d_applications; @@ -336,16 +322,6 @@ private: */ std::vector d_equalityTriggers; - struct TriggerInfo { - /** The trigger itself */ - Node trigger; - /** Polarity of the trigger */ - bool polarity; - TriggerInfo() {} - TriggerInfo(Node trigger, bool polarity) - : trigger(trigger), polarity(polarity) {} - }; - /** * Vector of original equalities of the triggers. */ @@ -515,6 +491,31 @@ private: */ std::vector d_nodeIndividualTrigger; + typedef std::hash_map DisequalityReasonsMap; + + /** + * A map from pairs of disequal terms, to the reason why we deduced they are disequal. + */ + DisequalityReasonsMap d_disequalityReasonsMap; + + /** + * A list of all the disequalities we deduced. + */ + std::vector d_deducedDisequalities; + + /** + * Context dependent size of the deduced disequalities + */ + context::CDO d_deducedDisequalitiesSize; + + /** + * For each disequality deduced, we add the pairs of equivalences needed to explain it. + */ + std::vector d_deducedDisequalityReasons; + + + bool storePropagatedDisequality(TNode lhs, TNode rhs, unsigned reasonsCount) const; + public: /** @@ -593,24 +594,19 @@ public: */ void getUseListTerms(TNode t, std::set& output); - /** - * Returns true if the two nodes are in the same equivalence class. - */ - bool areEqual(TNode t1, TNode t2) const; - /** * Get an explanation of the equality t1 = t2 begin true of false. * Returns the reasons (added when asserting) that imply it * in the assertions vector. */ - void explainEquality(TNode t1, TNode t2, bool polarity, std::vector& assertions); + void explainEquality(TNode t1, TNode t2, bool polarity, std::vector& assertions) const; /** * Get an explanation of the predicate being true or false. * Returns the reasons (added when asserting) that imply imply it * in the assertions vector. */ - void explainPredicate(TNode p, bool polarity, std::vector& assertions); + void explainPredicate(TNode p, bool polarity, std::vector& assertions) const; /** * Add term to the set of trigger terms with a corresponding tag. The notify class will get @@ -652,14 +648,14 @@ public: void addTriggerPredicate(TNode predicate); /** - * Check whether the two terms are equal. + * Returns true if the two are currently in the database and equal. */ - bool areEqual(TNode t1, TNode t2); + bool areEqual(TNode t1, TNode t2) const; /** * Check whether the two term are dis-equal. */ - bool areDisequal(TNode t1, TNode t2); + bool areDisequal(TNode t1, TNode t2, bool ensureProof) const; /** * Return the number of nodes in the equivalence class containing t @@ -667,6 +663,11 @@ public: */ size_t getSize(TNode t); + /** + * Returns true if the engine is in a consistents state. + */ + bool consistent() const { return !d_done; } + }; } // Namespace uf diff --git a/src/theory/uf/equality_engine_types.h b/src/theory/uf/equality_engine_types.h index a0d84a1ed..0baf70fcf 100644 --- a/src/theory/uf/equality_engine_types.h +++ b/src/theory/uf/equality_engine_types.h @@ -98,6 +98,16 @@ struct MergeCandidate { t1Id(x), t2Id(y), type(type), reason(reason) {} }; +/** + * Just an index into the reasons array, and the number of merges to consume. + */ +struct DisequalityReasonRef { + DefaultSizeType mergesStart; + DefaultSizeType mergesEnd; + DisequalityReasonRef(DefaultSizeType mergesStart = 0, DefaultSizeType mergesEnd = 0) + : mergesStart(mergesStart), mergesEnd(mergesEnd) {} +}; + /** * We mantaint uselist where a node appears in, and this is the node * of such a list. @@ -135,6 +145,15 @@ public: } }; +/** Main types of uselists */ +enum UseListType { + /** Use list of functions where the term appears in */ + USE_LIST_FUNCTIONS, + /** Use list of asserted disequalities */ + USE_LIST_DISEQUALITIES +}; + + /** * Main class for representing nodes in the equivalence class. The * nodes are a circular list, with the representative carrying the @@ -159,23 +178,28 @@ private: /** The use list of this node */ UseListNodeId d_useList; + /** The list of asserted disequalities that this node appears in */ + UseListNodeId d_diseqList; + public: /** * Creates a new node, which is in a list of it's own. */ EqualityNode(EqualityNodeId nodeId = null_id) - : d_size(1), - d_findId(nodeId), - d_nextId(nodeId), - d_useList(null_uselist_id) + : d_size(1) + , d_findId(nodeId) + , d_nextId(nodeId) + , d_useList(null_uselist_id) + , d_diseqList(null_uselist_id) {} /** - * Returns the function uselist. + * Returns the requested uselist. */ + template UseListNodeId getUseList() const { - return d_useList; + return type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList; } /** @@ -220,24 +244,38 @@ public: * Note that this node is used in a function application funId, or * a negatively asserted equality (dis-equality) with funId. */ - template + template void usedIn(EqualityNodeId funId, memory_class& memory) { + UseListNodeId& useList = type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList; UseListNodeId newUseId = memory.size(); - memory.push_back(UseListNode(funId, d_useList)); - d_useList = newUseId; + memory.push_back(UseListNode(funId, useList)); + useList = newUseId; } /** * For backtracking: remove the first element from the uselist and pop the memory. */ - template + template void removeTopFromUseList(memory_class& memory) { - Assert ((int)d_useList == (int)memory.size() - 1); - d_useList = memory.back().getNext(); + UseListNodeId& useList = type == USE_LIST_FUNCTIONS ? d_useList : d_diseqList; + Assert ((int) useList == (int)memory.size() - 1); + useList = memory.back().getNext(); memory.pop_back(); } }; +/** A pair of ids */ +typedef std::pair EqualityPair; + +struct EqualityPairHashFunction { + size_t operator () (const EqualityPair& pair) const { + size_t hash = 0; + hash = 0x9e3779b9 + pair.first; + hash ^= 0x9e3779b9 + pair.second + (hash << 6) + (hash >> 2); + return hash; + } +}; + /** * Represents the function APPLY a b. If isEquality is true then it * represents the predicate (a = b). Note that since one can not @@ -266,6 +304,35 @@ struct FunctionApplicationHashFunction { } }; +/** + * At time of addition a function application can already normalize to something, so + * we keep both the original, and the normalized version. + */ +struct FunctionApplicationPair { + FunctionApplication original; + FunctionApplication normalized; + FunctionApplicationPair() {} + FunctionApplicationPair(const FunctionApplication& original, const FunctionApplication& normalized) + : original(original), normalized(normalized) {} + bool isNull() const { + return !original.isApplication(); + } +}; + +/** + * Information about the added triggers. + */ +struct TriggerInfo { + /** The trigger itself */ + Node trigger; + /** Polarity of the trigger */ + bool polarity; + TriggerInfo() {} + TriggerInfo(Node trigger, bool polarity) + : trigger(trigger), polarity(polarity) {} + }; + + } // namespace eq } // namespace theory } // namespace CVC4 diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index 9c1229f80..ae8bdc8da 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -332,7 +332,7 @@ EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { } // Check for disequality - if (d_equalityEngine.areDisequal(a, b)) { + if (d_equalityEngine.areDisequal(a, b, false)) { // The terms are implied to be dis-equal return EQUALITY_FALSE; } @@ -388,16 +388,14 @@ void TheoryUF::computeCareGraph() { Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking arguments " << x << " and " << y << std::endl; - EqualityStatus eqStatusUf = getEqualityStatus(x, y); - - if (eqStatusUf == EQUALITY_FALSE) { + if (d_equalityEngine.areDisequal(x, y, false)) { // Mark that there is a dis-equal pair and break somePairIsDisequal = true; Debug("uf::sharing") << "TheoryUf::computeCareGraph(): disequal, disregarding all" << std::endl; break; } - if (eqStatusUf == EQUALITY_TRUE) { + if (d_equalityEngine.areEqual(x, y)) { // We don't need this one Debug("uf::sharing") << "TheoryUf::computeCareGraph(): equal" << std::endl; continue;