starting the slicer form scratch.
authorlianah <lianahady@gmail.com>
Fri, 25 Jan 2013 22:31:45 +0000 (17:31 -0500)
committerlianah <lianahady@gmail.com>
Fri, 25 Jan 2013 22:31:45 +0000 (17:31 -0500)
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h

index 8eb7d6127a571365693d2206167cf2407a8c6c8e..0014274884df76f7f0fca6a8c0c73e63ff179124 100644 (file)
@@ -17,7 +17,6 @@
  **/
 
 #include "theory/bv/slicer.h"
-#include "util/integer.h"
 #include "util/utility.h"
 #include "theory/bv/theory_bv_utils.h"
 #include "theory/rewriter.h"
@@ -27,582 +26,257 @@ using namespace CVC4::theory;
 using namespace CVC4::theory::bv;
 using namespace std; 
 
-void Base::decomposeNode(TNode node, std::vector<Node>& decomp) const {
-  Debug("bv-slicer") << "Base::decomposeNode " << node << "\n with base" << debugPrint() << endl;
-  
-  Index low = 0;
-  Index high = utils::getSize(node) - 1; 
-  if (node.getKind() == kind::BITVECTOR_EXTRACT) {
-    low = utils::getExtractLow(node);
-    high = utils::getExtractHigh(node);
-    node = node[0];
-  }
-  Index index = low; 
-  for (Index i = low; i <= high; ++i) {
-    if (isCutPoint(i)) {
-      // make sure the extract is pushed down before concat
-      Node slice = Rewriter::rewrite(utils::mkExtract(node, i, index));
-      index = i + 1;
-      decomp.push_back(slice);
-      Debug("bv-slicer") << slice <<" "; 
-    }
-  }
-  Debug("bv-slicer") << endl; 
-}
-
-/** 
- * Adds a cutPoint at Index i and splits the corresponding Splinter
+/**
+ * Base
  * 
- * @param i the index where the cut point will be introduced
- * @param sp the splinter pointer corresponding to the splinter to be sliced
- * @param low_splinter the resulting bottom part of the splinter
- * @param top_splinter the resulting top part of the splinter
  */
