From ff8572914d73449b26edba214ad134c596196e32 Mon Sep 17 00:00:00 2001 From: lianah Date: Thu, 21 Mar 2013 19:25:33 -0400 Subject: [PATCH] fixed more equality stuff --- src/theory/bv/bv_subtheory_core.cpp | 22 +++++++-- src/theory/bv/bv_subtheory_core.h | 4 +- src/theory/bv/slicer.cpp | 77 ++++++++++++++++------------- src/theory/bv/slicer.h | 6 +-- 4 files changed, 67 insertions(+), 42 deletions(-) diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index d7dab10f9..2af0e47b8 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -72,6 +72,9 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv) } } +CoreSolver::~CoreSolver() { + delete d_slicer; +} void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { d_equalityEngine.setMasterEqualityEngine(eq); } @@ -99,10 +102,11 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { } } -Node CoreSolver::getBaseDecomposition(TNode a, std::vector& explanation) { +Node CoreSolver::getBaseDecomposition(TNode a, std::vector& explanation) { std::vector a_decomp; d_slicer->getBaseDecomposition(a, a_decomp, explanation); Node new_a = utils::mkConcat(a_decomp); + Debug("bv-slicer") << "CoreSolver::getBaseDecomposition " << a <<" => " << new_a << "\n"; return new_a; } @@ -118,7 +122,7 @@ bool CoreSolver::decomposeFact(TNode fact) { TNode b = fact[1]; d_slicer->processEquality(fact); - std::vector explanation; + std::vector explanation; Node new_a = getBaseDecomposition(a, explanation); Node new_b = getBaseDecomposition(b, explanation); @@ -157,10 +161,20 @@ bool CoreSolver::decomposeFact(TNode fact) { d_slicer->assertEquality(fact); } else { // still need to register the terms + d_slicer->processEquality(fact[0]); TNode a = fact[0][0]; TNode b = fact[0][1]; - d_slicer->registerTerm(a); - d_slicer->registerTerm(b); + std::vector explanation_a; + Node new_a = getBaseDecomposition(a, explanation_a); + Node reason_a = explanation_a.empty()? mkTrue() : mkAnd(explanation_a); + assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, a, new_a), reason_a); + + std::vector explanation_b; + Node new_b = getBaseDecomposition(b, explanation_b); + Node reason_b = explanation_b.empty()? mkTrue() : mkAnd(explanation_b); + assertFactToEqualityEngine(utils::mkNode(kind::EQUAL, b, new_b), reason_b); + d_reasons.insert(reason_a); + d_reasons.insert(reason_b); } // finally assert the actual fact to the equality engine return assertFactToEqualityEngine(fact, fact); diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index f37cf5bf3..4f2d7a279 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -67,9 +67,10 @@ class CoreSolver : public SubtheorySolver { context::CDHashSet d_reasons; bool assertFactToEqualityEngine(TNode fact, TNode reason); bool decomposeFact(TNode fact); - Node getBaseDecomposition(TNode a, std::vector& explanation); + Node getBaseDecomposition(TNode a, std::vector& explanation); public: CoreSolver(context::Context* c, TheoryBV* bv); + ~CoreSolver(); bool isCoreTheory() { return d_isCoreTheory; } void setMasterEqualityEngine(eq::EqualityEngine* eq); void preRegister(TNode node); @@ -91,6 +92,7 @@ public: return EQUALITY_UNKNOWN; } bool hasTerm(TNode node) const { return d_equalityEngine.hasTerm(node); } + void addTermToEqualityEngine(TNode node) { d_equalityEngine.addTerm(node); } }; diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 437be9bf4..5d376ea50 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -41,8 +41,11 @@ Base::Base(uint32_t size) void Base::sliceAt(Index index) { + if (index == d_size) + return; + Assert(index < d_size); Index vector_index = index / 32; - Assert (vector_index < d_size); + Assert (vector_index < d_repr.size()); Index int_index = index % 32; uint32_t bit_mask = utils::pow2(int_index); d_repr[vector_index] = d_repr[vector_index] | bit_mask; @@ -184,6 +187,15 @@ TermId UnionFind::addNode(Index bitwidth) { } TermId UnionFind::addExtract(TermId topLevel, Index high, Index low) { + if (isExtractTerm(topLevel)) { + ExtractTerm top = getExtractTerm(topLevel); + Index top_high = top.high; + Index top_low = top.low; + Assert (top_high - top_low + 1 > high); + high += top_low; + low += top_low; + topLevel = top.id; + } ExtractTerm extract(topLevel, high, low); if (d_extractToId.find(extract) != d_extractToId.end()) { return d_extractToId[extract]; @@ -292,13 +304,13 @@ void UnionFind::split(TermId id, Index i) { Assert (i < getBitwidth(id)); if (!hasChildren(id)) { // first time we split this term - TermId bottom_id = addExtract(getTopLevel(id), i - 1, 0); - TermId top_id = addExtract(getTopLevel(id), getBitwidth(id) - 1, i); + TermId bottom_id = addExtract(id, i - 1, 0); + TermId top_id = addExtract(id, getBitwidth(id) - 1, i); setChildren(id, top_id, bottom_id); recordOperation(UnionFind::SPLIT, id); if (d_slicer->termInEqualityEngine(id)) { - d_slicer->enqueueSplit(id, i); + d_slicer->enqueueSplit(id, i, top_id, bottom_id); } } else { Index cut = getCutPoint(id); @@ -310,13 +322,13 @@ void UnionFind::split(TermId id, Index i) { ++(d_statistics.d_numSplits); } -TermId UnionFind::getTopLevel(TermId id) const { - __gnu_cxx::hash_map >::const_iterator it = d_idToExtract.find(id); - if (it != d_idToExtract.end()) { - return (*it).second.id; - } - return id; -} +// TermId UnionFind::getTopLevel(TermId id) const { +// __gnu_cxx::hash_map >::const_iterator it = d_idToExtract.find(id); +// if (it != d_idToExtract.end()) { +// return (*it).second.id; +// } +// return id; +// } void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { nf.clear(); @@ -576,7 +588,7 @@ ExtractTerm Slicer::registerTerm(TNode node) { if (node.getKind() == kind::BITVECTOR_EXTRACT) { n = node[0]; high = utils::getExtractHigh(node); - low = utils::getExtractLow(node); + low = utils::getExtractLow(node); } if (d_nodeToId.find(n) == d_nodeToId.end()) { TermId id = d_unionFind.addNode(utils::getSize(n)); @@ -584,6 +596,7 @@ ExtractTerm Slicer::registerTerm(TNode node) { d_idToNode[id] = n; } TermId id = d_nodeToId[n]; + d_unionFind.addExtract(id, high, low); ExtractTerm res(id, high, low); Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; return res; @@ -631,7 +644,7 @@ void Slicer::registerEquality(TNode eq) { } } -void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation) { +void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation) { Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; Index high = utils::getSize(node) - 1; @@ -655,16 +668,8 @@ void Slicer::getBaseDecomposition(TNode node, std::vector& decomp, std::ve explanation.push_back(exp); } - // construct actual extract nodes - Index size = utils::getSize(node); - Index current_low = size - 1; - Index current_high = size - 1; - for (int i = nf.decomp.size() - 1; i>=0 ; --i) { - Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); - current_low = current_low - current_size; - Node current = Rewriter::rewrite(utils::mkExtract(node, current_high, current_low+1)); - current_high -= current_size; + Node current = getNode(nf.decomp[i]); decomp.push_back(current); } @@ -763,17 +768,16 @@ bool Slicer::hasNode(TermId id) const { } Node Slicer::getNode(TermId id) const { - // if it was an extract - if (d_unionFind.isExtractTerm(id)) { - ExtractTerm extract = d_unionFind.getExtractTerm(id); - Assert (hasNode(extract.id)); - TNode node = d_idToNode.find(extract.id)->second; - Node ex = utils::mkExtract(node, extract.high, extract.low); - return ex; + if (hasNode(id)) { + return d_idToNode.find(id)->second; } - // otherwise must be a top-level term - Assert (hasNode(id)); - return (d_idToNode.find(id))->second; + // otherwise must be an extract + Assert (d_unionFind.isExtractTerm(id)); + ExtractTerm extract = d_unionFind.getExtractTerm(id); + Assert (hasNode(extract.id)); + TNode node = d_idToNode.find(extract.id)->second; + Node ex = utils::mkExtract(node, extract.high, extract.low); + return ex; } bool Slicer::termInEqualityEngine(TermId id) { @@ -781,13 +785,18 @@ bool Slicer::termInEqualityEngine(TermId id) { return d_coreSolver->hasTerm(node); } -void Slicer::enqueueSplit(TermId id, Index i) { +void Slicer::enqueueSplit(TermId id, Index i, TermId top_id, TermId bottom_id) { Node node = getNode(id); Node bottom = Rewriter::rewrite(utils::mkExtract(node, i -1 , 0)); Node top = Rewriter::rewrite(utils::mkExtract(node, utils::getSize(node) - 1, i)); + // must add terms to equality engine so we get notified when they get split more + d_coreSolver->addTermToEqualityEngine(bottom); + d_coreSolver->addTermToEqualityEngine(top); + Node eq = utils::mkNode(kind::EQUAL, node, utils::mkConcat(top, bottom)); d_newSplits.push_back(eq); - Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; + Debug("bv-slicer") << "Slicer::enqueueSplit " << eq << endl; + Debug("bv-slicer") << " " << id << "=" << top_id << " " << bottom_id << endl; } void Slicer::getNewSplits(std::vector& splits) { diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index f63cf7284..ab2d5e88f 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -224,7 +224,7 @@ class UnionFind : public context::ContextNotifyObj { Assert (id < d_nodes.size()); return d_nodes[id].hasChildren(); } - TermId getTopLevel(TermId id) const; + // TermId getTopLevel(TermId id) const; /// setter methods for the internal nodes void setRepr(TermId id, TermId new_repr, ExplanationId reason) { @@ -338,7 +338,7 @@ public: d_coreSolver(coreSolver) {} - void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); + void getBaseDecomposition(TNode node, std::vector& decomp, std::vector& explanation); void registerEquality(TNode eq); ExtractTerm registerTerm(TNode node); void processEquality(TNode eq); @@ -354,7 +354,7 @@ public: ExplanationId getExplanationId(TNode reason) const; bool termInEqualityEngine(TermId id); - void enqueueSplit(TermId id, Index i); + void enqueueSplit(TermId id, Index i, TermId top, TermId bottom); void getNewSplits(std::vector& splits); static void splitEqualities(TNode node, std::vector& equalities); static unsigned d_numAddedEqualities; -- 2.30.2