From bb59480a36fb0f799af53676c07b8fca43c2fff4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dejan=20Jovanovi=C4=87?= Date: Mon, 17 Oct 2011 03:12:17 +0000 Subject: [PATCH] Sharing work --- src/prop/minisat/core/Solver.cc | 2 + src/theory/arith/arith_rewriter.h | 4 + src/theory/arith/theory_arith.cpp | 9 + src/theory/arith/theory_arith.h | 2 + src/theory/arrays/theory_arrays_rewriter.h | 11 +- src/theory/booleans/theory_bool_rewriter.h | 5 + src/theory/builtin/theory_builtin_rewriter.h | 10 +- src/theory/bv/theory_bv_rewriter.h | 4 + src/theory/datatypes/datatypes_rewriter.h | 4 + src/theory/mkrewriter | 6 + src/theory/rewriter.cpp | 5 + src/theory/rewriter.h | 18 +- src/theory/rewriter_tables_template.h | 8 + src/theory/theory.cpp | 6 +- src/theory/theory.h | 56 +++-- src/theory/theory_engine.cpp | 200 +++++++++++++++--- src/theory/theory_engine.h | 64 +++--- src/theory/uf/equality_engine.h | 71 ++++++- src/theory/uf/equality_engine_impl.h | 189 +++++++++++++++-- src/theory/uf/theory_uf.cpp | 139 ++++++++++-- src/theory/uf/theory_uf.h | 28 ++- src/theory/uf/theory_uf_rewriter.h | 11 +- src/theory/valuation.cpp | 43 +++- src/theory/valuation.h | 32 +++ src/util/node_visitor.h | 2 +- .../regress0/arith/integers/Makefile.am | 72 +------ test/regress/regress0/uflra/Makefile.am | 22 +- test/regress/regress0/uflra/simple.01.cvc | 6 + test/regress/regress0/uflra/simple.02.cvc | 10 + test/regress/regress0/uflra/simple.03.cvc | 12 ++ test/regress/regress0/uflra/simple.04.cvc | 15 ++ test/unit/theory/theory_arith_white.h | 8 +- test/unit/theory/theory_black.h | 4 +- 33 files changed, 831 insertions(+), 247 deletions(-) create mode 100644 test/regress/regress0/uflra/simple.01.cvc create mode 100644 test/regress/regress0/uflra/simple.02.cvc create mode 100644 test/regress/regress0/uflra/simple.03.cvc create mode 100644 test/regress/regress0/uflra/simple.04.cvc diff --git a/src/prop/minisat/core/Solver.cc b/src/prop/minisat/core/Solver.cc index c1795a12c..1c327c5a8 100644 --- a/src/prop/minisat/core/Solver.cc +++ b/src/prop/minisat/core/Solver.cc @@ -631,6 +631,8 @@ CRef Solver::propagate(TheoryCheckType type) if (lemmas.size() > 0) { recheck = true; return updateLemmas(); + } else { + return confl; } } diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index d88a7ae92..822514f38 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -63,6 +63,10 @@ public: static RewriteResponse preRewrite(TNode n); static RewriteResponse postRewrite(TNode n); + static Node rewriteEquality(TNode equality) { + // Arithmetic owns the domain, so this is totally ok + return Rewriter::rewrite(equality); + } static void init() { } diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index c69960d37..066eb85b5 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -1035,3 +1035,12 @@ void TheoryArith::presolve(){ learner.clear(); check(FULL_EFFORT); } + +EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) { + if (getValue(a) == getValue(b)) { + return EQUALITY_TRUE_IN_MODEL; + } else { + return EQUALITY_FALSE_IN_MODEL; + } + +} diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 4731bea30..1ba9a50e0 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -189,6 +189,8 @@ public: std::string identify() const { return std::string("TheoryArith"); } + EqualityStatus getEqualityStatus(TNode a, TNode b); + private: /** The constant zero. */ DeltaRational d_DELTA_ZERO; diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h index 8c1c16de2..f3a19ff02 100644 --- a/src/theory/arrays/theory_arrays_rewriter.h +++ b/src/theory/arrays/theory_arrays_rewriter.h @@ -68,12 +68,7 @@ public: } if (node[0] > node[1]) { Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); - // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change) - if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) { - return RewriteResponse(REWRITE_AGAIN_FULL, newNode); - } else { - return RewriteResponse(REWRITE_DONE, newNode); - } + return RewriteResponse(REWRITE_DONE, newNode); } break; } @@ -89,6 +84,10 @@ public: return RewriteResponse(REWRITE_DONE, node); } + static Node rewriteEquality(TNode node) { + return postRewrite(node).node; + } + static inline void init() {} static inline void shutdown() {} diff --git a/src/theory/booleans/theory_bool_rewriter.h b/src/theory/booleans/theory_bool_rewriter.h index 6771f775c..d26a4d478 100644 --- a/src/theory/booleans/theory_bool_rewriter.h +++ b/src/theory/booleans/theory_bool_rewriter.h @@ -37,6 +37,11 @@ public: return preRewrite(node); } + static Node rewriteEquality(TNode node) { + Unreachable(); + return node; + } + static void init() {} static void shutdown() {} diff --git a/src/theory/builtin/theory_builtin_rewriter.h b/src/theory/builtin/theory_builtin_rewriter.h index 716323b8a..e299f84e7 100644 --- a/src/theory/builtin/theory_builtin_rewriter.h +++ b/src/theory/builtin/theory_builtin_rewriter.h @@ -36,9 +36,6 @@ class TheoryBuiltinRewriter { public: static inline RewriteResponse postRewrite(TNode node) { - if(node.getKind() == kind::EQUAL) { - return Rewriter::callPostRewrite(Theory::theoryOf(node[0]), node); - } return RewriteResponse(REWRITE_DONE, node); } @@ -46,13 +43,16 @@ public: switch(node.getKind()) { case kind::DISTINCT: return RewriteResponse(REWRITE_DONE, blastDistinct(node)); - case kind::EQUAL: - return Rewriter::callPreRewrite(Theory::theoryOf(node[0]), node); default: return RewriteResponse(REWRITE_DONE, node); } } + static inline Node rewriteEquality(TNode equality) { + Unreachable(); + return equality; + } + static inline void init() {} static inline void shutdown() {} diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index 20da74ce8..11a55ca61 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -42,6 +42,10 @@ public: return RewriteResponse(REWRITE_DONE, node); } + static inline Node rewriteEquality(TNode node) { + return postRewrite(node).node; + } + static void init(); static void shutdown(); };/* class TheoryBVRewriter */ diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 14f05d14c..7a45c98aa 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -108,6 +108,10 @@ public: return RewriteResponse(REWRITE_DONE, in); } + static Node rewriteEquality(TNode equality) { + return postRewrite(equality).node; + } + static inline void init() {} static inline void shutdown() {} diff --git a/src/theory/mkrewriter b/src/theory/mkrewriter index 395317045..b8fa51d77 100755 --- a/src/theory/mkrewriter +++ b/src/theory/mkrewriter @@ -47,6 +47,8 @@ post_rewrite_calls= post_rewrite_get_cache= post_rewrite_set_cache= +rewrite_equality_calls= + seen_theory=false seen_theory_builtin=false @@ -132,6 +134,9 @@ function rewriter { post_rewrite_set_cache="${post_rewrite_set_cache} case ${theory_id}: return RewriteAttibute<${theory_id}>::setPostRewriteCache(node, cache); " + rewrite_equality_calls="${rewrite_equality_calls} case ${theory_id}: return ${class}::rewriteEquality(node); +" + lineno=${BASH_LINENO[0]} check_theory_seen } @@ -224,6 +229,7 @@ for var in \ rewriter_includes \ pre_rewrite_calls \ post_rewrite_calls \ + rewrite_equality_calls \ pre_rewrite_get_cache \ post_rewrite_get_cache \ pre_rewrite_set_cache \ diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index f6aa75bbd..fddbbcd13 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -62,6 +62,11 @@ Node Rewriter::rewrite(Node node) { return rewriteTo(theory::Theory::theoryOf(node), node); } +Node Rewriter::rewriteEquality(theory::TheoryId theoryId, TNode node) { + Trace("rewriter") << "Rewriter::rewriteEquality(" << theoryId << "," << node << ")"<< std::endl; + return Rewriter::callRewriteEquality(theoryId, node); +} + Node Rewriter::rewriteTo(theory::TheoryId theoryId, Node node) { Trace("rewriter") << "Rewriter::rewriteTo(" << theoryId << "," << node << ")"<< std::endl; diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index 30267ce48..dacc4d212 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -73,7 +73,10 @@ class Rewriter { Rewriter() CVC4_UNUSED; Rewriter(const Rewriter&) CVC4_UNUSED; -public: + /** + * Rewrites the node using the given theory rewriter. + */ + static Node rewriteTo(theory::TheoryId theoryId, Node node); /** Calls the pre-rewriter for the given theory */ static RewriteResponse callPreRewrite(theory::TheoryId theoryId, TNode node); @@ -81,6 +84,14 @@ public: /** Calls the post-rewriter for the given theory */ static RewriteResponse callPostRewrite(theory::TheoryId theoryId, TNode node); + /** + * Calls the equality-rewruter for the given theory. + */ + static Node callRewriteEquality(theory::TheoryId theoryId, TNode equality); + +public: + + /** * Rewrites the node using theoryOf() to determine which rewriter to * use on the node. @@ -88,9 +99,10 @@ public: static Node rewrite(Node node); /** - * Rewrites the node using the given theory rewriter. + * Rewrite an equality between two terms that are already in normal form, so + * that the equality is in theory-normal form. */ - static Node rewriteTo(theory::TheoryId theoryId, Node node); + static Node rewriteEquality(theory::TheoryId theoryId, TNode node); /** * Should be called before the rewriter gets used for the first time. diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index 9ab98168e..cd79dcd9f 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -45,6 +45,14 @@ ${post_rewrite_calls} } } +Node Rewriter::callRewriteEquality(theory::TheoryId theoryId, TNode node) { + switch(theoryId) { +${rewrite_equality_calls} + default: + Unreachable(); + } +} + Node Rewriter::getPreRewriteCache(theory::TheoryId theoryId, TNode node) { switch(theoryId) { ${pre_rewrite_get_cache} diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index 1451f654a..ff2feb121 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -61,9 +61,9 @@ void Theory::computeCareGraph(CareGraph& careGraph) { // We don't care about the terms of different types continue; } - switch (equalityStatus(a, b)) { - case EQUALITY_TRUE: - case EQUALITY_FALSE: + switch (getEqualityStatus(a, b)) { + case EQUALITY_TRUE_AND_PROPAGATED: + case EQUALITY_FALSE_AND_PROPAGATED: // If we know about it, we should have propagated it, so we can skip break; default: diff --git a/src/theory/theory.h b/src/theory/theory.h index 17c9ef14a..d11d28aec 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -42,19 +42,31 @@ class TheoryEngine; namespace theory { /** - * The status of an equality in the current context. + * Information about an assertion for the theories. */ -enum EqualityStatus { - /** The eqaulity is known to be true */ - EQUALITY_TRUE, - /** The equality is known to be false */ - EQUALITY_FALSE, - /** The equality is not known, but is true in the current model */ - EQUALITY_TRUE_IN_MODEL, - /** The equality is not known, but is false in the current model */ - EQUALITY_FALSE_IN_MODEL, - /** The equality is completely unknown */ - EQUALITY_UNKNOWN +struct Assertion { + + /** The assertion */ + Node assertion; + /** Has this assertion been preregistered with this theory */ + bool isPreregistered; + + Assertion(TNode assertion, bool isPreregistered) + : assertion(assertion), isPreregistered(isPreregistered) {} + + /** + * Convert the assertion to a TNode. + */ + operator TNode () const { + return assertion; + } + + /** + * Convert the assertion to a Node. + */ + operator Node () const { + return assertion; + } }; /** @@ -113,7 +125,7 @@ private: * These can not be TNodes as some atoms (such as equalities) are sent * across theories without being stored in a global map. */ - context::CDList d_facts; + context::CDList d_facts; /** Index into the head of the facts list */ context::CDO d_factsHead; @@ -179,15 +191,15 @@ protected: * * @return the next assertion in the assertFact() queue */ - TNode get() { + Assertion get() { Assert( !done(), "Theory`() called with assertion queue empty!" ); // Get the assertion - TNode fact = d_facts[d_factsHead]; + Assertion fact = d_facts[d_factsHead]; d_factsHead = d_factsHead + 1; Trace("theory") << "Theory::get() => " << fact << " (" << d_facts.size() - d_factsHead << " left)" << std::endl; if(Dump.isOn("state")) { - Dump("state") << AssertCommand(fact.toExpr()) << std::endl; + Dump("state") << AssertCommand(fact.assertion.toExpr()) << std::endl; } return fact; @@ -199,7 +211,7 @@ protected: * * @return the iterator to the beginning of the fact queue */ - context::CDList::const_iterator facts_begin() const { + context::CDList::const_iterator facts_begin() const { return d_facts.begin(); } @@ -209,7 +221,7 @@ protected: * * @return the iterator to the end of the fact queue */ - context::CDList::const_iterator facts_end() const { + context::CDList::const_iterator facts_end() const { return d_facts.end(); } @@ -363,9 +375,9 @@ public: /** * Assert a fact in the current context. */ - void assertFact(TNode assertion) { - Trace("theory") << "Theory<" << getId() << ">::assertFact(" << assertion << ")" << std::endl; - d_facts.push_back(assertion); + void assertFact(TNode assertion, bool isPreregistered) { + Trace("theory") << "Theory<" << getId() << ">::assertFact(" << assertion << ", " << (isPreregistered ? "true" : "false") << ")" << std::endl; + d_facts.push_back(Assertion(assertion, isPreregistered)); } /** @@ -384,7 +396,7 @@ public: * Return the status of two terms in the current context. Should be implemented in * sub-theories to enable more efficient theory-combination. */ - virtual EqualityStatus equalityStatus(TNode a, TNode b) { return EQUALITY_UNKNOWN; } + virtual EqualityStatus getEqualityStatus(TNode a, TNode b) { return EQUALITY_UNKNOWN; } /** * This method is called by the shared term manager when a shared diff --git a/src/theory/theory_engine.cpp b/src/theory/theory_engine.cpp index 2cd3f4d72..c03b09be2 100644 --- a/src/theory/theory_engine.cpp +++ b/src/theory/theory_engine.cpp @@ -42,12 +42,13 @@ TheoryEngine::TheoryEngine(context::Context* context, : d_propEngine(NULL), d_context(context), d_userContext(userContext), - d_activeTheories(0), + d_activeTheories(context, 0), d_sharedTerms(context), d_atomPreprocessingCache(), d_possiblePropagations(), d_hasPropagated(context), d_inConflict(context, false), + d_sharedTermsExist(context, false), d_hasShutDown(false), d_incomplete(context, false), d_sharedAssertions(context), @@ -91,6 +92,8 @@ void TheoryEngine::preRegister(TNode preprocessed) { if (multipleTheories) { // Collect the shared terms if there are multipe theories NodeVisitor::run(d_sharedTermsVisitor, preprocessed); + // Mark the multiple theories flag + d_sharedTermsExist = true; } } @@ -124,6 +127,8 @@ void TheoryEngine::check(Theory::Effort effort) { while (true) { + Debug("theory") << "TheoryEngine::check(" << effort << "): running check" << std::endl; + // Do the checking CVC4_FOR_EACH_THEORY; @@ -133,26 +138,32 @@ void TheoryEngine::check(Theory::Effort effort) { << CheckSatCommand() << endl; } + Debug("theory") << "TheoryEngine::check(" << effort << "): running propagation after the initial check" << std::endl; + // We are still satisfiable, propagate as much as possible propagate(effort); // If we have any propagated equalities, we enqueue them to the theories and re-check if (d_propagatedEqualities.size() > 0) { + Debug("theory") << "TheoryEngine::check(" << effort << "): distributing shared equalities" << std::endl; assertSharedEqualities(); continue; } // If we added any lemmas, we're done if (d_lemmasAdded) { + Debug("theory") << "TheoryEngine::check(" << effort << "): lemmas added, done" << std::endl; break; } // If in full check and no lemmas added, run the combination - if (Theory::fullEffort(effort)) { + if (Theory::fullEffort(effort) && d_sharedTermsExist) { // Do the combination + Debug("theory") << "TheoryEngine::check(" << effort << "): running combination" << std::endl; combineTheories(); // If we have any propagated equalities, we enqueue them to the theories and re-check if (d_propagatedEqualities.size() > 0) { + Debug("theory") << "TheoryEngine::check(" << effort << "): distributing shared equalities" << std::endl; assertSharedEqualities(); } else { // No propagated equalities, we're either sat, or there are lemmas added @@ -166,6 +177,8 @@ void TheoryEngine::check(Theory::Effort effort) { // Clear any leftover propagated equalities d_propagatedEqualities.clear(); + Debug("theory") << "TheoryEngine::check(" << effort << "): done, we are " << (d_inConflict ? "unsat" : "sat") << (d_lemmasAdded ? " with new lemmas" : " with no new lemmas") << std::endl; + } catch(const theory::Interrupted&) { Trace("theory") << "TheoryEngine::check() => conflict" << endl; } @@ -176,9 +189,11 @@ void TheoryEngine::assertSharedEqualities() { for (unsigned i = 0; i < d_propagatedEqualities.size(); ++ i) { const SharedEquality& eq = d_propagatedEqualities[i]; // Check if the theory already got this one - if (d_sharedAssertions.find(eq.toAssert) != d_sharedAssertions.end()) { + // TODO: the real shared (non-sat) equalities + if (d_sharedAssertions.find(eq.toAssert) == d_sharedAssertions.end()) { + Debug("sharing") << "TheoryEngine::assertSharedEqualities(): asserting " << eq.toAssert.node << " to " << eq.toAssert.theory << " from " << eq.toExplain.theory << std::endl; d_sharedAssertions[eq.toAssert] = eq.toExplain; - theoryOf(eq.toAssert.theory)->assertFact(eq.toAssert.node); + theoryOf(eq.toAssert.theory)->assertFact(eq.toAssert.node, false); } } // Clear the equalities @@ -188,6 +203,8 @@ void TheoryEngine::assertSharedEqualities() { void TheoryEngine::combineTheories() { + Debug("sharing") << "TheoryEngine::combineTheories()" << std::endl; + CareGraph careGraph; #ifdef CVC4_FOR_EACH_THEORY_STATEMENT #undef CVC4_FOR_EACH_THEORY_STATEMENT @@ -205,20 +222,31 @@ void TheoryEngine::combineTheories() { for (unsigned i = 0; i < careGraph.size(); ++ i) { const CarePair& carePair = careGraph[i]; + Debug("sharing") << "TheoryEngine::combineTheories(): checking " << carePair.a << " = " << carePair.b << " from " << carePair.theory << std::endl; + Node equality = carePair.a.eqNode(carePair.b); Node normalizedEquality = Rewriter::rewrite(equality); // If the node has a literal, it has been asserted so we should just check it bool value; if (d_propEngine->isSatLiteral(normalizedEquality) && d_propEngine->hasValue(normalizedEquality, value)) { + Debug("sharing") << "TheoryEngine::combineTheories(): has a literal " << std::endl; + // Normalize the equality to the theory that requested it - Node toAssert = Rewriter::rewriteTo(carePair.theory, equality); + Node toAssert = Rewriter::rewriteEquality(carePair.theory, equality); + if (value) { - d_propagatedEqualities.push_back(SharedEquality(toAssert, normalizedEquality, theory::THEORY_LAST, carePair.theory)); + SharedEquality sharedEquality(toAssert, normalizedEquality, theory::THEORY_LAST, carePair.theory); + Assert(d_sharedAssertions.find(sharedEquality.toAssert) == d_sharedAssertions.end()); + d_propagatedEqualities.push_back(sharedEquality); } else { - d_propagatedEqualities.push_back(SharedEquality(toAssert.notNode(), normalizedEquality.notNode(), theory::THEORY_LAST, carePair.theory)); + SharedEquality sharedEquality(toAssert.notNode(), normalizedEquality.notNode(), theory::THEORY_LAST, carePair.theory); + Assert(d_sharedAssertions.find(sharedEquality.toAssert) == d_sharedAssertions.end()); + d_propagatedEqualities.push_back(sharedEquality); } } else { + Debug("sharing") << "TheoryEngine::combineTheories(): requesting a split " << std::endl; + // There is no value, so we need to split on it lemma(equality.orNode(equality.notNode()), false, false); } @@ -251,16 +279,6 @@ void TheoryEngine::propagate(Theory::Effort effort) { } } -Node TheoryEngine::getExplanation(TNode node, theory::Theory* theory) { - Node explanation = theory->explain(node); - if(Dump.isOn("t-explanations")) { - Dump("t-explanations") - << CommentCommand(string("theory explanation from ") + theory->identify() + ": expect valid") << endl - << QueryCommand(explanation.impNode(node).toExpr()) << endl; - } - return explanation; -} - bool TheoryEngine::properConflict(TNode conflict) const { bool value; if (conflict.getKind() == kind::AND) { @@ -497,7 +515,7 @@ void TheoryEngine::assertFact(TNode node) TNode atom = node.getKind() == kind::NOT ? node[0] : node; // Assert the fact to the apropriate theory - theoryOf(atom)->assertFact(node); + theoryOf(atom)->assertFact(node, true); // If any shared terms, notify the theories if (d_sharedTerms.hasSharedTerms(atom)) { @@ -512,13 +530,14 @@ void TheoryEngine::assertFact(TNode node) } } d_sharedTerms.markNotified(term, theories); + markActive(theories); } } } void TheoryEngine::propagate(TNode literal, theory::TheoryId theory) { - Debug("theory") << "EngineOutputChannel::propagate(" << literal << ")" << std::endl; + Debug("theory") << "EngineOutputChannel::propagate(" << literal << ", " << theory << ")" << std::endl; d_propEngine->checkTime(); @@ -530,7 +549,9 @@ void TheoryEngine::propagate(TNode literal, theory::TheoryId theory) { d_hasPropagated.insert(literal); } - if (literal.getKind() != kind::EQUAL) { + TNode atom = literal.getKind() == kind::NOT ? literal[0] : literal; + + if (!d_sharedTermsExist || atom.getKind() != kind::EQUAL) { // If not an equality it must be a SAT literal so we enlist it, and it can only // be propagated by the theory that owns it, as it is the only one that got // a preregister call with it. @@ -541,24 +562,139 @@ void TheoryEngine::propagate(TNode literal, theory::TheoryId theory) { Node normalizedEquality = Rewriter::rewrite(literal); if (d_propEngine->isSatLiteral(normalizedEquality)) { // If there is a literal, just enqueue it, same as above - d_propagatedLiterals.push_back(literal); - } else { - // Otherwise, we assert it to all interested theories + d_propagatedLiterals.push_back(normalizedEquality); + } + // Otherwise, we assert it to all interested theories + Theory::Set lhsNotified = d_sharedTerms.getNotifiedTheories(atom[0]); + Theory::Set rhsNotified = d_sharedTerms.getNotifiedTheories(atom[1]); + // Now notify the theories + if (Theory::setIntersection(lhsNotified, rhsNotified) != 0) { bool negated = literal.getKind() == kind::NOT; - Node equality = negated ? literal[0] : literal; - Theory::Set lhsNotified = d_sharedTerms.getNotifiedTheories(equality[0]); - Theory::Set rhsNotified = d_sharedTerms.getNotifiedTheories(equality[1]); - // Now notify the theories - for (TheoryId current = theory::THEORY_FIRST; current != theory::THEORY_LAST; ++ current) { - if (Theory::setContains(current, lhsNotified) && Theory::setContains(current, rhsNotified)) { + for (TheoryId currentTheory = theory::THEORY_FIRST; currentTheory != theory::THEORY_LAST; ++ currentTheory) { + if (currentTheory == theory) { + // Don't reassert to the same theory + continue; + } + // Assert if theory is interested in both terms + if (Theory::setContains(currentTheory, lhsNotified) && Theory::setContains(currentTheory, rhsNotified)) { // Normalize the equality - Node equalityNormalized = Rewriter::rewriteTo(current, equality); + Node equality = Rewriter::rewriteEquality(currentTheory, atom); // The assertion - Node assertion = negated ? equalityNormalized.notNode() : equalityNormalized; + Node assertion = negated ? equality.notNode() : equality; // Remember it to assert later - d_propagatedEqualities.push_back(SharedEquality(assertion, literal, theory, current)); + d_propagatedEqualities.push_back(SharedEquality(assertion, literal, theory, currentTheory)); } } } } } + +theory::EqualityStatus TheoryEngine::getEqualityStatus(TNode a, TNode b) { + Assert(a.getType() == b.getType()); + return theoryOf(Theory::theoryOf(a.getType()))->getEqualityStatus(a, b); +} + +Node TheoryEngine::getExplanation(TNode node) +{ + Debug("theory") << "TheoryEngine::getExplanation(" << node << ")" << std::endl; + + TNode atom = node.getKind() == kind::NOT ? node[0] : node; + + Node explanation; + + // The theory that has explaining to do + Theory* theory = theoryOf(atom); + if (d_sharedTermsExist && atom.getKind() == kind::EQUAL) { + SharedAssertionsMap::iterator find = d_sharedAssertions.find(NodeTheoryPair(node, theory::THEORY_LAST)); + if (find == d_sharedAssertions.end()) { + theory = theoryOf(atom); + } + } + + // Get the explanation + explanation = theory->explain(node); + + // Explain any shared equalities that might be in the explanation + if (d_sharedTermsExist) { + NodeBuilder<> nb(kind::AND); + explainEqualities(theory->getId(), explanation, nb); + explanation = nb; + } + + Assert(!explanation.isNull()); + if(Dump.isOn("t-explanations")) { + Dump("t-explanations") << CommentCommand(std::string("theory explanation from ") + theoryOf(atom)->identify() + ": expect valid") << std::endl + << QueryCommand(explanation.impNode(node).toExpr()) << std::endl; + } + Assert(properExplanation(node, explanation)); + + return explanation; +} + +void TheoryEngine::conflict(TNode conflict, TheoryId theoryId) { + + // Mark that we are in conflict + d_inConflict = true; + + if(Dump.isOn("t-conflicts")) { + Dump("t-conflicts") << CommentCommand("theory conflict: expect unsat") << std::endl + << CheckSatCommand(conflict.toExpr()) << std::endl; + } + + if (d_sharedTermsExist) { + // In the multiple-theories case, we need to reconstruct the conflict + NodeBuilder<> nb(kind::AND); + explainEqualities(theoryId, conflict, nb); + Node fullConflict = nb; + Assert(properConflict(fullConflict)); + Debug("theory") << "TheoryEngine::conflict(" << conflict << ", " << theoryId << "): " << fullConflict << std::endl; + lemma(fullConflict, true, false); + } else { + // When only one theory, the conflict should need no processing + Assert(properConflict(conflict)); + lemma(conflict, true, false); + } +} + +void TheoryEngine::explainEqualities(TheoryId theoryId, TNode literals, NodeBuilder<>& builder) { + Debug("theory") << "TheoryEngine::explainEqualities(" << theoryId << ", " << literals << ")" << std::endl; + if (literals.getKind() == kind::AND) { + for (unsigned i = 0, i_end = literals.getNumChildren(); i != i_end; ++ i) { + TNode literal = literals[i]; + TNode atom = literal.getKind() == kind::NOT ? literal[0] : literal; + if (atom.getKind() == kind::EQUAL) { + explainEquality(theoryId, literal, builder); + } else { + builder << literal; + } + } + } else { + TNode atom = literals.getKind() == kind::NOT ? literals[0] : literals; + if (atom.getKind() == kind::EQUAL) { + explainEquality(theoryId, literals, builder); + } else { + builder << literals; + } + } +} + +void TheoryEngine::explainEquality(TheoryId theoryId, TNode eqLiteral, NodeBuilder<>& builder) { + Assert(eqLiteral.getKind() == kind::EQUAL || (eqLiteral.getKind() == kind::NOT && eqLiteral[0].getKind() == kind::EQUAL)); + + SharedAssertionsMap::iterator find = d_sharedAssertions.find(NodeTheoryPair(eqLiteral, theoryId)); + if (find == d_sharedAssertions.end()) { + // Not a shared assertion, just add it since it must be SAT literal + builder << eqLiteral; + } else { + TheoryId explainingTheory = (*find).second.theory; + if (explainingTheory == theory::THEORY_LAST) { + // If the theory is from the SAT solver, just take the normalized one + builder << (*find).second.node; + } else { + // Explain it using the theory that propagated it + Node explanation = theoryOf(explainingTheory)->explain((*find).second.node); + explainEqualities(explainingTheory, explanation, builder); + } + } +} + diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index be3068a45..80890303b 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -96,7 +96,7 @@ class TheoryEngine { * runs (no sharing), can reduce the cost of walking the DAG on * registration, etc. */ - theory::Theory::Set d_activeTheories; + context::CDO d_activeTheories; /** * The database of shared terms. @@ -191,7 +191,7 @@ class TheoryEngine { void conflict(TNode conflictNode) throw(AssertionException) { Trace("theory") << "EngineOutputChannel<" << d_theory << ">::conflict(" << conflictNode << ")" << std::endl; ++ d_statistics.conflicts; - d_engine->conflict(conflictNode); + d_engine->conflict(conflictNode, d_theory); } void propagate(TNode literal) throw(AssertionException) { @@ -226,21 +226,26 @@ class TheoryEngine { */ context::CDO d_inConflict; - void conflict(TNode conflict) { - - Assert(properConflict(conflict), "not a proper conflict: %s", conflict.toString().c_str()); + /** + * Does the context contain terms shared among multiple theories. + */ + context::CDO d_sharedTermsExist; - // Mark that we are in conflict - d_inConflict = true; + /** + * Explain the equality literals and push all the explaining literals into the builder. All + * the non-equality literals are pushed to the builder. + */ + void explainEqualities(theory::TheoryId theoryId, TNode literals, NodeBuilder<>& builder); - if(Dump.isOn("t-conflicts")) { - Dump("t-conflicts") << CommentCommand("theory conflict: expect unsat") << std::endl - << CheckSatCommand(conflict.toExpr()) << std::endl; - } + /** + * Same as above, but just for single equalities. + */ + void explainEquality(theory::TheoryId theoryId, TNode eqLiteral, NodeBuilder<>& builder); - // Construct the lemma (note that no CNF caching should happen as all the literals already exists) - lemma(conflict, true, false); - } + /** + * Called by the theories to notify of a conflict. + */ + void conflict(TNode conflict, theory::TheoryId theoryId); /** * Debugging flag to ensure that shutdown() is called before the @@ -282,15 +287,20 @@ class TheoryEngine { NodeTheoryPair toExplain; SharedEquality(TNode assertion, TNode original, theory::TheoryId sendingTheory, theory::TheoryId receivingTheory) - : toAssert(assertion, sendingTheory), - toExplain(original, receivingTheory) + : toAssert(assertion, receivingTheory), + toExplain(original, sendingTheory) { } }; + /** + * Map from equalities asserted to a theory, to the theory that can explain them. + */ + typedef context::CDMap SharedAssertionsMap; + /** * A map from asserted facts to where they came from (for explanations). */ - context::CDMap d_sharedAssertions; + SharedAssertionsMap d_sharedAssertions; /** * Assert a shared equalities propagated by theories. @@ -480,23 +490,11 @@ public: } } - Node getExplanation(TNode node, theory::Theory* theory); - bool properConflict(TNode conflict) const; bool properPropagation(TNode lit) const; bool properExplanation(TNode node, TNode expl) const; - inline Node getExplanation(TNode node) { - TNode atom = node.getKind() == kind::NOT ? node[0] : node; - Node explanation = theoryOf(atom)->explain(node); - Assert(!explanation.isNull()); - if(Dump.isOn("t-explanations")) { - Dump("t-explanations") << CommentCommand(std::string("theory explanation from ") + theoryOf(atom)->identify() + ": expect valid") << std::endl - << QueryCommand(explanation.impNode(node).toExpr()) << std::endl; - } - Assert(properExplanation(node, explanation)); - return explanation; - } + Node getExplanation(TNode node); Node getValue(TNode node); @@ -522,6 +520,12 @@ public: return d_theoryTable[theoryId]; } + /** + * Returns the equality status of the two terms, from the theory that owns the domain type. + * The types of a and b must be the same. + */ + theory::EqualityStatus getEqualityStatus(TNode a, TNode b); + private: /** Default visitor for pre-registration */ diff --git a/src/theory/uf/equality_engine.h b/src/theory/uf/equality_engine.h index 18a525f2d..13b8898d5 100644 --- a/src/theory/uf/equality_engine.h +++ b/src/theory/uf/equality_engine.h @@ -141,7 +141,11 @@ public: * Creates a new node, which is in a list of it's own. */ EqualityNode(EqualityNodeId nodeId = null_id) - : d_size(1), d_findId(nodeId), d_nextId(nodeId), d_useList(null_uselist_id) {} + : d_size(1), + d_findId(nodeId), + d_nextId(nodeId), + d_useList(null_uselist_id) + {} /** * Retuerns the uselist. @@ -266,6 +270,12 @@ public: private: + /** The context we are using */ + context::Context* d_context; + + /** Whether to notify or not (temporarily disabled on equality checks) */ + bool d_performNotify; + /** The class to notify when a representative changes for a term */ NotifyClass d_notify; @@ -454,6 +464,21 @@ private: */ std::vector d_nodeTriggers; + /** + * List of terms that are marked as individual triggers. + */ + std::vector d_individualTriggers; + + /** + * Size of the individual triggers list. + */ + context::CDO d_individualTriggersSize; + + /** + * Map from ids to the individual trigger id representative. + */ + std::vector d_nodeIndividualTrigger; + /** * Adds the trigger with triggerId to the beginning of the trigger list of the node with id nodeId. */ @@ -503,6 +528,11 @@ private: */ void debugPrintGraph() const; + /** The true node */ + Node d_true; + /** The false node */ + Node d_false; + public: /** @@ -511,15 +541,20 @@ public: */ EqualityEngine(NotifyClass& notify, context::Context* context, std::string name) : ContextNotifyObj(context), + d_context(context), + d_performNotify(true), d_notify(notify), d_nodesCount(context, 0), d_assertedEqualitiesCount(context, 0), d_equalityTriggersCount(context, 0), + d_individualTriggersSize(context, 0), d_stats(name) { Debug("equality") << "EqualityEdge::EqualityEngine(): id_null = " << +null_id << std::endl; Debug("equality") << "EqualityEdge::EqualityEngine(): edge_null = " << +null_edge << std::endl; Debug("equality") << "EqualityEdge::EqualityEngine(): trigger_null = " << +null_trigger << std::endl; + d_true = NodeManager::currentNM()->mkConst(true); + d_false = NodeManager::currentNM()->mkConst(false); } /** @@ -560,6 +595,11 @@ public: */ void addEquality(TNode t1, TNode t2, TNode reason); + /** + * Adds an dis-equality t1 != t2 to the database. + */ + void addDisequality(TNode t1, TNode t2, TNode reason); + /** * Returns the representative of the term t. */ @@ -577,12 +617,39 @@ public: */ void getExplanation(TNode t1, TNode t2, std::vector& equalities) const; + /** + * Add term to the trigger terms. The notify class will get notified when two + * trigger terms become equal. Thihs will only happen on trigger term + * representatives. + */ + void addTriggerTerm(TNode t); + + /** + * Returns true if t is a trigger term or equal to some other trigger term. + */ + bool isTriggerTerm(TNode t) const; + /** * Adds a notify trigger for equality t1 = t2, i.e. when t1 = t2 the notify will be called with - * (t1 = t2). + * trigger. */ void addTriggerEquality(TNode t1, TNode t2, TNode trigger); + /** + * Adds a notify trigger for dis-equality t1 != t2, i.e. when t1 != t2 the notify will be called with + * trigger. + */ + void addTriggerDisequality(TNode t1, TNode t2, TNode trigger); + + /** + * Check whether the two terms are equal. + */ + bool areEqual(TNode t1, TNode t2); + + /** + * Check whether the two term are dis-equal. + */ + bool areDisequal(TNode t1, TNode t2); }; } // Namespace uf diff --git a/src/theory/uf/equality_engine_impl.h b/src/theory/uf/equality_engine_impl.h index b31d04a32..77c8e80b4 100644 --- a/src/theory/uf/equality_engine_impl.h +++ b/src/theory/uf/equality_engine_impl.h @@ -91,6 +91,8 @@ EqualityNodeId EqualityEngine::newNode(TNode node, bool isApplicati d_nodeTriggers.push_back(+null_trigger); // Add it to the equality graph d_equalityGraph.push_back(+null_edge); + // Mark the no-individual trigger + d_nodeIndividualTrigger.push_back(+null_id); // Add the equality node to the nodes d_equalityNodes.push_back(EqualityNode(newId)); @@ -176,14 +178,26 @@ void EqualityEngine::addEquality(TNode t1, TNode t2, TNode reason) Debug("equality") << "EqualityEngine::addEquality(" << t1 << "," << t2 << ")" << std::endl; // Add the terms if they are not already in the database - EqualityNodeId t1Id = getNodeId(t1); - EqualityNodeId t2Id = getNodeId(t2); + addTerm(t1); + addTerm(t2); // Add to the queue and propagate + EqualityNodeId t1Id = getNodeId(t1); + EqualityNodeId t2Id = getNodeId(t2); enqueue(MergeCandidate(t1Id, t2Id, MERGED_THROUGH_EQUALITY, reason)); propagate(); } +template +void EqualityEngine::addDisequality(TNode t1, TNode t2, TNode reason) { + + Debug("equality") << "EqualityEngine::addDisequality(" << t1 << "," << t2 << ")" << std::endl; + + Node equality = t1.eqNode(t2); + addEquality(equality, d_false, reason); +} + + template TNode EqualityEngine::getRepresentative(TNode t) const { @@ -304,6 +318,24 @@ void EqualityEngine::merge(EqualityNode& class1, EqualityNode& clas // Now merge the lists class1.merge(class2); + + // Notfiy the triggers + EqualityNodeId class1triggerId = d_nodeIndividualTrigger[class1Id]; + EqualityNodeId class2triggerId = d_nodeIndividualTrigger[class2Id]; + if (class2triggerId != +null_id) { + if (class1triggerId == +null_id) { + // If class1 is not an individual trigger, but class2 is, mark it + d_nodeIndividualTrigger[class1Id] = class2triggerId; + // Add it to the list for backtracking + d_individualTriggers.push_back(class1Id); + d_individualTriggersSize = d_individualTriggersSize + 1; + } else { + // Notify when done + if (d_performNotify) { + d_notify.notify(d_nodes[class1triggerId], d_nodes[class2triggerId]); + } + } + } } template @@ -437,11 +469,19 @@ void EqualityEngine::backtrack() { d_equalityEdges.resize(2 * d_assertedEqualitiesCount); } + if (d_individualTriggers.size() > d_individualTriggersSize) { + // Unset the individual triggers + for (int i = d_individualTriggers.size() - 1, i_end = d_individualTriggersSize; i >= i_end; -- i) { + d_nodeIndividualTrigger[d_individualTriggers[i]] = +null_id; + } + d_individualTriggers.resize(d_individualTriggersSize); + } + if (d_equalityTriggers.size() > d_equalityTriggersCount) { // Unlink the triggers from the lists for (int i = d_equalityTriggers.size() - 1, i_end = d_equalityTriggersCount; i >= i_end; -- i) { - const Trigger& trigger = d_equalityTriggers[i]; - d_nodeTriggers[trigger.classId] = trigger.nextTrigger; + const Trigger& trigger = d_equalityTriggers[i]; + d_nodeTriggers[trigger.classId] = trigger.nextTrigger; } // Get rid of the triggers d_equalityTriggers.resize(d_equalityTriggersCount); @@ -470,6 +510,7 @@ void EqualityEngine::backtrack() { d_nodes.resize(d_nodesCount); d_applications.resize(d_nodesCount); d_nodeTriggers.resize(d_nodesCount); + d_nodeIndividualTrigger.resize(d_nodesCount); d_equalityGraph.resize(d_nodesCount); d_equalityNodes.resize(d_nodesCount); } @@ -613,6 +654,13 @@ void EqualityEngine::getExplanation(EqualityNodeId t1Id, EqualityNo } } +template +void EqualityEngine::addTriggerDisequality(TNode t1, TNode t2, TNode trigger) { + Node equality = t1.eqNode(t2); + addTerm(equality); + addTriggerEquality(equality, d_false, trigger); +} + template void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode trigger) { @@ -651,7 +699,9 @@ void EqualityEngine::addTriggerEquality(TNode t1, TNode t2, TNode t // If the trigger is already on, we propagate if (t1classId == t2classId) { Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << "): triggered at setup time" << std::endl; - d_notify.notifyEquality(trigger); // Don't care about the return value + if (d_performNotify) { + d_notify.notify(trigger); // Don't care about the return value + } } Debug("equality") << "EqualityEngine::addTrigger(" << t1 << "," << t2 << ") => (" << t1NewTriggerId << ", " << t2NewTriggerId << ")" << std::endl; @@ -690,31 +740,30 @@ void EqualityEngine::propagate() { Assert(node1.getFind() == t1classId); Assert(node2.getFind() == t2classId); + // Add the actuall equality to the equality graph + addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason); + + // One more equality added + d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1; + // Depending on the merge preference (such as size), merge them std::vector triggers; if (node2.getSize() > node1.getSize()) { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t1Id]<< " into " << d_nodes[current.t2Id] << std::endl; - merge(node2, node1, triggers); d_assertedEqualities.push_back(Equality(t2classId, t1classId)); + merge(node2, node1, triggers); } else { Debug("equality") << "EqualityEngine::propagate(): merging " << d_nodes[current.t2Id] << " into " << d_nodes[current.t1Id] << std::endl; - merge(node1, node2, triggers); d_assertedEqualities.push_back(Equality(t1classId, t2classId)); + merge(node1, node2, triggers); } - // Add the actuall equality to the equality graph - addGraphEdge(current.t1Id, current.t2Id, current.type, current.reason); - - // One more equality added - d_assertedEqualitiesCount = d_assertedEqualitiesCount + 1; - - Assert(2*d_assertedEqualities.size() == d_equalityEdges.size()); - Assert(d_assertedEqualities.size() == d_assertedEqualitiesCount); - // Notify the triggers - for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) { - // Notify the trigger and exit if it fails - done = !d_notify.notifyEquality(d_equalityTriggersOriginal[triggers[trigger]]); + if (d_performNotify) { + for (size_t trigger = 0, trigger_end = triggers.size(); trigger < trigger_end && !done; ++ trigger) { + // Notify the trigger and exit if it fails + done = !d_notify.notify(d_equalityTriggersOriginal[triggers[trigger]]); + } } } } @@ -736,6 +785,108 @@ void EqualityEngine::debugPrintGraph() const { } } +class ScopedBool { + bool& watch; + bool oldValue; +public: + ScopedBool(bool& watch, bool newValue) + : watch(watch), oldValue(watch) { + watch = newValue; + } + ~ScopedBool() { + watch = oldValue; + } +}; + +template +bool EqualityEngine::areEqual(TNode t1, TNode t2) +{ + // Don't notify during this check + ScopedBool turnOfNotify(d_performNotify, false); + + // Push the context, so that we can remove the terms later + d_context->push(); + + // Add the terms + addTerm(t1); + addTerm(t2); + bool equal = getEqualityNode(t1).getFind() == getEqualityNode(t2).getFind(); + + // Pop the context (triggers new term removal) + d_context->pop(); + + // Return whether the two terms are equal + return equal; +} + +template +bool EqualityEngine::areDisequal(TNode t1, TNode t2) +{ + // Don't notify during this check + ScopedBool turnOfNotify(d_performNotify, false); + + // Push the context, so that we can remove the terms later + d_context->push(); + + // Add the terms + addTerm(t1); + addTerm(t2); + + // Check (t1 = t2) = false + Node equality1 = t1.eqNode(t2); + addTerm(equality1); + if (getEqualityNode(equality1).getFind() == getEqualityNode(d_false).getFind()) { + return true; + } + + // Check (t2 = t1) = false + Node equality2 = t2.eqNode(t1); + addTerm(equality2); + if (getEqualityNode(equality2).getFind() == getEqualityNode(d_false).getFind()) { + return true; + } + + // Pop the context (triggers new term removal) + d_context->pop(); + + // Return whether the terms are disequal + return false; +} + +template +void EqualityEngine::addTriggerTerm(TNode t) +{ + Debug("equality::internal") << "EqualityEngine::addTriggerTerm(" << t << ")" << std::endl; + + // Add the term if it's not already there + addTerm(t); + + // Get the node id + EqualityNodeId eqNodeId = getNodeId(t); + EqualityNode& eqNode = getEqualityNode(eqNodeId); + EqualityNodeId classId = eqNode.getFind(); + + if (d_nodeIndividualTrigger[classId] != +null_id) { + // No need to keep it, just propagate the existing individual triggers + if (d_performNotify) { + d_notify.notify(t, d_nodes[d_nodeIndividualTrigger[classId]]); + } + } else { + // Add it to the list for backtracking + d_individualTriggers.push_back(classId); + d_individualTriggersSize = d_individualTriggersSize + 1; + // Mark the class id as a trigger + d_nodeIndividualTrigger[classId] = eqNodeId; + } +} + +template +bool EqualityEngine::isTriggerTerm(TNode t) const { + if (!hasTerm(t)) return false; + EqualityNodeId classId = getEqualityNode(t).getFind(); + return d_nodeIndividualTrigger[classId] != +null_id; +} + } // Namespace uf } // Namespace theory } // Namespace CVC4 diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index 84fad2f19..3c28e9d9d 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -50,34 +50,38 @@ Node mkAnd(const std::vector& conjunctions) { void TheoryUF::check(Effort level) { - while (!done() && !d_conflict) { + while (!done() && !d_conflict) + { // Get all the assertions - TNode assertion = get(); - Debug("uf") << "TheoryUF::check(): processing " << assertion << std::endl; + Assertion assertion = get(); + TNode fact = assertion.assertion; - // Fo the work - switch (assertion.getKind()) { + Debug("uf") << "TheoryUF::check(): processing " << fact << std::endl; + + // If the assertion doesn't have a literal, it's a shared equality, so we set it up + if (!assertion.isPreregistered) { + Debug("uf") << "TheoryUF::check(): shared equality, setting up " << fact << std::endl; + if (fact.getKind() == kind::NOT) { + preRegisterTerm(fact[0]); + } else { + preRegisterTerm(fact); + } + } + + // Do the work + switch (fact.getKind()) { case kind::EQUAL: - d_equalityEngine.addEquality(assertion[0], assertion[1], assertion); + d_equalityEngine.addEquality(fact[0], fact[1], fact); break; case kind::APPLY_UF: - d_equalityEngine.addEquality(assertion, d_true, assertion); + d_equalityEngine.addEquality(fact, d_true, fact); break; case kind::NOT: - if (assertion[0].getKind() == kind::APPLY_UF) { - d_equalityEngine.addEquality(assertion[0], d_false, assertion); + if (fact[0].getKind() == kind::APPLY_UF) { + d_equalityEngine.addEquality(fact[0], d_false, fact); } else { - // Disequality check - TNode equality = assertion[0]; - if (d_equalityEngine.getRepresentative(equality[0]) == d_equalityEngine.getRepresentative(equality[1])) { - std::vector assumptions; - assumptions.push_back(assertion); - explain(equality, assumptions); - d_conflictNode = mkAnd(assumptions); - d_conflict = true; - } // Assert the dis-equality - d_equalityEngine.addEquality(assertion[0], d_false, assertion); + d_equalityEngine.addDisequality(fact[0][0], fact[0][1], fact); } break; default: @@ -138,9 +142,7 @@ void TheoryUF::preRegisterTerm(TNode node) { d_equalityEngine.addTerm(node[1]); // Add the trigger for equality d_equalityEngine.addTriggerEquality(node[0], node[1], node); - // Add the disequality terms and triggers - d_equalityEngine.addTerm(node); - d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); + d_equalityEngine.addTriggerDisequality(node[0], node[1], node.notNode()); break; case kind::APPLY_UF: // Function applications/predicates @@ -151,6 +153,8 @@ void TheoryUF::preRegisterTerm(TNode node) { d_equalityEngine.addTriggerEquality(node, d_true, node); d_equalityEngine.addTriggerEquality(node, d_false, node.notNode()); } + // Remember the function and predicate terms + d_functionsTerms.push_back(node); break; default: // Variables etc @@ -359,3 +363,94 @@ void TheoryUF::staticLearning(TNode n, NodeBuilder<>& learned) { d_symb.assertFormula(n); } } + +EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) { + if (d_equalityEngine.areEqual(a, b)) { + // The terms are implied to be equal + return EQUALITY_TRUE; + } + if (d_equalityEngine.areDisequal(a, b)) { + // The rems are implied to be dis-equal + return EQUALITY_FALSE; + } + // All other terms we interpret as dis-equal in the model + return EQUALITY_FALSE_IN_MODEL; +} + +void TheoryUF::addSharedTerm(TNode t) { + Debug("uf::sharing") << "TheoryUF::addSharedTerm(" << t << ")" << std::endl; + d_equalityEngine.addTriggerTerm(t); +} + +void TheoryUF::computeCareGraph(CareGraph& careGraph) { + + if (d_sharedTerms.size() > 0) { + + std::vector currentPairs; + + // Go through the function terms and see if there are any to split on + unsigned functionTerms = d_functionsTerms.size(); + for (unsigned i = 0; i < functionTerms; ++ i) { + TNode f1 = d_functionsTerms[i]; + Node op = f1.getOperator(); + for (unsigned j = i + 1; j < functionTerms; ++ j) { + + TNode f2 = d_functionsTerms[j]; + + // If the operators are not the same, we can skip this pair + if (f2.getOperator() != op) { + continue; + } + + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking function " << f1 << " and " << f2 << std::endl; + + // If the terms are already known to be equal, we are also in good shape + if (d_equalityEngine.areEqual(f1, f2)) { + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): equal, skipping" << std::endl; + continue; + } + + // We have two functions f(x1, ..., xn) and f(y1, ..., yn) no known to be equal + // We split on the argument pairs that are are not known, unless there is some + // argument pair that is already dis-equal. + bool somePairIsDisequal = false; + currentPairs.clear(); + for (unsigned k = 0; k < f1.getNumChildren(); ++ k) { + + TNode x = f1[k]; + TNode y = f2[k]; + + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): checking arguments " << x << " and " << y << std::endl; + + EqualityStatus eqStatusUf = getEqualityStatus(x, y); + + if (eqStatusUf == EQUALITY_FALSE) { + // Mark that there is a dis-equal pair and break + somePairIsDisequal = true; + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): disequal, disregarding all" << std::endl; + break; + } + + if (!d_equalityEngine.isTriggerTerm(x) || !d_equalityEngine.isTriggerTerm(y)) { + // Not connected to shared terms, we don't care + continue; + } + + if (eqStatusUf == EQUALITY_TRUE) { + // We don't neeed this one + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): equal" << std::endl; + continue; + } + + // Otherwise, we need to figure it out + Debug("uf::sharing") << "TheoryUf::computeCareGraph(): adding to care-graph" << std::endl; + currentPairs.push_back(CarePair(x, y, THEORY_UF)); + } + + if (!somePairIsDisequal) { + careGraph.insert(careGraph.end(), currentPairs.begin(), currentPairs.end()); + } + } + } + } +} diff --git a/src/theory/uf/theory_uf.h b/src/theory/uf/theory_uf.h index f8e17b1de..769caba5c 100644 --- a/src/theory/uf/theory_uf.h +++ b/src/theory/uf/theory_uf.h @@ -44,10 +44,16 @@ public: public: NotifyClass(TheoryUF& uf): d_uf(uf) {} - bool notifyEquality(TNode reason) { - Debug("uf") << "NotifyClass::notifyEquality(" << reason << ")" << std::endl; + bool notify(TNode propagation) { + Debug("uf") << "NotifyClass::notify(" << propagation << ")" << std::endl; // Just forward to uf - return d_uf.propagate(reason); + return d_uf.propagate(propagation); + } + + void notify(TNode t1, TNode t2) { + Debug("uf") << "NotifyClass::notify(" << t1 << ", " << t2 << ")" << std::endl; + Node equality = Rewriter::rewriteEquality(theory::THEORY_UF, t1.eqNode(t2)); + d_uf.propagate(equality); } }; @@ -59,9 +65,6 @@ private: /** Equaltity engine */ EqualityEngine d_equalityEngine; - /** All the literals known to be true */ - context::CDSet d_knownFacts; - /** Are we in conflict */ context::CDO d_conflict; @@ -79,11 +82,13 @@ private: void explain(TNode literal, std::vector& assumptions); /** Literals to propagate */ - context::CDList d_literalsToPropagate; + context::CDList d_literalsToPropagate; /** Index of the next literal to propagate */ context::CDO d_literalsToPropagateIndex; + /** All the function terms that the theory has seen */ + context::CDList d_functionsTerms; /** True node for predicates = true */ Node d_true; @@ -101,10 +106,10 @@ public: Theory(THEORY_UF, c, u, out, valuation), d_notify(*this), d_equalityEngine(d_notify, c, "theory::uf::TheoryUF"), - d_knownFacts(c), d_conflict(c, false), d_literalsToPropagate(c), - d_literalsToPropagateIndex(c, 0) + d_literalsToPropagateIndex(c, 0), + d_functionsTerms(c) { // The kinds we are treating as function application in congruence d_equalityEngine.addFunctionKind(kind::APPLY_UF); @@ -126,6 +131,11 @@ public: void staticLearning(TNode in, NodeBuilder<>& learned); void presolve(); + void addSharedTerm(TNode n); + void computeCareGraph(CareGraph& careGraph); + + EqualityStatus getEqualityStatus(TNode a, TNode b); + std::string identify() const { return "THEORY_UF"; } diff --git a/src/theory/uf/theory_uf_rewriter.h b/src/theory/uf/theory_uf_rewriter.h index e1aba2c95..be4906ab6 100644 --- a/src/theory/uf/theory_uf_rewriter.h +++ b/src/theory/uf/theory_uf_rewriter.h @@ -39,12 +39,7 @@ public: } if (node[0] > node[1]) { Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); - // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change) - if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) { - return RewriteResponse(REWRITE_AGAIN_FULL, newNode); - } else { - return RewriteResponse(REWRITE_DONE, newNode); - } + return RewriteResponse(REWRITE_DONE, newNode); } } return RewriteResponse(REWRITE_DONE, node); @@ -59,6 +54,10 @@ public: return RewriteResponse(REWRITE_DONE, node); } + static Node rewriteEquality(TNode equality) { + return postRewrite(equality).node; + } + static inline void init() {} static inline void shutdown() {} diff --git a/src/theory/valuation.cpp b/src/theory/valuation.cpp index 148b95632..627125a27 100644 --- a/src/theory/valuation.cpp +++ b/src/theory/valuation.cpp @@ -24,6 +24,38 @@ namespace CVC4 { namespace theory { +bool equalityStatusCompatible(EqualityStatus s1, EqualityStatus s2) { + switch (s1) { + case EQUALITY_TRUE: + case EQUALITY_TRUE_IN_MODEL: + case EQUALITY_TRUE_AND_PROPAGATED: + switch (s2) { + case EQUALITY_TRUE: + case EQUALITY_TRUE_IN_MODEL: + case EQUALITY_TRUE_AND_PROPAGATED: + return true; + default: + return false; + } + break; + case EQUALITY_FALSE: + case EQUALITY_FALSE_IN_MODEL: + case EQUALITY_FALSE_AND_PROPAGATED: + switch (s2) { + case EQUALITY_FALSE: + case EQUALITY_FALSE_IN_MODEL: + case EQUALITY_FALSE_AND_PROPAGATED: + return true; + default: + return false; + } + break; + default: + return false; + } +} + + Node Valuation::getValue(TNode n) const { return d_engine->getValue(n); } @@ -47,7 +79,16 @@ Node Valuation::getSatValue(TNode n) const { } bool Valuation::hasSatValue(TNode n, bool& value) const { - return d_engine->getPropEngine()->hasValue(n, value); + Node normalized = Rewriter::rewrite(n); + if (d_engine->getPropEngine()->isSatLiteral(normalized)) { + return d_engine->getPropEngine()->hasValue(normalized, value); + } else { + return false; + } +} + +EqualityStatus Valuation::getEqualityStatus(TNode a, TNode b) { + return d_engine->getEqualityStatus(a, b); } Node Valuation::ensureLiteral(TNode n) { diff --git a/src/theory/valuation.h b/src/theory/valuation.h index 71b194905..28a743dc8 100644 --- a/src/theory/valuation.h +++ b/src/theory/valuation.h @@ -32,6 +32,32 @@ class TheoryEngine; namespace theory { +/** + * The status of an equality in the current context. + */ +enum EqualityStatus { + /** The equality is known to be true, and has been propagated */ + EQUALITY_TRUE_AND_PROPAGATED, + /** The equality is known to be true and has been propagated */ + EQUALITY_FALSE_AND_PROPAGATED, + /** The equality is known to be true */ + EQUALITY_TRUE, + /** The equality is known to be false */ + EQUALITY_FALSE, + /** The equality is not known, but is true in the current model */ + EQUALITY_TRUE_IN_MODEL, + /** The equality is not known, but is false in the current model */ + EQUALITY_FALSE_IN_MODEL, + /** The equality is completely unknown */ + EQUALITY_UNKNOWN +}; + +/** + * Returns true if the two statuses are compatible, i.e. bot TRUE + * or both FALSE (regardles of inmodel/propagation). + */ +bool equalityStatusCompatible(EqualityStatus s1, EqualityStatus s2); + class Valuation { TheoryEngine* d_engine; public: @@ -69,6 +95,12 @@ public: */ bool hasSatValue(TNode n, bool& value) const; + /** + * Returns the equality status of the two terms, from the theory that owns the domain type. + * The types of a and b must be the same. + */ + EqualityStatus getEqualityStatus(TNode a, TNode b); + /** * Ensure that the given node will have a designated SAT literal * that is definitionally equal to it. The result of this function diff --git a/src/util/node_visitor.h b/src/util/node_visitor.h index 06a1a83e8..0dec26717 100644 --- a/src/util/node_visitor.h +++ b/src/util/node_visitor.h @@ -55,7 +55,7 @@ public: // Notify of a start visitor.start(node); - // Do a topological sort of the subexpressions and preregister them + // Do a topological sort of the subexpressions std::vector toVisit; toVisit.push_back(stack_element(node, node)); while (!toVisit.empty()) { diff --git a/test/regress/regress0/arith/integers/Makefile.am b/test/regress/regress0/arith/integers/Makefile.am index 3d7f40c71..d0340616f 100644 --- a/test/regress/regress0/arith/integers/Makefile.am +++ b/test/regress/regress0/arith/integers/Makefile.am @@ -23,77 +23,7 @@ TESTS = \ arith-int-017.cvc \ arith-int-018.cvc \ arith-int-019.cvc \ - arith-int-020.cvc \ - arith-int-021.cvc \ - arith-int-023.cvc \ - arith-int-024.cvc \ - arith-int-025.cvc \ - arith-int-026.cvc \ - arith-int-027.cvc \ - arith-int-028.cvc \ - arith-int-029.cvc \ - arith-int-030.cvc \ - arith-int-031.cvc \ - arith-int-032.cvc \ - arith-int-033.cvc \ - arith-int-034.cvc \ - arith-int-035.cvc \ - arith-int-036.cvc \ - arith-int-037.cvc \ - arith-int-038.cvc \ - arith-int-039.cvc \ - arith-int-040.cvc \ - arith-int-041.cvc \ - arith-int-044.cvc \ - arith-int-045.cvc \ - arith-int-046.cvc \ - arith-int-048.cvc \ - arith-int-049.cvc \ - arith-int-051.cvc \ - arith-int-052.cvc \ - arith-int-053.cvc \ - arith-int-054.cvc \ - arith-int-055.cvc \ - arith-int-056.cvc \ - arith-int-057.cvc \ - arith-int-058.cvc \ - arith-int-059.cvc \ - arith-int-060.cvc \ - arith-int-061.cvc \ - arith-int-062.cvc \ - arith-int-063.cvc \ - arith-int-064.cvc \ - arith-int-065.cvc \ - arith-int-066.cvc \ - arith-int-067.cvc \ - arith-int-068.cvc \ - arith-int-069.cvc \ - arith-int-070.cvc \ - arith-int-071.cvc \ - arith-int-072.cvc \ - arith-int-073.cvc \ - arith-int-074.cvc \ - arith-int-075.cvc \ - arith-int-076.cvc \ - arith-int-077.cvc \ - arith-int-078.cvc \ - arith-int-079.cvc \ - arith-int-080.cvc \ - arith-int-081.cvc \ - arith-int-083.cvc \ - arith-int-085.cvc \ - arith-int-086.cvc \ - arith-int-087.cvc \ - arith-int-088.cvc \ - arith-int-089.cvc \ - arith-int-090.cvc \ - arith-int-091.cvc \ - arith-int-092.cvc \ - arith-int-093.cvc \ - arith-int-094.cvc \ - arith-int-095.cvc \ - arith-int-096.cvc \ - arith-int-099.cvc + arith-int-020.cvc EXTRA_DIST = $(TESTS) \ arith-int-008.cvc \ diff --git a/test/regress/regress0/uflra/Makefile.am b/test/regress/regress0/uflra/Makefile.am index 5199f2b62..377489ef7 100644 --- a/test/regress/regress0/uflra/Makefile.am +++ b/test/regress/regress0/uflra/Makefile.am @@ -6,15 +6,19 @@ MAKEFLAGS = -k # put it below in "TESTS +=" # Regression tests for SMT inputs -#SMT_TESTS = \ -# pb_real_10_0100_10_10.smt \ -# pb_real_10_0100_10_11.smt \ -# pb_real_10_0100_10_15.smt \ -# pb_real_10_0100_10_16.smt \ -# pb_real_10_0100_10_19.smt \ -# pb_real_10_0200_10_22.smt \ -# pb_real_10_0200_10_26.smt \ -# pb_real_10_0200_10_29.smt +SMT_TESTS = \ + simple.01.cvc \ + simple.02.cvc \ + simple.03.cvc \ + simple.04.cvc \ + pb_real_10_0100_10_10.smt \ + pb_real_10_0100_10_11.smt \ + pb_real_10_0100_10_15.smt \ + pb_real_10_0100_10_16.smt \ + pb_real_10_0100_10_19.smt \ + pb_real_10_0200_10_22.smt \ + pb_real_10_0200_10_26.smt \ + pb_real_10_0200_10_29.smt # Regression tests for SMT2 inputs SMT2_TESTS = diff --git a/test/regress/regress0/uflra/simple.01.cvc b/test/regress/regress0/uflra/simple.01.cvc new file mode 100644 index 000000000..8904192ce --- /dev/null +++ b/test/regress/regress0/uflra/simple.01.cvc @@ -0,0 +1,6 @@ +% EXPECT: sat +% EXIT: 10 +x, y: REAL; +f: REAL -> REAL; + +CHECKSAT NOT (f(x) = f(y)); \ No newline at end of file diff --git a/test/regress/regress0/uflra/simple.02.cvc b/test/regress/regress0/uflra/simple.02.cvc new file mode 100644 index 000000000..a14ca8a1f --- /dev/null +++ b/test/regress/regress0/uflra/simple.02.cvc @@ -0,0 +1,10 @@ +% EXPECT: unsat +% EXIT: 20 +x, y: REAL; +f: REAL -> REAL; + +ASSERT (x <= y); +ASSERT (y <= x); +ASSERT NOT (f(x) = f(y)); + +CHECKSAT; \ No newline at end of file diff --git a/test/regress/regress0/uflra/simple.03.cvc b/test/regress/regress0/uflra/simple.03.cvc new file mode 100644 index 000000000..1fdeed40a --- /dev/null +++ b/test/regress/regress0/uflra/simple.03.cvc @@ -0,0 +1,12 @@ +% EXPECT: sat +% EXIT: 10 +x1, y1, z1: REAL; +x2, y2, z2: REAL; +f: REAL -> REAL; +g: (REAL, REAL) -> REAL; + +ASSERT (z1 = f(x1)); +ASSERT (z2 = f(y1)); +ASSERT NOT (g(z1, z2) = g(z2, y2)); + +CHECKSAT; \ No newline at end of file diff --git a/test/regress/regress0/uflra/simple.04.cvc b/test/regress/regress0/uflra/simple.04.cvc new file mode 100644 index 000000000..c9c226fa2 --- /dev/null +++ b/test/regress/regress0/uflra/simple.04.cvc @@ -0,0 +1,15 @@ +% EXPECT: unsat +% EXIT: 20 +x1, x2: REAL; +y1, y2: REAL; +f: REAL -> REAL; +g: (REAL, REAL) -> REAL; + +ASSERT (x1 <= x2) AND (x2 <= x1); + +ASSERT NOT (g(x1, y1) = g(x2, y2)); + +ASSERT (y1 <= f(x1)) AND (f(x1) <= y1); +ASSERT (y2 <= f(x2)) AND (f(x2) <= y2); + +CHECKSAT; \ No newline at end of file diff --git a/test/unit/theory/theory_arith_white.h b/test/unit/theory/theory_arith_white.h index f0073cc0b..4787dfd21 100644 --- a/test/unit/theory/theory_arith_white.h +++ b/test/unit/theory/theory_arith_white.h @@ -128,7 +128,7 @@ public: Node leq = d_nm->mkNode(LEQ, x, c); fakeTheoryEnginePreprocess(leq); - d_arith->assertFact(leq); + d_arith->assertFact(leq, true); d_arith->check(d_level); @@ -160,7 +160,7 @@ public: fakeTheoryEnginePreprocess(leq1); fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(lt1); + d_arith->assertFact(lt1, true); d_arith->check(d_level); @@ -199,7 +199,7 @@ public: fakeTheoryEnginePreprocess(leq1); fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(leq0); + d_arith->assertFact(leq0, true); d_arith->check(d_level); @@ -235,7 +235,7 @@ public: fakeTheoryEnginePreprocess(leq1); fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(leq1); + d_arith->assertFact(leq1, true); d_arith->check(d_level); diff --git a/test/unit/theory/theory_black.h b/test/unit/theory/theory_black.h index 63900c19c..be1d4a35b 100644 --- a/test/unit/theory/theory_black.h +++ b/test/unit/theory/theory_black.h @@ -214,8 +214,8 @@ public: void testDone() { TS_ASSERT(d_dummy->doneWrapper()); - d_dummy->assertFact(atom0); - d_dummy->assertFact(atom1); + d_dummy->assertFact(atom0, true); + d_dummy->assertFact(atom1, true); TS_ASSERT(!d_dummy->doneWrapper()); -- 2.30.2