fixing slicer bugs.
authorLiana Hadarean <lianahady@gmail.com>
Wed, 30 Jan 2013 04:09:03 +0000 (23:09 -0500)
committerLiana Hadarean <lianahady@gmail.com>
Wed, 30 Jan 2013 04:09:03 +0000 (23:09 -0500)
src/theory/bv/slicer.cpp
src/theory/bv/slicer.h

index c624b9c5e6b6b4e76f3ec72d5680724e2511dd8f..80a52525d73330d6a44c5700d11b80cf0ea8005a 100644 (file)
@@ -207,17 +207,14 @@ void UnionFind::merge(TermId t1, TermId t2) {
   if (t1 == t2)
     return;
 
-  Node n1 = getNode(t1); 
-  Node n2 = getNode(t2);
-  Assert (! n1.hasChildren() && ! n2.hasChildren());
-  n1.setRepr(t2); 
+  Assert (! hasChildren(t1) && ! hasChildren(t2));
+  setRepr(t1, t2); 
   d_representatives.erase(t1); 
 }
 
 TermId UnionFind::find(TermId id) const {
-  Node node = getNode(id); 
-  if (node.getRepr() != UndefinedId)
-    return find(node.getRepr());
+  if (getRepr(id) != UndefinedId)
+    return find(getRepr(id));
   return id; 
 }
 /** 
@@ -231,27 +228,25 @@ TermId UnionFind::find(TermId id) const {
 void UnionFind::split(TermId id, Index i) {
   Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl;
   id = find(id); 
-  Node node = getNode(id);
-  Debug("bv-slicer-uf") << "   node: " << node.debugPrint() << endl;
-  Assert (i < node.getBitwidth());
+  Debug("bv-slicer-uf") << "   node: " << d_nodes[id].debugPrint() << endl;
 
-  if (i == 0 || i == node.getBitwidth()) {
+  if (i == 0 || i == getBitwidth(id)) {
     // nothing to do 
     return;
   }
-
-  if (!node.hasChildren()) {
+  Assert (i < getBitwidth(id));
+  if (!hasChildren(id)) {
     // first time we split this term 
     TermId bottom_id = addTerm(i);
-    TermId top_id = addTerm(node.getBitwidth() - i);
-    node.setChildren(top_id, bottom_id);
+    TermId top_id = addTerm(getBitwidth(id) - i);
+    setChildren(id, top_id, bottom_id);
 
   } else {
-    Index cut = node.getCutPoint(*this); 
+    Index cut = getCutPoint(id); 
     if (i < cut )
-      split(node.getChild(0), i);
+      split(getChild(id, 1), i);
     else
-      split(node.getChild(1), i - cut); 
+      split(getChild(id, 0), i - cut); 
   }
 }
 
@@ -271,32 +266,31 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp)
   // making sure the term is aligned
   TermId id = find(term.id); 
 
-  Node node = getNode(id);
-  Assert (term.high < node.getBitwidth());
+  Assert (term.high < getBitwidth(id));
   // because we split the node, this must be the whole extract
-  if (!node.hasChildren()) {
-    Assert (term.high == node.getBitwidth() - 1 &&
+  if (!hasChildren(id)) {
+    Assert (term.high == getBitwidth(id) - 1 &&
             term.low == 0);
     decomp.push_back(id); 
   }
     
-  Index cut = node.getCutPoint(*this);
+  Index cut = getCutPoint(id);
   
   if (term.low < cut && term.high < cut) {
     // the extract falls entirely on the low child
-    ExtractTerm child_ex(node.getChild(0), term.high, term.low); 
+    ExtractTerm child_ex(getChild(id, 0), term.high, term.low); 
     getDecomposition(child_ex, decomp); 
   }
   else if (term.low >= cut && term.high >= cut){
     // the extract falls entirely on the high child
-    ExtractTerm child_ex(node.getChild(1), term.high - cut, term.low - cut);
+    ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut);
     getDecomposition(child_ex, decomp); 
   }
   else {
     // the extract is split over the two children
-    ExtractTerm low_child(node.getChild(0), cut - 1, term.low);
+    ExtractTerm low_child(getChild(id, 0), cut - 1, term.low);
     getDecomposition(low_child, decomp);
-    ExtractTerm high_child(node.getChild(1), term.high, cut);
+    ExtractTerm high_child(getChild(id, 1), term.high, cut);
     getDecomposition(high_child, decomp); 
   }
 }
@@ -397,7 +391,7 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) {
  */
 
 ExtractTerm Slicer::registerTerm(TNode node) {
-  Index low = 0, high = utils::getSize(node); 
+  Index low = 0, high = utils::getSize(node) - 1
   TNode n = node; 
   if (node.getKind() == kind::BITVECTOR_EXTRACT) {
     n = node[0];
index c4b3b06a1f160011e83c9912a8bcbd4b200a8d01..b27b85e65dbdc30bc576f3fd09b5313edf67289f 100644 (file)
@@ -119,7 +119,7 @@ class UnionFind {
   class Node {
     Index d_bitwidth;
     TermId d_ch1, d_ch2;
-    TermId d_repr;    
+    TermId d_repr;
   public:
     Node(Index b)
   : d_bitwidth(b),
@@ -136,23 +136,18 @@ class UnionFind {
       Assert (i < 2);
       return i == 0? d_ch1 : d_ch2;
     }
-    Index getCutPoint(const UnionFind& uf) const {
-      Assert (d_ch1 != UndefinedId && d_ch2 != UndefinedId);
-      return uf.getNode(d_ch1).getBitwidth(); 
-    }
     void setRepr(TermId id) {
       Assert (! hasChildren());
       d_repr = id;
     }
-
     void setChildren(TermId ch1, TermId ch2) {
       Assert (d_repr == UndefinedId && !hasChildren());
       d_ch1 = ch1;
       d_ch2 = ch2; 
     }
-    std::string debugPrint() const; 
+    std::string debugPrint() const;
   };
-
+  
   /// map from TermId to the nodes that represent them 
   std::vector<Node> d_nodes;
   /// a term is in this set if it is its own representative
@@ -160,6 +155,32 @@ class UnionFind {
   
   void getDecomposition(const ExtractTerm& term, Decomposition& decomp);
   void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common);
+  /// getter methods for the internal nodes
+  TermId getRepr(TermId id)  const {
+    Assert (id < d_nodes.size());
+    return d_nodes[id].getRepr(); 
+  }
+  TermId getChild(TermId id, Index i) const {
+    Assert (id < d_nodes.size());
+    return d_nodes[id].getChild(i); 
+  }
+  Index getCutPoint(TermId id) const {
+    return getBitwidth(getChild(id, 1)); 
+  }
+  bool hasChildren(TermId id) const {
+    Assert (id < d_nodes.size());
+    return d_nodes[id].hasChildren(); 
+  }
+  /// setter methods for the internal nodes
+  void setRepr(TermId id, TermId new_repr) {
+    Assert (id < d_nodes.size());
+    d_nodes[id].setRepr(new_repr); 
+  }
+  void setChildren(TermId id, TermId ch1, TermId ch2) {
+    Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch2));
+    d_nodes[id].setChildren(ch1, ch2); 
+  }
+
   
 public:
   UnionFind()
@@ -176,11 +197,6 @@ public:
   void getNormalForm(const ExtractTerm& term, NormalForm& nf);
   void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2);
   void ensureSlicing(const ExtractTerm& term);
-  
-  Node getNode(TermId id) const {
-    Assert (id < d_nodes.size());
-    return d_nodes[id]; 
-  }
   Index getBitwidth(TermId id) const {
     Assert (id < d_nodes.size());
     return d_nodes[id].getBitwidth(); 
@@ -208,6 +224,7 @@ public:
   static void splitEqualities(TNode node, std::vector<Node>& equalities); 
 }; 
 
+
 }/* CVC4::theory::bv namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */