From: Tim King Date: Wed, 24 Oct 2012 21:46:34 +0000 (+0000) Subject: Updated the ArithStaticLearner to be user context dependent. X-Git-Tag: cvc5-1.0.0~7676 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a6ac7fefed613c4d83e577361f98c28a8e18f3a9;p=cvc5.git Updated the ArithStaticLearner to be user context dependent. --- diff --git a/src/theory/arith/arith_static_learner.cpp b/src/theory/arith/arith_static_learner.cpp index a5d2b0a53..af2f0c9bc 100644 --- a/src/theory/arith/arith_static_learner.cpp +++ b/src/theory/arith/arith_static_learner.cpp @@ -35,10 +35,10 @@ namespace theory { namespace arith { -ArithStaticLearner::ArithStaticLearner(SubstitutionMap& pbSubstitutions) : - d_miplibTrick(), - d_miplibTrickKeys(), - d_pbSubstitutions(pbSubstitutions), +ArithStaticLearner::ArithStaticLearner(context::Context* userContext) : + d_miplibTrick(userContext), + d_minMap(userContext), + d_maxMap(userContext), d_statistics() {} @@ -108,11 +108,7 @@ void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){ } -void ArithStaticLearner::clear(){ - d_miplibTrick.clear(); - d_miplibTrickKeys.clear(); - // do not clear d_pbSubstitutions, as it is shared -} + void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){ @@ -140,11 +136,9 @@ void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet if(rewriteEqTo.getKind() == CONST_RATIONAL){ TNode var = n[1][0]; - if(d_miplibTrick.find(var) == d_miplibTrick.end()){ - d_miplibTrick.insert(make_pair(var, set())); - d_miplibTrickKeys.push_back(var); - } - d_miplibTrick[var].insert(n); + Node current = (d_miplibTrick.find(var) == d_miplibTrick.end()) ? + mkBoolNode(false) : d_miplibTrick[var]; + d_miplibTrick.insert(var, n.orNode(current)); Debug("arith::miplib") << "insert " << var << " const " << n << endl; } } @@ -249,9 +243,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ Debug("arith::static") << "iteConstant(" << n << ")" << endl; if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) { - DeltaRational min = std::min(d_minMap[n[1]], d_minMap[n[2]]); - NodeToMinMaxMap::iterator minFind = d_minMap.find(n); - if (minFind == d_minMap.end() || minFind->second < min) { + const DeltaRational& first = d_minMap[n[1]]; + const DeltaRational& second = d_minMap[n[2]]; + DeltaRational min = std::min(first, second); + CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n); + if (minFind == d_minMap.end() || (*minFind).second < min) { d_minMap[n] = min; Node nGeqMin; if (min.getInfinitesimalPart() == 0) { @@ -266,9 +262,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ } if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) { - DeltaRational max = std::max(d_maxMap[n[1]], d_maxMap[n[2]]); - NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n); - if (maxFind == d_maxMap.end() || maxFind->second > max) { + const DeltaRational& first = d_minMap[n[1]]; + const DeltaRational& second = d_minMap[n[2]]; + DeltaRational max = std::max(first, second); + CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n); + if (maxFind == d_maxMap.end() || (*maxFind).second > max) { d_maxMap[n] = max; Node nLeqMax; if (max.getInfinitesimalPart() == 0) { @@ -283,14 +281,29 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){ } } +std::set listToSet(TNode l){ + std::set ret; + while(l.getKind() == OR){ + Assert(l.getNumChildren() == 2); + ret.insert(l[0]); + l = l[1]; + } + return ret; +} void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ // == 3-FINITE VALUE SET == - list::iterator keyIter = d_miplibTrickKeys.begin(); - list::iterator endKeys = d_miplibTrickKeys.end(); + CDNodeToNodeListMap::const_iterator keyIter = d_miplibTrick.begin(); + CDNodeToNodeListMap::const_iterator endKeys = d_miplibTrick.end(); while(keyIter != endKeys) { - TNode var = *keyIter; - const set& imps = d_miplibTrick[var]; + TNode var = (*keyIter).first; + Node list = (*keyIter).second; + const set imps = listToSet(list); + + if(imps.empty()){ + ++keyIter; + continue; + } Assert(!imps.empty()); vector conditions; @@ -325,20 +338,9 @@ void ArithStaticLearner::postProcess(NodeBuilder<>& learned){ Result isTaut = PropositionalQuery::isTautology(possibleTaut); if(isTaut == Result(Result::VALID)){ miplibTrick(var, values, learned); - d_miplibTrick.erase(var); - // also have to erase from keys list - if(keyIter == endKeys) { - // last element is special: exit loop - d_miplibTrickKeys.erase(keyIter); - break; - } else { - // non-last element: make sure iterator is incremented before erase - list::iterator eraseIter = keyIter++; - d_miplibTrickKeys.erase(eraseIter); - } - } else { - ++keyIter; + d_miplibTrick.insert(var, mkBoolNode(false)); } + ++keyIter; } } @@ -384,8 +386,8 @@ void ArithStaticLearner::miplibTrick(TNode var, set& values, NodeBuild void ArithStaticLearner::addBound(TNode n) { - NodeToMinMaxMap::iterator minFind = d_minMap.find(n[0]); - NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n[0]); + CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]); + CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]); Rational constant = n[1].getConst(); DeltaRational bound = constant; @@ -395,7 +397,7 @@ void ArithStaticLearner::addBound(TNode n) { bound = DeltaRational(constant, -1); /* fall through */ case kind::LEQ: - if (maxFind == d_maxMap.end() || maxFind->second > bound) { + if (maxFind == d_maxMap.end() || (*maxFind).second > bound) { d_maxMap[n[0]] = bound; Debug("arith::static") << "adding bound " << n << endl; } @@ -404,7 +406,7 @@ void ArithStaticLearner::addBound(TNode n) { bound = DeltaRational(constant, 1); /* fall through */ case kind::GEQ: - if (minFind == d_minMap.end() || minFind->second < bound) { + if (minFind == d_minMap.end() || (*minFind).second < bound) { d_minMap[n[0]] = bound; Debug("arith::static") << "adding bound " << n << endl; } diff --git a/src/theory/arith/arith_static_learner.h b/src/theory/arith/arith_static_learner.h index 622650f02..b047018e8 100644 --- a/src/theory/arith/arith_static_learner.h +++ b/src/theory/arith/arith_static_learner.h @@ -24,8 +24,11 @@ #include "util/statistics_registry.h" #include "theory/arith/arith_utilities.h" #include "theory/substitutions.h" + +#include "context/context.h" +#include "context/cdlist.h" +#include "context/cdhashmap.h" #include -#include namespace CVC4 { namespace theory { @@ -33,44 +36,31 @@ namespace arith { class ArithStaticLearner { private: - typedef __gnu_cxx::hash_set TNodeSet; /* Maps a variable, x, to the set of defTrue nodes of the form * (=> _ (= x c)) * where c is a constant. */ - typedef __gnu_cxx::hash_map, NodeHashFunction> VarToNodeSetMap; - VarToNodeSetMap d_miplibTrick; - std::list d_miplibTrickKeys; - - /** - * Some integer variables are eligible to be replaced by - * pseudoboolean variables. This map collects those eligible - * substitutions. - * - * This is a reference to the substitution map in TheoryArith; as - * it's not "owned" by this static learner, it isn't cleared on - * clear(). This makes sense, as the static learner only - * accumulates information in the substitution map, it never uses it - * (i.e., it's write-only). - */ - SubstitutionMap& d_pbSubstitutions; + //typedef __gnu_cxx::hash_map, NodeHashFunction> VarToNodeSetMap; + typedef context::CDHashMap CDNodeToNodeListMap; + // The domain is an implicit list OR(x, OR(y, ..., FALSE )) + // or FALSE + CDNodeToNodeListMap d_miplibTrick; /** * Map from a node to it's minimum and maximum. */ - typedef __gnu_cxx::hash_map NodeToMinMaxMap; - NodeToMinMaxMap d_minMap; - NodeToMinMaxMap d_maxMap; + //typedef __gnu_cxx::hash_map NodeToMinMaxMap; + typedef context::CDHashMap CDNodeToMinMaxMap; + CDNodeToMinMaxMap d_minMap; + CDNodeToMinMaxMap d_maxMap; public: - ArithStaticLearner(SubstitutionMap& pbSubstitutions); + ArithStaticLearner(context::Context* userContext); void staticLearning(TNode n, NodeBuilder<>& learned); void addBound(TNode n); - void clear(); - private: void process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue); diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index 3ae61006d..c7f511a98 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -33,6 +33,7 @@ namespace arith { //Sets of Nodes typedef __gnu_cxx::hash_set NodeSet; +typedef __gnu_cxx::hash_set TNodeSet; typedef context::CDHashSet CDNodeSet; inline Node mkRationalNode(const Rational& q){ diff --git a/src/theory/arith/matrix.h b/src/theory/arith/matrix.h index e4646b765..ea6c389b9 100644 --- a/src/theory/arith/matrix.h +++ b/src/theory/arith/matrix.h @@ -20,7 +20,6 @@ #pragma once #include "expr/node.h" -#include "expr/attribute.h" #include "util/index.h" #include "util/dense_map.h" diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 6613cfaad..65d9551ac 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -60,7 +60,7 @@ TheoryArith::TheoryArith(context::Context* c, context::UserContext* u, OutputCha d_qflraStatus(Result::SAT_UNKNOWN), d_unknownsInARow(0), d_hasDoneWorkSinceCut(false), - d_learner(d_pbSubstitutions), + d_learner(u), d_setupLiteralCallback(this), d_assertionsThatDoNotMatchTheirLiterals(c), d_nextIntegerCheckVar(0), @@ -72,7 +72,6 @@ TheoryArith::TheoryArith(context::Context* c, context::UserContext* u, OutputCha d_tableau(), d_linEq(d_partialModel, d_tableau, d_basicVarModelUpdateCallBack), d_diosolver(c), - d_pbSubstitutions(u), d_restartsCounter(0), d_tableauSizeHasBeenModified(false), d_tableauResetDensity(1.6), @@ -633,36 +632,19 @@ void TheoryArith::addSharedTerm(TNode n){ } Node TheoryArith::ppRewrite(TNode atom) { - - if (!atom.getType().isBoolean()) { - return atom; - } - Debug("arith::preprocess") << "arith::preprocess() : " << atom << endl; - Node a = d_pbSubstitutions.apply(atom); - - if (a != atom) { - Debug("pb") << "arith::preprocess() : after pb substitutions: " << a << endl; - a = Rewriter::rewrite(a); - Debug("pb") << "arith::preprocess() : after pb substitutions and rewriting: " - << a << endl; - Debug("arith::preprocess") << "arith::preprocess() :" - << "after pb substitutions and rewriting: " - << a << endl; - } - - if (a.getKind() == kind::EQUAL && options::arithRewriteEq()) { - Node leq = NodeBuilder<2>(kind::LEQ) << a[0] << a[1]; - Node geq = NodeBuilder<2>(kind::GEQ) << a[0] << a[1]; + if (atom.getKind() == kind::EQUAL && options::arithRewriteEq()) { + Node leq = NodeBuilder<2>(kind::LEQ) << atom[0] << atom[1]; + Node geq = NodeBuilder<2>(kind::GEQ) << atom[0] << atom[1]; Node rewritten = Rewriter::rewrite(leq.andNode(geq)); Debug("arith::preprocess") << "arith::preprocess() : returning " << rewritten << endl; return rewritten; + } else { + return atom; } - - return a; } Theory::PPAssertStatus TheoryArith::ppAssert(TNode in, SubstitutionMap& outSubstitutions) { @@ -2256,8 +2238,6 @@ void TheoryArith::presolve(){ // d_out->lemma(lem); // } // } - - d_learner.clear(); } EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) { diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 334051901..fd664e04a 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -228,18 +228,6 @@ private: */ DioSolver d_diosolver; - /** - * Some integer variables can be replaced with pseudoboolean - * variables internally. This map is built up at static learning - * time for top-level asserted expressions of the shape "x = 0 OR x - * = 1". This substitution map is then applied in preprocess(). - * - * Note that expressions of the shape "x >= 0 AND x <= 1" are - * already substituted for PB versions at solve() time and won't - * appear here. - */ - SubstitutionMap d_pbSubstitutions; - /** Counts the number of notifyRestart() calls to the theory. */ uint32_t d_restartsCounter;