From 2d091366f7d437c3839307b1ad732a6999333fe0 Mon Sep 17 00:00:00 2001 From: lianah Date: Wed, 27 Mar 2013 17:48:39 -0400 Subject: [PATCH] reverted the core solver to do static slicing, added option --bv-core-solver --- src/theory/bv/bv_subtheory_core.cpp | 124 +++---- src/theory/bv/bv_subtheory_core.h | 2 +- src/theory/bv/options | 4 + src/theory/bv/slicer.cpp | 507 +++++----------------------- src/theory/bv/slicer.h | 236 ++++--------- src/theory/bv/theory_bv.cpp | 12 +- 6 files changed, 207 insertions(+), 678 deletions(-) diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index 91fed8a67..f8c26c35a 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -20,6 +20,7 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/bv/slicer.h" #include "theory/model.h" +#include "theory/bv/options.h" using namespace std; using namespace CVC4; @@ -32,7 +33,7 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv) : SubtheorySolver(c, bv), d_notify(*this), d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"), - d_slicer(new Slicer(c, this)), + d_slicer(new Slicer()), d_isCoreTheory(c, true), d_reasons(c) { @@ -85,7 +86,9 @@ void CoreSolver::preRegister(TNode node) { if (node.getKind() == kind::EQUAL) { d_equalityEngine.addTriggerEquality(node); - // d_slicer->processEquality(node); + if (options::bitvectorCoreSolver()) { + d_slicer->processEquality(node); + } } else { d_equalityEngine.addTerm(node); } @@ -102,9 +105,9 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { } } -Node CoreSolver::getBaseDecomposition(TNode a, std::vector& explanation) { +Node CoreSolver::getBaseDecomposition(TNode a) { std::vector a_decomp; - d_slicer->getBaseDecomposition(a, a_decomp, explanation); + d_slicer->getBaseDecomposition(a, a_decomp); Node new_a = utils::mkConcat(a_decomp); Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n"; return new_a; @@ -116,77 +119,49 @@ bool CoreSolver::decomposeFact(TNode fact) { // concat: // a == a_1 concat ... concat a_k // b == b_1 concat ... concat b_k + Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl; + // FIXME: are this the right things to assert? + // assert decompositions since the equality engine does not know the semantics of + // concat: + // a == a_1 concat ... concat a_k + // b == b_1 concat ... concat b_k + TNode eq = fact.getKind() == kind::NOT? fact[0] : fact; + TNode a = eq[0]; + TNode b = eq[1]; + Node new_a = getBaseDecomposition(a); + Node new_b = getBaseDecomposition(b); + + Assert (utils::getSize(new_a) == utils::getSize(new_b) && + utils::getSize(new_a) == utils::getSize(a)); + + NodeManager* nm = NodeManager::currentNM(); + Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); + Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); + + bool ok = true; + ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue()); + if (!ok) return false; + ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue()); + if (!ok) return false; + ok = assertFactToEqualityEngine(fact, fact); + if (!ok) return false; + if (fact.getKind() == kind::EQUAL) { - TNode a = fact[0]; - TNode b = fact[1]; - - d_slicer->processEquality(fact); - std::vector explanation_a; - Node new_a = getBaseDecomposition(a, explanation_a); - Node reason_a = mkAnd(explanation_a); - d_reasons.insert(reason_a); - - std::vector explanation_b; - Node new_b = getBaseDecomposition(b, explanation_b); - Node reason_b = mkAnd(explanation_b); - d_reasons.insert(reason_b); - - std::vector explanation; - explanation.push_back(fact); - explanation.insert(explanation.end(), explanation_a.begin(), explanation_a.end()); - explanation.insert(explanation.end(), explanation_b.begin(), explanation_b.end()); - - Node reason = utils::mkAnd(explanation); - d_reasons.insert(reason); - - Assert (utils::getSize(new_a) == utils::getSize(new_b) && - utils::getSize(new_a) == utils::getSize(a)); - - NodeManager* nm = NodeManager::currentNM(); - Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); - Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); - - bool ok = true; - ok = assertFactToEqualityEngine(a_eq_new_a, reason_a); - if (!ok) return false; - ok = assertFactToEqualityEngine(b_eq_new_b, reason_a); - if (!ok) return false; // assert the individual equalities as well // a_i == b_i if (new_a.getKind() == kind::BITVECTOR_CONCAT && new_b.getKind() == kind::BITVECTOR_CONCAT) { + Assert (new_a.getNumChildren() == new_b.getNumChildren()); for (unsigned i = 0; i < new_a.getNumChildren(); ++i) { Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]); - // this reason is not very precise!! - ok = assertFactToEqualityEngine(eq_i, reason); - d_reasons.insert(eq_i); + ok = assertFactToEqualityEngine(eq_i, fact); if (!ok) return false; } } - // merge the two terms in the slicer as well - d_slicer->assertEquality(fact); - } else { - // still need to register the terms - d_slicer->processEquality(fact[0]); - TNode a = fact[0][0]; - TNode b = fact[0][1]; - std::vector explanation_a; - Node new_a = getBaseDecomposition(a, explanation_a); - Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a); - assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a); - - std::vector explanation_b; - Node new_b = getBaseDecomposition(b, explanation_b); - Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b); - assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b); - - d_reasons.insert(reason_a); - d_reasons.insert(reason_b); } - // finally assert the actual fact to the equality engine - return assertFactToEqualityEngine(fact, fact); + return true; } bool CoreSolver::check(Theory::Effort e) { @@ -205,8 +180,11 @@ bool CoreSolver::check(Theory::Effort e) { // only reason about equalities if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) { - // ok = decomposeFact(fact); - ok = assertFactToEqualityEngine(fact, fact); + if (options::bitvectorCoreSolver()) { + ok = decomposeFact(fact); + } else { + ok = assertFactToEqualityEngine(fact, fact); + } } else { ok = assertFactToEqualityEngine(fact, fact); } @@ -214,16 +192,6 @@ bool CoreSolver::check(Theory::Effort e) { return false; } - // make sure to assert the new splits - // std::vector new_splits; - // d_slicer->getNewSplits(new_splits); - // for (unsigned i = 0; i < new_splits.size(); ++i) { - // ok = assertFactToEqualityEngine(new_splits[i], utils::mkTrue()); - // if (!ok) - // return false; - // } - - // if we are sat and in full check attempt to construct a model if (Theory::fullEffort(e) && isComplete()) { buildModel(); } @@ -232,6 +200,10 @@ bool CoreSolver::check(Theory::Effort e) { } void CoreSolver::buildModel() { + if (options::bitvectorCoreSolver()) { + // FIXME + return; + } Debug("bv-core") << "CoreSolver::buildModel() \n"; d_modelValues.clear(); TNodeSet constants; @@ -381,6 +353,10 @@ void CoreSolver::conflict(TNode a, TNode b) { } void CoreSolver::collectModelInfo(TheoryModel* m) { + if (options::bitvectorCoreSolver()) { + Unreachable(); + return; + } if (Debug.isOn("bitvector-model")) { context::CDQueue::const_iterator it = d_assertionQueue.begin(); for (; it!= d_assertionQueue.end(); ++it) { diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index d04dc164f..d314b2fbf 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -77,7 +77,7 @@ class CoreSolver : public SubtheorySolver { void buildModel(); bool assertFactToEqualityEngine(TNode fact, TNode reason); bool decomposeFact(TNode fact); - Node getBaseDecomposition(TNode a, std::vector& explanation); + Node getBaseDecomposition(TNode a); Statistics d_statistics; public: CoreSolver(context::Context* c, TheoryBV* bv); diff --git a/src/theory/bv/options b/src/theory/bv/options index 8e01c6572..cdc02c9ad 100644 --- a/src/theory/bv/options +++ b/src/theory/bv/options @@ -16,4 +16,8 @@ option bitvectorEagerFullcheck --bitblast-eager-fullcheck bool option bitvectorInequalitySolver --bv-inequality-solver bool turn on the inequality solver for the bit-vector theory + +option bitvectorCoreSolver --bv-core-solver bool + turn on the core solver for the bit-vector theory + endmodule diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 121802b65..2837b075f 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -19,7 +19,7 @@ #include "theory/bv/slicer.h" #include "theory/bv/theory_bv_utils.h" #include "theory/rewriter.h" -#include "theory/bv/bv_subtheory_core.h" +#include "theory/bv/options.h" using namespace CVC4; using namespace CVC4::theory; using namespace CVC4::theory::bv; @@ -41,24 +41,15 @@ Base::Base(uint32_t size) void Base::sliceAt(Index index) { - if (index == d_size) - return; - Assert(index < d_size); Index vector_index = index / 32; - Assert (vector_index < d_repr.size()); + if (vector_index == d_repr.size()) + return; + Index int_index = index % 32; uint32_t bit_mask = utils::pow2(int_index); d_repr[vector_index] = d_repr[vector_index] | bit_mask; } -void Base::undoSliceAt(Index index) { - Index vector_index = index / 32; - Assert (vector_index < d_size); - Index int_index = index % 32; - uint32_t bit_mask = utils::pow2(int_index); - d_repr[vector_index] = d_repr[vector_index] ^ bit_mask; -} - void Base::sliceWith(const Base& other) { Assert (d_size == other.d_size); for (unsigned i = 0; i < d_repr.size(); ++i) { @@ -156,13 +147,13 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const { return os.str(); } /** - * UnionFind::EqualityNode + * UnionFind::Node * */ -std::string UnionFind::EqualityNode::debugPrint() const { +std::string UnionFind::Node::debugPrint() const { ostringstream os; - os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] "; + os << "Repr " << d_repr << " ["<< d_bitwidth << "] "; os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; return os.str(); } @@ -172,94 +163,27 @@ std::string UnionFind::EqualityNode::debugPrint() const { * UnionFind * */ - -TermId UnionFind::registerTopLevelTerm(Index bitwidth) { - TermId id = mkEqualityNode(bitwidth); - d_topLevelIds.insert(id); - return id; -} - -TermId UnionFind::mkEqualityNode(Index bitwidth) { - Assert (bitwidth > 0); - EqualityNode node(bitwidth); - d_equalityNodes.push_back(node); - +TermId UnionFind::addTerm(Index bitwidth) { + Node node(bitwidth); + d_nodes.push_back(node); ++(d_statistics.d_numNodes); - TermId id = d_equalityNodes.size() - 1; - // d_representatives.insert(id); + TermId id = d_nodes.size() - 1; + d_representatives.insert(id); ++(d_statistics.d_numRepresentatives); + Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; return id; } -/** - * Create an extract term making sure there are no nested extracts. - * - * @param id - * @param high - * @param low - * - * @return - */ -ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) { - if (d_topLevelIds.find(id) != d_topLevelIds.end()) { - return ExtractTerm(id, high, low); - } - Assert (isExtractTerm(id)); - ExtractTerm top = getExtractTerm(id); - Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end()); - - Index top_low = top.low; - Assert (top.high - top_low + 1 > high); - high += top_low; - low += top_low; - id = top.id; - return ExtractTerm(id, high, low); -} - -/** - * Associate the given extract term with the given id. - * - * @param id - * @param extract - */ -void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) { - if (d_extractToId.find(extract) != d_extractToId.end()) { - Assert (d_extractToId[extract] == id); - return; - } - Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n"; - d_idToExtract[id] = extract; - d_extractToId[extract] = id; - } - -TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) { - ExtractTerm extract(id, high, low); - if (d_extractToId.find(extract) != d_extractToId.end()) { - // if the extract already exists we don't need to make a new node - TermId extract_id = d_extractToId[extract]; - Assert (extract_id < d_equalityNodes.size()); - return extract_id; - } - // otherwise make an equality node for it and store the extract - TermId node_id = mkEqualityNode(bitwidth); - storeExtractTerm(node_id, extract); - return node_id; -} - /** * At this point we assume the slicings of the two terms are properly aligned. * * @param t1 * @param t2 */ -void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) { - const ExtractTerm& t1 = getExtractTerm(id1); - const ExtractTerm& t2 = getExtractTerm(id2); - +void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" - << " " << t2.debugPrint() << "\n" - << " with reason " << reason << endl; + << " " << t2.debugPrint() << endl; Assert (t1.getBitwidth() == t2.getBitwidth()); NormalForm nf1(t1.getBitwidth()); @@ -272,11 +196,10 @@ void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) { Assert (nf1.base == nf2.base); for (unsigned i = 0; i < nf1.decomp.size(); ++i) { - merge (nf1.decomp[i], nf2.decomp[i], reason); + merge (nf1.decomp[i], nf2.decomp[i]); } } - /** * Merge the two terms in the union find. Both t1 and t2 * should be root terms. @@ -284,7 +207,7 @@ void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) { * @param t1 * @param t2 */ -void UnionFind::merge(TermId t1, TermId t2, TermId reason) { +void UnionFind::merge(TermId t1, TermId t2) { Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl; ++(d_statistics.d_numMerges); t1 = find(t1); @@ -294,9 +217,8 @@ void UnionFind::merge(TermId t1, TermId t2, TermId reason) { return; Assert (! hasChildren(t1) && ! hasChildren(t2)); - setRepr(t1, t2, reason); - recordOperation(UnionFind::MERGE, t1); - //d_representatives.erase(t1); + setRepr(t1, t2); + d_representatives.erase(t1); d_statistics.d_numRepresentatives += -1; } @@ -304,26 +226,11 @@ TermId UnionFind::find(TermId id) { TermId repr = getRepr(id); if (repr != UndefinedId) { TermId find_id = find(repr); + setRepr(id, find_id); return find_id; } return id; } - -TermId UnionFind::findWithExplanation(TermId id, std::vector& explanation) { - TermId repr = getRepr(id); - - if (repr != UndefinedId) { - TermId reason = getReason(id); - Assert (reason != UndefinedId); - explanation.push_back(reason); - - TermId find_id = findWithExplanation(repr, explanation); - return find_id; - } - return id; -} - - /** * Splits the representative of the term between i-1 and i * @@ -335,30 +242,19 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector& exp void UnionFind::split(TermId id, Index i) { Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; id = find(id); - Debug("bv-slicer-uf") << " node: " << d_equalityNodes[id].debugPrint() << endl; + Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; if (i == 0 || i == getBitwidth(id)) { // nothing to do - return; + return; } - Assert (i < getBitwidth(id)); if (!hasChildren(id)) { - // first time we split this term - ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0); - ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i); - - TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i); - TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i); - storeExtractTerm(bottom_id, bottom_extract); - storeExtractTerm(top_id, top_extract); - + // first time we split this term + TermId bottom_id = addTerm(i); + TermId top_id = addTerm(getBitwidth(id) - i); setChildren(id, top_id, bottom_id); - recordOperation(UnionFind::SPLIT, id); - - if (d_slicer->termInEqualityEngine(id)) { - d_slicer->enqueueSplit(id, i, top_id, bottom_id); - } + } else { Index cut = getCutPoint(id); if (i < cut ) @@ -369,14 +265,6 @@ void UnionFind::split(TermId id, Index i) { ++(d_statistics.d_numSplits); } -// TermId UnionFind::getTopLevel(TermId id) const { -// __gnu_cxx::hash_map >::const_iterator it = d_idToExtract.find(id); -// if (it != d_idToExtract.end()) { -// return (*it).second.id; -// } -// return id; -// } - void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { nf.clear(); getDecomposition(term, nf.decomp); @@ -423,56 +311,6 @@ void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) getDecomposition(high_child, decomp); } } - -void UnionFind::getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, - std::vector& explanation) { - nf.clear(); - getDecompositionWithExplanation(term, nf.decomp, explanation); - // update nf base - Index count = 0; - for (unsigned i = 0; i < nf.decomp.size(); ++i) { - count += getBitwidth(nf.decomp[i]); - nf.base.sliceAt(count); - } - Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl; - Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl; -} - -void UnionFind::getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, - std::vector& explanation) { - // making sure the term is aligned - TermId id = findWithExplanation(term.id, explanation); - - Assert (term.high < getBitwidth(id)); - // because we split the node, this must be the whole extract - if (!hasChildren(id)) { - Assert (term.high == getBitwidth(id) - 1 && - term.low == 0); - decomp.push_back(id); - return; - } - - Index cut = getCutPoint(id); - - if (term.low < cut && term.high < cut) { - // the extract falls entirely on the low child - ExtractTerm child_ex(getChild(id, 0), term.high, term.low); - getDecompositionWithExplanation(child_ex, decomp, explanation); - } - else if (term.low >= cut && term.high >= cut){ - // the extract falls entirely on the high child - ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); - getDecompositionWithExplanation(child_ex, decomp, explanation); - } - else { - // the extract is split over the two children - ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); - getDecompositionWithExplanation(low_child, decomp, explanation); - ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); - getDecompositionWithExplanation(high_child, decomp, explanation); - } -} - /** * May cause reslicings of the decompositions. Must not assume the decompositons * are the current normal form. @@ -518,10 +356,7 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit } -void UnionFind::alignSlicings(TermId id1, TermId id2) { - const ExtractTerm& term1 = getExtractTerm(id1); - const ExtractTerm& term2 = getExtractTerm(id2); - +void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl; Debug("bv-slicer") << " " << term2.debugPrint() << endl; NormalForm nf1(term1.getBitwidth()); @@ -569,169 +404,63 @@ void UnionFind::alignSlicings(TermId id1, TermId id2) { } } while (changed); } - - /** * Given an extract term a[i:j] makes sure a is sliced * at indices i and j. * * @param term */ -void UnionFind::ensureSlicing(TermId t) { - ExtractTerm term = getExtractTerm(t); +void UnionFind::ensureSlicing(const ExtractTerm& term) { //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl; - TermId id = term.id; + TermId id = find(term.id); split(id, term.high + 1); split(id, term.low); } -void UnionFind::backtrack() { - int size = d_undoStack.size(); - for (int i = size; i > (int)d_undoStackIndex.get(); --i) { - Assert (!d_undoStack.empty()); - Operation op = d_undoStack.back(); - d_undoStack.pop_back(); - if (op.op == UnionFind::MERGE) { - undoMerge(op.id); - } else { - Assert (op.op == UnionFind::SPLIT); - undoSplit(op.id); - } - } -} - -void UnionFind::undoMerge(TermId id) { - Assert (getRepr(id) != UndefinedId); - setRepr(id, UndefinedId, UndefinedId); -} - -void UnionFind::undoSplit(TermId id) { - Assert (hasChildren(id)); - setChildren(id, UndefinedId, UndefinedId); -} - -void UnionFind::recordOperation(OperationKind op, TermId term) { - d_undoStackIndex.set(d_undoStackIndex.get() + 1); - d_undoStack.push_back(Operation(op, term)); - Assert (d_undoStack.size() == d_undoStackIndex); -} - -void UnionFind::getBase(TermId id, Base& base, Index offset) { - id = find(id); - if (!hasChildren(id)) - return; - TermId id1 = find(getChild(id, 1)); - TermId id0 = find(getChild(id, 0)); - Index cut = getCutPoint(id); - base.sliceAt(cut + offset); - getBase(id1, base, cut + offset); - getBase(id0, base, offset); -} - -/// getter methods for the internal nodes -TermId UnionFind::getRepr(TermId id) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].getRepr(); -} -ExplanationId UnionFind::getReason(TermId id) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].getReason(); -} -TermId UnionFind::getChild(TermId id, Index i) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].getChild(i); -} -Index UnionFind::getCutPoint(TermId id) const { - return getBitwidth(getChild(id, 0)); -} -bool UnionFind::hasChildren(TermId id) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].hasChildren(); -} - -/// setter methods for the internal nodes -void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) { - Assert (id < d_equalityNodes.size()); - d_equalityNodes[id].setRepr(new_repr, reason); -} -void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) { - Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || - (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); - d_equalityNodes[id].setChildren(ch1, ch0); -} - - /** * Slicer * */ -TermId Slicer::registerTerm(TNode node) { +ExtractTerm Slicer::registerTerm(TNode node) { + Index low = 0, high = utils::getSize(node) - 1; + TNode n = node; if (node.getKind() == kind::BITVECTOR_EXTRACT) { - TNode n = node[0]; - TermId top_id = registerTopLevelTerm(n); - Index high = utils::getExtractHigh(node); - Index low = utils::getExtractLow(node); - TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low); - return id; + n = node[0]; + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); } - TermId id = registerTopLevelTerm(node); - return id; -} - -TermId Slicer::registerTopLevelTerm(TNode node) { - Assert (node.getKind() != kind::BITVECTOR_EXTRACT || - node.getKind() != kind::BITVECTOR_CONCAT); - - if (d_nodeToId.find(node) == d_nodeToId.end()) { - TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node)); - d_idToNode[id] = node; - d_nodeToId[node] = id; - Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl; - return id; + if (d_nodeToId.find(n) == d_nodeToId.end()) { + TermId id = d_unionFind.addTerm(utils::getSize(n)); + d_nodeToId[n] = id; + d_idToNode[id] = n; } - return d_nodeToId[node]; + TermId id = d_nodeToId[n]; + ExtractTerm res(id, high, low); + Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; + return res; } void Slicer::processEquality(TNode eq) { Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl; - - registerEquality(eq); + Assert (eq.getKind() == kind::EQUAL); TNode a = eq[0]; TNode b = eq[1]; - TermId a_id = registerTerm(a); - TermId b_id = registerTerm(b); + ExtractTerm a_ex= registerTerm(a); + ExtractTerm b_ex= registerTerm(b); - d_unionFind.ensureSlicing(a_id); - d_unionFind.ensureSlicing(b_id); + d_unionFind.ensureSlicing(a_ex); + d_unionFind.ensureSlicing(b_ex); - d_unionFind.alignSlicings(a_id, b_id); - - // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; - // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; - // Debug("bv-slicer") << "Slicer::processEquality done. " << endl; + d_unionFind.alignSlicings(a_ex, b_ex); + d_unionFind.unionTerms(a_ex, b_ex); + Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; + Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; + Debug("bv-slicer") << "Slicer::processEquality done. " << endl; } -void Slicer::assertEquality(TNode eq) { - Assert (eq.getKind() == kind::EQUAL); - TermId a = registerTerm(eq[0]); - TermId b = registerTerm(eq[1]); - ExplanationId reason = getExplanationId(eq); - d_unionFind.unionTerms(a, b, reason); -} - - -void Slicer::registerEquality(TNode eq) { - if (d_explanationToId.find(eq) == d_explanationToId.end()) { - ExplanationId id = d_explanations.size(); - d_explanations.push_back(eq); - d_explanationToId[eq] = id; - Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n"; - } -} - -void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation) { +void Slicer::getBaseDecomposition(TNode node, std::vector& decomp) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; Index high = utils::getSize(node) - 1; @@ -742,47 +471,45 @@ void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::ve low = utils::getExtractLow(node); top = node[0]; } - Assert (d_nodeToId.find(top) != d_nodeToId.end()); TermId id = d_nodeToId[top]; - NormalForm nf(high-low+1); - std::vector explanation_ids; - d_unionFind.getNormalFormWithExplanation(ExtractTerm(id, high, low), nf, explanation_ids); - - for (unsigned i = 0; i < explanation_ids.size(); ++i) { - Assert (hasExplanation(explanation_ids[i])); - TNode exp = getExplanation(explanation_ids[i]); - explanation.push_back(exp); - } + NormalForm nf(high-low+1); + d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf); - for (int i = nf.decomp.size() - 1; i>=0 ; --i) { - Node current = getNode(nf.decomp[i]); + // construct actual extract nodes + unsigned size = utils::getSize(node); + Index current_low = size; + Index current_high = size; + for (int i = nf.decomp.size() - 1; i >= 0; --i) { + Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); + current_low -= current_size; + Node current = Rewriter::rewrite(utils::mkExtract(node, current_high - 1, current_low)); + current_high = current_low; decomp.push_back(current); } - if (Debug.isOn("bv-slicer-explanation")) { - Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n" - << "as "; - for (unsigned i = 0; i < decomp.size(); ++i) { - Debug("bv-slicer-explanation") << decomp[i] <<" " ; - } - Debug("bv-slicer-explanation") << "\n Explanation : \n"; - for (unsigned i = 0; i < explanation.size(); ++i) { - Debug("bv-slicer-explanation") << " " << explanation[i] << "\n"; - } - + + Debug("bv-slicer") << "as ["; + for (unsigned i = 0; i < decomp.size(); ++i) { + Debug("bv-slicer") << decomp[i] <<" "; } + Debug("bv-slicer") << "]" << endl; } bool Slicer::isCoreTerm(TNode node) { if (d_coreTermCache.find(node) == d_coreTermCache.end()) { - Kind kind = node.getKind(); - if (//kind != kind::BITVECTOR_EXTRACT && - //kind != kind::BITVECTOR_CONCAT && + Kind kind = node.getKind(); + bool not_core; + if (options::bitvectorCoreSolver()) { + not_core = (kind != kind::BITVECTOR_EXTRACT && kind != kind::BITVECTOR_CONCAT); + } else { + not_core = true; + } + if (not_core && kind != kind::EQUAL && + kind != kind::NOT && kind != kind::STORE && kind != kind::SELECT && - kind != kind::NOT && node.getMetaKind() != kind::metakind::VARIABLE && kind != kind::CONST_BITVECTOR) { d_coreTermCache[node] = false; @@ -845,81 +572,7 @@ void Slicer::splitEqualities(TNode node, std::vector& equalities) { equalities.push_back(node); } d_numAddedEqualities += equalities.size() - 1; -} - - -ExtractTerm UnionFind::getExtractTerm(TermId id) const { - if (d_topLevelIds.find(id) != d_topLevelIds.end()) { - // if it's a top level term so we don't have an extract stored for it - return ExtractTerm(id, getBitwidth(id) - 1, 0); - } - Assert (isExtractTerm(id)); - - return (d_idToExtract.find(id))->second; -} - -bool UnionFind::isExtractTerm(TermId id) const { - return d_idToExtract.find(id) != d_idToExtract.end(); -} - -bool Slicer::isTopLevelNode(TermId id) const { - return d_idToNode.find(id) != d_idToNode.end(); -} - -Node Slicer::getNode(TermId id) const { - if (isTopLevelNode(id)) { - return d_idToNode.find(id)->second; - } - Assert (d_unionFind.isExtractTerm(id)); - const ExtractTerm& extract = d_unionFind.getExtractTerm(id); - Assert (isTopLevelNode(extract.id)); - TNode node = d_idToNode.find(extract.id)->second; - if (extract.high == utils::getSize(node) -1 && extract.low == 0) { - return node; - } - Node ex = utils::mkExtract(node, extract.high, extract.low); - return ex; -} - -bool Slicer::termInEqualityEngine(TermId id) { - Node node = getNode(id); - return d_coreSolver->hasTerm(node); -} - -void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) { - Node node = getNode(id); - Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0)); - Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i)); - // must add terms to equality engine so we get notified when they get split more - d_coreSolver->addTermToEqualityEngine(bottom); - d_coreSolver->addTermToEqualityEngine(top); - - Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom)); - d_newSplits.push_back(eq); - Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; - Debug("bv-slicer") << " " << id << "=" << top_id << " " << bottom_id << endl; -} - -void Slicer::getNewSplits(std::vector& splits) { - for (unsigned i = d_newSplitsIndex; i < d_newSplits.size(); ++i) { - splits.push_back(d_newSplits[i]); - } - d_newSplitsIndex = d_newSplits.size(); -} - -bool Slicer::hasExplanation(ExplanationId id) const { - return id < d_explanations.size(); -} - -TNode Slicer::getExplanation(ExplanationId id) const { - Assert(hasExplanation(id)); - return d_explanations[id]; -} - -ExplanationId Slicer::getExplanationId(TNode reason) const { - Assert (d_explanationToId.find(reason) != d_explanationToId.end()); - return d_explanationToId.find(reason)->second; -} +} std::string UnionFind::debugPrint(TermId id) { ostringstream os; diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index c46ef99ed..88254b983 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -29,10 +29,6 @@ #include "util/index.h" #include "expr/node.h" #include "theory/bv/theory_bv_utils.h" -#include "context/context.h" -#include "context/cdhashset.h" -#include "context/cdo.h" -#include "context/cdqueue.h" #ifndef __CVC4__THEORY__BV__SLICER_BV_H #define __CVC4__THEORY__BV__SLICER_BV_H @@ -45,10 +41,8 @@ namespace bv { typedef Index TermId; -typedef TermId ExplanationId; extern const TermId UndefinedId; -class CDBase; /** * Base @@ -57,11 +51,9 @@ class CDBase; class Base { Index d_size; std::vector d_repr; - void undoSliceAt(Index index); public: - Base (Index size); - void sliceAt(Index index); - + Base(Index size); + void sliceAt(Index index); void sliceWith(const Base& other); bool isCutPoint(Index index) const; void diffCutPoints(const Base& other, Base& res) const; @@ -84,23 +76,17 @@ public: } }; - /** * UnionFind * */ -typedef context::CDHashSet > CDTermSet; +typedef __gnu_cxx::hash_set TermSet; typedef std::vector Decomposition; struct ExtractTerm { TermId id; Index high; Index low; - ExtractTerm() - : id (UndefinedId), - high(UndefinedId), - low(UndefinedId) - {} ExtractTerm(TermId i, Index h, Index l) : id (i), high(h), @@ -108,24 +94,10 @@ struct ExtractTerm { { Assert (h >= l && id != UndefinedId); } - bool operator== (const ExtractTerm& other) const { - return id == other.id && high == other.high && low == other.low; - } Index getBitwidth() const { return high - low + 1; } std::string debugPrint() const; - friend class ExtractTermHashFunction; }; -struct ExtractTermHashFunction { - ::std::size_t operator() (const ExtractTerm& t) const { - __gnu_cxx::hash h; - unsigned id = t.id; - unsigned high = t.high; - unsigned low = t.low; - return (h(id) * 7919 + h(high))* 4391 + h(low); - } -}; - class UnionFind; struct NormalForm { @@ -148,34 +120,21 @@ struct NormalForm { void clear() { base.clear(); decomp.clear(); } }; -class Slicer; - -class UnionFind : public context::ContextNotifyObj { - struct ReprEdge { - TermId repr; - ExplanationId reason; - ReprEdge() - : repr(UndefinedId), - reason(UndefinedId) - {} - }; - - class EqualityNode { - Index d_bitwidth; - TermId d_ch1, d_ch0; // the ids of the two children if they exist - ReprEdge d_edge; // points to the representative and stores the explanation - +class UnionFind { + class Node { + Index d_bitwidth; + TermId d_ch1, d_ch0; + TermId d_repr; public: - EqualityNode(Index b) + Node(Index b) : d_bitwidth(b), d_ch1(UndefinedId), d_ch0(UndefinedId), - d_edge() + d_repr(UndefinedId) {} - - TermId getRepr() const { return d_edge.repr; } - ExplanationId getReason() const { return d_edge.reason; } + + TermId getRepr() const { return d_repr; } Index getBitwidth() const { return d_bitwidth; } bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } @@ -183,64 +142,51 @@ class UnionFind : public context::ContextNotifyObj { Assert (i < 2); return i == 0? d_ch0 : d_ch1; } - void setRepr(TermId repr, ExplanationId reason) { + void setRepr(TermId id) { Assert (! hasChildren()); - d_edge.repr = repr; - d_edge.reason = reason; + d_repr = id; } void setChildren(TermId ch1, TermId ch0) { + Assert (d_repr == UndefinedId && !hasChildren()); d_ch1 = ch1; d_ch0 = ch0; } std::string debugPrint() const; }; - - // the equality nodes in the union find - std::vector d_equalityNodes; - - /// getter methods for the internal nodes - TermId getRepr(TermId id) const; - ExplanationId getReason(TermId id) const; - TermId getChild(TermId id, Index i) const; - Index getCutPoint(TermId id) const; - bool hasChildren(TermId id) const; - /// setter methods for the internal nodes - void setRepr(TermId id, TermId new_repr, ExplanationId reason); - void setChildren(TermId id, TermId ch1, TermId ch0); - - // the mappings between ExtractTerms and ids - __gnu_cxx::hash_map > d_idToExtract; - __gnu_cxx::hash_map d_extractToId; - - __gnu_cxx::hash_set d_topLevelIds; + /// map from TermId to the nodes that represent them + std::vector d_nodes; + /// a term is in this set if it is its own representative + TermSet d_representatives; void getDecomposition(const ExtractTerm& term, Decomposition& decomp); - void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector& explanation); void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); - - /* Backtracking mechanisms */ - - enum OperationKind { - MERGE, - SPLIT - }; - - struct Operation { - OperationKind op; - TermId id; - Operation(OperationKind o, TermId i) - : op(o), id(i) {} - }; - - std::vector d_undoStack; - context::CDO d_undoStackIndex; + /// getter methods for the internal nodes + TermId getRepr(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getRepr(); + } + TermId getChild(TermId id, Index i) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getChild(i); + } + Index getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); + } + bool hasChildren(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].hasChildren(); + } + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr) { + Assert (id < d_nodes.size()); + d_nodes[id].setRepr(new_repr); + } + void setChildren(TermId id, TermId ch1, TermId ch0) { + Assert (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0)); + d_nodes[id].setChildren(ch1, ch0); + } - void backtrack(); - void undoMerge(TermId id); - void undoSplit(TermId id); - void recordOperation(OperationKind op, TermId term); - virtual ~UnionFind() throw(AssertionException) {} class Statistics { public: IntStat d_numNodes; @@ -249,106 +195,56 @@ class UnionFind : public context::ContextNotifyObj { IntStat d_numMerges; AverageStat d_avgFindDepth; ReferenceStat d_numAddedEqualities; + //IntStat d_numAddedEqualities; Statistics(); ~Statistics(); }; - Statistics d_statistics; - Slicer* d_slicer; - TermId d_termIdCount; - TermId mkEqualityNode(Index bitwidth); - ExtractTerm mkExtractTerm(TermId id, Index high, Index low); - void storeExtractTerm(Index id, const ExtractTerm& term); - ExtractTerm getExtractTerm(TermId id) const; - bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); } - TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; } - bool isExtractTerm(TermId id) const; + Statistics d_statistics +; + public: - UnionFind(context::Context* ctx, Slicer* slicer) - : ContextNotifyObj(ctx), - d_equalityNodes(), - d_idToExtract(), - d_extractToId(), - d_topLevelIds(), - d_undoStack(), - d_undoStackIndex(ctx), - d_statistics(), - d_slicer(slicer), - d_termIdCount(0) + UnionFind() + : d_nodes(), + d_representatives() {} - TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low); - TermId registerTopLevelTerm(Index bitwidth); - void unionTerms(TermId id1, TermId id2, TermId reason); - void merge(TermId t1, TermId t2, TermId reason); - TermId find(TermId t1); - TermId findWithExplanation(TermId id, std::vector& explanation); + TermId addTerm(Index bitwidth); + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); + void merge(TermId t1, TermId t2); + TermId find(TermId t1); void split(TermId term, Index i); + void getNormalForm(const ExtractTerm& term, NormalForm& nf); - void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector& explanation); - void alignSlicings(TermId id1, TermId id2); - void ensureSlicing(TermId id); + void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); + void ensureSlicing(const ExtractTerm& term); Index getBitwidth(TermId id) const { - Assert (id < d_equalityNodes.size()); - return d_equalityNodes[id].getBitwidth(); + Assert (id < d_nodes.size()); + return d_nodes[id].getBitwidth(); } - void getBase(TermId id, Base& base, Index offset); std::string debugPrint(TermId id); - - void contextNotifyPop() { - backtrack(); - } friend class Slicer; }; -class CoreSolver; - class Slicer { - __gnu_cxx::hash_map d_idToNode; + __gnu_cxx::hash_map d_idToNode; __gnu_cxx::hash_map d_nodeToId; __gnu_cxx::hash_map d_coreTermCache; - __gnu_cxx::hash_map d_explanationToId; - std::vector d_explanations; UnionFind d_unionFind; - - context::CDQueue d_newSplits; - context::CDO d_newSplitsIndex; - CoreSolver* d_coreSolver; - TermId registerTopLevelTerm(TNode node); - bool isTopLevelNode(TermId id) const; - TermId registerTerm(TNode node); + ExtractTerm registerTerm(TNode node); public: - Slicer(context::Context* ctx, CoreSolver* coreSolver) + Slicer() : d_idToNode(), d_nodeToId(), d_coreTermCache(), - d_explanationToId(), - d_explanations(), - d_unionFind(ctx, this), - d_newSplits(ctx), - d_newSplitsIndex(ctx, 0), - d_coreSolver(coreSolver) + d_unionFind() {} - void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); - void registerEquality(TNode eq); - + void getBaseDecomposition(TNode node, std::vector& decomp); void processEquality(TNode eq); - void assertEquality(TNode eq); bool isCoreTerm (TNode node); - - bool hasNode(TermId id) const; - Node getNode(TermId id) const; - - bool hasExplanation(ExplanationId id) const; - TNode getExplanation(ExplanationId id) const; - ExplanationId getExplanationId(TNode reason) const; - - bool termInEqualityEngine(TermId id); - void enqueueSplit(TermId id, Index i, TermId top, TermId bottom); - void getNewSplits(std::vector& splits); static void splitEqualities(TNode node, std::vector& equalities); - static unsigned d_numAddedEqualities; + static unsigned d_numAddedEqualities; }; diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 502d49f58..b202b7eca 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -239,12 +239,12 @@ Node TheoryBV::ppRewrite(TNode t) Node result = RewriteRule::run(t); return Rewriter::rewrite(result); } - - // if (t.getKind() == kind::EQUAL) { - // std::vector equalities; - // Slicer::splitEqualities(t, equalities); - // return utils::mkAnd(equalities); - // } + + if (options::bitvectorCoreSolver() && t.getKind() == kind::EQUAL) { + std::vector equalities; + Slicer::splitEqualities(t, equalities); + return utils::mkAnd(equalities); + } return t; } -- 2.30.2