more slicer changes for incremental
authorlianah <lianahady@gmail.com>
Wed, 6 Mar 2013 21:35:38 +0000 (16:35 -0500)
committerlianah <lianahady@gmail.com>
Wed, 6 Mar 2013 21:35:38 +0000 (16:35 -0500)
src/expr/type_checker_template.cpp
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/bv_subtheory_core.h
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h
src/theory/bv/theory_bv.cpp

index 16f9ba9171fb3817418d623dc3ee569277b0bf56..4d9cbc60d863ef8cf8458ea01c097336b4c2e888 100644 (file)
@@ -39,6 +39,9 @@ TypeNode TypeChecker::computeType(NodeManager* nodeManager, TNode n, bool check)
   case kind::BUILTIN:
     typeNode = nodeManager->builtinOperatorType();
     break;
+  case kind::BITVECTOR_EXTRACT_OP :
+    typeNode = nodeManager->builtinOperatorType();
+    break; 
 
 ${typerules}
 
index 3f2ede9e2998baad66a20f2293d55e7ab82e7b06..91cf29ee9997c0d1efc054b3a7a7cbc4ed6c1ddb 100644 (file)
@@ -100,21 +100,14 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
 }
 
 Node CoreSolver::getBaseDecomposition(TNode a) {
-  // if (d_normalFormCache.find(a) != d_normalFormCache.end()) {
-  //   return d_normalFormCache[a]; 
-  // }
-
-  // otherwise we must compute the normal form
   std::vector<Node> a_decomp;
   d_slicer->getBaseDecomposition(a, a_decomp);
   Node new_a = utils::mkConcat(a_decomp);
-  //  d_normalFormCache[a] = new_a;
   return new_a; 
 }
 
 bool CoreSolver::decomposeFact(TNode fact) {
   Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl;  
-  // FIXME: are this the right things to assert? 
   // assert decompositions since the equality engine does not know the semantics of
   // concat:
   //   a == a_1 concat ... concat a_k
@@ -123,6 +116,12 @@ bool CoreSolver::decomposeFact(TNode fact) {
 
   TNode a = eq[0];
   TNode b = eq[1];
+  // we need to get the old decomposition to keep track of the cuts we added
+  Base a_old_base = d_slicer->getTopLevelBase(a);
+  Base b_old_base = d_slicer->getTopLevelBase(b);
+
+  d_slicer->processEquality(eq); 
+  
   Node new_a = getBaseDecomposition(a);
   Node new_b = getBaseDecomposition(b); 
   
@@ -133,7 +132,15 @@ bool CoreSolver::decomposeFact(TNode fact) {
   Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a);
   Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b);
 
-  bool ok = true;
+  Base a_new_base = d_slicer->getTopLevelBase(a);
+  Base b_new_base = d_slicer->getTopLevelBase(b);
+
+  bool ok = true; 
+  ok = addNewSplits(a, a_old_base, a_new_base);
+  if (!ok) return false; 
+  ok = addNewSplits(b, b_old_base, b_new_base);
+  if (!ok) return false; 
+  
   ok = assertFact(a_eq_new_a, utils::mkTrue());
   if (!ok) return false; 
   ok = assertFact(b_eq_new_b, utils::mkTrue());
@@ -158,6 +165,56 @@ bool CoreSolver::decomposeFact(TNode fact) {
   return true; 
 }
 
+bool CoreSolver::addNewSplits(TNode n, Base& old_base, Base& new_base) {
+  if (n.getKind() == kind::BITVECTOR_EXTRACT) {
+    n = n[0]; 
+  }
+  Assert (old_base.getBitwidth() == new_base.getBitwidth() &&
+          utils::getSize(n) == old_base.getBitwidth()); 
+
+  Index high, low = 0;
+  std::vector<std::pair<Index, Index> > toSlice;
+  bool hasNewCut = false; 
+  // collect the intervals that need to be sliced
+  for (unsigned i = 0; i <= old_base.getBitwidth(); ++i) {
+    Assert (! old_base.isCutPoint(i) || new_base.isCutPoint(i));
+    if (new_base.isCutPoint(i) && !old_base.isCutPoint(i)) {
+      hasNewCut = true; 
+    }
+    if (new_base.isCutPoint(i) && old_base.isCutPoint(i)) {
+      high = i;
+      if (hasNewCut) {
+        toSlice.push_back(std::pair<Index, Index>(high, low));
+      }
+      low = i;
+      hasNewCut = false; 
+    }
+  }
+  // for each interval, assert the proper equality
+  for (unsigned i = 0; i < toSlice.size(); ++i) {
+    int high = toSlice[i].first;
+    int low = toSlice[i].second;
+    int prev = high;
+    std::vector<Node> extracts; 
+    for (int k = high -1; k >= low; --k) {
+      if (new_base.isCutPoint(k) && (!old_base.isCutPoint(k) || k == low)) {
+        // add a new extract
+        Node ex = utils::mkExtract(n, prev - 1, k);
+        prev = k;
+        extracts.push_back(ex); 
+      }
+    }
+    Node concat = utils::mkConcat(extracts);
+    Node current = utils::mkExtract(n, high - 1, low);
+    Node eq = utils::mkNode(kind::EQUAL, concat, current);
+    bool ok = assertFact(eq, utils::mkTrue());
+    if (!ok)
+      return false; 
+  }
+  return true; 
+}
+
+
 bool CoreSolver::addAssertions(const std::vector<TNode>& assertions, Theory::Effort e) {
   Trace("bitvector::core") << "CoreSolver::addAssertions \n";
   Assert (!d_bv->inConflict());
@@ -168,14 +225,13 @@ bool CoreSolver::addAssertions(const std::vector<TNode>& assertions, Theory::Eff
     TNode fact = assertions[i];
     
     // update whether we are in the core fragment
-    // FIXME: move isCoreTerm into CoreSolver
     if (d_isCoreTheory && !d_slicer->isCoreTerm(fact)) {
       d_isCoreTheory = false; 
     }
     
     // only reason about equalities
-    // FIXME: should we slice when we have the terms in inequalities?
     if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) {
+      TNode eq = fact.getKind() == kind::EQUAL ? fact : fact[0];
       ok = decomposeFact(fact);
     } else {
       ok = assertFact(fact, fact); 
index 38676bfa6b8d2726da8b8a204edcf7f0f00dc501..1adf813ff0d76ac8d7959cfc95cc1e25f792de70 100644 (file)
@@ -25,7 +25,7 @@ namespace theory {
 namespace bv {
 
 class Slicer; 
-
+class Base; 
 /**
  * Bitvector equality solver
  */
@@ -75,7 +75,8 @@ class CoreSolver : public SubtheorySolver {
 
   bool assertFact(TNode fact, TNode reason);  
   bool decomposeFact(TNode fact);
-  Node getBaseDecomposition(TNode a); 
+  Node getBaseDecomposition(TNode a);
+  bool addNewSplits(TNode n, Base& old_base, Base& new_base); 
 public:
   bool isCoreTheory() {return d_isCoreTheory; }
   CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer);
index 3a6ca8a2f18933c7e4adc607775fc72e628364c3..2334ed2b0365ef378c42cda0265419a2d2b6ce29 100644 (file)
@@ -167,7 +167,7 @@ TermId UnionFind::addTerm(Index bitwidth) {
   ++(d_statistics.d_numNodes);
   
   TermId id = d_nodes.size() - 1; 
-  d_representatives.insert(id);
+  //  d_representatives.insert(id);
   ++(d_statistics.d_numRepresentatives); 
 
   Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl;
@@ -217,7 +217,7 @@ void UnionFind::merge(TermId t1, TermId t2) {
   Assert (! hasChildren(t1) && ! hasChildren(t2));
   setRepr(t1, t2); 
   recordOperation(UnionFind::MERGE, t1); 
-  d_representatives.erase(t1);
+  //d_representatives.erase(t1);
   d_statistics.d_numRepresentatives += -1; 
 }
 
@@ -254,7 +254,6 @@ void UnionFind::split(TermId id, Index i) {
     TermId top_id = addTerm(getBitwidth(id) - i);
     setChildren(id, top_id, bottom_id);
     recordOperation(UnionFind::SPLIT, id); 
-
   } else {
     Index cut = getCutPoint(id); 
     if (i < cut )
@@ -418,8 +417,10 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) {
 }
 
 void UnionFind::backtrack() {
-  for (int i = d_undoStack.size() -1; i >= d_undoStackIndex; ++i) {
+  int size = d_undoStack.size(); 
+  for (int i = size; i > d_undoStackIndex.get(); --i) {
     Operation op = d_undoStack.back(); 
+    Assert (!d_undoStack.empty()); 
     d_undoStack.pop_back();
     if (op.op == UnionFind::MERGE) {
       undoMerge(op.id); 
@@ -431,23 +432,35 @@ void UnionFind::backtrack() {
 }
 
 void UnionFind::undoMerge(TermId id) {
-  Node& node = getNode(id);
-  Assert (getRepr(id) != id);
-  setRepr(id, id); 
+  TermId repr = getRepr(id);
+  Assert (repr != id);
+  setRepr(id, UndefinedId); 
 }
 
 void UnionFind::undoSplit(TermId id) {
-  Node& node = getNode(id);
-  Assert (hasChildren(node));
-  setChildren(id, UndefindId, UndefinedId); 
+  Assert (hasChildren(id));
+  setChildren(id, UndefinedId, UndefinedId); 
 }
 
 void UnionFind::recordOperation(OperationKind op, TermId term) {
-  ++d_undoStackIndex;
+  d_undoStackIndex.set(d_undoStackIndex.get() + 1);
   d_undoStack.push_back(Operation(op, term));
   Assert (d_undoStack.size() == d_undoStackIndex); 
 }
 
+void UnionFind::getBase(TermId id, Base& base, Index offset) {
+  id = find(id); 
+  if (!hasChildren(id))
+    return;
+  TermId id1 = find(getChild(id, 1));
+  TermId id0 = find(getChild(id, 0));
+  Index cut = getCutPoint(id);
+  base.sliceAt(cut + offset);
+  getBase(id1, base, cut + offset);
+  getBase(id0, base, offset); 
+}
+
+
 /**
  * Slicer
  * 
@@ -517,7 +530,6 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
     current_low += current_size;
     decomp.push_back(current); 
   }
-  // cache the result
 
   Debug("bv-slicer") << "as [";
   for (unsigned i = 0; i < decomp.size(); ++i) {
@@ -595,7 +607,28 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
     equalities.push_back(node);
   }
   d_numAddedEqualities += equalities.size() - 1; 
-} 
+}
+
+/** 
+ * Returns the base decomposition of the current term. 
+ * 
+ * @param id 
+ * 
+ * @return 
+ */
+Base Slicer::getTopLevelBase(TNode node) {
+  if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+    node = node[0];
+  }
+  // if we haven't seen this node before it must not be sliced yet
+  if (d_nodeToId.find(node) == d_nodeToId.end()) {
+    return Base(utils::getSize(node)); 
+  }
+  TermId id = d_nodeToId[node];
+  Base base(d_unionFind.getBitwidth(id));
+  d_unionFind.getBase(id, base, 0);
+  return base; 
+}
 
 std::string UnionFind::debugPrint(TermId id) {
   ostringstream os; 
index 731141262ec235ad2df3175446c0fd63398790d3..0508c67c17b2d539f63bbe067c25e929ae543fb7 100644 (file)
@@ -56,7 +56,7 @@ class Base {
   Index d_size;
   std::vector<uint32_t> d_repr;
 public:
-  Base(Index size);
+  Base (Index size);
   void sliceAt(Index index); 
   void sliceWith(const Base& other);
   bool isCutPoint(Index index) const;
@@ -84,7 +84,7 @@ public:
  * UnionFind
  * 
  */
-typedef context::CDHashSet<uint32_t> CDTermSet;
+typedef context::CDHashSet<uint32_t, std::hash<uint32_t> > CDTermSet;
 typedef std::vector<TermId> Decomposition; 
 
 struct ExtractTerm {
@@ -151,7 +151,7 @@ class UnionFind : public context::ContextNotifyObj {
       d_repr = id;
     }
     void setChildren(TermId ch1, TermId ch0) {
-      Assert (d_repr == UndefinedId && !hasChildren());
+      // Assert (d_repr == UndefinedId && !hasChildren());
       d_ch1 = ch1;
       d_ch0 = ch0; 
     }
@@ -161,7 +161,7 @@ class UnionFind : public context::ContextNotifyObj {
   /// map from TermId to the nodes that represent them 
   std::vector<Node> d_nodes;
   /// a term is in this set if it is its own representative
-  CDTermSet d_representatives;
+  //CDTermSet d_representatives;
   
   void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
   void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
@@ -187,7 +187,8 @@ class UnionFind : public context::ContextNotifyObj {
     d_nodes[id].setRepr(new_repr); 
   }
   void setChildren(TermId id, TermId ch1, TermId ch0) {
-    Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0));
+    Assert ((ch1 == UndefinedId && ch0 == UndefinedId) ||
+            (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)));
     d_nodes[id].setChildren(ch1, ch0); 
   }
 
@@ -212,7 +213,7 @@ class UnionFind : public context::ContextNotifyObj {
   void undoMerge(TermId id);
   void undoSplit(TermId id); 
   void recordOperation(OperationKind op, TermId term);
-  
+  virtual ~UnionFind() throw(AssertionException) {}
   class Statistics {
   public:
     IntStat d_numNodes; 
@@ -228,8 +229,9 @@ class UnionFind : public context::ContextNotifyObj {
   Statistics d_statistics;
 public:
   UnionFind(context::Context* ctx)
-    : d_nodes(),
-      d_representatives(ctx),
+    : ContextNotifyObj(ctx), 
+      d_nodes(),
+      //      d_representatives(ctx),
       d_undoStack(),
       d_undoStackIndex(ctx),
       d_statistics()
@@ -248,6 +250,7 @@ public:
     Assert (id < d_nodes.size());
     return d_nodes[id].getBitwidth(); 
   }
+  void getBase(TermId id, Base& base, Index offset); 
   std::string debugPrint(TermId id);
 
   void contextNotifyPop() {
@@ -274,7 +277,7 @@ public:
   void getBaseDecomposition(TNode node, std::vector<Node>& decomp);
   void processEquality(TNode eq);
   bool isCoreTerm (TNode node);
-  
+  Base getTopLevelBase(TNode node); 
   static void splitEqualities(TNode node, std::vector<Node>& equalities);
   static unsigned d_numAddedEqualities; 
 }; 
index bb4b480d64ad6552de315dcbc06a1c821ec4361a..6248782bdf9dcb965b83e32a8e3b65d336db24fd 100644 (file)
@@ -40,7 +40,7 @@ TheoryBV::TheoryBV(context::Context* c, context::UserContext* u, OutputChannel&
     d_context(c),
     d_alreadyPropagatedSet(c),
     d_sharedTermsSet(c),
-    d_slicer(),
+    d_slicer(c),
     d_bitblastAssertionsQueue(c),
     d_bitblastSolver(c, this),
     d_coreSolver(c, this, &d_slicer),
@@ -74,6 +74,8 @@ TheoryBV::Statistics::~Statistics() {
   StatisticsRegistry::unregisterStat(&d_solveTimer);
 }
 
+
+
 void TheoryBV::preRegisterTerm(TNode node) {
   Debug("bitvector-preregister") << "TheoryBV::preRegister(" << node << ")" << std::endl;
 
@@ -81,10 +83,6 @@ void TheoryBV::preRegisterTerm(TNode node) {
     // don't use the equality engine in the eager bit-blasting
     return;
   }
-
-  if (node.getKind() == kind::EQUAL) {
-    d_slicer.processEquality(node); 
-  }
   
   d_bitblastSolver.preRegister(node);
   d_coreSolver.preRegister(node);