fixed some more bugs
authorlianah <lianahady@gmail.com>
Thu, 31 Jan 2013 01:02:47 +0000 (20:02 -0500)
committerlianah <lianahady@gmail.com>
Thu, 31 Jan 2013 01:02:47 +0000 (20:02 -0500)
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h

index 80a52525d73330d6a44c5700d11b80cf0ea8005a..79f3f5b68ad89311c311a2a9a98fa29134de1f35 100644 (file)
@@ -34,7 +34,7 @@ const TermId CVC4::theory::bv::UndefinedId = -1;
  */
 Base::Base(uint32_t size) 
   : d_size(size),
-    d_repr((size-1)/32 + ((size-1) % 32 == 0? 0 : 1), 0)
+    d_repr(size/32 + (size % 32 == 0? 0 : 1), 0)
 {
   Assert (d_size > 0); 
 }
@@ -42,7 +42,7 @@ Base::Base(uint32_t size)
   
 void Base::sliceAt(Index index) {
   Index vector_index = index / 32;
-  Assert (vector_index < d_size - 1); 
+  Assert (vector_index < d_size); 
   Index int_index = index % 32;
   uint32_t bit_mask = utils::pow2(int_index); 
   d_repr[vector_index] = d_repr[vector_index] | bit_mask; 
@@ -56,12 +56,12 @@ void Base::sliceWith(const Base& other) {
 }
 
 bool Base::isCutPoint (Index index) const {
-  // there is an implicit cut point at the end of the bv
-  if (index == d_size - 1)
+  // there is an implicit cut point at the end and begining of the bv
+  if (index == d_size || index == 0)
     return true;
     
   Index vector_index = index / 32;
-  Assert (vector_index < d_size - 1); 
+  Assert (vector_index < d_size); 
   Index int_index = index % 32;
   uint32_t bit_mask = utils::pow2(int_index); 
 
@@ -88,7 +88,7 @@ std::string Base::debugPrint() const {
   std::ostringstream os;
   os << "[";
   bool first = true; 
-  for (unsigned i = 0; i < d_size - 1; ++i) {
+  for (int i = d_size - 1; i >= 0; --i) {
     if (isCutPoint(i)) {
       if (first)
         first = false;
@@ -118,26 +118,28 @@ std::string ExtractTerm::debugPrint() const {
  *
  */
 
-TermId NormalForm::getTerm(Index i, const UnionFind& uf) const {
-  Assert (i < base.getBitwidth()); 
+std::pair<TermId, Index> NormalForm::getTerm(Index index, const UnionFind& uf) const {
+  Assert (index < base.getBitwidth()); 
   Index count = 0;
   for (unsigned i = 0; i < decomp.size(); ++i) {
     Index size = uf.getBitwidth(decomp[i]); 
-    if ( count + size <= i && count >= i) {
-      return decomp[i]
+    if ( count + size > index && index >= count) {
+      return pair<TermId, Index>(decomp[i], count)
     }
     count += size; 
   }
   Unreachable(); 
 }
 
+
+
 std::string NormalForm::debugPrint(const UnionFind& uf) const {
   ostringstream os;
   os << "NF " << base.debugPrint() << endl;
   os << "("; 
-  for (unsigned i = 0; i < decomp.size(); ++i) {
+  for (int i = decomp.size() - 1; i>= 0; --i) {
     os << decomp[i] << "[" << uf.getBitwidth(decomp[i]) <<"]";
-    os << (i < decomp.size() - 1? ", " : "");  
+    os << (i != 0? ", " : "");  
   }
   os << ") \n"; 
   return os.str(); 
@@ -150,7 +152,7 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const {
 std::string UnionFind::Node::debugPrint() const {
   ostringstream os;
   os << "Repr " << d_repr << " ["<< d_bitwidth << "] ";
-  os << "( " << d_ch1 <<", " << d_ch2 << ")" << endl; 
+  os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; 
   return os.str(); 
 }
 
@@ -213,8 +215,9 @@ void UnionFind::merge(TermId t1, TermId t2) {
 }
 
 TermId UnionFind::find(TermId id) const {
-  if (getRepr(id) != UndefinedId)
-    return find(getRepr(id));
+  TermId repr = getRepr(id); 
+  if (repr != UndefinedId)
+    return find(repr);
   return id; 
 }
 /** 
@@ -244,13 +247,14 @@ void UnionFind::split(TermId id, Index i) {
   } else {
     Index cut = getCutPoint(id); 
     if (i < cut )
-      split(getChild(id, 1), i);
+      split(getChild(id, 0), i);
     else
-      split(getChild(id, 0), i - cut); 
+      split(getChild(id, 1), i - cut); 
   }
 }
 
 void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) {
+  nf.clear(); 
   getDecomposition(term, nf.decomp);
   // update nf base
   Index count = 0; 
@@ -271,7 +275,8 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp)
   if (!hasChildren(id)) {
     Assert (term.high == getBitwidth(id) - 1 &&
             term.low == 0);
-    decomp.push_back(id); 
+    decomp.push_back(id);
+    return; 
   }
     
   Index cut = getCutPoint(id);
@@ -290,7 +295,7 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp)
     // the extract is split over the two children
     ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
     getDecomposition(low_child, decomp);
-    ExtractTerm high_child(getChild(id, 1), term.high, cut);
+    ExtractTerm high_child(getChild(id, 1), term.high - cut, 0);
     getDecomposition(high_child, decomp); 
   }
 }
@@ -322,11 +327,11 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit
   start1 = start1 > start2 ? start2 : start1;
   start2 = start1 > start2 ? start1 : start2; 
 
-  if (start1 + common_size <= start2) {
+  if (start2 - start1 < common_size) {
     Index overlap = start1 + common_size - start2;
     Assert (overlap > 0);
-    Index diff = start2 - overlap;
-    Assert (diff > 0);
+    Index diff = common_size - overlap;
+    Assert (diff >= 0);
     Index granularity = utils::gcd(diff, overlap);
     // split the common part 
     for (unsigned i = 0; i < common_size; i+= granularity) {
@@ -362,13 +367,14 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2
   // align the cuts points of the two slicings
   // FIXME: this can be done more efficiently
   Base& cuts = nf1.base;
+  cuts.debugPrint(); 
   cuts.sliceWith(nf2.base);
   for (unsigned i = 0; i < cuts.getBitwidth(); ++i) {
     if (cuts.isCutPoint(i)) {
-      TermId t1 = nf1.getTerm(i, *this);
-      split(t1, i); 
-      TermId t2 = nf2.getTerm(i, *this);
-      split(t2, i); 
+      pair<TermId, Index> pair1 = nf1.getTerm(i, *this);
+      split(pair1.first, i - pair1.second); 
+      pair<TermId, Index> pair2 = nf2.getTerm(i, *this);
+      split(pair2.first, i - pair2.second); 
     }
   }
 }
@@ -423,23 +429,24 @@ void Slicer::processEquality(TNode eq) {
   
   d_unionFind.alignSlicings(a_ex, b_ex);
   d_unionFind.unionTerms(a_ex, b_ex);
-  
+
   Debug("bv-slicer") << "Slicer::processEquality done. " << endl;
 }
 
 void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
   Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl;
   
-  Index high = utils::getSize(node);
-  Index low = 0; 
+  Index high = utils::getSize(node) - 1;
+  Index low = 0;
+  TNode top = node; 
   if (node.getKind() == kind::BITVECTOR_EXTRACT) {
     high = utils::getExtractHigh(node);
     low = utils::getExtractLow(node);
-    node = node[0]; 
+    top = node[0]; 
   }
-  Assert (d_nodeToId.find(node) != d_nodeToId.end()); 
-  TermId id = d_nodeToId[node];
-  NormalForm nf(utils::getSize(node)); 
+  Assert (d_nodeToId.find(top) != d_nodeToId.end()); 
+  TermId id = d_nodeToId[top];
+  NormalForm nf(high-low+1); 
   d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf);
   
   // construct actual extract nodes
@@ -448,7 +455,7 @@ void Slicer::getBaseDecomposition(TNode node, std::vector<Node>& decomp) {
   for (unsigned i = 0; i < nf.decomp.size(); ++i) {
     Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); 
     current_high += current_size; 
-    Node current = utils::mkExtract(node, current_high - 1, current_low);
+    Node current = Rewriter::rewrite(utils::mkExtract(node, current_high - 1, current_low));
     current_low += current_size;
     decomp.push_back(current); 
   }
@@ -528,3 +535,20 @@ void Slicer::splitEqualities(TNode node, std::vector<Node>& equalities) {
     equalities.push_back(node);
   }
 } 
+
+std::string UnionFind::debugPrint(TermId id) {
+  ostringstream os; 
+  if (hasChildren(id)) {
+    TermId id1 = find(getChild(id, 1));
+    TermId id0 = find(getChild(id, 0));
+    os << debugPrint(id1) <<" ";
+    os << debugPrint(id0) <<" "; 
+  } else {
+    if (getRepr(id) == UndefinedId) {
+      os << id <<"[" << getBitwidth(id) <<"] "; 
+    } else {
+      os << debugPrint(find(id)) << " ";
+    }
+  }
+  return os.str(); 
+}
index b27b85e65dbdc30bc576f3fd09b5313edf67289f..c7451c28884684ce7ccdfdab3c984f4dc15cf636 100644 (file)
@@ -60,6 +60,11 @@ public:
   bool isEmpty() const;
   std::string debugPrint() const;
   Index getBitwidth() const { return d_size; }
+  void clear() {
+    for (unsigned i = 0; i < d_repr.size(); ++i) {
+      d_repr[i] = 0; 
+    }
+  }
   bool operator==(const Base& other) const {
     if (other.getBitwidth() != getBitwidth())
       return false;
@@ -110,40 +115,41 @@ struct NormalForm {
    * 
    * @return 
    */
-  TermId getTerm(Index i, const UnionFind& uf) const;
-  std::string debugPrint(const UnionFind& uf) const; 
+  std::pair<TermId, Index> getTerm(Index i, const UnionFind& uf) const;
+  std::string debugPrint(const UnionFind& uf) const;
+  void clear() { base.clear(); decomp.clear(); }
 };
 
 
 class UnionFind {
   class Node {
     Index d_bitwidth;
-    TermId d_ch1, d_ch2;
+    TermId d_ch1, d_ch0;
     TermId d_repr;
   public:
     Node(Index b)
   : d_bitwidth(b),
     d_ch1(UndefinedId),
-    d_ch2(UndefinedId), 
+    d_ch0(UndefinedId), 
     d_repr(UndefinedId)
     {}
     
     TermId getRepr() const { return d_repr; }
     Index getBitwidth() const { return d_bitwidth; }
-    bool hasChildren() const { return d_ch1 != UndefinedId && d_ch2 != UndefinedId; }
+    bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; }
 
     TermId getChild(Index i) const {
       Assert (i < 2);
-      return i == 0? d_ch1 : d_ch2;
+      return i == 0? d_ch0 : d_ch1;
     }
     void setRepr(TermId id) {
       Assert (! hasChildren());
       d_repr = id;
     }
-    void setChildren(TermId ch1, TermId ch2) {
+    void setChildren(TermId ch1, TermId ch0) {
       Assert (d_repr == UndefinedId && !hasChildren());
       d_ch1 = ch1;
-      d_ch2 = ch2
+      d_ch0 = ch0
     }
     std::string debugPrint() const;
   };
@@ -165,7 +171,7 @@ class UnionFind {
     return d_nodes[id].getChild(i); 
   }
   Index getCutPoint(TermId id) const {
-    return getBitwidth(getChild(id, 1)); 
+    return getBitwidth(getChild(id, 0)); 
   }
   bool hasChildren(TermId id) const {
     Assert (id < d_nodes.size());
@@ -176,9 +182,9 @@ class UnionFind {
     Assert (id < d_nodes.size());
     d_nodes[id].setRepr(new_repr); 
   }
-  void setChildren(TermId id, TermId ch1, TermId ch2) {
-    Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch2));
-    d_nodes[id].setChildren(ch1, ch2); 
+  void setChildren(TermId id, TermId ch1, TermId ch0) {
+    Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0));
+    d_nodes[id].setChildren(ch1, ch0); 
   }
 
   
@@ -201,6 +207,7 @@ public:
     Assert (id < d_nodes.size());
     return d_nodes[id].getBitwidth(); 
   }
+  std::string debugPrint(TermId id); 
 
 };