From 76805f8b4690093888bbb3d68e4d5c2c6ff221de Mon Sep 17 00:00:00 2001 From: Liana Hadarean Date: Tue, 5 Feb 2013 00:49:39 -0500 Subject: [PATCH] Added path compression and caching for getBaseDecomposition. --- src/theory/bv/bv_subtheory_core.cpp | 44 +++++++++++++++++++---------- src/theory/bv/bv_subtheory_core.h | 2 ++ src/theory/bv/slicer.cpp | 9 ++++-- src/theory/bv/slicer.h | 2 +- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index a3290ff7c..e31ab2fdf 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -33,6 +33,7 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer) d_notify(*this), d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"), d_assertions(c), + d_normalFormCache(), d_slicer(slicer), d_isCoreTheory(c, true) { @@ -97,6 +98,19 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { } } +Node CoreSolver::getBaseDecomposition(TNode a) { + if (d_normalFormCache.find(a) != d_normalFormCache.end()) { + return d_normalFormCache[a]; + } + + // otherwise we must compute the normal form + std::vector a_decomp; + d_slicer->getBaseDecomposition(a, a_decomp); + Node new_a = utils::mkConcat(a_decomp); + d_normalFormCache[a] = new_a; + return new_a; +} + bool CoreSolver::decomposeFact(TNode fact) { Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl; // FIXME: are this the right things to assert? @@ -107,18 +121,13 @@ bool CoreSolver::decomposeFact(TNode fact) { TNode eq = fact.getKind() == kind::NOT? fact[0] : fact; TNode a = eq[0]; - TNode b = eq[1]; - std::vector a_decomp; - std::vector b_decomp; - - d_slicer->getBaseDecomposition(a, a_decomp); - d_slicer->getBaseDecomposition(b, b_decomp); - - Assert (a_decomp.size() == b_decomp.size()); + TNode b = eq[1]; + Node new_a = getBaseDecomposition(a); + Node new_b = getBaseDecomposition(b); + + Assert (utils::getSize(new_a) == utils::getSize(new_b) && + utils::getSize(new_a) == utils::getSize(a)); - Node new_a = utils::mkConcat(a_decomp); - Node new_b = utils::mkConcat(b_decomp); - NodeManager* nm = NodeManager::currentNM(); Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); @@ -134,10 +143,15 @@ bool CoreSolver::decomposeFact(TNode fact) { if (fact.getKind() == kind::EQUAL) { // assert the individual equalities as well // a_i == b_i - for (unsigned i = 0; i < a_decomp.size(); ++i) { - Node eq_i = nm->mkNode(kind::EQUAL, a_decomp[i], b_decomp[i]); - ok = assertFact(eq_i, fact); - if (!ok) return false; + if (new_a.getKind() == kind::BITVECTOR_CONCAT && + new_b.getKind() == kind::BITVECTOR_CONCAT) { + + Assert (new_a.getNumChildren() == new_b.getNumChildren()); + for (unsigned i = 0; i < new_a.getNumChildren(); ++i) { + Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]); + ok = assertFact(eq_i, fact); + if (!ok) return false; + } } } return true; diff --git a/src/theory/bv/bv_subtheory_core.h b/src/theory/bv/bv_subtheory_core.h index 20b42d61c..38676bfa6 100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@ -69,11 +69,13 @@ class CoreSolver : public SubtheorySolver { /** FIXME: for debugging purposes only */ context::CDList d_assertions; + __gnu_cxx::hash_map d_normalFormCache; Slicer* d_slicer; context::CDO d_isCoreTheory; bool assertFact(TNode fact, TNode reason); bool decomposeFact(TNode fact); + Node getBaseDecomposition(TNode a); public: bool isCoreTheory() {return d_isCoreTheory; } CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer); diff --git a/src/theory/bv/slicer.cpp b/src/theory/bv/slicer.cpp index 87295e8f6..f41612df3 100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@ -220,10 +220,13 @@ void UnionFind::merge(TermId t1, TermId t2) { d_statistics.d_numRepresentatives += -1; } -TermId UnionFind::find(TermId id) const { +TermId UnionFind::find(TermId id) { TermId repr = getRepr(id); - if (repr != UndefinedId) - return find(repr); + if (repr != UndefinedId) { + TermId find_id = find(repr); + setRepr(id, find_id); + return find_id; + } return id; } /** diff --git a/src/theory/bv/slicer.h b/src/theory/bv/slicer.h index b0929d617..55cecb117 100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@ -212,7 +212,7 @@ public: TermId addTerm(Index bitwidth); void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); void merge(TermId t1, TermId t2); - TermId find(TermId t1) const ; + TermId find(TermId t1); void split(TermId term, Index i); void getNormalForm(const ExtractTerm& term, NormalForm& nf); -- 2.30.2