Updated the ArithStaticLearner to be user context dependent.
authorTim King <taking@cs.nyu.edu>
Wed, 24 Oct 2012 21:46:34 +0000 (21:46 +0000)
committerTim King <taking@cs.nyu.edu>
Wed, 24 Oct 2012 21:46:34 +0000 (21:46 +0000)
src/theory/arith/arith_static_learner.cpp
src/theory/arith/arith_static_learner.h
src/theory/arith/arith_utilities.h
src/theory/arith/matrix.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h

index a5d2b0a53ed4e8ddb09c35b83e637b148eed122b..af2f0c9bcc1ff3aa984d11946616451c0019f145 100644 (file)
@@ -35,10 +35,10 @@ namespace theory {
 namespace arith {
 
 
-ArithStaticLearner::ArithStaticLearner(SubstitutionMap& pbSubstitutions) :
-  d_miplibTrick(),
-  d_miplibTrickKeys(),
-  d_pbSubstitutions(pbSubstitutions),
+ArithStaticLearner::ArithStaticLearner(context::Context* userContext) :
+  d_miplibTrick(userContext),
+  d_minMap(userContext),
+  d_maxMap(userContext),
   d_statistics()
 {}
 
@@ -108,11 +108,7 @@ void ArithStaticLearner::staticLearning(TNode n, NodeBuilder<>& learned){
 }
 
 
-void ArithStaticLearner::clear(){
-  d_miplibTrick.clear();
-  d_miplibTrickKeys.clear();
-  // do not clear d_pbSubstitutions, as it is shared
-}
+
 
 
 void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue){
@@ -140,11 +136,9 @@ void ArithStaticLearner::process(TNode n, NodeBuilder<>& learned, const TNodeSet
       if(rewriteEqTo.getKind() == CONST_RATIONAL){
 
         TNode var = n[1][0];
-        if(d_miplibTrick.find(var)  == d_miplibTrick.end()){
-          d_miplibTrick.insert(make_pair(var, set<Node>()));
-          d_miplibTrickKeys.push_back(var);
-        }
-        d_miplibTrick[var].insert(n);
+        Node current = (d_miplibTrick.find(var)  == d_miplibTrick.end()) ?
+          mkBoolNode(false) : d_miplibTrick[var];
+        d_miplibTrick.insert(var, n.orNode(current));
         Debug("arith::miplib") << "insert " << var  << " const " << n << endl;
       }
     }
@@ -249,9 +243,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
   Debug("arith::static") << "iteConstant(" << n << ")" << endl;
 
   if (d_minMap.find(n[1]) != d_minMap.end() && d_minMap.find(n[2]) != d_minMap.end()) {
-    DeltaRational min = std::min(d_minMap[n[1]], d_minMap[n[2]]);
-    NodeToMinMaxMap::iterator minFind = d_minMap.find(n);
-    if (minFind == d_minMap.end() || minFind->second < min) {
+    const DeltaRational& first = d_minMap[n[1]];
+    const DeltaRational& second = d_minMap[n[2]];
+    DeltaRational min = std::min(first, second);
+    CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n);
+    if (minFind == d_minMap.end() || (*minFind).second < min) {
       d_minMap[n] = min;
       Node nGeqMin;
       if (min.getInfinitesimalPart() == 0) {
@@ -266,9 +262,11 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
   }
 
   if (d_maxMap.find(n[1]) != d_maxMap.end() && d_maxMap.find(n[2]) != d_maxMap.end()) {
-    DeltaRational max = std::max(d_maxMap[n[1]], d_maxMap[n[2]]);
-    NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n);
-    if (maxFind == d_maxMap.end() || maxFind->second > max) {
+    const DeltaRational& first = d_minMap[n[1]];
+    const DeltaRational& second = d_minMap[n[2]];
+    DeltaRational max = std::max(first, second);
+    CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n);
+    if (maxFind == d_maxMap.end() || (*maxFind).second > max) {
       d_maxMap[n] = max;
       Node nLeqMax;
       if (max.getInfinitesimalPart() == 0) {
@@ -283,14 +281,29 @@ void ArithStaticLearner::iteConstant(TNode n, NodeBuilder<>& learned){
   }
 }
 
+std::set<Node> listToSet(TNode l){
+  std::set<Node> ret;
+  while(l.getKind() == OR){
+    Assert(l.getNumChildren() == 2);
+    ret.insert(l[0]);
+    l = l[1];
+  }
+  return ret;
+}
 
 void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
   // == 3-FINITE VALUE SET ==
-  list<TNode>::iterator keyIter = d_miplibTrickKeys.begin();
-  list<TNode>::iterator endKeys = d_miplibTrickKeys.end();
+  CDNodeToNodeListMap::const_iterator keyIter = d_miplibTrick.begin();
+  CDNodeToNodeListMap::const_iterator endKeys = d_miplibTrick.end();
   while(keyIter != endKeys) {
-    TNode var = *keyIter;
-    const set<Node>& imps = d_miplibTrick[var];
+    TNode var = (*keyIter).first;
+    Node list = (*keyIter).second;
+    const set<Node> imps = listToSet(list);
+
+    if(imps.empty()){
+      ++keyIter;
+      continue;
+    }
 
     Assert(!imps.empty());
     vector<Node> conditions;
@@ -325,20 +338,9 @@ void ArithStaticLearner::postProcess(NodeBuilder<>& learned){
     Result isTaut = PropositionalQuery::isTautology(possibleTaut);
     if(isTaut == Result(Result::VALID)){
       miplibTrick(var, values, learned);
-      d_miplibTrick.erase(var);
-      // also have to erase from keys list
-      if(keyIter == endKeys) {
-        // last element is special: exit loop
-        d_miplibTrickKeys.erase(keyIter);
-        break;
-      } else {
-        // non-last element: make sure iterator is incremented before erase
-        list<TNode>::iterator eraseIter = keyIter++;
-        d_miplibTrickKeys.erase(eraseIter);
-      }
-    } else {
-      ++keyIter;
+      d_miplibTrick.insert(var, mkBoolNode(false));
     }
+    ++keyIter;
   }
 }
 
@@ -384,8 +386,8 @@ void ArithStaticLearner::miplibTrick(TNode var, set<Rational>& values, NodeBuild
 
 void ArithStaticLearner::addBound(TNode n) {
 
-  NodeToMinMaxMap::iterator minFind = d_minMap.find(n[0]);
-  NodeToMinMaxMap::iterator maxFind = d_maxMap.find(n[0]);
+  CDNodeToMinMaxMap::const_iterator minFind = d_minMap.find(n[0]);
+  CDNodeToMinMaxMap::const_iterator maxFind = d_maxMap.find(n[0]);
 
   Rational constant = n[1].getConst<Rational>();
   DeltaRational bound = constant;
@@ -395,7 +397,7 @@ void ArithStaticLearner::addBound(TNode n) {
     bound = DeltaRational(constant, -1);
     /* fall through */
   case kind::LEQ:
-    if (maxFind == d_maxMap.end() || maxFind->second > bound) {
+    if (maxFind == d_maxMap.end() || (*maxFind).second > bound) {
       d_maxMap[n[0]] = bound;
       Debug("arith::static") << "adding bound " << n << endl;
     }
@@ -404,7 +406,7 @@ void ArithStaticLearner::addBound(TNode n) {
     bound = DeltaRational(constant, 1);
     /* fall through */
   case kind::GEQ:
-    if (minFind == d_minMap.end() || minFind->second < bound) {
+    if (minFind == d_minMap.end() || (*minFind).second < bound) {
       d_minMap[n[0]] = bound;
       Debug("arith::static") << "adding bound " << n << endl;
     }
index 622650f022a21a000ad59c057e79b14bdd848224..b047018e8eb5e29803655b199382f46b735466d3 100644 (file)
 #include "util/statistics_registry.h"
 #include "theory/arith/arith_utilities.h"
 #include "theory/substitutions.h"
+
+#include "context/context.h"
+#include "context/cdlist.h"
+#include "context/cdhashmap.h"
 #include <set>
-#include <list>
 
 namespace CVC4 {
 namespace theory {
@@ -33,44 +36,31 @@ namespace arith {
 
 class ArithStaticLearner {
 private:
-  typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
 
   /* Maps a variable, x, to the set of defTrue nodes of the form
    *  (=> _ (= x c))
    * where c is a constant.
    */
-  typedef __gnu_cxx::hash_map<Node, std::set<Node>, NodeHashFunction> VarToNodeSetMap;
-  VarToNodeSetMap d_miplibTrick;
-  std::list<TNode> d_miplibTrickKeys;
-
-  /**
-   * Some integer variables are eligible to be replaced by
-   * pseudoboolean variables.  This map collects those eligible
-   * substitutions.
-   *
-   * This is a reference to the substitution map in TheoryArith; as
-   * it's not "owned" by this static learner, it isn't cleared on
-   * clear().  This makes sense, as the static learner only
-   * accumulates information in the substitution map, it never uses it
-   * (i.e., it's write-only).
-   */
-  SubstitutionMap& d_pbSubstitutions;
+  //typedef __gnu_cxx::hash_map<Node, std::set<Node>, NodeHashFunction> VarToNodeSetMap;
+  typedef context::CDHashMap<Node, Node, NodeHashFunction> CDNodeToNodeListMap;
+  // The domain is an implicit list OR(x, OR(y, ..., FALSE ))
+  // or FALSE
+  CDNodeToNodeListMap d_miplibTrick;
 
   /**
    * Map from a node to it's minimum and maximum.
    */
-  typedef __gnu_cxx::hash_map<Node, DeltaRational, NodeHashFunction> NodeToMinMaxMap;
-  NodeToMinMaxMap d_minMap;
-  NodeToMinMaxMap d_maxMap;
+  //typedef __gnu_cxx::hash_map<Node, DeltaRational, NodeHashFunction> NodeToMinMaxMap;
+  typedef context::CDHashMap<Node, DeltaRational, NodeHashFunction> CDNodeToMinMaxMap;
+  CDNodeToMinMaxMap d_minMap;
+  CDNodeToMinMaxMap d_maxMap;
 
 public:
-  ArithStaticLearner(SubstitutionMap& pbSubstitutions);
+  ArithStaticLearner(context::Context* userContext);
   void staticLearning(TNode n, NodeBuilder<>& learned);
 
   void addBound(TNode n);
 
-  void clear();
-
 private:
   void process(TNode n, NodeBuilder<>& learned, const TNodeSet& defTrue);
 
index 3ae61006ded2b9590ec40598e432d0896a1ab867..c7f511a98f4c3abde7463dd32833c62d9f2dfa9d 100644 (file)
@@ -33,6 +33,7 @@ namespace arith {
 
 //Sets of Nodes
 typedef __gnu_cxx::hash_set<Node, NodeHashFunction> NodeSet;
+typedef __gnu_cxx::hash_set<TNode, TNodeHashFunction> TNodeSet;
 typedef context::CDHashSet<Node, NodeHashFunction> CDNodeSet;
 
 inline Node mkRationalNode(const Rational& q){
index e4646b7652d513f0a6771da2f803aca9101be390..ea6c389b954f87c81b5d31cbd657e95dfa490b0a 100644 (file)
@@ -20,7 +20,6 @@
 #pragma once
 
 #include "expr/node.h"
-#include "expr/attribute.h"
 
 #include "util/index.h"
 #include "util/dense_map.h"
index 6613cfaad2b803c48b73ad5125d2cd7687f0cf3c..65d9551acc7ca868b4dc231701bc7f6f5a254111 100644 (file)
@@ -60,7 +60,7 @@ TheoryArith::TheoryArith(context::Context* c, context::UserContext* u, OutputCha
   d_qflraStatus(Result::SAT_UNKNOWN),
   d_unknownsInARow(0),
   d_hasDoneWorkSinceCut(false),
-  d_learner(d_pbSubstitutions),
+  d_learner(u),
   d_setupLiteralCallback(this),
   d_assertionsThatDoNotMatchTheirLiterals(c),
   d_nextIntegerCheckVar(0),
@@ -72,7 +72,6 @@ TheoryArith::TheoryArith(context::Context* c, context::UserContext* u, OutputCha
   d_tableau(),
   d_linEq(d_partialModel, d_tableau, d_basicVarModelUpdateCallBack),
   d_diosolver(c),
-  d_pbSubstitutions(u),
   d_restartsCounter(0),
   d_tableauSizeHasBeenModified(false),
   d_tableauResetDensity(1.6),
@@ -633,36 +632,19 @@ void TheoryArith::addSharedTerm(TNode n){
 }
 
 Node TheoryArith::ppRewrite(TNode atom) {
-
-  if (!atom.getType().isBoolean()) {
-    return atom;
-  }
-
   Debug("arith::preprocess") << "arith::preprocess() : " << atom << endl;
 
-  Node a = d_pbSubstitutions.apply(atom);
-
-  if (a != atom) {
-    Debug("pb") << "arith::preprocess() : after pb substitutions: " << a << endl;
-    a = Rewriter::rewrite(a);
-    Debug("pb") << "arith::preprocess() : after pb substitutions and rewriting: "
-                << a << endl;
-    Debug("arith::preprocess") << "arith::preprocess() :"
-                               << "after pb substitutions and rewriting: "
-                               << a << endl;
-  }
 
-
-  if (a.getKind() == kind::EQUAL  && options::arithRewriteEq()) {
-    Node leq = NodeBuilder<2>(kind::LEQ) << a[0] << a[1];
-    Node geq = NodeBuilder<2>(kind::GEQ) << a[0] << a[1];
+  if (atom.getKind() == kind::EQUAL  && options::arithRewriteEq()) {
+    Node leq = NodeBuilder<2>(kind::LEQ) << atom[0] << atom[1];
+    Node geq = NodeBuilder<2>(kind::GEQ) << atom[0] << atom[1];
     Node rewritten = Rewriter::rewrite(leq.andNode(geq));
     Debug("arith::preprocess") << "arith::preprocess() : returning "
                                << rewritten << endl;
     return rewritten;
+  } else {
+    return atom;
   }
-
-  return a;
 }
 
 Theory::PPAssertStatus TheoryArith::ppAssert(TNode in, SubstitutionMap& outSubstitutions) {
@@ -2256,8 +2238,6 @@ void TheoryArith::presolve(){
   //     d_out->lemma(lem);
   //   }
   // }
-
-  d_learner.clear();
 }
 
 EqualityStatus TheoryArith::getEqualityStatus(TNode a, TNode b) {
index 3340519011ed56fa0f407397b6933c733f8aae3b..fd664e04a34590500fc0497a3cc5c496a2e33250 100644 (file)
@@ -228,18 +228,6 @@ private:
    */
   DioSolver d_diosolver;
 
-  /**
-   * Some integer variables can be replaced with pseudoboolean
-   * variables internally.  This map is built up at static learning
-   * time for top-level asserted expressions of the shape "x = 0 OR x
-   * = 1".  This substitution map is then applied in preprocess().
-   *
-   * Note that expressions of the shape "x >= 0 AND x <= 1" are
-   * already substituted for PB versions at solve() time and won't
-   * appear here.
-   */
-  SubstitutionMap d_pbSubstitutions;
-
   /** Counts the number of notifyRestart() calls to the theory. */
   uint32_t d_restartsCounter;