-void Slice::split(Index i, SplinterPointer& sp, Splinter*& low_splinter, Splinter*& top_splinter) {
-  Debug("bv-slicer") << "Slice::split " << this->debugPrint() << "\n";
-  Assert (!d_base.isCutPoint(i));
-  d_base.sliceAt(i);
-  
-  Splinter* s = NULL;
-  Slice::const_iterator it = begin();
-  bool lt, gt; 
-  for (; it != end(); ++it) {
-    lt = (it->second)->getHigh() >= i;
-    gt = (it->second)->getLow() <= i;
-    if (gt && lt) {
-      s = it->second;
-      break; 
-    }
-  }
-  
-  Assert (s != NULL);
+Base::Base(uint32_t size) 
+  : d_size(size),
+    d_repr((size-1)/32 + ((size-1) % 32 == 0? 0 : 1), 0)
+{
+  Assert (d_size > 0); 
+}
 
-  sp = s->getPointer();
-  Index low = s->getLow();
-  Index high = s->getHigh();
-  // creating the two splinter fragments 
-  low_splinter = new Splinter(i, low);
-  top_splinter = new Splinter(high, i+1);
   
-  addSplinter(low, low_splinter);
-  addSplinter(i+1, top_splinter);
-  Debug("bv-slicer") << "          to " << this->debugPrint() << "\n"
-} 
-void Slice::addSplinter(Index i, Splinter* sp) {
-  Assert (i == sp->getLow() && sp->getHigh() < d_bitwidth);
+void Base::sliceAt(Index index) {
+  Index vector_index = index / 32;
+  Assert (vector_index < d_size - 1)
+  Index int_index = index % 32;
+  uint32_t bit_mask = utils::pow2(int_index); 
+  d_repr[vector_index] = d_repr[vector_index] | bit_mask; 
+}
 
-  if (i != 0) {
-    d_base.sliceAt(i - 1);
+void Base::sliceWith(const Base& other) {
+  Assert (d_size == other.d_size);
+  for (unsigned i = 0; i < d_repr.size(); ++i) {
+    d_repr[i] = d_repr[i] | other.d_repr[i]; 
   }
-  // d_base.debugPrint(); 
-  // free the memory associated with the previous splinter
-  if (d_splinters.find(i) != d_splinters.end()) {
-    delete d_splinters[i];
-  }
-  d_splinters[i] = sp;
 }
 
-bool Slice::isConsistent() {
-  // check that base is consistent with slicings
-  // and that the slicings are continous
-  std::map<Index, Splinter*>::const_iterator it = d_splinters.begin();
-  Index prev = -1;
-  for (; it != d_splinters.end(); ++it) {
-    Index index = (*it).first;  
-    Splinter* splinter = (*it).second;
-    if (index != 0 && !d_base.isCutPoint(index-1))
-      return false; 
-    if (index != splinter->getLow())
-      return false;
-    if (prev + 1 != index)
-      return false;
-    prev = splinter->getHigh(); 
-  }
-  if (prev != d_bitwidth - 1)
-    return false; 
+bool Base::isCutPoint (Index index) const {
+  // there is an implicit cut point at the end of the bv
+  if (index == d_size - 1)
+    return true;
+    
+  Index vector_index = index / 32;
+  Assert (vector_index < d_size - 1); 
+  Index int_index = index % 32;
+  uint32_t bit_mask = utils::pow2(int_index); 
 
-  for (unsigned i = 0; i < d_bitwidth - 1; ++i) {
-    if (d_base.isCutPoint(i) && d_splinters.find(i+1) == d_splinters.end())
-      return false; 
-  }
-  return true; 
+  return (bit_mask & d_repr[vector_index]) != 0; 
 }
 
-std::string Slice::debugPrint() {
-  std::ostringstream os;
-  os << d_base.debugPrint(); 
-  os << "{ ";
-  for (Slice::const_iterator it = begin(); it != end(); ++it) {
-    Splinter* s = (*it).second;
-    os << "[" << s->getLow() << ":" << s->getHigh() <<"]"; 
-    Assert ((*it).first == s->getLow());
-    Assert (s->getLow() == 0 || d_base.isCutPoint(s->getLow() - 1)); 
-    SplinterPointer sp = s->getPointer(); 
-    if (s->getPointer() != Undefined) {
-      os << "->" << sp.debugPrint(); 
-    }
-    os << " "; 
+void Base::diffCutPoints(const Base& other, Base& res) const {
+  Assert (d_size == other.d_size && res.d_size == d_size);
+  for (unsigned i = 0; i < d_repr.size(); ++i) {
+    Assert (res.d_repr[i] == 0); 
+    res.d_repr[i] = d_repr[i] ^ other.d_repr[i]; 
   }
-  os << "}";
-  return os.str(); 
 }
 
-void SliceBlock::computeBlockBase(BlockIdSet& changedSet)  {
-  Debug("bv-slicer") << "SliceBlock::computeBlockBase for block" << d_rootId << endl;
-  Debug("bv-slicer") << this->debugPrint() << endl; 
-
-  ++(d_slicer->d_statistics.d_numBlockBaseComputations); 
-  
-  Base new_cut_points(d_bitwidth);
-  
-  // at this point d_base has all the cut points in the individual slices
-  for (unsigned row = 0; row < d_block.size(); ++row) {
-    Slice* slice = d_block[row];
-    new_cut_points.reset(); 
-    slice->getBase().diffCutPoints(d_base, new_cut_points);
-
-    if (! new_cut_points.isEmpty()) {
-      // use the cut points from the base to split the current slice
-      for (unsigned i = 0; i < d_bitwidth; ++i) {
-        const Base& base = slice->getBase(); // the base may have changed if splinters of the same slice are equal
-
-        if (new_cut_points.isCutPoint(i) && i != d_bitwidth - 1 && ! base.isCutPoint(i) ) {
-
-          Debug("bv-slicer") << "    adding cut point at " << i << " for row " << row << endl; 
-          // split this slice (this updates the slice's base)
-          Splinter* bottom, *top = NULL;
-          SplinterPointer sp;
-
-          ++(d_slicer->d_statistics.d_numSplits); 
-          slice->split(i, sp, bottom, top);
-          Assert (bottom != NULL && top != NULL); 
-          Assert (i >= bottom->getLow());
-          
-          if (sp != Undefined) {
-            unsigned delta = i - bottom->getLow();   
-            // if we do need to split something else split it now
-            Debug("bv-slicer") <<"    must split " << sp.debugPrint(); 
-            Slice* other_slice = d_slicer->getSlice(sp);
-            Splinter* s = d_slicer->getSplinter(sp);
-            Index cutPoint = s->getLow() + delta; 
-            Splinter* new_bottom = new Splinter(cutPoint, s->getLow());
-            Splinter* new_top = new Splinter(s->getHigh(), cutPoint + 1);
-            new_bottom->setPointer(SplinterPointer(d_rootId, row, bottom->getLow()));
-            new_top->setPointer(SplinterPointer(d_rootId, row, top->getLow()));
-            // note that this could modify the current splinter 
-            other_slice->addSplinter(new_bottom->getLow(), new_bottom);
-            other_slice->addSplinter(new_top->getLow(), new_top); 
-          
-            bottom->setPointer(SplinterPointer(sp.term, sp.row, new_bottom->getLow()));
-            top->setPointer(SplinterPointer(sp.term, sp.row, new_top->getLow()));
-            // update base for block
-            d_slicer->getSliceBlock(sp)->sliceBaseAt(cutPoint);
-            // add to queue of blocks that have changed base
-            Debug("bv-slicer") << "    adding block to queue: " << sp.term << endl; 
-            changedSet.insert(sp.term); 
-          }
-        }
-      }
-    }
+bool Base::isEmpty() const {
+  for (unsigned i = 0; i< d_repr.size(); ++i) {
+    if (d_repr[i] != 0)
+      return false;
   }
-
-  Debug("bv-slicer") << "base computed: " << d_rootId << endl;
-  Debug("bv-slicer") << this->debugPrint() << endl;
-  Debug("bv-slicer") << "SliceBlock::computeBlockBase done. \n"; 
-
+  return true;
 }
 
-std::string SliceBlock::debugPrint() {
+std::string Base::debugPrint() const {
   std::ostringstream os;
-  os << "Width " << d_bitwidth << endl; 
-  os << "Base " << d_base.debugPrint() << endl;
-  for (SliceBlock::const_iterator it = begin(); it!= end(); ++it) {
-    os << (*it)->debugPrint() << endl;
+  os << "[";
+  bool first = true; 
+  for (unsigned i = 0; i < d_size - 1; ++i) {
+    if (isCutPoint(i)) {
+      if (first)
+        first = false;
+      else
+        os <<"| "; 
+        
+      os << i ; 
+    }
   }
+  os << "]"; 
   return os.str(); 
 }
-
-Slicer::Slicer()
-  : d_simpleEqualities(),
-    d_roots(),
-    d_numRoots(0),
-    d_nodeRootMap(),
-    d_rootBlocks(),
-    d_coreTermCache(),
-    d_statistics()
-{}
-
-Slicer::Statistics::Statistics() :
-  d_numBlocks("TheoryBV::Slicer::NumBlocks", 0),
-  d_avgBlockSize("TheoryBV::Slicer::AvgBlockSize"),
-  d_avgBlockBitwitdh("TheoryBV::Slicer::AvgBlockBitwidth"),
-  d_numBlockBaseComputations("TheoryBV::Slicer::NumBlockBaseComputations", 0), 
-  d_numSplits("TheoryBV::Slicer::NumSplits", 0),
-  d_numSimpleEqualities("TheoryBV::Slicer::NumSimpleEqualities", 0),
-  d_numSlices("TheoryBV::Slicer::NumSlices", 0)
-{
-  StatisticsRegistry::registerStat(&d_numBlocks);
-  StatisticsRegistry::registerStat(&d_avgBlockSize);
-  StatisticsRegistry::registerStat(&d_avgBlockBitwitdh);
-  StatisticsRegistry::registerStat(&d_numBlockBaseComputations);
-  StatisticsRegistry::registerStat(&d_numSplits);
-  StatisticsRegistry::registerStat(&d_numSimpleEqualities);
-  StatisticsRegistry::registerStat(&d_numSlices);
-}
-
-Slicer::Statistics::~Statistics() {
-  StatisticsRegistry::unregisterStat(&d_numBlocks);
-  StatisticsRegistry::unregisterStat(&d_avgBlockSize);
-  StatisticsRegistry::unregisterStat(&d_avgBlockBitwitdh);
-  StatisticsRegistry::unregisterStat(&d_numBlockBaseComputations);
-  StatisticsRegistry::unregisterStat(&d_numSplits);
-  StatisticsRegistry::unregisterStat(&d_numSimpleEqualities);
-  StatisticsRegistry::unregisterStat(&d_numSlices);
+/**
+ * UnionFind
+ * 
+ */
+TermId UnionFind::addTerm(Index bitwidth) {
+  Node node(bitwidth);
+  d_nodes.push_back(node);
+  TermId id = d_nodes.size() - 1; 
+  d_representatives.insert(id);
+  d_topLevelTerms.insert(id); 
 }
 
-
-RootId Slicer::makeRoot(TNode n)  {
-  Assert (n.getKind() != kind::BITVECTOR_EXTRACT && n.getKind() != kind::BITVECTOR_CONCAT);
-  if (d_nodeRootMap.find(n) != d_nodeRootMap.end()) {
-    return d_nodeRootMap[n];
-  }
-  RootId id = d_roots.size();
-  d_nodeRootMap[n] = id; 
-  d_roots.push_back(n); 
-  d_rootBlocks.push_back(new SliceBlock(id, utils::getSize(n), this));
-  Assert (d_roots.size() == d_rootBlocks.size());
+void UnionFind::merge(TermId t1, TermId t2) {
   
-  Debug("bv-slicer") << "Slicer::makeRoot " << n << " -> " << id << endl;
+}
+TermId UnionFind::find(TermId id) {
+  Node node = getNode(id); 
+  if (node.getRepr() != -1)
+    return find(node.getRepr());
   return id; 
 }
+/** 
+ * Splits the representative of the term between i-1 and i
+ * 
+ * @param id the id of the term
+ * @param i the index we are splitting at
+ * 
+ * @return 
+ */
+void UnionFind::split(TermId id, Index i) {
+  id = find(id); 
+  Node node = getNode(id); 
+  Assert (i < node.getBitwidth());
 
-void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
-  Assert (node.getKind() == kind::EQUAL);
-  TNode t1 = node[0];
-  TNode t2 = node[1];
-
-  uint32_t width = utils::getSize(t1); 
-  
-  Base base1(width); 
-  if (t1.getKind() == kind::BITVECTOR_CONCAT) {
-    int size = -1;
-    // no need to count the last child since the end cut point is implicit 
-    for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) {
-      size = size + utils::getSize(t1[i]);
-      base1.sliceAt(size); 
-    }
+  if (i == 0 || i == node.getBitwidth()) {
+    // nothing to do 
+    return;
   }
 
-  Base base2(width); 
-  if (t2.getKind() == kind::BITVECTOR_CONCAT) {
-    unsigned size = -1; 
-    for (int i = t2.getNumChildren() - 1; i >= 1; --i) {
-      size = size + utils::getSize(t2[i]);
-      base2.sliceAt(size); 
-    }
-  }
+  if (!node.hasChildren()) {
+    // first time we split this term 
+    TermId bottom_id = addTerm(i);
+    TermId top_id = addTerm(node.getBitwidth() - i);
+    node.addChildren(top_id, bottom_id);
 
-  base1.sliceWith(base2); 
-  if (!base1.isEmpty()) {
-    // we split the equalities according to the base
-    int last = 0; 
-    for (unsigned i = 0; i < utils::getSize(t1); ++i) {
-      if (base1.isCutPoint(i)) {
-        Node extract1 = Rewriter::rewrite(utils::mkExtract(t1, i, last));
-        Node extract2 = Rewriter::rewrite(utils::mkExtract(t2, i, last));
-        last = i + 1;
-        Assert (utils::getSize(extract1) == utils::getSize(extract2)); 
-        equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2)); 
-      }
-    }
   } else {
-    // just return same equality
-    equalities.push_back(node);
+    Index cut = node.getCutPoint(); 
+    if (i < cut )
+      split(child1, i);
+    else
+      split(node.getChild(1), i - cut); 
   }
-} 
-void Slicer::processEquality(TNode node) {
-  Assert (node.getKind() == kind::EQUAL);
-  Debug("bv-slicer") << "theory::bv::Slicer::processEquality " << node << endl; 
-  // std::vector<Node> equalities;
-  // splitEqualities(node, equalities); 
-  // for (unsigned i = 0; i < equalities.size(); ++i) {
-  //   Debug("bv-slicer") << "    splitEqualities " << node << endl;
-  registerSimpleEquality(node); 
-  d_simpleEqualities.push_back(node);
-  //  }
 }
 
-void Slicer::registerSimpleEquality(TNode eq) {
-  Assert (eq.getKind() == kind::EQUAL);
-  ++(d_statistics.d_numSimpleEqualities);
-  
-  Debug("bv-slicer-eq") << "theory::bv::Slicer::registerSimpleEquality " << eq << endl;  
-  TNode a = eq[0];
-  TNode b = eq[1];
-
-  if (a == b)
-    return;
-  
-  RootId id_a = registerTerm(a);
-  RootId id_b = registerTerm(b);
-  
-  unsigned low_a = 0; 
-  unsigned low_b = 0; 
-
-  if (a.getKind() == kind::BITVECTOR_EXTRACT) {
-    low_a  = utils::getExtractLow(a);
-  }
-  
-  if (b.getKind() == kind::BITVECTOR_EXTRACT) {
-    low_b  = utils::getExtractLow(b);
+void UnionFind::getNormalForm(ExtractTerm term, NormalForm& nf) {
+  TermId id = find(term.id);
+  getDecomposition(term, nf.decomp);
+  // update nf base
+  Index count = 0; 
+  for (unsigned i = 0; i < nf.decomp.size(); ++i) {
+    count += getBitwidth(nf.decomp[i]);
+    nf.base.sliceAt(count); 
   }
+}
 
-  if (id_a == id_b ) {
-    // we are in the special case a[i0:j0] = a[i1:j1]
-    Index high_a = utils::getExtractHigh(a);
-    Index high_b = utils::getExtractHigh(b);
-    
-    unsigned intersection_low = std::max(low_a, low_b);
-    unsigned intersection_high = std::min(high_a, high_b);
-    if (intersection_low <= intersection_high) {
-      // if the two extracts intersect 
-      unsigned intersection_size = intersection_high - intersection_low + 1;
-      // gcd between overlapping area and difference
-      unsigned diff = low_a > low_b ? low_a - low_b  : low_b - low_a; 
-      unsigned granularity = gcd(intersection_size, diff);
-      SliceBlock* block_a = d_rootBlocks[id_a];
-      Assert (a.getKind() == kind::BITVECTOR_EXTRACT);
-      unsigned size = utils::getSize(a[0]);
-      
-      Slice* slice = new Slice(size);
-      unsigned low = low_a > low_b ? low_b : low_a;
-      unsigned high = high_a > high_b ? high_a : high_b;
-      Splinter* prev_splinter = NULL;
-      // the row the new slice will be in 
-      unsigned block_row = block_a->getSize(); 
-      for (unsigned i = low; i <= high; i+=granularity) {
-        Splinter* s = new Splinter(i+ granularity-1, i);
-        slice->addSplinter(i, s);
-        // update splinter pointers to reflect entailed equalities 
-        if (prev_splinter!= NULL) {
-          // the previous splinter will be equal to the current 
-          prev_splinter->setPointer(SplinterPointer(id_a, block_row, i));
-        }
-        prev_splinter = s; 
-      }
-      // make sure to splinters for the extremities
-      if (low!= 0) {
-        Splinter* s = new Splinter(low -1 , 0);
-        slice->addSplinter(0, s); 
-      }
-      if (high != size - 1) {
-        Splinter* s = new Splinter(size - 1, high + 1);
-        slice->addSplinter(high+1, s); 
-      }
-      block_a->addSlice(slice);
-      d_rootBlocks[id_a] = block_a; 
-      Debug("bv-slicer") << "     updated block" << id_a << " to " << endl;
-      Debug("bv-slicer") << block_a->debugPrint() << endl;
-      return; 
-    }
+void UnionFind::getDecomposition(ExtractTerm term, Decomposition& decomp) {
+  // making sure the term is aligned
+  TermId id = find(term.id); 
+
+  Node node = getNode(id);
+  Assert (term.high < node.getBitwidth());
+  // because we split the node, this must be the whole extract
+  if (!node.hasChildren()) {
+    Assert (term.high == node.getBitwidth() - 1 &&
+            term.low == 0);
+    decomp.push_back(id); 
   }
+    
+  Index cut = node.getCutPoint();
   
-  Slice* slice_a = makeSlice(a);
-  Slice* slice_b = makeSlice(b); 
-
-  SliceBlock* block_a = d_rootBlocks[id_a];
-  SliceBlock* block_b = d_rootBlocks[id_b];
-
-  uint32_t row_a = block_a->addSlice(slice_a);
-  uint32_t row_b = block_b->addSlice(slice_b); 
-
-  SplinterPointer sp_a = SplinterPointer(id_a, row_a, low_a);
-  SplinterPointer sp_b = SplinterPointer(id_b, row_b, low_b); 
-
-  slice_a->getSplinter(low_a)->setPointer(sp_b);
-  slice_b->getSplinter(low_b)->setPointer(sp_a);
-  Debug("bv-slicer") << "     updated block" << id_a << " to " << endl;
-  Debug("bv-slicer") << block_a->debugPrint() << endl;
-  Debug("bv-slicer") << "     updated block" <<id_b << " to " << endl;
-  Debug("bv-slicer") << block_b->debugPrint() << endl;
-}
-
-Slice* Slicer::makeSlice(TNode node) {
-  //Assert (d_sliceSet.find(node) == d_sliceSet.end());
-  ++(d_statistics.d_numSlices); 
-  Index bitwidth = utils::getSize(node); 
-  Index low = 0;
-  Index high = bitwidth -1;
-  if (node.getKind() == kind::BITVECTOR_EXTRACT) {
-    low  = utils::getExtractLow(node);
-    high = utils::getExtractHigh(node);
-    bitwidth = utils::getSize(node[0]); 
+  if (low < cut && high < cut) {
+    // the extract falls entirely on the low child
+    ExtractTerm child_ex(node.getChild(0), high, low); 
+    getDecomposition(child_ex, decomp); 
   }
-  Splinter* splinter = new Splinter(high, low);
-  Slice* slice = new Slice(bitwidth);
-  slice->addSplinter(low, splinter);
-  if (low != 0) {
-    Splinter* bottom_splinter = new Splinter(low-1, 0);
-    slice->addSplinter(0, bottom_splinter); 
+  else if (low >= cut && high >= cut){
+    // the extract falls entirely on the high child
+    ExtractTerm child_ex(node.getChild(1), high - cut, low - cut);
+    getDecomposition(child_ex, decomp); 
   }
-  if (high != bitwidth - 1) {
-    Splinter* top_splinter = new Splinter(bitwidth - 1, high + 1);
-    slice->addSplinter(high+1, top_splinter); 
+  else {
+    // the extract is split over the two children
+    ExtractTerm low_child(node.getChild(0), cut - 1, low);
+    getDecomposition(low_child, decomp);
+    ExtractTerm high_child(node.getChild(1), high, cut);
+    getDecomposition(high_child, decomp); 
   }
-  return slice; 
 }
 
-
-RootId Slicer::registerTerm(TNode node) {
-  if (node.getKind() == kind::BITVECTOR_EXTRACT ) {
-    node = node[0];
-    Assert (isRootTerm(node)); 
+void UnionFind::alignSlicings(NormalForm& nf1, NormalForm& nf2) {
+  Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth());
+  // check if the two have
+  std::vector<TermId> intersection; 
+  intersection(nf1.decomp, nf2.decomp, intersection); 
+  for (unsigned i = 0; i < intersection.size(); ++i) {
+    TermId overlap = intersection[i];
+    Index start1 = 0;
+    Decomposition& decomp1 = nf1.decomp; 
+    for (unsigned j = 0; j < decomp1.size(); ++j) {
+      if (decomp1[j] == overlap)
+        break;
+      start1 += getSize(decomp1[j]); 
+    }
   }
-  // setting up the data-structures for the root term
-  RootId id = makeRoot(node);
-  return id; 
-}
-
-bool Slicer::isCoreTerm(TNode node) {
-  if (d_coreTermCache.find(node) == d_coreTermCache.end()) {
-    Kind kind = node.getKind(); 
-    if (kind != kind::BITVECTOR_EXTRACT &&
-        kind != kind::BITVECTOR_CONCAT &&
-        kind != kind::EQUAL && kind != kind::NOT &&
-        node.getMetaKind() != kind::metakind::VARIABLE &&
-        kind != kind::CONST_BITVECTOR) {
-      d_coreTermCache[node] = false;
-      return false; 
-    } else {
-      // we need to recursively check whether the term is a root term or not
-      bool isCore = true;
-      for (unsigned i = 0; i < node.getNumChildren(); ++i) {
-        isCore = isCore && isCoreTerm(node[i]); 
-      }
-      d_coreTermCache[node] = isCore;
-      return isCore; 
+  
+  Base new_cuts1 = nf1.base.diffCutPoints(nf2.base);
+  Base new_cuts2 = nf2.base.diffCutPoints(nf1.base); 
+  for (unsigned i = 0; i < new_cuts.base.getBitwidth(); ++i) {
+    if (new_cuts1.isCutPoint(i)) {
+      
     }
   }
-  return d_coreTermCache[node]; 
+  
 }
 
-bool Slicer::isRootTerm(TNode node) {
-  Kind kind = node.getKind();
-  return kind != kind::BITVECTOR_EXTRACT && kind != kind::BITVECTOR_CONCAT;
+void UnionFind::ensureSlicing(ExtractTerm& term) {
+  TermId id = find(term.id);
+  split(id, term.high);
+  split(id, term.low);
+
+  
 }
 
-// Base Slicer::getBase(TNode node) {
-//   Assert (d_bases.find(node) != d_bases.end());
-//   return d_bases[node]; 
-// }
+/**
+ * Slicer
+ * 
+ */
 
-// void Slicer::updateBase(TNode node, const Base& base) {
-//   Assert (d_bases.find(node) != d_bases.end());
-//   d_bases[node] = d_bases[node].bitwiseOr(base); 
-// }
 
 
-void Slicer::computeCoarsestBase() {
-  Debug("bv-slicer") << "theory::bv::Slicer::computeCoarsestBase " << endl; 
-  BlockIdSet changed_set;
-  for (unsigned i = 0; i < d_rootBlocks.size(); ++i) {
-    SliceBlock* block = d_rootBlocks[i];
-    block->computeBlockBase(changed_set);
-    ++(d_statistics.d_numBlocks);
-    d_statistics.d_avgBlockSize.addEntry(block->getSize());
-    d_statistics.d_avgBlockBitwitdh.addEntry(block->getBitwidth()); 
+void Slicer::registerTerm(TNode node) {
+  Index low = 0, high = utils::getSize(node); 
+  TNode n = node; 
+  if (node.getKind() == kind::BITVECTOR_EXTRACT) {
+    TNode n = node[0];
+    high = utils::getExtractHigh(node);
+    low = utils::getExtractLow(node); 
   }
-
-  Debug("bv-slicer") << " processing changeSet of size " << changed_set.size() << endl; 
-  while (!changed_set.empty()) {
-    // process split candidate
-    RootId current = *(changed_set.begin()); 
-    changed_set.erase(current); 
-    SliceBlock* block = d_rootBlocks[current];
-    block->computeBlockBase(changed_set); 
+  if (d_nodeToId.find(n) == d_nodeToId.end()) {
+    id = d_uf.addTerm(utils::getSize(n)); 
+    d_nodeToId[n] = id;
+    d_idToNode[id] = n; 
   }
-  Assert(debugCheckBase());
-  std::cout << "Done computing coarsest base \n"; 
-}
-
+  TermId id = d_nodeToId[n];
 
-void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
-  Assert (node.getKind() != kind::BITVECTOR_CONCAT); 
-  TNode root = node.getKind() == kind::BITVECTOR_EXTRACT ? node[0] : node; 
-  Assert (isRootTerm(root)); 
-  const Base& base = getSliceBlock(getRootId(root))->getBase();
-  base.decomposeNode(node, decomp);
+  return ExtractTerm(id, high, low); 
 }
 
-bool Slicer::debugCheckBase() {
-  // check that all terms involved in equalities are properly sliced w.r.t.
-  // these equalities 
-  for (unsigned i = 0; i < d_simpleEqualities.size(); ++i) {
-    TNode a = d_simpleEqualities[i][0];
-    TNode b = d_simpleEqualities[i][1];
-    std::vector<Node> a_decomp;
-    std::vector<Node> b_decomp;
+void Slicer::processSimpleEquality(TNode eq) {
+  Assert (eq.getKind() == kind::EQUAL);
+  TNode a = eq[0];
+  TNode b = eq[1];
+  ExtractTerm a_ex= registerTerm(a);
+  ExtractTerm b_ex= registerTerm(b);
+  
+  NormalForm a_nf, b_nf;
+  d_uf.ensureSlicing(a_ex);
+  d_uf.ensureSlicing(b_ex);
+  
+  d_uf.getNormalForm(a_ex, a_nf);
+  d_uf.getNormalForm(b_ex, b_nf);
 
-    const Base& base_a = getSliceBlock(getRootId(a.getKind() == kind::BITVECTOR_EXTRACT ? a[0] : a))->getBase();
-    const Base& base_b = getSliceBlock(getRootId(b.getKind() == kind::BITVECTOR_EXTRACT ? b[0] : b))->getBase();
-    base_a.decomposeNode(a, a_decomp);
-    base_b.decomposeNode(b, b_decomp);
-    if (a_decomp.size() != b_decomp.size()) {
-      Debug("bv-slicer-check") << "Slicer::debugCheckBase different decomposition sizes for \n"
-                               << a <<" and \n"
-                               << b <<" \n"; 
-      return false;
-    }
-    for (unsigned j = 0; j < a_decomp.size(); ++j) {
-      if (utils::getSize(a_decomp[j]) != utils::getSize(b_decomp[j])) {
-        Debug("bv-slicer-check") << "Slicer::debugCheckBase inconsistent decompositions  \n"; 
-        return false;
-      }
-    }
-  }
-  // iterate through blocks and check that the block base is the same as each slice base
-  for (unsigned i = 0; i < d_rootBlocks.size(); ++i) {
-    SliceBlock* block = d_rootBlocks[i];
-    const Base& block_base = block->getBase();
-    Base diff_points(block->getBitwidth());
-    SliceBlock::const_iterator it = block->begin();
-    for (; it != block->end(); ++it) {
-      Slice* slice = *it;
-      if (!slice->isConsistent()) {
-        Debug("bv-slicer-check") << "Slicer::debugCheckBase inconsistent slice:  \n"
-                                 << slice->debugPrint() << "\n"; 
-        return false;
-      }
-      
-      diff_points.reset(); 
-      slice->getBase().diffCutPoints(block_base, diff_points);
-      if (!diff_points.isEmpty()) {
-        Debug("bv-slicer-check") << "Slicer::debugCheckBase slice missing cut points:  \n"
-                                 << slice->debugPrint()
-                                 << "Block base: " << block->getBase().debugPrint() << endl; 
-        return false;
-      }
-      Slice::const_iterator slice_it = slice->begin();
-      for (; slice_it!= slice->end(); ++slice_it) {
-        Splinter* splinter = (*slice_it).second;
-        const SplinterPointer& sp = splinter->getPointer();
-        if (sp != Undefined) {
-          Splinter* other = getSplinter(sp);
-          if (splinter->getBitwidth() != other->getBitwidth()) {
-            Debug("bv-slicer-check") << "Slicer::debugCheckBase inconsistent splinter pointer  \n"; 
-            return false;
-          }
-        }
-      }
-    }
-  }
-  return true; 
+  d_uf.alignSlicings(a_nf, b_nf); 
 }
 
-
-
-
+void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) const {
+}
 
 
index 16820801e4af45953689cc0f4dc3937eacf058e0..288b72bacc30ae4a4725805797ce63bfaaae2a3b 100644 (file)
@@ -38,388 +38,157 @@ namespace CVC4 {
 namespace theory {
 namespace bv {
 
-typedef uint32_t RootId;
-typedef uint32_t SplinterId;
+typedef uint32_t TermId;
 typedef uint32_t Index;
 
+
+
+/** 
+ * Base
+ * 
+ */
 class Base {
-  uint32_t d_size;
+  Index d_size;
   std::vector<uint32_t> d_repr;
 public:
-  Base(uint32_t size) 
-    : d_size(size),
-      d_repr((size-1)/32 + ((size-1) % 32 == 0? 0 : 1), 0)
-  {
-    Assert (d_size > 0); 
-  }
+  Base(uint32_t size);
+  void sliceAt(Index index); 
+  void sliceWith(const Base& other);
+  bool isCutPoint(Index index) const;
+  void diffCutPoints(const Base& other, Base& res) const;
+  bool isEmpty() const;
+  std::string debugPrint() const;
+}; 
 
-  
-  /** 
-   * Marks the base by adding a cut between index and index + 1
-   * 
-   * @param index 
-   */
-  void sliceAt(Index index) {
-    Index vector_index = index / 32;
-    Assert (vector_index < d_size - 1); 
-    Index int_index = index % 32;
-    uint32_t bit_mask = utils::pow2(int_index); 
-    d_repr[vector_index] = d_repr[vector_index] | bit_mask; 
+/**
+ * UnionFind
+ * 
+ */
+typedef __gnu_cxx::hash_set<TermId> TermSet;
+typedef std::vector<TermId> Decomposition; 
+
+struct ExtractTerm {
+  TermId id;
+  Index high;
+  Index low;
+  ExtractTerm(TermId i, Index h, Index l)
+    : id (i)
+      high(h)
+      low(l)
+  {
+    Assert (h >= l && id != -1); 
   }
+};
 
-  void sliceWith(const Base& other) {
-    Assert (d_size == other.d_size);
-    for (unsigned i = 0; i < d_repr.size(); ++i) {
-      d_repr[i] = d_repr[i] | other.d_repr[i]; 
-    }
-  }
+struct NormalForm {
+  Base base;
+  Decomposition decomp;
+  NormalForm(Index bitwidth)
+    : base(bitwidth),
+      decomp()
+  {}
+};
 
-  void decomposeNode(TNode node, std::vector<Node>& decomp) const;
 
-  bool isCutPoint (Index index) const {
-    // there is an implicit cut point at the end of the bv
-    if (index == d_size - 1)
-      return true;
+class UnionFind {
+  class Node {
+    TermId d_repr;
+    TermId d_ch1, d_ch2;
+    Index d_bitwidth;
+  public:
+    Node(Index b)
+  : d_bitwidth(b),
+    d_ch1(-1),
+    d_ch2(-1), 
+    d_repr(-1)
+    {}
     
-    Index vector_index = index / 32;
-    Assert (vector_index < d_size - 1); 
-    Index int_index = index % 32;
-    uint32_t bit_mask = utils::pow2(int_index); 
-
-    return (bit_mask & d_repr[vector_index]) != 0; 
-  }
+    TermId getRepr() const { return d_repr; }
+    Index getBitwidth() const { return d_bitwidth; }
+    bool hasChildren() const { return d_ch1 != -1 && d_ch2 != -1; }
 
-  void diffCutPoints(const Base& other, Base& res) const {
-    Assert (d_size == other.d_size && res.d_size == d_size);
-    for (unsigned i = 0; i < d_repr.size(); ++i) {
-      Assert (res.d_repr[i] == 0); 
-      res.d_repr[i] = d_repr[i] ^ other.d_repr[i]; 
+    TermId getChild(Index i) const {
+      Assert (i < 2);
+      return i == 0? ch1 : ch2;
     }
-  }
-
-  bool isEmpty() const {
-    for (unsigned i = 0; i< d_repr.size(); ++i) {
-      if (d_repr[i] != 0)
-        return false;
+    Index getCutPoint() const {
+      Assert (d_ch1 != -1 && d_ch2 != -1);
+      return getNode(d_ch1).getBitwidth(); 
     }
-    return true;
-  }
-
-  void reset() {
-    for (unsigned i = 0; i< d_repr.size(); ++i) {
-      d_repr[i] = 0;
+    void setRepr(TermId id) {
+      Assert (d_children.empty());
+      d_repr = id;
     }
-  }
-
-  // bool operator==(const Base& other) const {
-  //   Assert (d_size == other.d_size);
-  //   for (unsigned i = 0; i < d_repr.size(); ++i) {
-  //     if (d_repr[i] != other.d_repr[i])
-  //       return false; 
-  //   }
-  //   return true; 
-  // }
-  // bool operator!=(const Base& other) const {
-  //   return !(*this == other); 
-  // }
 
-  std::string debugPrint() const {
-    std::ostringstream os;
-    os << "[";
-    bool first = true; 
-    for (unsigned i = 0; i < d_size - 1; ++i) {
-      if (isCutPoint(i)) {
-        if (first)
-          first = false;
-        else
-          os <<"| "; 
-        
-        os << i ; 
-      }
+    void setChildren(TermId ch1, TermId ch2) {
+      Assert (d_repr == -1 && d_children.empty());
+      markAsNotTopLevel(ch1);
+      markAsNotTopLevel(ch2); 
+      d_children.push_back(ch1);
+      d_children.push_back(ch2); 
     }
-    os << "]"; 
-    return os.str(); 
-  }
-  
-}; 
-
-
-struct SplinterPointer {
-  RootId term;
-  uint32_t row; 
-  Index index;
-
-  SplinterPointer()
-    : term(-1),
-      row(-1),
-      index(-1)
-  {}
-
-  SplinterPointer(RootId t, uint32_t r,  Index i)
-    : term(t),
-      row(r),
-      index(i)
-  {}
-  
-  bool operator==(const SplinterPointer& other) const {
-    return term == other.term && index == other.index && row == other.row; 
-  }
-  bool operator!=(const SplinterPointer& other) const {
-    return !(*this == other); 
-  }
-
-  std::string debugPrint() {
-    std::ostringstream os;
-    os << "(id" << term << ", row" << row <<", i" << index << ")";
-    return os.str();
-  }
-  
-};
-
-static const SplinterPointer Undefined = SplinterPointer(-1, -1, -1); 
-
-class Splinter {
-  // start and end indices in slice
-  Index d_low;
-  Index d_high;
-
-  // keeps track of splinter this splinter is equal to
-  // equal to Undefined if there is none
-  SplinterPointer d_pointer;
-  
-public:
-  Splinter(uint32_t high, uint32_t low) :
-    d_low(low),
-    d_high(high),
-    d_pointer(Undefined)
-  {
-    Assert (high >= low); 
-  }
     
-  void setPointer(const SplinterPointer& pointer) {
-    Assert (d_pointer == Undefined);
-    d_pointer = pointer; 
-  }
-
-  const SplinterPointer& getPointer() const {
-    return d_pointer; 
-  }
+    // void setChildren(TermId ch1, TermId ch2, TermId ch3) {
+    //   Assert (d_repr == -1 && d_children.empty());
+    //   d_children.push_back(ch1);
+    //   d_children.push_back(ch2);
+    //   d_children.push_back(ch3); 
+    // }
+    
+  };
 
-  Index getLow() const { return d_low; }
-  Index getHigh() const {return d_high; }
-  uint32_t getBitwidth() const { return d_high - d_low; }
-};
+  std::vector<Node> d_nodes;
 
-class Slice {
-  uint32_t d_bitwidth; 
-  // map from the beginning of a splinter to the actual splinter id
-  std::map<Index, Splinter*> d_splinters;
-  Base d_base;
-  
-public:
-  Slice(uint32_t bitwidth)
-    : d_bitwidth(bitwidth),
-      d_splinters(),
-      d_base(bitwidth)
-  {}
-  /** 
-   * Split the slice by adding a cut point between indices i and i+1
-   * 
-   * @param i index where to cut
-   * @param id the id of the root term this slice belongs to
-   * @param row the row of the SliceBlock this Slice belongs to
-   */
-  void split(Index i, SplinterPointer& sp, Splinter*& low_splinter, Splinter*& top_splinter);
-  /** 
-   * Add splinter sp at Index i. If a splinter already exists there
-   * replace it and free the memory it occupied. 
-   * 
-   * @param i index where splinter starts
-   * @param sp new splinter
-   */
-  void addSplinter(Index i, Splinter* sp); 
-  /** 
-   * Return the splinter starting at Index start.
-   * 
-   * @param start 
-   * 
-   * @return 
-   */
-  Splinter* getSplinter (Index start) {
-    Assert (d_splinters.find(start) != d_splinters.end()); 
-    return d_splinters[start]; 
+  TermSet d_representatives;
+  TermSet d_topLevelTerms;
+  void markAsNotTopLevel(TermId id) {
+    if (d_topLevelTerms.find(id) != d_topLevelTerms.end())
+      d_topLevelTerms.erase(id); 
   }
-  const Base& getBase() const { return d_base; }
 
-  typedef std::map<Index, Splinter*>::const_iterator const_iterator; 
-  std::map<Index, Splinter*>::const_iterator begin() {
-    return d_splinters.begin(); 
+  bool isTopLevel(TermId id) {
+    return d_topLevelTerms.find(id) != d_topLevelTerms.end(); 
   }
-  std::map<Index, Splinter*>::const_iterator end() {
-    return d_splinters.end(); 
+  
+  Index getBitwidth(TermId id) {
+    Assert (id < d_nodes.size());
+    return d_nodes[id].getBitwidth(); 
   }
-  std::string debugPrint();
-  bool isConsistent();
-};
-
-class Slicer; 
-
-typedef __gnu_cxx::hash_set<RootId> BlockIdSet; 
-
-class SliceBlock {
-  uint32_t d_bitwidth; 
-  RootId d_rootId;                /**< the id of the root term this block corresponds to */
-  std::vector<Slice*> d_block;    /**< the slices in the block */
-  Base d_base;                    /**<  the base corresponding to this block containing all the cut points.
-                                   Invariant: the base should contain all the cut-points in the slices*/
-  Slicer* d_slicer; // FIXME: more elegant way to do this
   
 public:
-  SliceBlock(RootId rootId, uint32_t bitwidth, Slicer* slicer)
-    : d_bitwidth(bitwidth),
-      d_rootId(rootId),
-      d_block(),
-      d_base(bitwidth),
-      d_slicer(slicer)
+  UnionFind()
+    : d_nodes(),
+      d_representatives()
   {}
 
-  uint32_t addSlice(Slice* slice) {
-    // update the base with the cut-points in the slice
-    Debug("bv-slice") << "SliceBlock::addSlice Block"<< d_rootId << " adding slice " << slice->debugPrint() << std::endl; 
-    d_base.sliceWith(slice->getBase()); 
-    d_block.push_back(slice);
-    return d_block.size() - 1; 
-  }
-
-  Slice* getSlice(uint32_t row) const {
-    Assert (row < d_block.size()); 
-    return d_block[row]; 
-  }
-  /** 
-   * Propagate all the cut points in the Base to all the Slices. If one of the
-   * splinters that needs to get cut has a pointer to a splinter in a different
-   * block that splinter will also be split. 
-   * 
-   * @param queue other blocks that changed their base. 
-   */
-  void computeBlockBase(BlockIdSet& recompute);
-
-  void sliceBaseAt(Index i) {
-    d_base.sliceAt(i); 
-  }
-  typedef std::vector<Slice*>::const_iterator const_iterator; 
-  std::vector<Slice*>::const_iterator begin() {
-    return d_block.begin(); 
-  }
-  std::vector<Slice*>::const_iterator end() {
-    return d_block.end(); 
-  }
+  TermId addTerm(Index bitwidth);
+  void merge(TermId t1, TermId t2);
+  TermId find(TermId t1); 
+  TermId split(TermId term, Index i);
 
-  uint32_t getBitwidth() const {
-    return d_bitwidth; 
-  }
-  Base& getBase() {
-    return d_base; 
-  }
+  void getNormalForm(ExtractTerm term, NormalForm& nf);
+  void alignSlicings(NormalForm& nf1, NormalForm& nf2);
 
-  unsigned getSize() const {
-    return d_block.size(); 
+  Node getNode(TermId id) {
+    Assert (id < d_nodes.size());
+    return d_nodes[id]; 
   }
-  std::string debugPrint(); 
 };
 
-typedef __gnu_cxx::hash_map<TNode, bool, TNodeHashFunction> RootTermCache;
-
-typedef __gnu_cxx::hash_map<TNode, RootId, TNodeHashFunction> NodeRootIdMap;
-typedef std::vector<TNode> Roots; 
-
-typedef __gnu_cxx::hash_map<TNode, SplinterId, TNodeHashFunction> NodeSplinterIdMap;
-typedef std::vector<Splinter*> Splinters;
-
-typedef std::vector<SliceBlock*> SliceBlocks;
-
 class Slicer {
-  std::vector<Node> d_simpleEqualities; /**< equalities of the form a[i0:j0] = b[i1:j1] */
-  Roots d_roots;
-  uint32_t d_numRoots; 
-  NodeRootIdMap d_nodeRootMap;
-  /* Indexed by Root Id */
-  SliceBlocks d_rootBlocks;
-  RootTermCache d_coreTermCache;
-
-
-public:
-  Slicer();
-  void computeCoarsestBase();
-  /** 
-   * Takes an equality that is of the following form:
-   *          a1 a2 ... an = b1 b2 ... bk
-   * where ai, and bi are either variables or extracts over variables,
-   * and consecutive extracts have been merged. 
-   * 
-   * @param node 
-   */
-  void processEquality(TNode node);
-  bool isCoreTerm(TNode node); 
-  static void splitEqualities(TNode node, std::vector<Node>& equalities);
-private:
-  void registerSimpleEquality(TNode node);
-
-  TNode addSimpleTerm(TNode t);
-  bool isRootTerm(TNode node);
-
-  TNode getRoot(RootId id) {return d_roots[id]; }
-
-  RootId getRootId(TNode node) {
-    Assert (d_nodeRootMap.find(node) != d_nodeRootMap.end());
-    return d_nodeRootMap[node]; 
-  }
-
-  RootId registerTerm(TNode node); 
-  RootId makeRoot(TNode n);
-  Slice* makeSlice(TNode node);
-
-  bool debugCheckBase();
-
-  class Statistics {
-  public:
-    IntStat d_numBlocks;
-    AverageStat d_avgBlockSize;
-    AverageStat d_avgBlockBitwitdh;
-    IntStat d_numBlockBaseComputations; 
-    IntStat d_numSplits;
-    IntStat d_numSimpleEqualities;
-    IntStat d_numSlices; 
-    Statistics();
-    ~Statistics(); 
-  };
-
+  __gnu_cxx::hash_map<TermId, TNode> d_idToNode;
+  __gnu_cxx::hash_map<TNode, TermId> d_nodeToId;
+  UnionFind d_unionFind();
 public:
-  Statistics d_statistics;
-  
-  Slice* getSlice(const SplinterPointer& sp) {
-    Assert (sp != Undefined); 
-    SliceBlock* sb = d_rootBlocks[sp.term];
-    return sb->getSlice(sp.row); 
-  }
+  Slicer()
+    : d_topLevelTerms(),
+      d_unionFind()
+  {}
   
-  Splinter* getSplinter(const SplinterPointer& sp) {
-    Slice* slice = getSlice(sp);
-    return slice->getSplinter(sp.index); 
-  }
-
-  SliceBlock* getSliceBlock(RootId id) {
-    Assert (id < d_rootBlocks.size());
-    return d_rootBlocks[id]; 
-  }
-
-  SliceBlock* getSliceBlock(const SplinterPointer& sp) {
-    return getSliceBlock(sp.term); 
-  }
-
-  void getBaseDecomposition(TNode node, std::vector<Node>& decomp); 
-
-}; /* Slicer class */
+  void getBaseDecomposition(TNode node, std::vector<Node>& decomp) const;
+  void processEquality(TNode eq);
+}; 
 
 }/* CVC4::theory::bv namespace */
 }/* CVC4::theory namespace */