From 76a7010156d3dc696b25c32483467ec39b92f2ed Mon Sep 17 00:00:00 2001 From: Liana Hadarean Date: Tue, 29 Jan 2013 23:09:03 -0500 Subject: [PATCH] fixing slicer bugs. --- src/theory/bv/slicer.cpp | 50 ++++++++++++++++++---------------------- src/theory/bv/slicer.h | 43 +++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index c624b9c5e..80a52525d 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -207,17 +207,14 @@ void UnionFind::merge(TermId t1, TermId t2) { if (t1 == t2) return; - Node n1 = getNode(t1); - Node n2 = getNode(t2); - Assert (! n1.hasChildren() && ! n2.hasChildren()); - n1.setRepr(t2); + Assert (! hasChildren(t1) && ! hasChildren(t2)); + setRepr(t1, t2); d_representatives.erase(t1); } TermId UnionFind::find(TermId id) const { - Node node = getNode(id); - if (node.getRepr() != UndefinedId) - return find(node.getRepr()); + if (getRepr(id) != UndefinedId) + return find(getRepr(id)); return id; } /** @@ -231,27 +228,25 @@ TermId UnionFind::find(TermId id) const { void UnionFind::split(TermId id, Index i) { Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; id = find(id); - Node node = getNode(id); - Debug("bv-slicer-uf") << " node: " << node.debugPrint() << endl; - Assert (i < node.getBitwidth()); + Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; - if (i == 0 || i == node.getBitwidth()) { + if (i == 0 || i == getBitwidth(id)) { // nothing to do return; } - - if (!node.hasChildren()) { + Assert (i < getBitwidth(id)); + if (!hasChildren(id)) { // first time we split this term TermId bottom_id = addTerm(i); - TermId top_id = addTerm(node.getBitwidth() - i); - node.setChildren(top_id, bottom_id); + TermId top_id = addTerm(getBitwidth(id) - i); + setChildren(id, top_id, bottom_id); } else { - Index cut = node.getCutPoint(*this); + Index cut = getCutPoint(id); if (i < cut ) - split(node.getChild(0), i); + split(getChild(id, 1), i); else - split(node.getChild(1), i - cut); + split(getChild(id, 0), i - cut); } } @@ -271,32 +266,31 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) // making sure the term is aligned TermId id = find(term.id); - Node node = getNode(id); - Assert (term.high < node.getBitwidth()); + Assert (term.high < getBitwidth(id)); // because we split the node, this must be the whole extract - if (!node.hasChildren()) { - Assert (term.high == node.getBitwidth() - 1 && + if (!hasChildren(id)) { + Assert (term.high == getBitwidth(id) - 1 && term.low == 0); decomp.push_back(id); } - Index cut = node.getCutPoint(*this); + Index cut = getCutPoint(id); if (term.low < cut && term.high < cut) { // the extract falls entirely on the low child - ExtractTerm child_ex(node.getChild(0), term.high, term.low); + ExtractTerm child_ex(getChild(id, 0), term.high, term.low); getDecomposition(child_ex, decomp); } else if (term.low >= cut && term.high >= cut){ // the extract falls entirely on the high child - ExtractTerm child_ex(node.getChild(1), term.high - cut, term.low - cut); + ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); getDecomposition(child_ex, decomp); } else { // the extract is split over the two children - ExtractTerm low_child(node.getChild(0), cut - 1, term.low); + ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); getDecomposition(low_child, decomp); - ExtractTerm high_child(node.getChild(1), term.high, cut); + ExtractTerm high_child(getChild(id, 1), term.high, cut); getDecomposition(high_child, decomp); } } @@ -397,7 +391,7 @@ void UnionFind::ensureSlicing(const ExtractTerm& term) { */ ExtractTerm Slicer::registerTerm(TNode node) { - Index low = 0, high = utils::getSize(node); + Index low = 0, high = utils::getSize(node) - 1; TNode n = node; if (node.getKind() == kind::BITVECTOR_EXTRACT) { n = node[0]; diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index c4b3b06a1..b27b85e65 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -119,7 +119,7 @@ class UnionFind { class Node { Index d_bitwidth; TermId d_ch1, d_ch2; - TermId d_repr; + TermId d_repr; public: Node(Index b) : d_bitwidth(b), @@ -136,23 +136,18 @@ class UnionFind { Assert (i < 2); return i == 0? d_ch1 : d_ch2; } - Index getCutPoint(const UnionFind& uf) const { - Assert (d_ch1 != UndefinedId && d_ch2 != UndefinedId); - return uf.getNode(d_ch1).getBitwidth(); - } void setRepr(TermId id) { Assert (! hasChildren()); d_repr = id; } - void setChildren(TermId ch1, TermId ch2) { Assert (d_repr == UndefinedId && !hasChildren()); d_ch1 = ch1; d_ch2 = ch2; } - std::string debugPrint() const; + std::string debugPrint() const; }; - + /// map from TermId to the nodes that represent them std::vector d_nodes; /// a term is in this set if it is its own representative @@ -160,6 +155,32 @@ class UnionFind { void getDecomposition(const ExtractTerm& term, Decomposition& decomp); void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); + /// getter methods for the internal nodes + TermId getRepr(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getRepr(); + } + TermId getChild(TermId id, Index i) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getChild(i); + } + Index getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 1)); + } + bool hasChildren(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].hasChildren(); + } + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr) { + Assert (id < d_nodes.size()); + d_nodes[id].setRepr(new_repr); + } + void setChildren(TermId id, TermId ch1, TermId ch2) { + Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch2)); + d_nodes[id].setChildren(ch1, ch2); + } + public: UnionFind() @@ -176,11 +197,6 @@ public: void getNormalForm(const ExtractTerm& term, NormalForm& nf); void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); void ensureSlicing(const ExtractTerm& term); - - Node getNode(TermId id) const { - Assert (id < d_nodes.size()); - return d_nodes[id]; - } Index getBitwidth(TermId id) const { Assert (id < d_nodes.size()); return d_nodes[id].getBitwidth(); @@ -208,6 +224,7 @@ public: static void splitEqualities(TNode node, std::vector& equalities); }; + }/* CVC4::theory::bv namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ -- 2.30.2