From 0ced5194e3072c8e466e0ed597ac71ae5acf7ea2 Mon Sep 17 00:00:00 2001 From: Tim King Date: Sun, 13 Feb 2011 21:19:20 +0000 Subject: [PATCH] 3 heuristics were added to arithmetic. A heuristic for detecting an encoding of min added to static learning in LRA. A heuristic added for when the true branch and false branch are both constants (also in static learning). A heuristic for checking whether any variables begin in conflict before pivoting. --- src/theory/arith/arith_utilities.h | 21 ++-- src/theory/arith/normal_form.cpp | 2 +- src/theory/arith/simplex.cpp | 75 ++++++++++++- src/theory/arith/simplex.h | 5 + src/theory/arith/theory_arith.cpp | 166 ++++++++++++++++++++++++++++- src/theory/arith/theory_arith.h | 3 + 6 files changed, 262 insertions(+), 10 deletions(-) diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index 25aff4e75..452d54fae 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -120,7 +120,7 @@ inline bool isRelationOperator(Kind k){ } /** is k \in {LT, LEQ, EQ, GEQ, GT} */ -inline Kind negateRelationKind(Kind k){ +inline Kind reverseRelationKind(Kind k){ using namespace kind; switch(k){ @@ -134,6 +134,7 @@ inline Kind negateRelationKind(Kind k){ Unreachable(); } } + inline bool evaluateConstantPredicate(Kind k, const Rational& left, const Rational& right){ using namespace kind; @@ -211,20 +212,24 @@ inline int deltaCoeff(Kind k){ } /** - * Given a rewritten predicate to TheoryArith return a single kind to + * Given a literal to TheoryArith return a single kind to * to indicate its underlying structure. * The function returns the following in each case: - * - (K left right) -> K where is a wildcard for EQUAL, LEQ, or GEQ: + * - (K left right) -> K where is a wildcard for EQUAL, LT, GT, LEQ, or GEQ: * - (NOT (EQUAL left right)) -> DISTINCT * - (NOT (LEQ left right)) -> GT * - (NOT (GEQ left right)) -> LT + * - (NOT (LT left right)) -> GEQ + * - (NOT (GT left right)) -> LEQ * If none of these match, it returns UNDEFINED_KIND. */ inline Kind simplifiedKind(TNode assertion){ switch(assertion.getKind()){ + case kind::LT: + case kind::GT: case kind::LEQ: - case kind::GEQ: - case kind::EQUAL: + case kind::GEQ: + case kind::EQUAL: return assertion.getKind(); case kind::NOT: { @@ -232,8 +237,12 @@ inline int deltaCoeff(Kind k){ switch(atom.getKind()){ case kind::LEQ: //(not (LEQ x c)) <=> (GT x c) return kind::GT; - case kind::GEQ: //(not (GEQ x c) <=> (LT x c) + case kind::GEQ: //(not (GEQ x c)) <=> (LT x c) return kind::LT; + case kind::LT: //(not (LT x c)) <=> (GEQ x c) + return kind::GEQ; + case kind::GT: //(not (GT x c) <=> (LEQ x c) + return kind::LEQ; case kind::EQUAL: return kind::DISTINCT; default: diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 2a8c1077e..aea1a43d8 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -277,7 +277,7 @@ Comparison Comparison::addConstant(const Constant& constant) const { Comparison Comparison::multiplyConstant(const Constant& constant) const { Assert(!isBoolean()); - Kind newOper = (constant.getValue() < 0) ? negateRelationKind(oper) : oper; + Kind newOper = (constant.getValue() < 0) ? reverseRelationKind(oper) : oper; return mkComparison(newOper, left*Monomial(constant), right*constant); } diff --git a/src/theory/arith/simplex.cpp b/src/theory/arith/simplex.cpp index 153ccad98..2e9fb7352 100644 --- a/src/theory/arith/simplex.cpp +++ b/src/theory/arith/simplex.cpp @@ -18,7 +18,10 @@ SimplexDecisionProcedure::Statistics::Statistics(): d_statAssertLowerConflicts("theory::arith::AssertLowerConflicts", 0), d_statUpdateConflicts("theory::arith::UpdateConflicts", 0), d_statEjections("theory::arith::Ejections", 0), - d_statUnEjections("theory::arith::UnEjections", 0) + d_statUnEjections("theory::arith::UnEjections", 0), + d_statEarlyConflicts("theory::arith::EarlyConflicts", 0), + d_statEarlyConflictImprovements("theory::arith::EarlyConflictImprovements", 0), + d_selectInitialConflictTime("theory::arith::selectInitialConflictTime") { StatisticsRegistry::registerStat(&d_statPivots); StatisticsRegistry::registerStat(&d_statUpdates); @@ -27,6 +30,9 @@ SimplexDecisionProcedure::Statistics::Statistics(): StatisticsRegistry::registerStat(&d_statUpdateConflicts); StatisticsRegistry::registerStat(&d_statEjections); StatisticsRegistry::registerStat(&d_statUnEjections); + StatisticsRegistry::registerStat(&d_statEarlyConflicts); + StatisticsRegistry::registerStat(&d_statEarlyConflictImprovements); + StatisticsRegistry::registerStat(&d_selectInitialConflictTime); } SimplexDecisionProcedure::Statistics::~Statistics(){ @@ -37,6 +43,9 @@ SimplexDecisionProcedure::Statistics::~Statistics(){ StatisticsRegistry::unregisterStat(&d_statUpdateConflicts); StatisticsRegistry::unregisterStat(&d_statEjections); StatisticsRegistry::unregisterStat(&d_statUnEjections); + StatisticsRegistry::unregisterStat(&d_statEarlyConflicts); + StatisticsRegistry::unregisterStat(&d_statEarlyConflictImprovements); + StatisticsRegistry::unregisterStat(&d_selectInitialConflictTime); } @@ -370,8 +379,70 @@ ArithVar SimplexDecisionProcedure::selectSlack(ArithVar x_i){ return slack; } +Node betterConflict(TNode x, TNode y){ + if(x.isNull()) return y; + else if(y.isNull()) return x; + else if(x.getNumChildren() <= y.getNumChildren()) return x; + else return y; +} + +Node SimplexDecisionProcedure::selectInitialConflict() { + Node bestConflict = Node::null(); + + TimerStat::CodeTimer codeTimer(d_statistics.d_selectInitialConflictTime); + + vector init; + + while( !d_griggioRuleQueue.empty()){ + ArithVar var = d_griggioRuleQueue.top().first; + if(d_basicManager.isMember(var)){ + if(!d_partialModel.assignmentIsConsistent(var)){ + init.push_back( d_griggioRuleQueue.top()); + } + } + d_griggioRuleQueue.pop(); + } + + int conflictChanges = 0; + + for(vector::iterator i=init.begin(), end=init.end(); i != end; ++i){ + ArithVar x_i = (*i).first; + d_griggioRuleQueue.push(*i); + + DeltaRational beta_i = d_partialModel.getAssignment(x_i); + + if(d_partialModel.belowLowerBound(x_i, beta_i, true)){ + DeltaRational l_i = d_partialModel.getLowerBound(x_i); + ArithVar x_j = selectSlackBelow(x_i); + if(x_j == ARITHVAR_SENTINEL ){ + Node better = betterConflict(bestConflict, generateConflictBelow(x_i)); + if(better != bestConflict) ++conflictChanges; + bestConflict = better; + ++(d_statistics.d_statEarlyConflicts); + } + }else if(d_partialModel.aboveUpperBound(x_i, beta_i, true)){ + DeltaRational u_i = d_partialModel.getUpperBound(x_i); + ArithVar x_j = selectSlackAbove(x_i); + if(x_j == ARITHVAR_SENTINEL ){ + Node better = betterConflict(bestConflict, generateConflictAbove(x_i)); + if(better != bestConflict) ++conflictChanges; + bestConflict = better; + ++(d_statistics.d_statEarlyConflicts); + } + } + } + if(conflictChanges > 1) ++(d_statistics.d_statEarlyConflictImprovements); + return bestConflict; +} + Node SimplexDecisionProcedure::updateInconsistentVars(){ - Node possibleConflict = privateUpdateInconsistentVars(); + if(d_griggioRuleQueue.empty()) return Node::null(); + + Node possibleConflict = selectInitialConflict(); + if(possibleConflict.isNull()){ + possibleConflict = privateUpdateInconsistentVars(); + } + Assert(!possibleConflict.isNull() || d_griggioRuleQueue.empty()); Assert(!possibleConflict.isNull() || d_possiblyInconsistent.empty()); d_pivotStage = true; diff --git a/src/theory/arith/simplex.h b/src/theory/arith/simplex.h index 7514b6284..d8997af93 100644 --- a/src/theory/arith/simplex.h +++ b/src/theory/arith/simplex.h @@ -134,6 +134,8 @@ public: private: Node privateUpdateInconsistentVars(); + Node selectInitialConflict(); + private: /** * Given the basic variable x_i, @@ -197,6 +199,9 @@ private: IntStat d_statAssertLowerConflicts, d_statUpdateConflicts; IntStat d_statEjections, d_statUnEjections; + + IntStat d_statEarlyConflicts, d_statEarlyConflictImprovements; + TimerStat d_selectInitialConflictTime; Statistics(); ~Statistics(); }; diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index b9c983215..ff79c18e6 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -72,12 +72,14 @@ TheoryArith::Statistics::Statistics(): d_statUserVariables("theory::arith::UserVariables", 0), d_statSlackVariables("theory::arith::SlackVariables", 0), d_statDisequalitySplits("theory::arith::DisequalitySplits", 0), - d_statDisequalityConflicts("theory::arith::DisequalityConflicts", 0) + d_statDisequalityConflicts("theory::arith::DisequalityConflicts", 0), + d_staticLearningTimer("theory::arith::staticLearningTimer") { StatisticsRegistry::registerStat(&d_statUserVariables); StatisticsRegistry::registerStat(&d_statSlackVariables); StatisticsRegistry::registerStat(&d_statDisequalitySplits); StatisticsRegistry::registerStat(&d_statDisequalityConflicts); + StatisticsRegistry::registerStat(&d_staticLearningTimer); } TheoryArith::Statistics::~Statistics(){ @@ -85,8 +87,170 @@ TheoryArith::Statistics::~Statistics(){ StatisticsRegistry::unregisterStat(&d_statSlackVariables); StatisticsRegistry::unregisterStat(&d_statDisequalitySplits); StatisticsRegistry::unregisterStat(&d_statDisequalityConflicts); + StatisticsRegistry::unregisterStat(&d_staticLearningTimer); } +void TheoryArith::staticLearning(TNode n, NodeBuilder<>& learned) { + TimerStat::CodeTimer codeTimer(d_statistics.d_staticLearningTimer); + + vector workList; + workList.push_back(n); + __gnu_cxx::hash_set processed; + + while(!workList.empty()) { + n = workList.back(); + + bool unprocessedChildren = false; + for(TNode::iterator i = n.begin(), iend = n.end(); i != iend; ++i) { + if(processed.find(*i) == processed.end()) { + // unprocessed child + workList.push_back(*i); + unprocessedChildren = true; + } + } + + if(unprocessedChildren) { + continue; + } + + workList.pop_back(); + // has node n been processed in the meantime ? + if(processed.find(n) != processed.end()) { + continue; + } + processed.insert(n); + + // == MINS == + + Debug("mins") << "===================== looking at" << endl << n << endl; + if(n.getKind() == kind::ITE && n[0].getKind() != EQUAL && isRelationOperator(n[0].getKind()) ){ + TNode c = n[0]; + Kind k = simplifiedKind(c); + TNode t = n[1]; + TNode e = n[2]; + TNode cleft = (c.getKind() == NOT) ? c[0][0] : c[0]; + TNode cright = (c.getKind() == NOT) ? c[0][1] : c[1]; + + if((t == cright) && (e == cleft)){ + TNode tmp = t; + t = e; + e = tmp; + k = reverseRelationKind(k); + } + if(t == cleft && e == cright){ + // t == cleft && e == cright + Assert( t == cleft ); + Assert( e == cright ); + switch(k){ + case LT: // (ite (< x y) x y) + case LEQ: { // (ite (<= x y) x y) + Node nLeqX = NodeBuilder<2>(LEQ) << n << t; + Node nLeqY = NodeBuilder<2>(LEQ) << n << e; + Debug("arith::mins") << n << "is a min =>" << nLeqX << nLeqY << endl; + learned << nLeqX << nLeqY; + break; + } + case GT: // (ite (> x y) x y) + case GEQ: { // (ite (>= x y) x y) + Node nGeqX = NodeBuilder<2>(GEQ) << n << t; + Node nGeqY = NodeBuilder<2>(GEQ) << n << e; + Debug("arith::mins") << n << "is a max =>" << nGeqX << nGeqY << endl; + learned << nGeqX << nGeqY; + break; + } + default: Unreachable(); + } + } + } + // == 2-CONSTANTS == + + if(n.getKind() == ITE && + (n[1].getKind() == CONST_RATIONAL || n[1].getKind() == CONST_INTEGER) && + (n[2].getKind() == CONST_RATIONAL || n[2].getKind() == CONST_INTEGER)) { + Rational t = coerceToRational(n[1]); + Rational e = coerceToRational(n[2]); + TNode min = (t <= e) ? n[1] : n[2]; + TNode max = (t >= e) ? n[1] : n[2]; + + Node nGeqMin = NodeBuilder<2>(GEQ) << n << min; + Node nLeqMax = NodeBuilder<2>(LEQ) << n << max; + Debug("arith::mins") << n << " is a constant sandwich" << nGeqMin << nLeqMax << endl; + learned << nGeqMin << nLeqMax; + } + + // // binary OR of binary ANDs of EQUALities + // if(n.getKind() == kind::OR && n.getNumChildren() == 2 && + // n[0].getKind() == kind::AND && n[0].getNumChildren() == 2 && + // n[1].getKind() == kind::AND && n[1].getNumChildren() == 2 && + // (n[0][0].getKind() == kind::EQUAL || n[0][0].getKind() == kind::IFF) && + // (n[0][1].getKind() == kind::EQUAL || n[0][1].getKind() == kind::IFF) && + // (n[1][0].getKind() == kind::EQUAL || n[1][0].getKind() == kind::IFF) && + // (n[1][1].getKind() == kind::EQUAL || n[1][1].getKind() == kind::IFF)) { + // // now we have (a = b && c = d) || (e = f && g = h) + + // Debug("diamonds") << "has form of a diamond!" << endl; + + // TNode + // a = n[0][0][0], b = n[0][0][1], + // c = n[0][1][0], d = n[0][1][1], + // e = n[1][0][0], f = n[1][0][1], + // g = n[1][1][0], h = n[1][1][1]; + + // // test that one of {a, b} = one of {c, d}, and make "b" the + // // shared node (i.e. put in the form (a = b && b = d)) + // // note we don't actually care about the shared ones, so the + // // "swaps" below are one-sided, ignoring b and c + // if(a == c) { + // a = b; + // } else if(a == d) { + // a = b; + // d = c; + // } else if(b == c) { + // // nothing to do + // } else if(b == d) { + // d = c; + // } else { + // // condition not satisfied + // Debug("diamonds") << "+ A fails" << endl; + // continue; + // } + + // Debug("diamonds") << "+ A holds" << endl; + + // // same: one of {e, f} = one of {g, h}, and make "f" the + // // shared node (i.e. put in the form (e = f && f = h)) + // if(e == g) { + // e = f; + // } else if(e == h) { + // e = f; + // h = g; + // } else if(f == g) { + // // nothing to do + // } else if(f == h) { + // h = g; + // } else { + // // condition not satisfied + // Debug("diamonds") << "+ B fails" << endl; + // continue; + // } + + // Debug("diamonds") << "+ B holds" << endl; + + // // now we have (a = b && b = d) || (e = f && f = h) + // // test that {a, d} == {e, h} + // if( (a == e && d == h) || + // (a == h && d == e) ) { + // // learn: n implies a == d + // Debug("diamonds") << "+ C holds" << endl; + // Node newEquality = a.getType().isBoolean() ? a.iffNode(d) : a.eqNode(d); + // Debug("diamonds") << " ==> " << newEquality << endl; + // learned << n.impNode(newEquality); + // } else { + // Debug("diamonds") << "+ C fails" << endl; + // } + // } + } +} diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index c95ca6cc4..5d39f626c 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -125,6 +125,8 @@ public: check(FULL_EFFORT); } + void staticLearning(TNode in, NodeBuilder<>& learned); + std::string identify() const { return std::string("TheoryArith"); } private: @@ -167,6 +169,7 @@ private: IntStat d_statUserVariables, d_statSlackVariables; IntStat d_statDisequalitySplits; IntStat d_statDisequalityConflicts; + TimerStat d_staticLearningTimer; Statistics(); ~Statistics(); -- 2.30.2