From 8882aef2dd4f1f629b0de99fc3a7f390fab2f83e Mon Sep 17 00:00:00 2001 From: lianah Date: Sat, 23 Mar 2013 13:40:29 -0400 Subject: [PATCH] fixed some explanation problems for the core theory; still slow --- src/theory/bv/bv_subtheory_core.cpp | 40 +++-- src/theory/bv/bv_subtheory_core.h | 2 +- src/theory/bv/slicer.cpp | 250 +++++++++++++++++++--------- src/theory/bv/slicer.h | 111 ++++++------ src/theory/bv/theory_bv_utils.h | 128 ++++++++------ 5 files changed, 330 insertions(+), 201 deletions(-) diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index 2af0e47b8..6f5fd4119 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -102,7 +102,7 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { } } -Node CoreSolver::getBaseDecomposition(TNode a, std::vector& explanation) { +Node CoreSolver::getBaseDecomposition(TNode a, std::vector& explanation) { std::vector a_decomp; d_slicer->getBaseDecomposition(a, a_decomp, explanation); Node new_a = utils::mkConcat(a_decomp); @@ -122,28 +122,35 @@ bool CoreSolver::decomposeFact(TNode fact) { TNode b = fact[1]; d_slicer->processEquality(fact); - std::vector explanation; - Node new_a = getBaseDecomposition(a, explanation); - Node new_b = getBaseDecomposition(b, explanation); + std::vector explanation_a; + Node new_a = getBaseDecomposition(a, explanation_a); + Node reason_a = mkAnd(explanation_a); + d_reasons.insert(reason_a); + + std::vector explanation_b; + Node new_b = getBaseDecomposition(b, explanation_b); + Node reason_b = mkAnd(explanation_b); + d_reasons.insert(reason_b); + std::vector explanation; explanation.push_back(fact); + explanation.insert(explanation.end(), explanation_a.begin(), explanation_a.end()); + explanation.insert(explanation.end(), explanation_b.begin(), explanation_b.end()); + Node reason = utils::mkAnd(explanation); d_reasons.insert(reason); Assert (utils::getSize(new_a) == utils::getSize(new_b) && utils::getSize(new_a) == utils::getSize(a)); - // FIXME: do we still need to assert these? + NodeManager* nm = NodeManager::currentNM(); Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); - d_reasons.insert(a_eq_new_a); - d_reasons.insert(b_eq_new_b); - bool ok = true; - ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue()); + ok = assertFactToEqualityEngine(a_eq_new_a, reason_a); if (!ok) return false; - ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue()); + ok = assertFactToEqualityEngine(b_eq_new_b, reason_a); if (!ok) return false; // assert the individual equalities as well // a_i == b_i @@ -152,6 +159,7 @@ bool CoreSolver::decomposeFact(TNode fact) { Assert (new_a.getNumChildren() == new_b.getNumChildren()); for (unsigned i = 0; i < new_a.getNumChildren(); ++i) { Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]); + // this reason is not very precise!! ok = assertFactToEqualityEngine(eq_i, reason); d_reasons.insert(eq_i); if (!ok) return false; @@ -164,15 +172,16 @@ bool CoreSolver::decomposeFact(TNode fact) { d_slicer->processEquality(fact[0]); TNode a = fact[0][0]; TNode b = fact[0][1]; - std::vector explanation_a; + std::vector explanation_a; Node new_a = getBaseDecomposition(a, explanation_a); Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a); assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a); - std::vector explanation_b; + std::vector explanation_b; Node new_b = getBaseDecomposition(b, explanation_b); Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b); assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b); + d_reasons.insert(reason_a); d_reasons.insert(reason_b); } @@ -279,13 +288,16 @@ void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { bool CoreSolver::storePropagation(TNode literal) { return d_bv->storePropagation(literal, SUB_CORE); } - + void CoreSolver::conflict(TNode a, TNode b) { std::vector assumptions; d_equalityEngine.explainEquality(a, b, true, assumptions); - d_bv->setConflict(mkAnd(assumptions)); + Node conflict = flattenAnd(assumptions); + d_bv->setConflict(conflict); } + + void CoreSolver::collectModelInfo(TheoryModel* m) { if (Debug.isOn("bitvector-model")) { context::CDQueue::const_iterator it = d_assertionQueue.begin(); diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index 4f2d7a279..868f3754f 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -67,7 +67,7 @@ class CoreSolver : public SubtheorySolver { context::CDHashSet d_reasons; bool assertFactToEqualityEngine(TNode fact, TNode reason); bool decomposeFact(TNode fact); - Node getBaseDecomposition(TNode a, std::vector& explanation); + Node getBaseDecomposition(TNode a, std::vector& explanation); public: CoreSolver(context::Context* c, TheoryBV* bv); ~CoreSolver(); diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 5d376ea50..b24702635 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -156,11 +156,11 @@ std::string NormalForm::debugPrint(const UnionFind& uf) const { return os.str(); } /** - * UnionFind::Node + * UnionFind::EqualityNode * */ -std::string UnionFind::Node::debugPrint() const { +std::string UnionFind::EqualityNode::debugPrint() const { ostringstream os; os << "Repr " << d_edge.repr << " ["<< d_bitwidth << "] "; os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; @@ -172,41 +172,80 @@ std::string UnionFind::Node::debugPrint() const { * UnionFind * */ -TermId UnionFind::addNode(Index bitwidth) { + +TermId UnionFind::registerTopLevelTerm(Index bitwidth) { + TermId id = mkEqualityNode(bitwidth); + d_topLevelIds.insert(id); + return id; +} + +TermId UnionFind::mkEqualityNode(Index bitwidth) { Assert (bitwidth > 0); - Node node(bitwidth); - d_nodes.push_back(node); + EqualityNode node(bitwidth); + d_equalityNodes.push_back(node); ++(d_statistics.d_numNodes); - TermId id = d_nodes.size() - 1; + TermId id = d_equalityNodes.size() - 1; // d_representatives.insert(id); ++(d_statistics.d_numRepresentatives); Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; return id; } - -TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { - if (isExtractTerm(topLevel)) { - ExtractTerm top = getExtractTerm(topLevel); - Index top_high = top.high; - Index top_low = top.low; - Assert (top_high - top_low + 1 > high); - high += top_low; - low += top_low; - topLevel = top.id; +/** + * Create an extract term making sure there are no nested extracts. + * + * @param id + * @param high + * @param low + * + * @return + */ +ExtractTerm UnionFind::mkExtractTerm(TermId id, Index high, Index low) { + if (d_topLevelIds.find(id) != d_topLevelIds.end()) { + return ExtractTerm(id, high, low); } - ExtractTerm extract(topLevel, high, low); + Assert (isExtractTerm(id)); + ExtractTerm top = getExtractTerm(id); + Assert (d_topLevelIds.find(top.id) != d_topLevelIds.end()); + + Index top_high = top.high; + Index top_low = top.low; + Assert (top_high - top_low + 1 > high); + high += top_low; + low += top_low; + id = top.id; + return ExtractTerm(id, high, low); +} + +/** + * Associate the given extract term with the given id. + * + * @param id + * @param extract + */ +void UnionFind::storeExtractTerm(TermId id, const ExtractTerm& extract) { if (d_extractToId.find(extract) != d_extractToId.end()) { - return d_extractToId[extract]; + Assert (d_extractToId[extract] == id); + return; } - - Assert (high >= low); - - TermId id = addNode(high - low + 1); + Debug("bv-slicer") << "UnionFind::storeExtract " << extract.debugPrint() << " => id" << id << "\n"; d_idToExtract[id] = extract; d_extractToId[extract] = id; - return id; + } + +TermId UnionFind::addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low) { + ExtractTerm extract(id, high, low); + if (d_extractToId.find(extract) != d_extractToId.end()) { + // if the extract already exists we don't need to make a new node + TermId extract_id = d_extractToId[extract]; + Assert (extract_id < d_equalityNodes.size()); + return extract_id; + } + // otherwise make an equality node for it and store the extract + TermId node_id = mkEqualityNode(bitwidth); + storeExtractTerm(node_id, extract); + return node_id; } /** @@ -215,7 +254,10 @@ TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { * @param t1 * @param t2 */ -void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason) { +void UnionFind::unionTerms(TermId id1, TermId id2, TermId reason) { + const ExtractTerm& t1 = getExtractTerm(id1); + const ExtractTerm& t2 = getExtractTerm(id2); + Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" << " " << t2.debugPrint() << "\n" << " with reason " << reason << endl; @@ -294,7 +336,7 @@ TermId UnionFind::findWithExplanation(TermId id, std::vector& exp void UnionFind::split(TermId id, Index i) { Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; id = find(id); - Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; + Debug("bv-slicer-uf") << " node: " << d_equalityNodes[id].debugPrint() << endl; if (i == 0 || i == getBitwidth(id)) { // nothing to do @@ -303,9 +345,15 @@ void UnionFind::split(TermId id, Index i) { Assert (i < getBitwidth(id)); if (!hasChildren(id)) { - // first time we split this term - TermId bottom_id = addExtract(id, i - 1, 0); - TermId top_id = addExtract(id, getBitwidth(id) - 1, i); + // first time we split this term + ExtractTerm bottom_extract = mkExtractTerm(id, i-1, 0); + ExtractTerm top_extract = mkExtractTerm(id, getBitwidth(id) - 1, i); + + TermId bottom_id = extractHasId(bottom_extract)? getExtractId(bottom_extract) : mkEqualityNode(i); + TermId top_id = extractHasId(top_extract)? getExtractId(top_extract) : mkEqualityNode(getBitwidth(id) - i); + storeExtractTerm(bottom_id, bottom_extract); + storeExtractTerm(top_id, top_extract); + setChildren(id, top_id, bottom_id); recordOperation(UnionFind::SPLIT, id); @@ -471,7 +519,10 @@ void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposit } -void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { +void UnionFind::alignSlicings(TermId id1, TermId id2) { + const ExtractTerm& term1 = getExtractTerm(id1); + const ExtractTerm& term2 = getExtractTerm(id2); + Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl; Debug("bv-slicer") << " " << term2.debugPrint() << endl; NormalForm nf1(term1.getBitwidth()); @@ -519,15 +570,18 @@ void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2 } } while (changed); } + + /** * Given an extract term a[i:j] makes sure a is sliced * at indices i and j. * * @param term */ -void UnionFind::ensureSlicing(const ExtractTerm& term) { +void UnionFind::ensureSlicing(TermId t) { + ExtractTerm term = getExtractTerm(t); //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl; - TermId id = find(term.id); + TermId id = term.id; split(id, term.high + 1); split(id, term.low); } @@ -576,30 +630,69 @@ void UnionFind::getBase(TermId id, Base& base, Index offset) { getBase(id0, base, offset); } +/// getter methods for the internal nodes +TermId UnionFind::getRepr(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getRepr(); +} +ExplanationId UnionFind::getReason(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getReason(); +} +TermId UnionFind::getChild(TermId id, Index i) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getChild(i); +} +Index UnionFind::getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); +} +bool UnionFind::hasChildren(TermId id) const { + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].hasChildren(); +} + +/// setter methods for the internal nodes +void UnionFind::setRepr(TermId id, TermId new_repr, ExplanationId reason) { + Assert (id < d_equalityNodes.size()); + d_equalityNodes[id].setRepr(new_repr, reason); +} +void UnionFind::setChildren(TermId id, TermId ch1, TermId ch0) { + Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || + (id < d_equalityNodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); + d_equalityNodes[id].setChildren(ch1, ch0); +} + /** * Slicer * */ -ExtractTerm Slicer::registerTerm(TNode node) { - Index low = 0, high = utils::getSize(node) - 1; - TNode n = node; +TermId Slicer::registerTerm(TNode node) { if (node.getKind() == kind::BITVECTOR_EXTRACT) { - n = node[0]; - high = utils::getExtractHigh(node); - low = utils::getExtractLow(node); - } - if (d_nodeToId.find(n) == d_nodeToId.end()) { - TermId id = d_unionFind.addNode(utils::getSize(n)); - d_nodeToId[n] = id; - d_idToNode[id] = n; + TNode n = node[0]; + TermId top_id = registerTopLevelTerm(n); + Index high = utils::getExtractHigh(node); + Index low = utils::getExtractLow(node); + TermId id = d_unionFind.addEqualityNode(utils::getSize(node), top_id, high, low); + return id; + } + TermId id = registerTopLevelTerm(node); + return id; +} + +TermId Slicer::registerTopLevelTerm(TNode node) { + Assert (node.getKind() != kind::BITVECTOR_EXTRACT || + node.getKind() != kind::BITVECTOR_CONCAT); + + if (d_nodeToId.find(node) == d_nodeToId.end()) { + TermId id = d_unionFind.registerTopLevelTerm(utils::getSize(node)); + d_idToNode[id] = node; + d_nodeToId[node] = id; + Debug("bv-slicer") << "Slicer::registerTopLevelTerm " << node << " => id" << id << endl; + return id; } - TermId id = d_nodeToId[n]; - d_unionFind.addExtract(id, high, low); - ExtractTerm res(id, high, low); - Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; - return res; + return d_nodeToId[node]; } void Slicer::processEquality(TNode eq) { @@ -609,42 +702,38 @@ void Slicer::processEquality(TNode eq) { Assert (eq.getKind() == kind::EQUAL); TNode a = eq[0]; TNode b = eq[1]; - ExtractTerm a_ex= registerTerm(a); - ExtractTerm b_ex= registerTerm(b); + TermId a_id = registerTerm(a); + TermId b_id = registerTerm(b); - d_unionFind.ensureSlicing(a_ex); - d_unionFind.ensureSlicing(b_ex); + d_unionFind.ensureSlicing(a_id); + d_unionFind.ensureSlicing(b_id); - d_unionFind.alignSlicings(a_ex, b_ex); + d_unionFind.alignSlicings(a_id, b_id); - Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; - Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; - Debug("bv-slicer") << "Slicer::processEquality done. " << endl; + // Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; + // Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; + // Debug("bv-slicer") << "Slicer::processEquality done. " << endl; } void Slicer::assertEquality(TNode eq) { Assert (eq.getKind() == kind::EQUAL); - ExtractTerm a = registerTerm(eq[0]); - ExtractTerm b = registerTerm(eq[1]); + TermId a = registerTerm(eq[0]); + TermId b = registerTerm(eq[1]); ExplanationId reason = getExplanationId(eq); d_unionFind.unionTerms(a, b, reason); } -TermId Slicer::getId(TNode node) const { - __gnu_cxx::hash_map::const_iterator it = d_nodeToId.find(node); - Assert (it != d_nodeToId.end()); - return it->second; -} void Slicer::registerEquality(TNode eq) { if (d_explanationToId.find(eq) == d_explanationToId.end()) { ExplanationId id = d_explanations.size(); d_explanations.push_back(eq); - d_explanationToId[eq] = id; + d_explanationToId[eq] = id; + Debug("bv-slicer-explanation") << "Slicer::registerEquality " << eq << " => id"<< id << "\n"; } } -void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation) { +void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; Index high = utils::getSize(node) - 1; @@ -672,13 +761,18 @@ void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::ve Node current = getNode(nf.decomp[i]); decomp.push_back(current); } - - - Debug("bv-slicer") << "as ["; - for (unsigned i = 0; i < decomp.size(); ++i) { - Debug("bv-slicer") << decomp[i] <<" "; + if (Debug.isOn("bv-slicer-explanation")) { + Debug("bv-slicer-explanation") << "Slicer::getBaseDecomposition for " << node << "\n" + << "as "; + for (unsigned i = 0; i < decomp.size(); ++i) { + Debug("bv-slicer-explanation") << decomp[i] <<" " ; + } + Debug("bv-slicer-explanation") << "\n Explanation : \n"; + for (unsigned i = 0; i < explanation.size(); ++i) { + Debug("bv-slicer-explanation") << " " << explanation[i] << "\n"; + } + } - Debug("bv-slicer") << "]" << endl; } @@ -754,6 +848,10 @@ void Slicer::splitEqualities(TNode node, std::vector& equalities) { ExtractTerm UnionFind::getExtractTerm(TermId id) const { + if (d_topLevelIds.find(id) != d_topLevelIds.end()) { + // if it's a top level term so we don't have an extract stored for it + return ExtractTerm(id, getBitwidth(id) - 1, 0); + } Assert (isExtractTerm(id)); return (d_idToExtract.find(id))->second; @@ -763,19 +861,21 @@ bool UnionFind::isExtractTerm(TermId id) const { return d_idToExtract.find(id) != d_idToExtract.end(); } -bool Slicer::hasNode(TermId id) const { +bool Slicer::isTopLevelNode(TermId id) const { return d_idToNode.find(id) != d_idToNode.end(); } Node Slicer::getNode(TermId id) const { - if (hasNode(id)) { + if (isTopLevelNode(id)) { return d_idToNode.find(id)->second; } - // otherwise must be an extract Assert (d_unionFind.isExtractTerm(id)); - ExtractTerm extract = d_unionFind.getExtractTerm(id); - Assert (hasNode(extract.id)); + const ExtractTerm& extract = d_unionFind.getExtractTerm(id); + Assert (isTopLevelNode(extract.id)); TNode node = d_idToNode.find(extract.id)->second; + if (extract.high == utils::getSize(node) -1 && extract.low == 0) { + return node; + } Node ex = utils::mkExtract(node, extract.high, extract.low); return ex; } diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index ab2d5e88f..c46ef99ed 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -161,13 +161,13 @@ class UnionFind : public context::ContextNotifyObj { {} }; - class Node { + class EqualityNode { Index d_bitwidth; TermId d_ch1, d_ch0; // the ids of the two children if they exist ReprEdge d_edge; // points to the representative and stores the explanation public: - Node(Index b) + EqualityNode(Index b) : d_bitwidth(b), d_ch1(UndefinedId), d_ch0(UndefinedId), @@ -189,54 +189,36 @@ class UnionFind : public context::ContextNotifyObj { d_edge.reason = reason; } void setChildren(TermId ch1, TermId ch0) { - // Assert (d_repr == UndefinedId && !hasChildren()); d_ch1 = ch1; d_ch0 = ch0; } std::string debugPrint() const; }; + + // the equality nodes in the union find + std::vector d_equalityNodes; + + /// getter methods for the internal nodes + TermId getRepr(TermId id) const; + ExplanationId getReason(TermId id) const; + TermId getChild(TermId id, Index i) const; + Index getCutPoint(TermId id) const; + bool hasChildren(TermId id) const; - /// map from TermId to the nodes that represent them - std::vector d_nodes; + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr, ExplanationId reason); + void setChildren(TermId id, TermId ch1, TermId ch0); + + // the mappings between ExtractTerms and ids __gnu_cxx::hash_map > d_idToExtract; __gnu_cxx::hash_map d_extractToId; + + __gnu_cxx::hash_set d_topLevelIds; void getDecomposition(const ExtractTerm& term, Decomposition& decomp); void getDecompositionWithExplanation(const ExtractTerm& term, Decomposition& decomp, std::vector& explanation); void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); - /// getter methods for the internal nodes - TermId getRepr(TermId id) const { - Assert (id < d_nodes.size()); - return d_nodes[id].getRepr(); - } - ExplanationId getReason(TermId id) const { - Assert (id < d_nodes.size()); - return d_nodes[id].getReason(); - } - TermId getChild(TermId id, Index i) const { - Assert (id < d_nodes.size()); - return d_nodes[id].getChild(i); - } - Index getCutPoint(TermId id) const { - return getBitwidth(getChild(id, 0)); - } - bool hasChildren(TermId id) const { - Assert (id < d_nodes.size()); - return d_nodes[id].hasChildren(); - } - // TermId getTopLevel(TermId id) const; - /// setter methods for the internal nodes - void setRepr(TermId id, TermId new_repr, ExplanationId reason) { - Assert (id < d_nodes.size()); - d_nodes[id].setRepr(new_repr, reason); - } - void setChildren(TermId id, TermId ch1, TermId ch0) { - Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || - (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); - d_nodes[id].setChildren(ch1, ch0); - } - /* Backtracking mechanisms */ enum OperationKind { @@ -271,36 +253,44 @@ class UnionFind : public context::ContextNotifyObj { ~Statistics(); }; Statistics d_statistics; - Slicer* d_slicer; + Slicer* d_slicer; + TermId d_termIdCount; + + TermId mkEqualityNode(Index bitwidth); + ExtractTerm mkExtractTerm(TermId id, Index high, Index low); + void storeExtractTerm(Index id, const ExtractTerm& term); + ExtractTerm getExtractTerm(TermId id) const; + bool extractHasId(const ExtractTerm& ex) const { return d_extractToId.find(ex) != d_extractToId.end(); } + TermId getExtractId(const ExtractTerm& ex) const {Assert (extractHasId(ex)); return d_extractToId.find(ex)->second; } + bool isExtractTerm(TermId id) const; public: UnionFind(context::Context* ctx, Slicer* slicer) : ContextNotifyObj(ctx), - d_nodes(), + d_equalityNodes(), d_idToExtract(), - d_extractToId(), + d_extractToId(), + d_topLevelIds(), d_undoStack(), d_undoStackIndex(ctx), d_statistics(), - d_slicer(slicer) + d_slicer(slicer), + d_termIdCount(0) {} - TermId addNode(Index bitwidth); - TermId addExtract(Index topLevel, Index high, Index low); - ExtractTerm getExtractTerm(TermId id) const; - bool isExtractTerm(TermId id) const; - - void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2, TermId reason); + TermId addEqualityNode(unsigned bitwidth, TermId id, Index high, Index low); + TermId registerTopLevelTerm(Index bitwidth); + void unionTerms(TermId id1, TermId id2, TermId reason); void merge(TermId t1, TermId t2, TermId reason); TermId find(TermId t1); TermId findWithExplanation(TermId id, std::vector& explanation); void split(TermId term, Index i); void getNormalForm(const ExtractTerm& term, NormalForm& nf); void getNormalFormWithExplanation(const ExtractTerm& term, NormalForm& nf, std::vector& explanation); - void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); - void ensureSlicing(const ExtractTerm& term); + void alignSlicings(TermId id1, TermId id2); + void ensureSlicing(TermId id); Index getBitwidth(TermId id) const { - Assert (id < d_nodes.size()); - return d_nodes[id].getBitwidth(); + Assert (id < d_equalityNodes.size()); + return d_equalityNodes[id].getBitwidth(); } void getBase(TermId id, Base& base, Index offset); std::string debugPrint(TermId id); @@ -314,17 +304,19 @@ public: class CoreSolver; class Slicer { - __gnu_cxx::hash_map > d_idToNode; - __gnu_cxx::hash_map d_nodeToId; - __gnu_cxx::hash_map d_coreTermCache; - __gnu_cxx::hash_map d_explanationToId; - std::vector d_explanations; + __gnu_cxx::hash_map d_idToNode; + __gnu_cxx::hash_map d_nodeToId; + __gnu_cxx::hash_map d_coreTermCache; + __gnu_cxx::hash_map d_explanationToId; + std::vector d_explanations; UnionFind d_unionFind; context::CDQueue d_newSplits; context::CDO d_newSplitsIndex; CoreSolver* d_coreSolver; - TermId d_termIdCount; + TermId registerTopLevelTerm(TNode node); + bool isTopLevelNode(TermId id) const; + TermId registerTerm(TNode node); public: Slicer(context::Context* ctx, CoreSolver* coreSolver) : d_idToNode(), @@ -338,16 +330,15 @@ public: d_coreSolver(coreSolver) {} - void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); + void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); void registerEquality(TNode eq); - ExtractTerm registerTerm(TNode node); + void processEquality(TNode eq); void assertEquality(TNode eq); bool isCoreTerm (TNode node); bool hasNode(TermId id) const; Node getNode(TermId id) const; - TermId getId(TNode node) const; bool hasExplanation(ExplanationId id) const; TNode getExplanation(ExplanationId id) const; diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index e5a7bbb84..98bc8041d 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -69,28 +69,6 @@ inline Node mkVar(unsigned size) { return nm->mkSkolem("bv_$$", nm->mkBitVectorType(size), "is a variable created by the theory of bitvectors"); } -inline Node mkAnd(std::vector& children) { - std::set distinctChildren; - distinctChildren.insert(children.begin(), children.end()); - - if (distinctChildren.size() == 0) { - return mkTrue(); - } - - if (distinctChildren.size() == 1) { - return *children.begin(); - } - - NodeBuilder<> conjunction(kind::AND); - std::set::const_iterator it = distinctChildren.begin(); - std::set::const_iterator it_end = distinctChildren.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -} inline Node mkSortedNode(Kind kind, std::vector& children) { Assert (kind == kind::BITVECTOR_AND || @@ -155,14 +133,6 @@ inline Node mkXor(TNode node1, TNode node2) { } -inline Node mkAnd(std::vector& children) { - if(children.size() > 1) { - return NodeManager::currentNM()->mkNode(kind::AND, children); - } else { - return children[0]; - } -} - inline Node mkExtract(TNode node, unsigned high, unsigned low) { Node extractOp = NodeManager::currentNM()->mkConst(BitVectorExtract(high, low)); std::vector children; @@ -268,7 +238,6 @@ inline Node mkConjunction(const std::set nodes) { return conjunction; } - inline unsigned isPow2Const(TNode node) { if (node.getKind() != kind::CONST_BITVECTOR) { return false; @@ -278,6 +247,83 @@ inline unsigned isPow2Const(TNode node) { return bv.isPow2(); } +typedef __gnu_cxx::hash_set TNodeSet; + +inline Node mkAnd(const std::vector& conjunctions) { + std::set all; + all.insert(conjunctions.begin(), conjunctions.end()); + + if (all.size() == 0) { + return mkTrue(); + } + + if (all.size() == 1) { + // All the same, or just one + return conjunctions[0]; + } + + + NodeBuilder<> conjunction(kind::AND); + std::set::const_iterator it = all.begin(); + std::set::const_iterator it_end = all.end(); + while (it != it_end) { + conjunction << *it; + ++ it; + } + + return conjunction; +}/* mkAnd() */ + +inline Node mkAnd(const std::vector& conjunctions) { + std::set all; + all.insert(conjunctions.begin(), conjunctions.end()); + + if (all.size() == 0) { + return mkTrue(); + } + + if (all.size() == 1) { + // All the same, or just one + return conjunctions[0]; + } + + + NodeBuilder<> conjunction(kind::AND); + std::set::const_iterator it = all.begin(); + std::set::const_iterator it_end = all.end(); + while (it != it_end) { + conjunction << *it; + ++ it; + } + + return conjunction; +}/* mkAnd() */ + + + +inline Node flattenAnd(std::vector& queue) { + TNodeSet nodes; + while(!queue.empty()) { + TNode current = queue.back(); + queue.pop_back(); + if (current.getKind() == kind::AND) { + for (unsigned i = 0; i < current.getNumChildren(); ++i) { + if (nodes.count(current[i]) == 0) { + queue.push_back(current[i]); + } + } + } else { + nodes.insert(current); + } + } + std::vector children; + for (TNodeSet::const_iterator it = nodes.begin(); it!= nodes.end(); ++it) { + children.push_back(*it); + } + return mkAnd(children); +} + + // neeed a better name, this is not technically a ground term inline bool isBVGroundTerm(TNode node) { if (node.getNumChildren() == 0) { @@ -356,27 +402,7 @@ inline Node mkConjunction(const std::vector& nodes) { } -inline Node mkAnd(const std::vector& conjunctions) { - Assert(conjunctions.size() > 0); - - std::set all; - all.insert(conjunctions.begin(), conjunctions.end()); - if (all.size() == 1) { - // All the same, or just one - return conjunctions[0]; - } - - NodeBuilder<> conjunction(kind::AND); - std::set::const_iterator it = all.begin(); - std::set::const_iterator it_end = all.end(); - while (it != it_end) { - conjunction << *it; - ++ it; - } - - return conjunction; -}/* mkAnd() */ // Turn a set into a string -- 2.30.2