From e792bb8628ea7010fa9c452bf1aa7ba1b60291a3 Mon Sep 17 00:00:00 2001 From: Tim King Date: Tue, 29 Jun 2010 20:53:47 +0000 Subject: [PATCH] Merging the unate-propagator branch into the trunk. This is a big update so expect a little turbulence. This commit will not compile. There will be a second commit that fixes this in a moment. I am delaying a change to avoid svn whining about a conflict. --- src/prop/cnf_stream.cpp | 41 +-- src/prop/cnf_stream.h | 36 ++- src/prop/minisat/core/Solver.C | 111 ++++++-- src/prop/minisat/core/Solver.h | 23 +- src/prop/minisat/simp/SimpSolver.C | 6 +- src/prop/sat.cpp | 45 +++- src/prop/sat.h | 14 +- src/theory/Makefile.am | 1 + src/theory/arith/Makefile.am | 3 + src/theory/arith/arith_propagator.cpp | 347 +++++++++++++++++++++++++ src/theory/arith/arith_propagator.h | 111 ++++++++ src/theory/arith/ordered_bounds_list.h | 212 +++++++++++++++ src/theory/arith/theory_arith.cpp | 93 ++++++- src/theory/arith/theory_arith.h | 7 +- src/theory/theory.cpp | 16 ++ src/theory/theory.h | 2 + src/theory/theory_engine.h | 46 +++- src/theory/theory_test_utils.h | 81 ++++++ src/theory/uf/theory_uf.cpp | 2 +- test/unit/Makefile.am | 1 + test/unit/theory/theory_arith_white.h | 312 ++++++++++++++++++++++ test/unit/theory/theory_uf_white.h | 56 +--- 22 files changed, 1430 insertions(+), 136 deletions(-) create mode 100644 src/theory/arith/arith_propagator.cpp create mode 100644 src/theory/arith/arith_propagator.h create mode 100644 src/theory/arith/ordered_bounds_list.h create mode 100644 src/theory/theory_test_utils.h create mode 100644 test/unit/theory/theory_arith_white.h diff --git a/src/prop/cnf_stream.cpp b/src/prop/cnf_stream.cpp index 45f7ab398..9136a73c3 100644 --- a/src/prop/cnf_stream.cpp +++ b/src/prop/cnf_stream.cpp @@ -102,26 +102,10 @@ Node CnfStream::getNode(const SatLiteral& literal) { return node; } -SatLiteral CnfStream::getLiteral(TNode node) { - TranslationCache::iterator find = d_translationCache.find(node); - Assert(find != d_translationCache.end(), "Literal not in the CNF Cache"); - SatLiteral literal = find->second; - Debug("cnf") << "CnfStream::getLiteral(" << node << ") => " << literal << std::endl; - return literal; -} - -const CnfStream::NodeCache& CnfStream::getNodeCache() const { - return d_nodeCache; -} - -const CnfStream::TranslationCache& CnfStream::getTranslationCache() const { - return d_translationCache; -} - -SatLiteral TseitinCnfStream::handleAtom(TNode node) { +SatLiteral CnfStream::convertAtom(TNode node) { Assert(!isCached(node), "atom already mapped!"); - Debug("cnf") << "handleAtom(" << node << ")" << endl; + Debug("cnf") << "convertAtom(" << node << ")" << endl; bool theoryLiteral = node.getKind() != kind::VARIABLE; SatLiteral lit = newLiteral(node, theoryLiteral); @@ -137,6 +121,23 @@ SatLiteral TseitinCnfStream::handleAtom(TNode node) { return lit; } +SatLiteral CnfStream::getLiteral(TNode node, bool create /* = false */) { + TranslationCache::iterator find = d_translationCache.find(node); + SatLiteral literal; + if(create) { + if(find == d_translationCache.end()) { + literal = convertAtom(node); + } else { + literal = find->second; + } + } else { + Assert(find != d_translationCache.end(), "Literal not in the CNF Cache"); + literal = find->second; + } + Debug("cnf") << "CnfStream::getLiteral(" << node << ", create = " << create << ") => " << literal << std::endl; + return literal; +} + SatLiteral TseitinCnfStream::handleXor(TNode xorNode) { Assert(!isCached(xorNode), "Atom already mapped!"); Assert(xorNode.getKind() == XOR, "Expecting an XOR expression!"); @@ -366,10 +367,10 @@ SatLiteral TseitinCnfStream::toCNF(TNode node, bool negated) { default: { //TODO make sure this does not contain any boolean substructure - nodeLit = handleAtom(node); + nodeLit = convertAtom(node); //Unreachable(); //Node atomic = handleNonAtomicNode(node); - //return isCached(atomic) ? lookupInCache(atomic) : handleAtom(atomic); + //return isCached(atomic) ? lookupInCache(atomic) : convertAtom(atomic); } } } diff --git a/src/prop/cnf_stream.h b/src/prop/cnf_stream.h index abb69f590..ba87cf269 100644 --- a/src/prop/cnf_stream.h +++ b/src/prop/cnf_stream.h @@ -127,6 +127,16 @@ protected: */ SatLiteral newLiteral(TNode node, bool theoryLiteral = false); + /** + * Constructs a new literal for an atom and returns it. Calls + * newLiteral(). + * + * @param node the node to convert; there should be no boolean + * structure in this expression. Assumed to not be in the + * translation cache. + */ + SatLiteral convertAtom(TNode node); + public: /** @@ -161,14 +171,25 @@ public: /** * Returns the literal that represents the given node in the SAT CNF - * representation. [Presumably there are some constraints on the kind - * of node? E.g., it needs to be a boolean? -Chris] - * + * representation. + * @param node [Presumably there are some constraints on the kind of + * node? E.g., it needs to be a boolean? -Chris] + * @param create Controls whether or not to create a new SAT literal + * mapping for Node if it does not exist. This exists to break + * circular dependencies, where an atom is converted and asserted to + * the SAT solver, which propagates it immediately since it's a + * unit, which can theory-propagate additional literals that don't + * yet have a SAT literal mapping. */ - SatLiteral getLiteral(TNode node); + SatLiteral getLiteral(TNode node, bool create = false); + + const TranslationCache& getTranslationCache() const { + return d_translationCache; + } - const TranslationCache& getTranslationCache() const; - const NodeCache& getNodeCache() const; + const NodeCache& getNodeCache() const { + return d_nodeCache; + } }; /* class CnfStream */ /** @@ -178,7 +199,7 @@ public: * will be equivalent to each subexpression in the constructed equi-satisfiable * formula, then substitute the new literal for the formula, and so on, * recursively. - * + * * This implementation does this in a single recursive pass. [??? -Chris] */ class TseitinCnfStream : public CnfStream { @@ -211,7 +232,6 @@ private: // - returning l // // handleX( n ) can assume that n is not in d_translationCache - SatLiteral handleAtom(TNode node); SatLiteral handleNot(TNode node); SatLiteral handleXor(TNode node); SatLiteral handleImplies(TNode node); diff --git a/src/prop/minisat/core/Solver.C b/src/prop/minisat/core/Solver.C index 8533e191b..1667af20d 100644 --- a/src/prop/minisat/core/Solver.C +++ b/src/prop/minisat/core/Solver.C @@ -29,6 +29,28 @@ namespace CVC4 { namespace prop { namespace minisat { +Clause* Solver::lazy_reason = reinterpret_cast(1); + +Clause* Solver::getReason(Lit l) +{ + if (reason[var(l)] != lazy_reason) return reason[var(l)]; + // Get the explanation from the theory + SatClause explanation; + if (value(l) == l_True) { + proxy->explainPropagation(l, explanation); + assert(explanation[0] == l); + } else { + proxy->explainPropagation(~l, explanation); + assert(explanation[0] == ~l); + } + Clause* real_reason = Clause_new(explanation, true); + reason[var(l)] = real_reason; + // Add it to the database + learnts.push(real_reason); + attachClause(*real_reason); + return real_reason; +} + Solver::Solver(SatSolver* proxy, context::Context* context) : // SMT stuff @@ -122,7 +144,7 @@ bool Solver::addClause(vec& ps, ClauseType type) assert(type != CLAUSE_LEMMA); assert(value(ps[0]) == l_Undef); uncheckedEnqueue(ps[0]); - return ok = (propagate() == NULL); + return ok = (propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) == NULL); }else{ Clause* c = Clause_new(ps, false); clauses.push(c); @@ -282,7 +304,7 @@ void Solver::analyze(Clause* confl, vec& out_learnt, int& out_btlevel) // Select next clause to look at: while (!seen[var(trail[index--])]); p = trail[index+1]; - confl = reason[var(p)]; + confl = getReason(p); seen[var(p)] = 0; pathC--; @@ -299,12 +321,12 @@ void Solver::analyze(Clause* confl, vec& out_learnt, int& out_btlevel) out_learnt.copyTo(analyze_toclear); for (i = j = 1; i < out_learnt.size(); i++) - if (reason[var(out_learnt[i])] == NULL || !litRedundant(out_learnt[i], abstract_level)) + if (getReason(out_learnt[i]) == NULL || !litRedundant(out_learnt[i], abstract_level)) out_learnt[j++] = out_learnt[i]; }else{ out_learnt.copyTo(analyze_toclear); for (i = j = 1; i < out_learnt.size(); i++){ - Clause& c = *reason[var(out_learnt[i])]; + Clause& c = *getReason(out_learnt[i]); for (int k = 1; k < c.size(); k++) if (!seen[var(c[k])] && level[var(c[k])] > 0){ out_learnt[j++] = out_learnt[i]; @@ -342,13 +364,13 @@ bool Solver::litRedundant(Lit p, uint32_t abstract_levels) analyze_stack.clear(); analyze_stack.push(p); int top = analyze_toclear.size(); while (analyze_stack.size() > 0){ - assert(reason[var(analyze_stack.last())] != NULL); + assert(getReason(analyze_stack.last()) != NULL); Clause& c = *reason[var(analyze_stack.last())]; analyze_stack.pop(); for (int i = 1; i < c.size(); i++){ Lit p = c[i]; if (!seen[var(p)] && level[var(p)] > 0){ - if (reason[var(p)] != NULL && (abstractLevel(var(p)) & abstract_levels) != 0){ + if (getReason(p) != NULL && (abstractLevel(var(p)) & abstract_levels) != 0){ seen[var(p)] = 1; analyze_stack.push(p); analyze_toclear.push(p); @@ -415,42 +437,74 @@ void Solver::uncheckedEnqueue(Lit p, Clause* from) polarity [var(p)] = sign(p); trail.push(p); - if (theory[var(p)]) { + if (theory[var(p)] && from != lazy_reason) { // Enqueue to the theory proxy->enqueueTheoryLiteral(p); } } -Clause* Solver::propagate() +Clause* Solver::propagate(TheoryCheckType type) { Clause* confl = NULL; - while(qhead < trail.size()) { - confl = propagateBool(); - if (confl != NULL) break; - confl = propagateTheory(); - if (confl != NULL) break; + // If this is the final check, no need for Boolean propagation and + // theory propagation + if (type == CHECK_WITHOUTH_PROPAGATION_FINAL) { + return theoryCheck(theory::Theory::FULL_EFFORT); } + // The effort we will be using to theory check + theory::Theory::Effort effort = type == CHECK_WITHOUTH_PROPAGATION_QUICK ? + theory::Theory::QUICK_CHECK : theory::Theory::STANDARD; + + // Keep running until we have checked everything, we + // have no conflict and no new literals have been asserted + bool new_assertions; + do { + new_assertions = false; + while(qhead < trail.size()) { + confl = propagateBool(); + if (confl != NULL) break; + confl = theoryCheck(effort); + if (confl != NULL) break; + } + + if (confl == NULL && type == CHECK_WITH_PROPAGATION_STANDARD) { + new_assertions = propagateTheory(); + if (!new_assertions) break; + } + } while (new_assertions); + return confl; } +bool Solver::propagateTheory() { + std::vector propagatedLiterals; + proxy->theoryPropagate(propagatedLiterals); + const unsigned i_end = propagatedLiterals.size(); + for (unsigned i = 0; i < i_end; ++ i) { + uncheckedEnqueue(propagatedLiterals[i], lazy_reason); + } + proxy->clearPropagatedLiterals(); + return propagatedLiterals.size() > 0; +} + /*_________________________________________________________________________________________________ | -| propagateTheory : [void] -> [Clause*] +| theoryCheck: [void] -> [Clause*] | | Description: -| Propagates all enqueued theory facts. If a conflict arises, the conflicting clause is returned, -| otherwise NULL. +| Checks all enqueued theory facts for satisfiability. If a conflict arises, the conflicting +| clause is returned, otherwise NULL. | | Note: the propagation queue might be NOT empty |________________________________________________________________________________________________@*/ -Clause* Solver::propagateTheory() +Clause* Solver::theoryCheck(theory::Theory::Effort effort) { Clause* c = NULL; SatClause clause; - proxy->theoryCheck(clause); + proxy->theoryCheck(effort, clause); int clause_size = clause.size(); Assert(clause_size != 1, "Can't handle unit clause explanations"); if(clause_size > 0) { @@ -598,7 +652,7 @@ bool Solver::simplify() { assert(decisionLevel() == 0); - if (!ok || propagate() != NULL) + if (!ok || propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL) return ok = false; if (nAssigns() == simpDB_assigns || (simpDB_props > 0)) @@ -643,9 +697,9 @@ lbool Solver::search(int nof_conflicts, int nof_learnts) starts++; bool first = true; - + TheoryCheckType check_type = CHECK_WITH_PROPAGATION_STANDARD; for (;;){ - Clause* confl = propagate(); + Clause* confl = propagate(check_type); if (confl != NULL){ // CONFLICT conflicts++; conflictC++; @@ -671,9 +725,16 @@ lbool Solver::search(int nof_conflicts, int nof_learnts) varDecayActivity(); claDecayActivity(); + // We have a conflict so, we are going back to standard checks + check_type = CHECK_WITH_PROPAGATION_STANDARD; + }else{ // NO CONFLICT + // If this was a final check, we are satisfiable + if (check_type == CHECK_WITHOUTH_PROPAGATION_FINAL) + return l_True; + if (nof_conflicts >= 0 && conflictC >= nof_conflicts){ // Reached bound on number of conflicts: progress_estimate = progressEstimate(); @@ -709,9 +770,11 @@ lbool Solver::search(int nof_conflicts, int nof_learnts) decisions++; next = pickBranchLit(polarity_mode, random_var_freq); - if (next == lit_Undef) - // Model found: - return l_True; + if (next == lit_Undef) { + // We need to do a full theory check to confirm + check_type = CHECK_WITHOUTH_PROPAGATION_FINAL; + continue; + } } // Increase decision level and enqueue 'next' diff --git a/src/prop/minisat/core/Solver.h b/src/prop/minisat/core/Solver.h index 312fe44d5..2e44803e9 100644 --- a/src/prop/minisat/core/Solver.h +++ b/src/prop/minisat/core/Solver.h @@ -23,6 +23,7 @@ OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWA #define __CVC4__PROP__MINISAT__SOLVER_H #include "context/context.h" +#include "theory/theory.h" #include #include @@ -161,7 +162,11 @@ protected: vec trail_lim; // Separator indices for different decision levels in 'trail'. vec lemmas; // List of lemmas we added (context dependent) vec lemmas_lim; // Separator indices for different decision levels in 'lemmas'. - vec reason; // 'reason[var]' is the clause that implied the variables current value, or 'NULL' if none. + static Clause* lazy_reason; // The mark when we need to ask the theory engine for a reason + vec reason; // 'reason[var]' is the clause that implied the variables current value, lazy_reason if theory propagated, or 'NULL' if none. + + Clause* getReason(Lit l); // Returns the reason, or asks the theory for an explanation + vec level; // 'level[var]' contains the level at which the assignment was made. int qhead; // Head of queue (as index into the trail -- no more explicit propagation queue in MiniSat). int lhead; // Head of the lemma stack (for backtracking) @@ -181,6 +186,15 @@ protected: vec analyze_toclear; vec add_tmp; + enum TheoryCheckType { + // Quick check, but don't perform theory propagation + CHECK_WITHOUTH_PROPAGATION_QUICK, + // Check and perform theory propagation + CHECK_WITH_PROPAGATION_STANDARD, + // The SAT problem is satisfiable, perform a full theory check + CHECK_WITHOUTH_PROPAGATION_FINAL + }; + // Main internal methods: // void insertVarOrder (Var x); // Insert a variable in the decision order priority queue. @@ -188,9 +202,10 @@ protected: void newDecisionLevel (); // Begins a new decision level. void uncheckedEnqueue (Lit p, Clause* from = NULL); // Enqueue a literal. Assumes value of literal is undefined. bool enqueue (Lit p, Clause* from = NULL); // Test if fact 'p' contradicts current state, enqueue otherwise. - Clause* propagate (); // Perform Boolean and Theory. Returns possibly conflicting clause. + Clause* propagate (TheoryCheckType type); // Perform Boolean and Theory. Returns possibly conflicting clause. Clause* propagateBool (); // Perform Boolean propagation. Returns possibly conflicting clause. - Clause* propagateTheory (); // Perform Theory propagation. Returns possibly conflicting clause. + bool propagateTheory (); // Perform Theory propagation. Return true if any literals were asserted. + Clause* theoryCheck (theory::Theory::Effort effort); // Perform a theory satisfiability check. Returns possibly conflicting clause. void cancelUntil (int level); // Backtrack until a certain level. void analyze (Clause* confl, vec& out_learnt, int& out_btlevel); // (bt = backtrack) void analyzeFinal (Lit p, vec& out_conflict); // COULD THIS BE IMPLEMENTED BY THE ORDINARIY "analyze" BY SOME REASONABLE GENERALIZATION? @@ -216,7 +231,7 @@ protected: // Misc: // - int decisionLevel () const; // Gives the current decisionlevel. + int decisionLevel () const; // Gives the current decision level. uint32_t abstractLevel (Var x) const; // Used to represent an abstraction of sets of decision levels. double progressEstimate () const; // DELETE THIS ?? IT'S NOT VERY USEFUL ... diff --git a/src/prop/minisat/simp/SimpSolver.C b/src/prop/minisat/simp/SimpSolver.C index 9aad6aea7..00f93402f 100644 --- a/src/prop/minisat/simp/SimpSolver.C +++ b/src/prop/minisat/simp/SimpSolver.C @@ -212,7 +212,7 @@ bool SimpSolver::strengthenClause(Clause& c, Lit l) updateElimHeap(var(l)); } - return c.size() == 1 ? enqueue(c[0]) && propagate() == NULL : true; + return c.size() == 1 ? enqueue(c[0]) && propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) == NULL : true; } @@ -312,7 +312,7 @@ bool SimpSolver::implied(const vec& c) uncheckedEnqueue(~c[i]); } - bool result = propagate() != NULL; + bool result = propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL; cancelUntil(0); return result; } @@ -394,7 +394,7 @@ bool SimpSolver::asymm(Var v, Clause& c) else l = c[i]; - if (propagate() != NULL){ + if (propagate(CHECK_WITHOUTH_PROPAGATION_QUICK) != NULL){ cancelUntil(0); asymm_lits++; if (!strengthenClause(c, l)) diff --git a/src/prop/sat.cpp b/src/prop/sat.cpp index 207bda4db..a7b536a57 100644 --- a/src/prop/sat.cpp +++ b/src/prop/sat.cpp @@ -26,9 +26,9 @@ namespace CVC4 { namespace prop { -void SatSolver::theoryCheck(SatClause& conflict) { +void SatSolver::theoryCheck(theory::Theory::Effort effort, SatClause& conflict) { // Try theory propagation - bool ok = d_theoryEngine->check(theory::Theory::FULL_EFFORT); + bool ok = d_theoryEngine->check(effort); // If in conflict construct the conflict clause if (!ok) { // We have a conflict, get it @@ -47,6 +47,47 @@ void SatSolver::theoryCheck(SatClause& conflict) { } } +void SatSolver::theoryPropagate(std::vector& output) { + // Propagate + d_theoryEngine->propagate(); + // Get the propagated literals + const std::vector& outputNodes = d_theoryEngine->getPropagatedLiterals(); + // If any literals, make a clause + const unsigned i_end = outputNodes.size(); + for (unsigned i = 0; i < i_end; ++ i) { + Debug("prop-explain") << "theoryPropagate() => " << outputNodes[i].toString() << endl; + // The second argument ("true") instructs the CNF stream to create + // a new literal mapping if it doesn't exist. This can happen due + // to a circular dependence, if a SAT literal "a" is asserted as a + // unit to the SAT solver, a round of theory propagation can occur + // before all Nodes have SAT variable mappings. + SatLiteral l = d_cnfStream->getLiteral(outputNodes[i], true); + output.push_back(l); + } +} + +void SatSolver::explainPropagation(SatLiteral l, SatClause& explanation) { + TNode lNode = d_cnfStream->getNode(l); + Debug("prop-explain") << "explainPropagation(" << lNode.toString() << ")" << endl; + Node theoryExplanation = d_theoryEngine->getExplanation(lNode); + Debug("prop-explain") << "explainPropagation() => " << theoryExplanation.toString() << endl; + if (lNode.getKind() == kind::AND) { + Node::const_iterator it = theoryExplanation.begin(); + Node::const_iterator it_end = theoryExplanation.end(); + explanation.push(l); + for (; it != it_end; ++ it) { + explanation.push(~d_cnfStream->getLiteral(*it)); + } + } else { + explanation.push(l); + explanation.push(~d_cnfStream->getLiteral(theoryExplanation)); + } +} + +void SatSolver::clearPropagatedLiterals() { + d_theoryEngine->clearPropagatedLiterals(); +} + void SatSolver::enqueueTheoryLiteral(const SatLiteral& l) { Node literalNode = d_cnfStream->getNode(l); Debug("prop") << "enqueueing theory literal " << l << " " << literalNode << std::endl; diff --git a/src/prop/sat.h b/src/prop/sat.h index f64697d7b..992d8ecd2 100644 --- a/src/prop/sat.h +++ b/src/prop/sat.h @@ -27,6 +27,7 @@ #include "util/options.h" #include "util/stats.h" +#include "theory/theory.h" #ifdef __CVC4_USE_MINISAT @@ -199,7 +200,13 @@ public: SatVariable newVar(bool theoryAtom = false); - void theoryCheck(SatClause& conflict); + void theoryCheck(theory::Theory::Effort effort, SatClause& conflict); + + void explainPropagation(SatLiteral l, SatClause& explanation); + + void theoryPropagate(std::vector& output); + + void clearPropagatedLiterals(); void enqueueTheoryLiteral(const SatLiteral& l); @@ -229,6 +236,11 @@ inline SatSolver::SatSolver(PropEngine* propEngine, TheoryEngine* theoryEngine, // Make minisat reuse the literal values d_minisat->polarity_mode = minisat::SimpSolver::polarity_user; + // No random choices + if(debugTagIsOn("no_rnd_decisions")){ + d_minisat->random_var_freq = 0; + } + d_statistics.init(d_minisat); } diff --git a/src/theory/Makefile.am b/src/theory/Makefile.am index 7cfc1571b..d0d2f23d7 100644 --- a/src/theory/Makefile.am +++ b/src/theory/Makefile.am @@ -9,6 +9,7 @@ libtheory_la_SOURCES = \ @srcdir@/theoryof_table.h \ theory_engine.h \ theory_engine.cpp \ + theory_test_utils.h \ theory.h \ theory.cpp diff --git a/src/theory/arith/Makefile.am b/src/theory/arith/Makefile.am index 83d44e285..37df73edd 100644 --- a/src/theory/arith/Makefile.am +++ b/src/theory/arith/Makefile.am @@ -15,10 +15,13 @@ libarith_la_SOURCES = \ delta_rational.cpp \ partial_model.h \ partial_model.cpp \ + ordered_bounds_list.h \ basic.h \ normal.h \ slack.h \ tableau.h \ + arith_propagator.h \ + arith_propagator.cpp \ theory_arith.h \ theory_arith.cpp diff --git a/src/theory/arith/arith_propagator.cpp b/src/theory/arith/arith_propagator.cpp new file mode 100644 index 000000000..e40575054 --- /dev/null +++ b/src/theory/arith/arith_propagator.cpp @@ -0,0 +1,347 @@ +#include "theory/arith/arith_propagator.h" +#include "theory/arith/arith_utilities.h" + +#include + +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; +using namespace CVC4::theory::arith::propagator; + +using namespace CVC4::kind; + +using namespace std; + +ArithUnatePropagator::ArithUnatePropagator(context::Context* cxt) : + d_assertions(cxt), d_pendingAssertions(cxt,0) +{ } + + +bool acceptedKinds(Kind k){ + switch(k){ + case EQUAL: + case LEQ: + case GEQ: + return true; + default: + return false; + } +} + +void ArithUnatePropagator::addAtom(TNode atom){ + Assert(acceptedKinds(atom.getKind())); + + TNode left = atom[0]; + TNode right = atom[1]; + + if(!leftIsSetup(left)){ + setupLefthand(left); + } + + switch(atom.getKind()){ + case EQUAL: + { + OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList()); + Assert(!eqList->contains(atom)); + eqList->append(atom); + break; + } + case LEQ: + { + OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList()); + Assert(! leqList->contains(atom)); + leqList->append(atom); + break; + } + break; + case GEQ: + { + OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList()); + Assert(! geqList->contains(atom)); + geqList->append(atom); + break; + } + default: + Unreachable(); + } +} +bool ArithUnatePropagator::leftIsSetup(TNode left){ + return left.hasAttribute(propagator::PropagatorEqList()); +} + +void ArithUnatePropagator::setupLefthand(TNode left){ + Assert(!leftIsSetup(left)); + + OrderedBoundsList* eqList = new OrderedBoundsList(); + OrderedBoundsList* geqList = new OrderedBoundsList(); + OrderedBoundsList* leqList = new OrderedBoundsList(); + + left.setAttribute(propagator::PropagatorEqList(), eqList); + left.setAttribute(propagator::PropagatorLeqList(), leqList); + left.setAttribute(propagator::PropagatorGeqList(), geqList); +} + +void ArithUnatePropagator::assertLiteral(TNode lit){ + + if(lit.getKind() == NOT){ + Assert(!lit[0].getAttribute(propagator::PropagatorMarked())); + lit[0].setAttribute(propagator::PropagatorMarked(), true); + }else{ + Assert(!lit.getAttribute(propagator::PropagatorMarked())); + lit.setAttribute(propagator::PropagatorMarked(), true); + } + d_assertions.push_back(lit); +} + +std::vector ArithUnatePropagator::getImpliedLiterals(){ + std::vector impliedButNotAsserted; + + while(d_pendingAssertions < d_assertions.size()){ + TNode assertion = d_assertions[d_pendingAssertions]; + d_pendingAssertions = d_pendingAssertions + 1; + + enqueueImpliedLiterals(assertion, impliedButNotAsserted); + } + + if(debugTagIsOn("arith::propagator")){ + for(std::vector::iterator i = impliedButNotAsserted.begin(), + endIter = impliedButNotAsserted.end(); i != endIter; ++i){ + Node imp = *i; + Debug("arith::propagator") << explain(imp) << " (prop)-> " << imp << endl; + } + } + + return impliedButNotAsserted; +} + +/** This function is effectively a case split. */ +void ArithUnatePropagator::enqueueImpliedLiterals(TNode lit, std::vector& buffer){ + switch(lit.getKind()){ + case EQUAL: + enqueueEqualityImplications(lit, buffer); + break; + case LEQ: + enqueueUpperBoundImplications(lit, lit, buffer); + break; + case GEQ: + enqueueLowerBoundImplications(lit, lit, buffer); + break; + case NOT: + { + TNode under = lit[0]; + switch(under.getKind()){ + case EQUAL: + //Do nothing + break;; + case LEQ: + enqueueLowerBoundImplications(under, lit, buffer); + break; + case GEQ: + enqueueUpperBoundImplications(under, lit, buffer); + break; + default: + Unreachable(); + } + break; + } + default: + Unreachable(); + } +} + +/** + * An equality (x = c) has been asserted. + * In this case we can propagate everything by comparing against the other constants. + */ +void ArithUnatePropagator::enqueueEqualityImplications(TNode orig, std::vector& buffer){ + TNode left = orig[0]; + TNode right = orig[1]; + const Rational& c = right.getConst(); + + OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList()); + OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList()); + OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList()); + + + /* (x = c) /\ (c !=d) => (x != d) */ + for(OrderedBoundsList::iterator i = eqList->begin(); i != eqList->end(); ++i){ + TNode eq = *i; + Assert(eq.getKind() == EQUAL); + if(!eq.getAttribute(propagator::PropagatorMarked())){ //Note that (x = c) is marked + Assert(eq[1].getConst() != c); + + eq.setAttribute(propagator::PropagatorMarked(), true); + + Node neq = NodeManager::currentNM()->mkNode(NOT, eq); + neq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(neq); + } + } + for(OrderedBoundsList::iterator i = leqList->begin(); i != leqList->end(); ++i){ + TNode leq = *i; + Assert(leq.getKind() == LEQ); + if(!leq.getAttribute(propagator::PropagatorMarked())){ + leq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = leq[1].getConst(); + if(c <= d){ + /* (x = c) /\ (c <= d) => (x <= d) */ + leq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(leq); + }else{ + /* (x = c) /\ (c > d) => (x > d) */ + Node gt = NodeManager::currentNM()->mkNode(NOT, leq); + gt.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(gt); + } + } + } + + for(OrderedBoundsList::iterator i = geqList->begin(); i != geqList->end(); ++i){ + TNode geq = *i; + Assert(geq.getKind() == GEQ); + if(!geq.getAttribute(propagator::PropagatorMarked())){ + geq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = geq[1].getConst(); + if(c >= d){ + /* (x = c) /\ (c >= d) => (x >= d) */ + geq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(geq); + }else{ + /* (x = c) /\ (c >= d) => (x >= d) */ + Node lt = NodeManager::currentNM()->mkNode(NOT, geq); + lt.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(lt); + } + } + } +} + +void ArithUnatePropagator::enqueueUpperBoundImplications(TNode atom, TNode orig, std::vector& buffer){ + + Assert(atom.getKind() == LEQ || (orig.getKind() == NOT && atom.getKind() == GEQ)); + + TNode left = atom[0]; + TNode right = atom[1]; + const Rational& c = right.getConst(); + + OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList()); + OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList()); + OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList()); + + + //For every node (x <= d), we will restrict ourselves to look at the cases when (d >= c) + for(OrderedBoundsList::iterator i = leqList->lower_bound(atom); i != leqList->end(); ++i){ + TNode leq = *i; + Assert(leq.getKind() == LEQ); + if(!leq.getAttribute(propagator::PropagatorMarked())){ + leq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = leq[1].getConst(); + Assert( c <= d ); + + leq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(leq); // (x<=c) /\ (c <= d) => (x <= d) + //Note that if c=d, that at the node is not marked this can only be reached when (x < c) + //So we do not have to worry about a circular dependency + }else if(leq != atom){ + break; //No need to examine the rest, this atom implies the rest of the possible propagataions + } + } + + for(OrderedBoundsList::iterator i = geqList->upper_bound(atom); i != geqList->end(); ++i){ + TNode geq = *i; + Assert(geq.getKind() == GEQ); + if(!geq.getAttribute(propagator::PropagatorMarked())){ + geq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = geq[1].getConst(); + Assert( c < d ); + + Node lt = NodeManager::currentNM()->mkNode(NOT, geq); + lt.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(lt); // x<=c /\ d > c => x < d + }else{ + break; //No need to examine this atom implies the rest + } + } + + for(OrderedBoundsList::iterator i = eqList->upper_bound(atom); i != eqList->end(); ++i){ + TNode eq = *i; + Assert(eq.getKind() == EQUAL); + if(!eq.getAttribute(propagator::PropagatorMarked())){ + eq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = eq[1].getConst(); + Assert( c < d ); + + Node neq = NodeManager::currentNM()->mkNode(NOT, eq); + neq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(neq); // x<=c /\ c < d => x != d + } + } +} + +void ArithUnatePropagator::enqueueLowerBoundImplications(TNode atom, TNode orig, std::vector& buffer){ + + Assert(atom.getKind() == GEQ || (orig.getKind() == NOT && atom.getKind() == LEQ)); + + TNode left = atom[0]; + TNode right = atom[1]; + const Rational& c = right.getConst(); + + OrderedBoundsList* eqList = left.getAttribute(propagator::PropagatorEqList()); + OrderedBoundsList* leqList = left.getAttribute(propagator::PropagatorLeqList()); + OrderedBoundsList* geqList = left.getAttribute(propagator::PropagatorGeqList()); + + + for(OrderedBoundsList::reverse_iterator i = geqList->reverse_lower_bound(atom); + i != geqList->rend(); i++){ + TNode geq = *i; + Assert(geq.getKind() == GEQ); + if(!geq.getAttribute(propagator::PropagatorMarked())){ + geq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = geq[1].getConst(); + Assert( c >= d ); + + geq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(geq); // x>=c /\ c >= d => x >= d + }else if(geq != atom){ + break; //No need to examine the rest, this atom implies the rest of the possible propagataions + } + } + + for(OrderedBoundsList::reverse_iterator i = leqList->reverse_upper_bound(atom); + i != leqList->rend(); ++i){ + TNode leq = *i; + Assert(leq.getKind() == LEQ); + if(!leq.getAttribute(propagator::PropagatorMarked())){ + leq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = leq[1].getConst(); + Assert( c > d ); + + Node gt = NodeManager::currentNM()->mkNode(NOT, leq); + gt.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(gt); // x>=c /\ d < c => x > d + }else{ + break; //No need to examine this atom implies the rest + } + } + + for(OrderedBoundsList::reverse_iterator i = eqList->reverse_upper_bound(atom); + i != eqList->rend(); ++i){ + TNode eq = *i; + Assert(eq.getKind() == EQUAL); + if(!eq.getAttribute(propagator::PropagatorMarked())){ + eq.setAttribute(propagator::PropagatorMarked(), true); + const Rational& d = eq[1].getConst(); + Assert( c > d ); + + Node neq = NodeManager::currentNM()->mkNode(NOT, eq); + neq.setAttribute(propagator::PropagatorExplanation(), orig); + buffer.push_back(neq); // x>=c /\ c > d => x != d + } + } + +} + +Node ArithUnatePropagator::explain(TNode lit){ + Assert(lit.hasAttribute(propagator::PropagatorExplanation())); + return lit.getAttribute(propagator::PropagatorExplanation()); +} diff --git a/src/theory/arith/arith_propagator.h b/src/theory/arith/arith_propagator.h new file mode 100644 index 000000000..a623517fb --- /dev/null +++ b/src/theory/arith/arith_propagator.h @@ -0,0 +1,111 @@ + + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__ARITH__ARITH_PROPAGATOR_H +#define __CVC4__THEORY__ARITH__ARITH_PROPAGATOR_H + +#include "expr/node.h" +#include "context/cdlist.h" +#include "context/context.h" +#include "context/cdo.h" +#include "theory/arith/ordered_bounds_list.h" + +#include +#include + +namespace CVC4 { +namespace theory { +namespace arith { + +class ArithUnatePropagator { +private: + /** Index of assertions. */ + context::CDList d_assertions; + + /** Index of the last assertion in d_assertions to be asserted. */ + context::CDO d_pendingAssertions; + +public: + ArithUnatePropagator(context::Context* cxt); + + /** + * Adds a new atom for the propagator to watch. + * Atom is assumed to have been rewritten by TheoryArith::rewrite(). + */ + void addAtom(TNode atom); + + /** + * Informs the propagator that a literal has been asserted to the theory. + */ + void assertLiteral(TNode lit); + + + /** + * returns a vector of literals that are + */ + std::vector getImpliedLiterals(); + + /** Explains a literal that was asserted in the current context. */ + Node explain(TNode lit); + +private: + /** returns true if the left hand side side left has been setup. */ + bool leftIsSetup(TNode left); + + /** + * Sets up a left hand side. + * This initializes the attributes PropagatorEqList, PropagatorGeqList, and PropagatorLeqList for left. + */ + void setupLefthand(TNode left); + + /** + * Given that the literal lit is now asserted, + * enqueue additional entailed assertions in buffer. + */ + void enqueueImpliedLiterals(TNode lit, std::vector& buffer); + + void enqueueEqualityImplications(TNode original, std::vector& buffer); + void enqueueLowerBoundImplications(TNode atom, TNode original, std::vector& buffer); + /** + * Given that the literal original is now asserted, which is either (<= x c) or (not (>= x c)), + * enqueue additional entailed assertions in buffer. + */ + void enqueueUpperBoundImplications(TNode atom, TNode original, std::vector& buffer); +}; + + + +namespace propagator { + +/** Basic memory management wrapper for deleting PropagatorEqList, PropagatorGeqList, and PropagatorLeqList.*/ +struct ListCleanupStrategy{ + static void cleanup(OrderedBoundsList* l){ + Debug("arithgc") << "cleaning up " << l << "\n"; + delete l; + } +}; + + +struct PropagatorEqListID {}; +typedef expr::Attribute PropagatorEqList; + +struct PropagatorGeqListID {}; +typedef expr::Attribute PropagatorGeqList; + +struct PropagatorLeqListID {}; +typedef expr::Attribute PropagatorLeqList; + + +struct PropagatorMarkedID {}; +typedef expr::CDAttribute PropagatorMarked; + +struct PropagatorExplanationID {}; +typedef expr::CDAttribute PropagatorExplanation; +}/* CVC4::theory::arith::propagator */ + +}/* CVC4::theory::arith namespace */ +}/* CVC4::theory namespace */ +}/* CVC4 namespace */ + +#endif /* __CVC4__THEORY__ARITH__THEORY_ARITH_H */ diff --git a/src/theory/arith/ordered_bounds_list.h b/src/theory/arith/ordered_bounds_list.h new file mode 100644 index 000000000..d21283afa --- /dev/null +++ b/src/theory/arith/ordered_bounds_list.h @@ -0,0 +1,212 @@ + + +#include "cvc4_private.h" + + +#ifndef __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H +#define __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H + +#include "expr/node.h" +#include "util/rational.h" +#include "expr/kind.h" + +#include +#include + +namespace CVC4 { +namespace theory { +namespace arith { + +struct RightHandRationalLT +{ + bool operator()(TNode s1, TNode s2) const + { + Assert(s1.getNumChildren() >= 2); + Assert(s2.getNumChildren() >= 2); + + Assert(s1[1].getKind() == kind::CONST_RATIONAL); + Assert(s2[1].getKind() == kind::CONST_RATIONAL); + + TNode rh1 = s1[1]; + TNode rh2 = s2[1]; + const Rational& c1 = rh1.getConst(); + const Rational& c2 = rh2.getConst(); + return c1.cmp(c2) < 0; + } +}; + +struct RightHandRationalGT +{ + bool operator()(TNode s1, TNode s2) const + { + Assert(s1.getNumChildren() >= 2); + Assert(s2.getNumChildren() >= 2); + + Assert(s1[1].getKind() == kind::CONST_RATIONAL); + Assert(s2[1].getKind() == kind::CONST_RATIONAL); + + TNode rh1 = s1[1]; + TNode rh2 = s2[1]; + const Rational& c1 = rh1.getConst(); + const Rational& c2 = rh2.getConst(); + return c1.cmp(c2) > 0; + } +}; + +/** + * An OrderedBoundsList is a lazily sorted vector of Arithmetic constraints. + * The intended use is for a list of rewriting arithmetic atoms. + * An example of such a list would be [(<= x 5);(= y 78); (>= x 9)]. + * + * Nodes are required to have a CONST_RATIONAL child as their second node. + * Nodes are sorted in increasing order according to RightHandRationalLT. + * + * The lists are lazily sorted in the sense that the list is not sorted until + * an operation to access the element is attempted. + * + * An append() may make the list no longer sorted. + * After an append() operation all iterators for the list become invalid. + */ +class OrderedBoundsList { +private: + bool d_isSorted; + std::vector d_list; + +public: + typedef std::vector::const_iterator iterator; + typedef std::vector::const_reverse_iterator reverse_iterator; + + /** + * Constucts a new and empty OrderBoundsList. + * The empty list is initially sorted. + */ + OrderedBoundsList() : d_isSorted(true){} + + /** + * Appends a node onto the back of the list. + * The list may no longer be sorted. + */ + void append(TNode n){ + Assert(n.getNumChildren() >= 2); + Assert(n[1].getKind() == kind::CONST_RATIONAL); + d_isSorted = false; + d_list.push_back(n); + } + + /** returns the size of the list */ + unsigned int size(){ + return d_list.size(); + } + + /** returns the i'th element in the sort list. This may sort the list.*/ + TNode at(unsigned int idx){ + sortIfNeeded(); + return d_list.at(idx); + } + + /** returns true if the list is known to be sorted. */ + bool isSorted() const{ + return d_isSorted; + } + + /** sorts the list. */ + void sort(){ + d_isSorted = true; + std::sort(d_list.begin(), d_list.end(), RightHandRationalLT()); + } + + /** + * returns an iterator to the list that iterates in ascending order. + * This may sort the list. + */ + iterator begin(){ + sortIfNeeded(); + return d_list.begin(); + } + /** + * returns an iterator to the end of the list when interating in ascending order. + */ + iterator end() const{ + return d_list.end(); + } + + /** + * returns an iterator to the list that iterates in descending order. + * This may sort the list. + */ + reverse_iterator rbegin(){ + sortIfNeeded(); + return d_list.rend(); + } + /** + * returns an iterator to the end of the list when interating in descending order. + */ + reverse_iterator rend() const{ + return d_list.rend(); + } + + /** + * returns an iterator to the least strict upper bound of value. + * if the list is [(<= x 2);(>= x 80);(< y 70)] + * then *upper_bound((< z 70)) == (>= x 80) + * + * This may sort the list. + * see stl::upper_bound for more information. + */ + iterator upper_bound(TNode value){ + sortIfNeeded(); + return std::upper_bound(begin(), end(), value, RightHandRationalLT()); + } + /** + * returns an iterator to the greatest lower bound of value. + * This is bound is not strict. + * if the list is [(<= x 2);(>= x 80);(< y 70)] + * then *lower_bound((< z 70)) == (< y 70) + * + * This may sort the list. + * see stl::upper_bound for more information. + */ + iterator lower_bound(TNode value){ + sortIfNeeded(); + return std::lower_bound(begin(), end(), value, RightHandRationalLT()); + } + /** + * see OrderedBoundsList::upper_bound for more information. + * The difference is that the iterator goes in descending order. + */ + reverse_iterator reverse_upper_bound(TNode value){ + sortIfNeeded(); + return std::upper_bound(rbegin(), rend(), value, RightHandRationalGT()); + } + /** + * see OrderedBoundsList::lower_bound for more information. + * The difference is that the iterator goes in descending order. + */ + reverse_iterator reverse_lower_bound(TNode value){ + sortIfNeeded(); + return std::lower_bound(rbegin(), rend(), value, RightHandRationalGT()); + } + + /** + * This is an O(n) method for searching the array to check if it contains n. + */ + bool contains(TNode n) const { + for(std::vector::const_iterator i = d_list.begin(); i != d_list.end(); ++i){ + if(*i == n) return true; + } + return false; + } +private: + /** Sorts the list if it is not already sorted. */ + void sortIfNeeded(){ + if(!d_isSorted){ + sort(); + } + } +}; + +}/* CVC4::theory::arith namespace */ +}/* CVC4::theory namespace */ +}/* CVC4 namespace */ + +#endif /* __CVC4__THEORY__ARITH__ORDERED_BOUNDS_LIST_H */ diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index b3b7f58be..bd35e0797 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -34,6 +34,7 @@ #include "theory/arith/basic.h" #include "theory/arith/arith_rewriter.h" +#include "theory/arith/arith_propagator.h" #include "theory/arith/theory_arith.h" #include @@ -55,6 +56,7 @@ TheoryArith::TheoryArith(context::Context* c, OutputChannel& out) : d_partialModel(c), d_diseq(c), d_rewriter(&d_constants), + d_propagator(c), d_statistics() { uint64_t ass_id = partial_model::Assignment::getId(); @@ -81,6 +83,15 @@ TheoryArith::Statistics::Statistics(): StatisticsRegistry::registerStat(&d_statUpdateConflicts); } +TheoryArith::Statistics::~Statistics(){ + StatisticsRegistry::unregisterStat(&d_statPivots); + StatisticsRegistry::unregisterStat(&d_statUpdates); + StatisticsRegistry::unregisterStat(&d_statAssertUpperConflicts); + StatisticsRegistry::unregisterStat(&d_statAssertLowerConflicts); + StatisticsRegistry::unregisterStat(&d_statUpdateConflicts); +} + + bool isBasicSum(TNode n){ if(n.getKind() != kind::PLUS) return false; @@ -143,6 +154,8 @@ void TheoryArith::preRegisterTerm(TNode n) { Assert(isNormalAtom(n)); + d_propagator.addAtom(n); + TNode left = n[0]; TNode right = n[1]; if(left.getKind() == PLUS){ @@ -206,6 +219,10 @@ void TheoryArith::setupVariable(TNode x){ //lower bound. This is done to strongly enforce the notion that basic //variables should not be changed without begin checked. + //Strictly speaking checking x is unnessecary as it cannot have an upper or + //lower bound. This is done to strongly enforce the notion that basic + //variables should not be changed without begin checked. + } Debug("arithgc") << "setupVariable("< " << assertion << ")" << std::endl; + d_propagator.assertLiteral(original); bool conflictDuringAnAssert = assertionCases(original, assertion); + if(conflictDuringAnAssert){ - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } d_partialModel.revertAssignmentChanges(); - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - return; } } - if(fullEffort(level)){ - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } + //TODO This must be done everytime for the time being + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - Node possibleConflict = updateInconsistentVars(); - if(possibleConflict != Node::null()){ + Node possibleConflict = updateInconsistentVars(); + if(possibleConflict != Node::null()){ - d_partialModel.revertAssignmentChanges(); + d_partialModel.revertAssignmentChanges(); - d_out->conflict(possibleConflict, true); + if(debugTagIsOn("arith::print-conflict")) + Debug("arith_conflict") << (possibleConflict) << std::endl; - Debug("arith_conflict") <<"Found a conflict "<< possibleConflict << endl; - }else{ - d_partialModel.commitAssignmentChanges(); - } - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } + d_out->conflict(possibleConflict); + + Debug("arith_conflict") <<"Found a conflict "<< possibleConflict << endl; + }else{ + d_partialModel.commitAssignmentChanges(); } + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } + Debug("arith") << "TheoryArith::check end" << std::endl; + if(debugTagIsOn("arith::print_model")) { + Debug("arith::print_model") << "Model:" << endl; + + for (unsigned i = 0; i < d_variables.size(); ++ i) { + Debug("arith::print_model") << d_variables[i] << " : " << + d_partialModel.getAssignment(d_variables[i]); + if(isBasic(d_variables[i])) + Debug("arith::print_model") << " (basic)"; + Debug("arith::print_model") << endl; + } + } + if(debugTagIsOn("arith::print_assertions")) { + Debug("arith::print_assertions") << "Assertions:" << endl; + for (unsigned i = 0; i < d_variables.size(); ++ i) { + Node x = d_variables[i]; + if (x.hasAttribute(partial_model::LowerConstraint())) { + Node constr = d_partialModel.getLowerConstraint(x); + Debug("arith::print_assertions") << constr.toString() << endl; + } + if (x.hasAttribute(partial_model::UpperConstraint())) { + Node constr = d_partialModel.getUpperConstraint(x); + Debug("arith::print_assertions") << constr.toString() << endl; + } + } + } } /** @@ -750,3 +795,23 @@ void TheoryArith::checkTableau(){ Assert(sum == shouldBe); } } + + +void TheoryArith::explain(TNode n, Effort e) { + Node explanation = d_propagator.explain(n); + Debug("arith") << "arith::explain("<" + << explanation << endl; + d_out->explanation(explanation, true); +} + +void TheoryArith::propagate(Effort e) { + + if(quickCheckOrMore(e)){ + std::vector implied = d_propagator.getImpliedLiterals(); + for(std::vector::iterator i = implied.begin(); + i != implied.end(); + ++i){ + d_out->propagate(*i); + } + } +} diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index aff60f651..c76923bee 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -30,6 +30,7 @@ #include "theory/arith/tableau.h" #include "theory/arith/arith_rewriter.h" #include "theory/arith/partial_model.h" +#include "theory/arith/arith_propagator.h" #include "util/stats.h" @@ -96,6 +97,7 @@ private: */ ArithRewriter d_rewriter; + ArithUnatePropagator d_propagator; public: TheoryArith(context::Context* c, OutputChannel& out); @@ -115,8 +117,8 @@ public: void registerTerm(TNode n); void check(Effort e); - void propagate(Effort e) { Unimplemented(); } - void explain(TNode n, Effort e) { Unimplemented(); } + void propagate(Effort e); + void explain(TNode n, Effort e); void shutdown(){ } @@ -242,6 +244,7 @@ private: IntStat d_statAssertLowerConflicts, d_statUpdateConflicts; Statistics(); + ~Statistics(); }; Statistics d_statistics; diff --git a/src/theory/theory.cpp b/src/theory/theory.cpp index e06c9594c..5e83d3728 100644 --- a/src/theory/theory.cpp +++ b/src/theory/theory.cpp @@ -103,5 +103,21 @@ Node Theory::get() { return fact; } +std::ostream& operator<<(std::ostream& os, Theory::Effort level){ + switch(level){ + case Theory::MIN_EFFORT: + os << "MIN_EFFORT"; break; + case Theory::QUICK_CHECK: + os << "QUICK_CHECK:"; break; + case Theory::STANDARD: + os << "STANDARD"; break; + case Theory::FULL_EFFORT: + os << "FULL_EFFORT"; break; + default: + Unreachable(); + } + return os; +} + }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/theory.h b/src/theory/theory.h index 1bf6f660c..6f4effe78 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -331,6 +331,8 @@ protected: };/* class Theory */ +std::ostream& operator<<(std::ostream& os, Theory::Effort level); + }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/theory_engine.h b/src/theory/theory_engine.h index c2511f4e6..15b406cdd 100644 --- a/src/theory/theory_engine.h +++ b/src/theory/theory_engine.h @@ -64,13 +64,22 @@ class TheoryEngine { TheoryEngine* d_engine; context::Context* d_context; context::CDO d_conflictNode; + context::CDO d_explanationNode; + + /** + * Literals that are propagated by the theory. Note that these are TNodes. + * The theory can only propagate nodes that have an assigned literal in the + * sat solver and are hence referenced in the SAT solver. + */ + std::vector d_propagatedLiterals; public: EngineOutputChannel(TheoryEngine* engine, context::Context* context) : d_engine(engine), d_context(context), - d_conflictNode(context) { + d_conflictNode(context), + d_explanationNode(context){ } void conflict(TNode conflictNode, bool safe) throw(theory::Interrupted, AssertionException) { @@ -82,7 +91,9 @@ class TheoryEngine { } } - void propagate(TNode, bool) throw(theory::Interrupted, AssertionException) { + void propagate(TNode lit, bool) throw(theory::Interrupted, AssertionException) { + d_propagatedLiterals.push_back(lit); + ++(d_engine->d_statistics.d_statPropagate); ++(d_engine->d_statistics.d_statPropagate); } @@ -94,7 +105,9 @@ class TheoryEngine { ++(d_engine->d_statistics.d_statAugLemma); d_engine->newAugmentingLemma(node); } - void explanation(TNode, bool) throw(theory::Interrupted, AssertionException) { + void explanation(TNode explanationNode, bool) throw(theory::Interrupted, AssertionException) { + d_explanationNode = explanationNode; + ++(d_engine->d_statistics.d_statExplanatation); ++(d_engine->d_statistics.d_statExplanatation); } }; @@ -302,6 +315,7 @@ public: inline bool check(theory::Theory::Effort effort) { d_theoryOut.d_conflictNode = Node::null(); + d_theoryOut.d_propagatedLiterals.clear(); // Do the checking try { //d_bool.check(effort); @@ -316,13 +330,23 @@ public: return d_theoryOut.d_conflictNode.get().isNull(); } + inline const std::vector& getPropagatedLiterals() const { + return d_theoryOut.d_propagatedLiterals; + } + + void clearPropagatedLiterals() { + d_theoryOut.d_propagatedLiterals.clear(); + } + inline void newLemma(TNode node) { d_propEngine->assertLemma(node); } + inline void newAugmentingLemma(TNode node) { Node preprocessed = preprocess(node); d_propEngine->assertFormula(preprocessed); } + /** * Returns the last conflict (if any). */ @@ -330,6 +354,21 @@ public: return d_theoryOut.d_conflictNode; } + inline void propagate() { + d_theoryOut.d_propagatedLiterals.clear(); + // Do the propagation + d_uf.propagate(theory::Theory::FULL_EFFORT); + d_arith.propagate(theory::Theory::FULL_EFFORT); + } + + inline Node getExplanation(TNode node){ + d_theoryOut.d_explanationNode = Node::null(); + theory::Theory* theory = + node.getKind() == kind::NOT ? theoryOf(node[0]) : theoryOf(node); + theory->explain(node); + return d_theoryOut.d_explanationNode; + } + private: class Statistics { public: @@ -350,6 +389,7 @@ private: }; Statistics d_statistics; + };/* class TheoryEngine */ }/* CVC4 namespace */ diff --git a/src/theory/theory_test_utils.h b/src/theory/theory_test_utils.h new file mode 100644 index 000000000..dc08788f3 --- /dev/null +++ b/src/theory/theory_test_utils.h @@ -0,0 +1,81 @@ + + +#include "cvc4_public.h" + + +#ifndef __CVC4__THEORY__THEORY_TEST_UTILS_H +#define __CVC4__THEORY__ITHEORY_TEST_UTILS_H + +#include "util/Assert.h" +#include "expr/node.h" +#include "theory/output_channel.h" +#include "theory/interrupted.h" + +#include + +namespace CVC4{ + +namespace theory { + +/** + * Very basic OutputChannel for testing simple Theory Behaviour. + * Stores a call sequence for the output channel + */ +enum OutputChannelCallType { CONFLICT, PROPOGATE, AUG_LEMMA, LEMMA, EXPLANATION }; + + +class TestOutputChannel : public theory::OutputChannel { +public: + std::vector< pair > d_callHistory; + + TestOutputChannel() {} + + ~TestOutputChannel() {} + + void safePoint() throw(Interrupted, AssertionException) {} + + void conflict(TNode n, bool safe = false) throw(Interrupted, AssertionException) { + push(CONFLICT, n); + } + + void propagate(TNode n, bool safe = false) throw(Interrupted, AssertionException) { + push(PROPOGATE, n); + } + + void lemma(TNode n, bool safe = false) throw(Interrupted, AssertionException) { + push(LEMMA, n); + } + void augmentingLemma(TNode n, bool safe = false) throw(Interrupted, AssertionException){ + push(AUG_LEMMA, n); + } + void explanation(TNode n, bool safe = false) throw(Interrupted, AssertionException) { + push(EXPLANATION, n); + } + + void clear() { + d_callHistory.clear(); + } + + Node getIthNode(int i) { + Node tmp = (d_callHistory[i]).second; + return tmp; + } + + OutputChannelCallType getIthCallType(int i) { + return (d_callHistory[i]).first; + } + + unsigned getNumCalls() { + return d_callHistory.size(); + } + +private: + void push(OutputChannelCallType call, TNode n) { + d_callHistory.push_back(make_pair(call,n)); + } +};/* class TestOutputChannel */ + +}/* namespace theory */ +}/* namespace CVC4 */ + +#endif /* __CVC4__THEORY__THEORY_TEST_UTILS_H */ diff --git a/src/theory/uf/theory_uf.cpp b/src/theory/uf/theory_uf.cpp index d13baf6a9..f440c3d0f 100644 --- a/src/theory/uf/theory_uf.cpp +++ b/src/theory/uf/theory_uf.cpp @@ -307,7 +307,7 @@ void TheoryUF::check(Effort level) { merge(); } - if(fullEffort(level)) { + if(standardEffortOrMore(level)) { for(CDList::const_iterator diseqIter = d_disequality.begin(); diseqIter != d_disequality.end(); ++diseqIter) { diff --git a/test/unit/Makefile.am b/test/unit/Makefile.am index 9f8379d54..ddab915bf 100644 --- a/test/unit/Makefile.am +++ b/test/unit/Makefile.am @@ -23,6 +23,7 @@ UNIT_TESTS = \ context/cdmap_white \ theory/theory_black \ theory/theory_uf_white \ + theory/theory_arith_white \ util/assert_white \ util/bitvector_black \ util/configuration_black \ diff --git a/test/unit/theory/theory_arith_white.h b/test/unit/theory/theory_arith_white.h new file mode 100644 index 000000000..fe9cbb388 --- /dev/null +++ b/test/unit/theory/theory_arith_white.h @@ -0,0 +1,312 @@ + +#include + +#include "theory/theory.h" +#include "theory/arith/theory_arith.h" +#include "expr/node.h" +#include "expr/node_manager.h" +#include "context/context.h" +#include "util/rational.h" + +#include "theory/theory_test_utils.h" + +#include + +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; +using namespace CVC4::expr; +using namespace CVC4::context; +using namespace CVC4::kind; + +using namespace std; + +class TheoryArithWhite : public CxxTest::TestSuite { + + Context* d_ctxt; + NodeManager* d_nm; + NodeManagerScope* d_scope; + + TestOutputChannel d_outputChannel; + Theory::Effort d_level; + + TheoryArith* d_arith; + + TypeNode* d_booleanType; + TypeNode* d_realType; + + const Rational d_zero; + const Rational d_one; + + std::set* preregistered; + + bool debug; + +public: + + TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {} + + void setUp() { + d_ctxt = new Context; + d_nm = new NodeManager(d_ctxt); + d_scope = new NodeManagerScope(d_nm); + d_outputChannel.clear(); + d_arith = new TheoryArith(d_ctxt, d_outputChannel); + + preregistered = new std::set(); + + d_booleanType = new TypeNode(d_nm->booleanType()); + d_realType = new TypeNode(d_nm->realType()); + + } + + void tearDown() { + delete d_realType; + delete d_booleanType; + + delete preregistered; + + delete d_arith; + d_outputChannel.clear(); + delete d_scope; + delete d_nm; + delete d_ctxt; + } + + Node fakeTheoryEnginePreprocess(TNode inp){ + Node rewrite = d_arith->rewrite(inp); + + if(debug) cout << rewrite << inp << endl; + + std::list toPreregister; + + toPreregister.push_back(rewrite); + for(std::list::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){ + Node n = *i; + preregistered->insert(n); + + for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){ + Node c = *citer; + if(preregistered->find(c) == preregistered->end()){ + toPreregister.push_back(c); + } + } + } + for(std::list::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){ + Node n = *i; + if(debug) cout << n.getId() << " "<< n << endl; + d_arith->preRegisterTerm(n); + } + + return rewrite; + } + + void testAssert() { + Node x = d_nm->mkVar(*d_realType); + Node c = d_nm->mkConst(d_zero); + + Node leq = d_nm->mkNode(LEQ, x, c); + Node rLeq = fakeTheoryEnginePreprocess(leq); + + d_arith->assertFact(rLeq); + + d_arith->check(d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u); + } + + Node simulateSplit(TNode l, TNode r){ + Node eq = d_nm->mkNode(EQUAL, l, r); + Node lt = d_nm->mkNode(LT, l, r); + Node gt = d_nm->mkNode(GT, l, r); + + Node dis = d_nm->mkNode(OR, eq, lt, gt); + return dis; + } + + void testAssertEqualityEagerSplit() { + Node x = d_nm->mkVar(*d_realType); + Node c = d_nm->mkConst(d_zero); + + Node eq = d_nm->mkNode(EQUAL, x, c); + Node expectedDisjunct = simulateSplit(x,c); + + Node rEq = fakeTheoryEnginePreprocess(eq); + + d_arith->assertFact(rEq); + + d_arith->check(d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u); + + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA); + + } + void testLtRewrite() { + Node x = d_nm->mkVar(*d_realType); + Node c = d_nm->mkConst(d_zero); + + Node lt = d_nm->mkNode(LT, x, c); + Node geq = d_nm->mkNode(GEQ, x, c); + Node expectedRewrite = d_nm->mkNode(NOT, geq); + + Node rewrite = d_arith->rewrite(lt); + + TS_ASSERT_EQUALS(expectedRewrite, rewrite); + } + + void testBasicConflict() { + Node x = d_nm->mkVar(*d_realType); + Node c = d_nm->mkConst(d_zero); + + Node eq = d_nm->mkNode(EQUAL, x, c); + Node lt = d_nm->mkNode(LT, x, c); + Node expectedDisjunct = simulateSplit(x,c); + + Node rEq = fakeTheoryEnginePreprocess(eq); + Node rLt = fakeTheoryEnginePreprocess(lt); + + d_arith->assertFact(rEq); + d_arith->assertFact(rLt); + + + d_arith->check(d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA); + + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT); + + Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt); + + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict); + } + + void testBasicPropagate() { + Node x = d_nm->mkVar(*d_realType); + Node c = d_nm->mkConst(d_zero); + + Node eq = d_nm->mkNode(EQUAL, x, c); + Node lt = d_nm->mkNode(LT, x, c); + Node expectedDisjunct = simulateSplit(x,c); + + Node rEq = fakeTheoryEnginePreprocess(eq); + Node rLt = fakeTheoryEnginePreprocess(lt); + + d_arith->assertFact(rEq); + + + d_arith->check(d_level); + d_arith->propagate(d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA); + + + Node expectedProp = d_nm->mkNode(GEQ, x, c); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPOGATE); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedProp); + + } + void testTPLt1() { + Node x = d_nm->mkVar(*d_realType); + Node c0 = d_nm->mkConst(d_zero); + Node c1 = d_nm->mkConst(d_one); + + Node leq0 = d_nm->mkNode(LEQ, x, c0); + Node leq1 = d_nm->mkNode(LEQ, x, c1); + Node lt1 = d_nm->mkNode(LT, x, c1); + + Node rLeq0 = fakeTheoryEnginePreprocess(leq0); + Node rLt1 = fakeTheoryEnginePreprocess(lt1); + Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + + d_arith->assertFact(rLt1); + + + d_arith->check(d_level); + d_arith->propagate(d_level); + +#ifdef CVC4_ASSERTIONS + TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException ); +#endif + d_arith->explain(rLeq1, d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPOGATE); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1); + } + + + void testTPLeq0() { + Node x = d_nm->mkVar(*d_realType); + Node c0 = d_nm->mkConst(d_zero); + Node c1 = d_nm->mkConst(d_one); + + Node leq0 = d_nm->mkNode(LEQ, x, c0); + Node leq1 = d_nm->mkNode(LEQ, x, c1); + Node lt1 = d_nm->mkNode(LT, x, c1); + + Node rLeq0 = fakeTheoryEnginePreprocess(leq0); + Node rLt1 = fakeTheoryEnginePreprocess(lt1); + Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + + d_arith->assertFact(rLeq0); + + + d_arith->check(d_level); + d_arith->propagate(d_level); + + + d_arith->explain(rLt1, d_level); +#ifdef CVC4_ASSERTIONS + TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); +#endif + d_arith->explain(rLeq1, d_level); + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPOGATE); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), PROPOGATE); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION); + TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION); + + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1); + + + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0); + } + void testTPLeq1() { + Node x = d_nm->mkVar(*d_realType); + Node c0 = d_nm->mkConst(d_zero); + Node c1 = d_nm->mkConst(d_one); + + Node leq0 = d_nm->mkNode(LEQ, x, c0); + Node leq1 = d_nm->mkNode(LEQ, x, c1); + Node lt1 = d_nm->mkNode(LT, x, c1); + + Node rLeq0 = fakeTheoryEnginePreprocess(leq0); + Node rLt1 = fakeTheoryEnginePreprocess(lt1); + Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + + d_arith->assertFact(rLeq1); + + + d_arith->check(d_level); + d_arith->propagate(d_level); + +#ifdef CVC4_ASSERTIONS + TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(rLeq1, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException ); +#endif + + TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u); + } +}; diff --git a/test/unit/theory/theory_uf_white.h b/test/unit/theory/theory_uf_white.h index 50c201606..203d669b7 100644 --- a/test/unit/theory/theory_uf_white.h +++ b/test/unit/theory/theory_uf_white.h @@ -24,6 +24,8 @@ #include "expr/node_manager.h" #include "context/context.h" +#include "theory/theory_test_utils.h" + #include using namespace CVC4; @@ -34,60 +36,6 @@ using namespace CVC4::context; using namespace std; -/** - * Very basic OutputChannel for testing simple Theory Behaviour. - * Stores a call sequence for the output channel - */ -enum OutputChannelCallType { CONFLICT, PROPOGATE, LEMMA, EXPLANATION }; -class TestOutputChannel : public OutputChannel { -private: - void push(OutputChannelCallType call, TNode n) { - d_callHistory.push_back(make_pair(call,n)); - } -public: - vector< pair > d_callHistory; - - TestOutputChannel() {} - - ~TestOutputChannel() {} - - void safePoint() throw(Interrupted, AssertionException) {} - - void conflict(TNode n, bool safe = false) throw(Interrupted, AssertionException) { - push(CONFLICT, n); - } - - void propagate(TNode n, bool safe = false) throw(Interrupted, AssertionException) { - push(PROPOGATE, n); - } - - void lemma(TNode n, bool safe = false) throw(Interrupted, AssertionException) { - push(LEMMA, n); - } - void augmentingLemma(TNode n, bool safe = false) throw(Interrupted, AssertionException){ - Unreachable(); - } - void explanation(TNode n, bool safe = false) throw(Interrupted, AssertionException) { - push(EXPLANATION, n); - } - - void clear() { - d_callHistory.clear(); - } - - Node getIthNode(int i) { - Node tmp = (d_callHistory[i]).second; - return tmp; - } - - OutputChannelCallType getIthCallType(int i) { - return (d_callHistory[i]).first; - } - - unsigned getNumCalls() { - return d_callHistory.size(); - } -}; class TheoryUFWhite : public CxxTest::TestSuite { -- 2.30.2