From: Tim King Date: Wed, 16 Jun 2010 21:19:34 +0000 (+0000) Subject: More assorted changes to arithmetic in preparation for the code review. X-Git-Tag: cvc5-1.0.0~8984 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2ceddc34920376bcb181c5fbbe2a9c0f4b87f436;p=cvc5.git More assorted changes to arithmetic in preparation for the code review. --- diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 2f62c8bc1..514dce3f7 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -47,13 +47,10 @@ using namespace CVC4::kind; using namespace CVC4::theory; using namespace CVC4::theory::arith; -struct EagerSplittingTag {}; -typedef expr::Attribute EagerlySplitUpon; TheoryArith::TheoryArith(context::Context* c, OutputChannel& out) : Theory(c, out), - d_preprocessed(c), d_constants(NodeManager::currentNM()), d_partialModel(c), d_diseq(c), @@ -105,25 +102,20 @@ bool isNormalAtom(TNode n){ } void TheoryArith::preRegisterTerm(TNode n) { - Debug("arith_preregister") << "arith: begin TheoryArith::preRegisterTerm(" - << n << ")" << endl; - + Debug("arith_preregister") <<"begin arith::preRegisterTerm("<< n <<")"<< endl; Kind k = n.getKind(); - if(n.getKind() == EQUAL){ - if(!n.getAttribute(EagerlySplitUpon())){ - TNode left = n[0]; - TNode right = n[1]; + if(k == EQUAL){ + TNode left = n[0]; + TNode right = n[1]; - Node lt = NodeManager::currentNM()->mkNode(LT, left,right); - Node gt = NodeManager::currentNM()->mkNode(GT, left,right); - Node eagerSplit = NodeManager::currentNM()->mkNode(OR, n, lt, gt); + Node lt = NodeManager::currentNM()->mkNode(LT, left,right); + Node gt = NodeManager::currentNM()->mkNode(GT, left,right); + Node eagerSplit = NodeManager::currentNM()->mkNode(OR, n, lt, gt); - d_splits.push_back(eagerSplit); + d_splits.push_back(eagerSplit); - n.setAttribute(EagerlySplitUpon(), true); - d_out->augmentingLemma(eagerSplit); - } + d_out->augmentingLemma(eagerSplit); } if(n.getMetaKind() == metakind::VARIABLE){ @@ -148,13 +140,10 @@ void TheoryArith::preRegisterTerm(TNode n) { } } } - - Debug("arith_preregister") << "arith: end TheoryArith::preRegisterTerm(" - << n << ")" << endl; + Debug("arith_preregister") << "end arith::preRegisterTerm("<< n <<")"<< endl; } void TheoryArith::setupSlack(TNode left){ - //TODO TypeNode real_type = NodeManager::currentNM()->realType(); Node slack = NodeManager::currentNM()->mkVar(real_type); @@ -179,7 +168,6 @@ void TheoryArith::checkBasicVariable(TNode basic){ } /* Requirements: - * Variable must have been set to be basic. * For basic variables the row must have been added to the tableau. */ void TheoryArith::setupVariable(TNode x){ @@ -194,20 +182,16 @@ void TheoryArith::setupVariable(TNode x){ //This can go away if the tableau creation is done at preregister //time instead of register - - DeltaRational q = computeRowValueUsingSavedAssignment(x); - if(!(q == d_constants.d_ZERO_DELTA)){ - Debug("arith_setup") << "setup("<::iterator i = row->begin(); i != row->end();++i){ + for(Row::iterator i = row->begin(); i != row->end();++i){ TNode nonbasic = *i; const Rational& coeff = row->lookup(nonbasic); - DeltaRational assignment = d_partialModel.getAssignment(nonbasic); + const DeltaRational& assignment = d_partialModel.getAssignment(nonbasic); sum = sum + (assignment * coeff); } return sum; @@ -237,10 +221,10 @@ DeltaRational TheoryArith::computeRowValueUsingSavedAssignment(TNode x){ DeltaRational sum = d_constants.d_ZERO_DELTA; Row* row = d_tableau.lookup(x); - for(std::set::iterator i = row->begin(); i != row->end();++i){ + for(Row::iterator i = row->begin(); i != row->end();++i){ TNode nonbasic = *i; const Rational& coeff = row->lookup(nonbasic); - DeltaRational assignment = d_partialModel.getSafeAssignment(nonbasic); + const DeltaRational& assignment = d_partialModel.getSafeAssignment(nonbasic); sum = sum + (assignment * coeff); } return sum; @@ -250,19 +234,13 @@ Node TheoryArith::rewrite(TNode n){ Debug("arith") << "rewrite(" << n << ")" << endl; Node result = d_rewriter.rewrite(n); - Debug("arith-rewrite") << "rewrite(" - << n << " -> " << result - << ")" << endl; + Debug("arith-rewrite") << "rewrite(" << n << ") -> " << result << endl; return result; } void TheoryArith::registerTerm(TNode tn){ Debug("arith") << "registerTerm(" << tn << ")" << endl; - - if(tn.getKind() == kind::BUILTIN) return; - - } /* procedure AssertUpper( x_i <= c_i) */ @@ -329,7 +307,6 @@ bool TheoryArith::AssertLower(TNode n, TNode original){ }else{ checkBasicVariable(x_i); } - //d_partialModel.printModel(x_i); return false; } @@ -351,7 +328,7 @@ void TheoryArith::update(TNode x_i, DeltaRational& v){ if(row_j->has(x_i)){ const Rational& a_ji = row_j->lookup(x_i); - DeltaRational assignment = d_partialModel.getAssignment(x_j); + const DeltaRational& assignment = d_partialModel.getAssignment(x_j); DeltaRational nAssignment = assignment+(diff * a_ji); d_partialModel.setAssignment(x_j, nAssignment); checkBasicVariable(x_j); @@ -372,7 +349,7 @@ void TheoryArith::pivotAndUpdate(TNode x_i, TNode x_j, DeltaRational& v){ const Rational& a_ij = row_i->lookup(x_j); - DeltaRational betaX_i = d_partialModel.getAssignment(x_i); + const DeltaRational& betaX_i = d_partialModel.getAssignment(x_i); Rational inv_aij = a_ij.inverse(); DeltaRational theta = (v - betaX_i)*inv_aij; @@ -390,7 +367,7 @@ void TheoryArith::pivotAndUpdate(TNode x_i, TNode x_j, DeltaRational& v){ Row* row_k = d_tableau.lookup(x_k); if(x_k != x_i && row_k->has(x_j)){ - Rational a_kj = row_k->lookup(x_j); + const Rational& a_kj = row_k->lookup(x_j); DeltaRational nextAssignment = d_partialModel.getAssignment(x_k) + (theta * a_kj); d_partialModel.setAssignment(x_k, nextAssignment); checkBasicVariable(x_k); @@ -438,40 +415,28 @@ TNode TheoryArith::selectSmallestInconsistentVar(){ return TNode::null(); } -TNode TheoryArith::selectSlackBelow(TNode x_i){ //beta(x_i) < l_i - Row* row_i = d_tableau.lookup(x_i); +template +TNode TheoryArith::selectSlack(TNode x_i){ + Row* row_i = d_tableau.lookup(x_i); - typedef std::set::iterator NonBasicIter; - - for(NonBasicIter nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ + for(Row::iterator nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ TNode nonbasic = *nbi; - - Rational a_ij = row_i->lookup(nonbasic); - if(a_ij > d_constants.d_ZERO && d_partialModel.strictlyBelowUpperBound(nonbasic)){ - return nonbasic; - }else if(a_ij < d_constants.d_ZERO && d_partialModel.strictlyAboveLowerBound(nonbasic)){ - return nonbasic; - } - } - return TNode::null(); -} - -TNode TheoryArith::selectSlackAbove(TNode x_i){ // beta(x_i) > u_i - Row* row_i = d_tableau.lookup(x_i); - - typedef std::set::iterator NonBasicIter; - - for(NonBasicIter nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ - TNode nonbasic = *nbi; - - Rational a_ij = row_i->lookup(nonbasic); - if(a_ij < d_constants.d_ZERO && d_partialModel.strictlyBelowUpperBound(nonbasic)){ - return nonbasic; - }else if(a_ij > d_constants.d_ZERO && d_partialModel.strictlyAboveLowerBound(nonbasic)){ - return nonbasic; + const Rational& a_ij = row_i->lookup(nonbasic); + int cmp = a_ij.cmp(d_constants.d_ZERO); + if(above){ // beta(x_i) > u_i + if( cmp < 0 && d_partialModel.strictlyBelowUpperBound(nonbasic)){ + return nonbasic; + }else if( cmp > 0 && d_partialModel.strictlyAboveLowerBound(nonbasic)){ + return nonbasic; + } + }else{ //beta(x_i) < l_i + if(cmp > 0 && d_partialModel.strictlyBelowUpperBound(nonbasic)){ + return nonbasic; + }else if(cmp < 0 && d_partialModel.strictlyAboveLowerBound(nonbasic)){ + return nonbasic; + } } } - return TNode::null(); } @@ -522,9 +487,7 @@ Node TheoryArith::generateConflictAbove(TNode conflictVar){ nb << bound; - typedef std::set::iterator NonBasicIter; - - for(NonBasicIter nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ + for(Row::iterator nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ TNode nonbasic = *nbi; const Rational& a_ij = row_i->lookup(nonbasic); @@ -560,9 +523,7 @@ Node TheoryArith::generateConflictBelow(TNode conflictVar){ << " " << bound << endl; nb << bound; - typedef std::set::iterator NonBasicIter; - - for(NonBasicIter nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ + for(Row::iterator nbi = row_i->begin(); nbi != row_i->end(); ++nbi){ TNode nonbasic = *nbi; const Rational& a_ij = row_i->lookup(nonbasic); @@ -593,11 +554,9 @@ Node TheoryArith::generateConflictBelow(TNode conflictVar){ Node TheoryArith::simulatePreprocessing(TNode n){ if(n.getKind() == NOT){ Node sub = simulatePreprocessing(n[0]); - if(sub.getKind() == NOT){ - return sub[0]; - }else{ - return NodeManager::currentNM()->mkNode(NOT,sub); - } + Assert(sub.getKind() != NOT); + return NodeManager::currentNM()->mkNode(NOT,sub); + }else{ Assert(isNormalAtom(n)); Kind k = n.getKind(); @@ -615,132 +574,83 @@ Node TheoryArith::simulatePreprocessing(TNode n){ } } +bool TheoryArith::assertionCases(TNode original, TNode assertion){ + switch(assertion.getKind()){ + case LEQ: + return AssertUpper(assertion, original); + case GEQ: + return AssertLower(assertion, original); + case EQUAL: + if(AssertUpper(assertion, original)){ + return true; + }else{ + return AssertLower(assertion, original); + } + case NOT: + { + TNode atom = assertion[0]; + switch(atom.getKind()){ + case LEQ: //(not (LEQ x c)) <=> (GT x c) + { + Node pushedin = pushInNegation(assertion); + return AssertLower(pushedin,original); + } + case GEQ: //(not (GEQ x c) <=> (LT x c) + { + Node pushedin = pushInNegation(assertion); + return AssertUpper(pushedin,original); + } + case EQUAL: + d_diseq.push_back(assertion); + return false; + default: + Unreachable(); + return false; + } + } + default: + Unreachable(); + return false; + } +} + void TheoryArith::check(Effort level){ Debug("arith") << "TheoryArith::check begun" << std::endl; - - bool conflictDuringAnAssert = false; - - while(!done() && !conflictDuringAnAssert){ - //checkTableau(); + while(!done()){ Node original = get(); Node assertion = simulatePreprocessing(original); Debug("arith_assertions") << "arith assertion(" << original << " \\-> " << assertion << ")" << std::endl; - d_preprocessed.push_back(assertion); - - switch(assertion.getKind()){ - case LEQ: - conflictDuringAnAssert = AssertUpper(assertion, original); - break; - case GEQ: - conflictDuringAnAssert = AssertLower(assertion, original); - break; - case EQUAL: - conflictDuringAnAssert = AssertUpper(assertion, original); - if(!conflictDuringAnAssert){ - conflictDuringAnAssert = AssertLower(assertion, original); - } - break; - case NOT: - { - TNode atom = assertion[0]; - switch(atom.getKind()){ - case LEQ: //(not (LEQ x c)) <=> (GT x c) - { - Node pushedin = pushInNegation(assertion); - conflictDuringAnAssert = AssertLower(pushedin,original); - break; - } - case GEQ: //(not (GEQ x c) <=> (LT x c) - { - Node pushedin = pushInNegation(assertion); - conflictDuringAnAssert = AssertUpper(pushedin,original); - break; - } - case EQUAL: - d_diseq.push_back(assertion); - break; - default: - Unhandled(); - } - break; - } - default: - Unhandled(); - } - } - if(conflictDuringAnAssert){ - while(!done()) { get(); } - - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - d_partialModel.revertAssignmentChanges(); - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - + bool conflictDuringAnAssert = assertionCases(original, assertion); + if(conflictDuringAnAssert){ + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } + d_partialModel.revertAssignmentChanges(); + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - //return - return; + return; + } } if(fullEffort(level)){ + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } + Node possibleConflict = updateInconsistentVars(); if(possibleConflict != Node::null()){ - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } d_partialModel.revertAssignmentChanges(); - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - d_out->conflict(possibleConflict, true); - - Debug("arith_conflict") << "Found a conflict " - << possibleConflict << endl; + Debug("arith_conflict") <<"Found a conflict "<< possibleConflict << endl; }else{ - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - d_partialModel.commitAssignmentChanges(); - - if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } - - Debug("arith_conflict") << "No conflict found" << endl; } + if(debugTagIsOn("paranoid:check_tableau")){ checkTableau(); } } - // if(fullEffort(level)){ -// bool enqueuedCaseSplit = false; -// typedef context::CDList::const_iterator diseq_iterator; -// for(diseq_iterator i = d_diseq.begin(); i!= d_diseq.end(); ++i){ - -// Node assertion = *i; -// Debug("arith") << "splitting" << assertion << endl; -// TNode eq = assertion[0]; -// TNode x_i = eq[0]; -// TNode c_i = eq[1]; -// DeltaRational constant = c_i.getConst(); -// Debug("arith") << "broken apart" << endl; -// if(d_partialModel.getAssignment(x_i) == constant){ -// Debug("arith") << "here" << endl; -// enqueuedCaseSplit = true; -// Node lt = NodeManager::currentNM()->mkNode(LT,x_i,c_i); -// Node gt = NodeManager::currentNM()->mkNode(GT,x_i,c_i); -// Node caseSplit = NodeManager::currentNM()->mkNode(OR, eq, lt, gt); -// //d_out->enqueueCaseSplits(caseSplit); -// Debug("arith") << "finished" << caseSplit << endl; -// } -// Debug("arith") << "end of for loop" << endl; - -// } -// Debug("arith") << "finished" << endl; - -// if(enqueuedCaseSplit){ -// //d_out->caseSplit(); -// //Warning() << "Outstanding case split in theory arith" << endl; -// } -// } - Debug("arith") << "TheoryArith::check end" << std::endl; } @@ -760,7 +670,7 @@ void TheoryArith::checkTableau(){ Row* row_k = d_tableau.lookup(basic); DeltaRational sum; Debug("paranoid:check_tableau") << "starting row" << basic << endl; - for(std::set::iterator nonbasicIter = row_k->begin(); + for(Row::iterator nonbasicIter = row_k->begin(); nonbasicIter != row_k->end(); ++nonbasicIter){ TNode nonbasic = *nonbasicIter; @@ -770,7 +680,8 @@ void TheoryArith::checkTableau(){ sum = sum + (beta*coeff); } DeltaRational shouldBe = d_partialModel.getAssignment(basic); - Debug("paranoid:check_tableau") << "ending row" << sum << "," << shouldBe << endl; + Debug("paranoid:check_tableau") << "ending row" << sum + << "," << shouldBe << endl; Assert(sum == shouldBe); } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index e54f273e9..ddd876f76 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -38,73 +38,189 @@ namespace CVC4 { namespace theory { namespace arith { + +/** + * Implementation of QF_LRA. + * Based upon: + * http://research.microsoft.com/en-us/um/people/leonardo/cav06.pdf + */ class TheoryArith : public Theory { private: + + /* TODO Everything in the chopping block needs to be killed. */ /* Chopping block begins */ std::vector d_splits; //This stores the eager splits sent out of the theory. - //TODO get rid of this. - - context::CDList d_preprocessed; - //TODO This is currently needed to save preprocessed nodes that may not - //currently have an outisde reference. Get rid of when preprocessing is occuring - //correctly. std::vector d_variables; - //TODO get rid of this. Currently forces every variable and skolem constant that + // Currently forces every variable and skolem constant that // can hit the tableau to stay alive forever! - //This needs to come before d_partialModel and d_tableau in the file + + /* Chopping block ends */ + /** + * Priority Queue of the basic variables that may be inconsistent. + * + * This is required to contain at least 1 instance of every inconsistent + * basic variable. This is only required to be a superset though so its + * contents must be checked to still be basic and inconsistent. + */ std::priority_queue d_possiblyInconsistent; - /* Chopping block ends */ + /** Stores system wide constants to avoid unnessecary reconstruction. */ ArithConstants d_constants; + + /** + * Manages information about the assignment and upper and lower bounds on + * variables. + */ ArithPartialModel d_partialModel; + + /** + * List of all of the inequalities asserted in the current context. + */ context::CDList d_diseq; - Tableau d_tableau; - ArithRewriter d_rewriter; + /** + * The tableau for all of the constraints seen thus far in the system. + */ + Tableau d_tableau; + /** + * The rewriter module for arithmetic. + */ + ArithRewriter d_rewriter; public: TheoryArith(context::Context* c, OutputChannel& out); ~TheoryArith(); + /** + * Rewrites a node to a unique normal form given in normal_form_notes.txt + */ Node rewrite(TNode n); + /** + * Does non-context dependent setup for a node connected to a theory. + */ void preRegisterTerm(TNode n); + + /** CD setup for a node. Currently does nothing. */ void registerTerm(TNode n); + void check(Effort e); void propagate(Effort e) { Unimplemented(); } void explain(TNode n, Effort e) { Unimplemented(); } private: + /** + * Assert*(n, orig) takes an bound n that is implied by orig. + * and asserts that as a new bound if it is tighter than the current bound + * and updates the value of a basic variable if needed. + * If this new bound is in conflict with the other bound, + * a conflict is created and asserted to the output channel. + * + * orig must be an atom in the SAT solver so that it can be used for + * conflict analysis. + * + * n is of the form (x =?= c) where x is a variable, + * c is a constant and =?= is either LT, LEQ, EQ, GEQ, or GT. + * + * returns true if a conflict was asserted. + */ bool AssertLower(TNode n, TNode orig); bool AssertUpper(TNode n, TNode orig); + + /** + * Updates the assignment of a nonbasic variable x_i to v. + * Also updates the assignment of basic variables accordingly. + */ void update(TNode x_i, DeltaRational& v); + + /** + * Updates the value of a basic variable x_i to v, + * and then pivots x_i with the nonbasic variable in its row x_j. + * Updates the assignment of the other basic variables accordingly. + */ void pivotAndUpdate(TNode x_i, TNode x_j, DeltaRational& v); + /** + * Tries to update the assignments of variables such that all of the + * assignments are consistent with their bounds. + * + * This is done by searching through the tableau. + * If all of the variables can be made consistent with their bounds + * Node::null() is returned. Otherwise a minimized conflict is returned. + * + * If a conflict is found, changes to the assignments need to be reverted. + * + * Tableau pivoting is performed so variables may switch from being basic to + * nonbasic and vice versa. + * + * Corresponds to the "check()" procedure in [Cav06]. + */ Node updateInconsistentVars(); - TNode selectSlackBelow(TNode x_i); - TNode selectSlackAbove(TNode x_i); + /** + * Given the basic variable x_i, + * this function finds the smallest nonbasic variable x_j in the row of x_i + * in the tableau that can "take up the slack" to let x_i satisfy its bounds. + * This returns TNode::null() if none exists. + * + * More formally one of the following conditions must be satisfied: + * - above && a_ij < 0 && assignment(x_j) < upperbound(x_j) + * - above && a_ij > 0 && assignment(x_j) > lowerbound(x_j) + * - !above && a_ij > 0 && assignment(x_j) < upperbound(x_j) + * - !above && a_ij < 0 && assignment(x_j) > lowerbound(x_j) + */ + template + TNode selectSlack(TNode x_i); + + TNode selectSlackBelow(TNode x_i) { return selectSlack(x_i); } + TNode selectSlackAbove(TNode x_i) { return selectSlack(x_i); } + + /** + * Returns the smallest basic variable whose assignment is not consistent + * with its upper and lower bounds. + */ TNode selectSmallestInconsistentVar(); + /** + * Given a non-basic variable that is know to not be updatable + * to a consistent value, construct and return a conflict. + * Follows section 4.2 in the CAV06 paper. + */ Node generateConflictAbove(TNode conflictVar); Node generateConflictBelow(TNode conflictVar); + + /** Initial (not context dependent) sets up for a variable.*/ void setupVariable(TNode x); + + /** Initial (not context dependent) sets up for a new slack variable.*/ void setupSlack(TNode left); + + /** Computes the value of a row in the tableau using the current assignment.*/ DeltaRational computeRowValueUsingAssignment(TNode x); + + /** Computes the value of a row in the tableau using the safe assignment.*/ DeltaRational computeRowValueUsingSavedAssignment(TNode x); + + /** Checks to make sure the assignment is consistent with the tableau. */ void checkTableau(); + /** Check to make sure all of the basic variables are within their bounds. */ void checkBasicVariable(TNode basic); + /** + * Handles the case splitting for check() for a new assertion. + * returns true if their is a conflict. + */ + bool assertionCases(TNode original, TNode assertion); //TODO get rid of this!