From 6875e78dc08bd345061e38c0fabb0efd2ceff41d Mon Sep 17 00:00:00 2001 From: lianah Date: Wed, 30 Jan 2013 20:02:47 -0500 Subject: [PATCH] fixed some more bugs --- src/theory/bv/slicer.cpp | 92 +++++++++++++++++++++++++--------------- src/theory/bv/slicer.h | 31 ++++++++------ 2 files changed, 77 insertions(+), 46 deletions(-) diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 80a52525d..79f3f5b68 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -34,7 +34,7 @@ const TermId CVC4::theory::bv::UndefinedId = -1; */ Base::Base(uint32_t size) : d_size(size), - d_repr((size-1)/32 + ((size-1) % 32 == 0? 0 : 1), 0) + d_repr(size/32 + (size % 32 == 0? 0 : 1), 0) { Assert (d_size > 0); } @@ -42,7 +42,7 @@ Base::Base(uint32_t size) void Base::sliceAt(Index index) { Index vector_index = index / 32; - Assert (vector_index < d_size - 1); + Assert (vector_index < d_size); Index int_index = index % 32; uint32_t bit_mask = utils::pow2(int_index); d_repr[vector_index] = d_repr[vector_index] | bit_mask; @@ -56,12 +56,12 @@ void Base::sliceWith(const Base& other) { } bool Base::isCutPoint (Index index) const { - // there is an implicit cut point at the end of the bv - if (index == d_size - 1) + // there is an implicit cut point at the end and begining of the bv + if (index == d_size || index == 0) return true; Index vector_index = index / 32; - Assert (vector_index < d_size - 1); + Assert (vector_index < d_size); Index int_index = index % 32; uint32_t bit_mask = utils::pow2(int_index); @@ -88,7 +88,7 @@ std::string Base::debugPrint() const { std::ostringstream os; os << "["; bool first = true; - for (unsigned i = 0; i < d_size - 1; ++i) { + for (int i = d_size - 1; i >= 0; --i) { if (isCutPoint(i)) { if (first) first = false; @@ -118,26 +118,28 @@ std::string ExtractTerm::debugPrint() const { * */ -TermId NormalForm::getTerm(Index i, const UnionFind& uf) const { - Assert (i < base.getBitwidth()); +std::pair NormalForm::getTerm(Index index, const UnionFind& uf) const { + Assert (index < base.getBitwidth()); Index count = 0; for (unsigned i = 0; i < decomp.size(); ++i) { Index size = uf.getBitwidth(decomp[i]); - if ( count + size <= i && count >= i) { - return decomp[i]; + if ( count + size > index && index >= count) { + return pair(decomp[i], count); } count += size; } Unreachable(); } + + std::string NormalForm::debugPrint(const UnionFind& uf) const { ostringstream os; os << "NF " << base.debugPrint() << endl; os << "("; - for (unsigned i = 0; i < decomp.size(); ++i) { + for (int i = decomp.size() - 1; i>= 0; --i) { os << decomp[i] << "[" << uf.getBitwidth(decomp[i]) <<"]"; - os << (i < decomp.size() - 1? ", " : ""); + os << (i != 0? ", " : ""); } os << ") \n"; return os.str(); @@ -150,7 +152,7 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const { std::string UnionFind::Node::debugPrint() const { ostringstream os; os << "Repr " << d_repr << " ["<< d_bitwidth << "] "; - os << "( " << d_ch1 <<", " << d_ch2 << ")" << endl; + os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; return os.str(); } @@ -213,8 +215,9 @@ void UnionFind::merge(TermId t1, TermId t2) { } TermId UnionFind::find(TermId id) const { - if (getRepr(id) != UndefinedId) - return find(getRepr(id)); + TermId repr = getRepr(id); + if (repr != UndefinedId) + return find(repr); return id; } /** @@ -244,13 +247,14 @@ void UnionFind::split(TermId id, Index i) { } else { Index cut = getCutPoint(id); if (i < cut ) - split(getChild(id, 1), i); + split(getChild(id, 0), i); else - split(getChild(id, 0), i - cut); + split(getChild(id, 1), i - cut); } } void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { + nf.clear(); getDecomposition(term, nf.decomp); // update nf base Index count = 0; @@ -271,7 +275,8 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) if (!hasChildren(id)) { Assert (term.high == getBitwidth(id) - 1 && term.low == 0); - decomp.push_back(id); + decomp.push_back(id); + return; } Index cut = getCutPoint(id); @@ -290,7 +295,7 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) // the extract is split over the two children ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); getDecomposition(low_child, decomp); - ExtractTerm high_child(getChild(id, 1), term.high, cut); + ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); getDecomposition(high_child, decomp); } } @@ -322,11 +327,11 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit start1 = start1 > start2 ? start2 : start1; start2 = start1 > start2 ? start1 : start2; - if (start1 + common_size <= start2) { + if (start2 - start1 < common_size) { Index overlap = start1 + common_size - start2; Assert (overlap > 0); - Index diff = start2 - overlap; - Assert (diff > 0); + Index diff = common_size - overlap; + Assert (diff >= 0); Index granularity = utils::gcd(diff, overlap); // split the common part for (unsigned i = 0; i < common_size; i+= granularity) { @@ -362,13 +367,14 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2 // align the cuts points of the two slicings // FIXME: this can be done more efficiently Base& cuts = nf1.base; + cuts.debugPrint(); cuts.sliceWith(nf2.base); for (unsigned i = 0; i < cuts.getBitwidth(); ++i) { if (cuts.isCutPoint(i)) { - TermId t1 = nf1.getTerm(i, *this); - split(t1, i); - TermId t2 = nf2.getTerm(i, *this); - split(t2, i); + pair pair1 = nf1.getTerm(i, *this); + split(pair1.first, i - pair1.second); + pair pair2 = nf2.getTerm(i, *this); + split(pair2.first, i - pair2.second); } } } @@ -423,23 +429,24 @@ void Slicer::processEquality(TNode eq) { d_unionFind.alignSlicings(a_ex, b_ex); d_unionFind.unionTerms(a_ex, b_ex); - + Debug("bv-slicer") << "Slicer::processEquality done. " << endl; } void Slicer::getBaseDecomposition(TNode node, std::vector& decomp) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; - Index high = utils::getSize(node); - Index low = 0; + Index high = utils::getSize(node) - 1; + Index low = 0; + TNode top = node; if (node.getKind() == kind::BITVECTOR_EXTRACT) { high = utils::getExtractHigh(node); low = utils::getExtractLow(node); - node = node[0]; + top = node[0]; } - Assert (d_nodeToId.find(node) != d_nodeToId.end()); - TermId id = d_nodeToId[node]; - NormalForm nf(utils::getSize(node)); + Assert (d_nodeToId.find(top) != d_nodeToId.end()); + TermId id = d_nodeToId[top]; + NormalForm nf(high-low+1); d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf); // construct actual extract nodes @@ -448,7 +455,7 @@ void Slicer::getBaseDecomposition(TNode node, std::vector& decomp) { for (unsigned i = 0; i < nf.decomp.size(); ++i) { Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); current_high += current_size; - Node current = utils::mkExtract(node, current_high - 1, current_low); + Node current = Rewriter::rewrite(utils::mkExtract(node, current_high - 1, current_low)); current_low += current_size; decomp.push_back(current); } @@ -528,3 +535,20 @@ void Slicer::splitEqualities(TNode node, std::vector& equalities) { equalities.push_back(node); } } + +std::string UnionFind::debugPrint(TermId id) { + ostringstream os; + if (hasChildren(id)) { + TermId id1 = find(getChild(id, 1)); + TermId id0 = find(getChild(id, 0)); + os << debugPrint(id1) <<" "; + os << debugPrint(id0) <<" "; + } else { + if (getRepr(id) == UndefinedId) { + os << id <<"[" << getBitwidth(id) <<"] "; + } else { + os << debugPrint(find(id)) << " "; + } + } + return os.str(); +} diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index b27b85e65..c7451c288 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -60,6 +60,11 @@ public: bool isEmpty() const; std::string debugPrint() const; Index getBitwidth() const { return d_size; } + void clear() { + for (unsigned i = 0; i < d_repr.size(); ++i) { + d_repr[i] = 0; + } + } bool operator==(const Base& other) const { if (other.getBitwidth() != getBitwidth()) return false; @@ -110,40 +115,41 @@ struct NormalForm { * * @return */ - TermId getTerm(Index i, const UnionFind& uf) const; - std::string debugPrint(const UnionFind& uf) const; + std::pair getTerm(Index i, const UnionFind& uf) const; + std::string debugPrint(const UnionFind& uf) const; + void clear() { base.clear(); decomp.clear(); } }; class UnionFind { class Node { Index d_bitwidth; - TermId d_ch1, d_ch2; + TermId d_ch1, d_ch0; TermId d_repr; public: Node(Index b) : d_bitwidth(b), d_ch1(UndefinedId), - d_ch2(UndefinedId), + d_ch0(UndefinedId), d_repr(UndefinedId) {} TermId getRepr() const { return d_repr; } Index getBitwidth() const { return d_bitwidth; } - bool hasChildren() const { return d_ch1 != UndefinedId && d_ch2 != UndefinedId; } + bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } TermId getChild(Index i) const { Assert (i < 2); - return i == 0? d_ch1 : d_ch2; + return i == 0? d_ch0 : d_ch1; } void setRepr(TermId id) { Assert (! hasChildren()); d_repr = id; } - void setChildren(TermId ch1, TermId ch2) { + void setChildren(TermId ch1, TermId ch0) { Assert (d_repr == UndefinedId && !hasChildren()); d_ch1 = ch1; - d_ch2 = ch2; + d_ch0 = ch0; } std::string debugPrint() const; }; @@ -165,7 +171,7 @@ class UnionFind { return d_nodes[id].getChild(i); } Index getCutPoint(TermId id) const { - return getBitwidth(getChild(id, 1)); + return getBitwidth(getChild(id, 0)); } bool hasChildren(TermId id) const { Assert (id < d_nodes.size()); @@ -176,9 +182,9 @@ class UnionFind { Assert (id < d_nodes.size()); d_nodes[id].setRepr(new_repr); } - void setChildren(TermId id, TermId ch1, TermId ch2) { - Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch2)); - d_nodes[id].setChildren(ch1, ch2); + void setChildren(TermId id, TermId ch1, TermId ch0) { + Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)); + d_nodes[id].setChildren(ch1, ch0); } @@ -201,6 +207,7 @@ public: Assert (id < d_nodes.size()); return d_nodes[id].getBitwidth(); } + std::string debugPrint(TermId id); }; -- 2.30.2