From: lianah Date: Wed, 13 Mar 2013 17:44:33 +0000 (-0400) Subject: post failed attempts at getting the incremental solver to work X-Git-Tag: cvc5-1.0.0~7361^2~42 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=3fcdb18fe92e5213aa708285c0d7d5e55633492b;p=cvc5.git post failed attempts at getting the incremental solver to work --- 3fcdb18fe92e5213aa708285c0d7d5e55633492b diff --cc src/theory/bv/bitblast_strategies.cpp index 3ce9bcb44,a952b2929..773685997 --- a/src/theory/bv/bitblast_strategies.cpp +++ b/src/theory/bv/bitblast_strategies.cpp @@@ -342,11 -342,11 +342,11 @@@ void DefaultVarBB (TNode node, Bits& bi } if(Debug.isOn("bitvector-bb")) { - BVDebug("bitvector-bb") << "theory::bv::DefaultVarBB bitblasting " << node << "\n"; - BVDebug("bitvector-bb") << " with bits " << toString(bits); + Debug("bitvector-bb") << "theory::bv::DefaultVarBB bitblasting " << node << "\n"; + Debug("bitvector-bb") << " with bits " << toString(bits); } - bb->storeVariable(node); + bb->storeVariable(node); } void DefaultConstBB (TNode node, Bits& bits, Bitblaster* bb) { diff --cc src/theory/bv/bv_subtheory.h index a256b6001,4dbba0797..d95aaa873 --- a/src/theory/bv/bv_subtheory.h +++ b/src/theory/bv/bv_subtheory.h @@@ -72,19 -71,19 +72,31 @@@ protected /** The bit-vector theory */ TheoryBV* d_bv; -- ++ context::CDQueue d_assertionQueue; ++ context::CDO d_assertionIndex; public: SubtheorySolver(context::Context* c, TheoryBV* bv) : d_context(c), -- d_bv(bv) ++ d_bv(bv), ++ d_assertionQueue(c), ++ d_assertionIndex(c, 0) {} virtual ~SubtheorySolver() {} ++ ++ virtual bool check(Theory::Effort e) = 0; ++ virtual void explain(TNode literal, std::vector& assumptions) = 0; ++ virtual void preRegister(TNode node) {} ++ virtual void collectModelInfo(TheoryModel* m) = 0; ++ bool done() { return d_assertionQueue.size() == d_assertionIndex; } ++ TNode get() { ++ Assert (!done()); ++ TNode res = d_assertionQueue[d_assertionIndex]; ++ d_assertionIndex = d_assertionIndex + 1; ++ return res; ++ } ++ void assertFact(TNode fact) { d_assertionQueue.push_back(fact); } -- virtual bool addAssertions(const std::vector& assertions, Theory::Effort e) = 0; -- virtual void explain(TNode literal, std::vector& assumptions) = 0; -- virtual void preRegister(TNode node) {} -- virtual void collectModelInfo(TheoryModel* m) = 0; }; } diff --cc src/theory/bv/bv_subtheory_bitblast.cpp index 985a9b500,501aafb29..2f76e32d3 --- a/src/theory/bv/bv_subtheory_bitblast.cpp +++ b/src/theory/bv/bv_subtheory_bitblast.cpp @@@ -52,12 -52,22 +52,21 @@@ void BitblastSolver::explain(TNode lite d_bitblaster->explain(literal, assumptions); } --bool BitblastSolver::addAssertions(const std::vector& assertions, Theory::Effort e) { -- Debug("bitvector::bitblaster") << "BitblastSolver::addAssertions (" << e << ")" << std::endl; - Debug("bitvector::bitblaster") << "number of assertions: " << assertions.size() << std::endl; - //// Lazy bit-blasting ++bool BitblastSolver::check(Theory::Effort e) { + //// Eager bit-blasting + if (options::bitvectorEagerBitblast()) { - for (unsigned i = 0; i < assertions.size(); ++i) { - TNode atom = assertions[i].getKind() == kind::NOT ? assertions[i][0] : assertions[i]; ++ while (!done()) { ++ TNode assertion = get(); ++ TNode atom = assertion.getKind() == kind::NOT ? assertion[0] : assertion; + if (atom.getKind() != kind::BITVECTOR_BITOF) { + d_bitblaster->bbAtom(atom); + } ++ return true; + } - return true; + } + //// Lazy bit-blasting - // bit-blast enqueued nodes while (!d_bitblastQueue.empty()) { TNode atom = d_bitblastQueue.front(); @@@ -65,9 -75,9 +74,9 @@@ d_bitblastQueue.pop(); } -- // propagation -- for (unsigned i = 0; i < assertions.size(); ++i) { -- TNode fact = assertions[i]; ++ // Processinga ssertions ++ while (!done()) { ++ TNode fact = get(); if (!d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_BITBLAST)) { // Some atoms have not been bit-blasted yet d_bitblaster->bbAtom(fact); @@@ -93,7 -103,7 +102,7 @@@ } } -- // solving ++ // Solving if (e == Theory::EFFORT_FULL || options::bitvectorEagerFullcheck()) { Assert(!d_bv->inConflict()); Debug("bitvector::bitblaster") << "BitblastSolver::addAssertions solving. \n"; diff --cc src/theory/bv/bv_subtheory_bitblast.h index 3396d813b,3396d813b..318fdd230 --- a/src/theory/bv/bv_subtheory_bitblast.h +++ b/src/theory/bv/bv_subtheory_bitblast.h @@@ -42,7 -42,7 +42,7 @@@ public ~BitblastSolver(); void preRegister(TNode node); -- bool addAssertions(const std::vector& assertions, Theory::Effort e); ++ bool check(Theory::Effort e); void explain(TNode literal, std::vector& assumptions); EqualityStatus getEqualityStatus(TNode a, TNode b); void collectModelInfo(TheoryModel* m); diff --cc src/theory/bv/bv_subtheory_core.cpp index 91cf29ee9,000000000..2e1320d1a mode 100644,000000..100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@@ -1,330 -1,0 +1,278 @@@ +/********************* */ +/*! \file bv_subtheory_eq.cpp + ** \verbatim + ** Original author: dejan + ** Major contributors: none + ** Minor contributors (to current version): lianah + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#include "theory/bv/bv_subtheory_eq.h" + +#include "theory/bv/theory_bv.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/bv/slicer.h" +#include "theory/model.h" + +using namespace std; +using namespace CVC4; +using namespace CVC4::context; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace CVC4::theory::bv::utils; + +CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer) + : SubtheorySolver(c, bv), + d_notify(*this), + d_equalityEngine(d_notify, c, "theory::bv::TheoryBV"), + d_assertions(c), + d_normalFormCache(), + d_slicer(slicer), - d_isCoreTheory(c, true) ++ d_isCoreTheory(c, true), ++ d_baseChanged(false), ++ d_checkCalled(false) +{ + if (d_useEqualityEngine) { + + // The kinds we are treating as function application in congruence + d_equalityEngine.addFunctionKind(kind::BITVECTOR_CONCAT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_AND); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_OR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NAND); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_XNOR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_COMP); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_MULT); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_PLUS); + d_equalityEngine.addFunctionKind(kind::BITVECTOR_EXTRACT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SUB); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_NEG); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UDIV); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UREM); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SDIV); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SREM); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SMOD); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SHL); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_LSHR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ASHR); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_ULE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_UGE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SLE); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGT); + // d_equalityEngine.addFunctionKind(kind::BITVECTOR_SGE); + } +} + +void CoreSolver::setMasterEqualityEngine(eq::EqualityEngine* eq) { + d_equalityEngine.setMasterEqualityEngine(eq); +} + +void CoreSolver::preRegister(TNode node) { + if (!d_useEqualityEngine) + return; + + if (node.getKind() == kind::EQUAL) { + d_equalityEngine.addTriggerEquality(node); ++ d_slicer->processEquality(node); + } else { + d_equalityEngine.addTerm(node); + } +} + + +void CoreSolver::explain(TNode literal, std::vector& assumptions) { + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; + if (atom.getKind() == kind::EQUAL) { + d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions); + } else { + d_equalityEngine.explainPredicate(atom, polarity, assumptions); + } +} + +Node CoreSolver::getBaseDecomposition(TNode a) { + std::vector a_decomp; ++ // FIXME: hack to do bitwise decomposition ++ // for (int i = utils::getSize(a) - 1; i>= 0; --i) { ++ // Node bit = Rewriter::rewrite(utils::mkExtract(a, i, i)); ++ // a_decomp.push_back(bit); ++ // } + d_slicer->getBaseDecomposition(a, a_decomp); + Node new_a = utils::mkConcat(a_decomp); + return new_a; +} + +bool CoreSolver::decomposeFact(TNode fact) { + Debug("bv-slicer") << "CoreSolver::decomposeFact fact=" << fact << endl; + // assert decompositions since the equality engine does not know the semantics of + // concat: + // a == a_1 concat ... concat a_k + // b == b_1 concat ... concat b_k + TNode eq = fact.getKind() == kind::NOT? fact[0] : fact; + + TNode a = eq[0]; + TNode b = eq[1]; - // we need to get the old decomposition to keep track of the cuts we added - Base a_old_base = d_slicer->getTopLevelBase(a); - Base b_old_base = d_slicer->getTopLevelBase(b); + - d_slicer->processEquality(eq); ++ // d_slicer->processEquality(eq); + + Node new_a = getBaseDecomposition(a); + Node new_b = getBaseDecomposition(b); + + Assert (utils::getSize(new_a) == utils::getSize(new_b) && + utils::getSize(new_a) == utils::getSize(a)); + + NodeManager* nm = NodeManager::currentNM(); + Node a_eq_new_a = nm->mkNode(kind::EQUAL, a, new_a); + Node b_eq_new_b = nm->mkNode(kind::EQUAL, b, new_b); + - Base a_new_base = d_slicer->getTopLevelBase(a); - Base b_new_base = d_slicer->getTopLevelBase(b); - + bool ok = true; - ok = addNewSplits(a, a_old_base, a_new_base); - if (!ok) return false; - ok = addNewSplits(b, b_old_base, b_new_base); - if (!ok) return false; - - ok = assertFact(a_eq_new_a, utils::mkTrue()); ++ ok = assertFactToEqualityEngine(a_eq_new_a, utils::mkTrue()); + if (!ok) return false; - ok = assertFact(b_eq_new_b, utils::mkTrue()); ++ ok = assertFactToEqualityEngine(b_eq_new_b, utils::mkTrue()); + if (!ok) return false; - ok = assertFact(fact, fact); ++ ok = assertFactToEqualityEngine(fact, fact); + if (!ok) return false; + + if (fact.getKind() == kind::EQUAL) { + // assert the individual equalities as well + // a_i == b_i + if (new_a.getKind() == kind::BITVECTOR_CONCAT && + new_b.getKind() == kind::BITVECTOR_CONCAT) { + + Assert (new_a.getNumChildren() == new_b.getNumChildren()); + for (unsigned i = 0; i < new_a.getNumChildren(); ++i) { + Node eq_i = nm->mkNode(kind::EQUAL, new_a[i], new_b[i]); - ok = assertFact(eq_i, fact); ++ ok = assertFactToEqualityEngine(eq_i, fact); + if (!ok) return false; + } + } + } + return true; +} + - bool CoreSolver::addNewSplits(TNode n, Base& old_base, Base& new_base) { - if (n.getKind() == kind::BITVECTOR_EXTRACT) { - n = n[0]; - } - Assert (old_base.getBitwidth() == new_base.getBitwidth() && - utils::getSize(n) == old_base.getBitwidth()); - - Index high, low = 0; - std::vector > toSlice; - bool hasNewCut = false; - // collect the intervals that need to be sliced - for (unsigned i = 0; i <= old_base.getBitwidth(); ++i) { - Assert (! old_base.isCutPoint(i) || new_base.isCutPoint(i)); - if (new_base.isCutPoint(i) && !old_base.isCutPoint(i)) { - hasNewCut = true; - } - if (new_base.isCutPoint(i) && old_base.isCutPoint(i)) { - high = i; - if (hasNewCut) { - toSlice.push_back(std::pair(high, low)); - } - low = i; - hasNewCut = false; - } - } - // for each interval, assert the proper equality - for (unsigned i = 0; i < toSlice.size(); ++i) { - int high = toSlice[i].first; - int low = toSlice[i].second; - int prev = high; - std::vector extracts; - for (int k = high -1; k >= low; --k) { - if (new_base.isCutPoint(k) && (!old_base.isCutPoint(k) || k == low)) { - // add a new extract - Node ex = utils::mkExtract(n, prev - 1, k); - prev = k; - extracts.push_back(ex); - } - } - Node concat = utils::mkConcat(extracts); - Node current = utils::mkExtract(n, high - 1, low); - Node eq = utils::mkNode(kind::EQUAL, concat, current); - bool ok = assertFact(eq, utils::mkTrue()); - if (!ok) - return false; - } - return true; - } - - - bool CoreSolver::addAssertions(const std::vector& assertions, Theory::Effort e) { - Trace("bitvector::core") << "CoreSolver::addAssertions \n"; ++bool CoreSolver::check(Theory::Effort e) { ++ d_checkCalled = true; ++ Trace("bitvector::core") << "CoreSolver::check \n"; + Assert (!d_bv->inConflict()); + + bool ok = true; + std::vector core_eqs; - for (unsigned i = 0; i < assertions.size(); ++i) { - TNode fact = assertions[i]; ++ while (! done()) { ++ TNode fact = get(); + + // update whether we are in the core fragment + if (d_isCoreTheory && !d_slicer->isCoreTerm(fact)) { + d_isCoreTheory = false; + } + + // only reason about equalities + if (fact.getKind() == kind::EQUAL || (fact.getKind() == kind::NOT && fact[0].getKind() == kind::EQUAL)) { + TNode eq = fact.getKind() == kind::EQUAL ? fact : fact[0]; + ok = decomposeFact(fact); + } else { - ok = assertFact(fact, fact); ++ ok = assertFactToEqualityEngine(fact, fact); + } + if (!ok) + return false; + } - ++ + return true; +} + - bool CoreSolver::assertFact(TNode fact, TNode reason) { - Debug("bv-slicer") << "CoreSolver::assertFact fact=" << fact << endl; ++bool CoreSolver::assertFactToEqualityEngine(TNode fact, TNode reason) { ++ Debug("bv-slicer") << "CoreSolver::assertFactToEqualityEngine fact=" << fact << endl; + Debug("bv-slicer") << " reason=" << reason << endl; + // Notify the equality engine + if (d_useEqualityEngine && !d_bv->inConflict() && !d_bv->propagatedBy(fact, SUB_CORE) ) { + Trace("bitvector::core") << " (assert " << fact << ")\n"; + //d_assertions.push_back(fact); + bool negated = fact.getKind() == kind::NOT; + TNode predicate = negated ? fact[0] : fact; + if (predicate.getKind() == kind::EQUAL) { + if (negated) { + // dis-equality + d_equalityEngine.assertEquality(predicate, false, reason); + } else { + // equality + d_equalityEngine.assertEquality(predicate, true, reason); + } + } else { + // Adding predicate if the congruence over it is turned on + if (d_equalityEngine.isFunctionKind(predicate.getKind())) { + d_equalityEngine.assertPredicate(predicate, !negated, reason); + } + } + } + + // checking for a conflict + if (d_bv->inConflict()) { + return false; + } + return true; +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerEquality(TNode equality, bool value) { - BVDebug("bitvector::core") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; ++ Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerEquality(" << equality << ", " << (value ? "true" : "false" )<< ")" << std::endl; + if (value) { + return d_solver.storePropagation(equality); + } else { + return d_solver.storePropagation(equality.notNode()); + } +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, bool value) { - BVDebug("bitvector::core") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl; ++ Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false" ) << ")" << std::endl; + if (value) { + return d_solver.storePropagation(predicate); + } else { + return d_solver.storePropagation(predicate.notNode()); + } +} + +bool CoreSolver::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) { + Debug("bitvector::core") << "NotifyClass::eqNotifyTriggerTermMerge(" << t1 << ", " << t2 << ")" << std::endl; + if (value) { + return d_solver.storePropagation(t1.eqNode(t2)); + } else { + return d_solver.storePropagation(t1.eqNode(t2).notNode()); + } +} + +void CoreSolver::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { + d_solver.conflict(t1, t2); +} + +bool CoreSolver::storePropagation(TNode literal) { + return d_bv->storePropagation(literal, SUB_CORE); +} + +void CoreSolver::conflict(TNode a, TNode b) { + std::vector assumptions; + d_equalityEngine.explainEquality(a, b, true, assumptions); + d_bv->setConflict(mkAnd(assumptions)); +} + +void CoreSolver::collectModelInfo(TheoryModel* m) { + if (Debug.isOn("bitvector-model")) { + context::CDList::const_iterator it = d_assertions.begin(); + for (; it!= d_assertions.end(); ++it) { + Debug("bitvector-model") << "CoreSolver::collectModelInfo (assert " + << *it << ")\n"; + } + } + set termSet; + d_bv->computeRelevantTerms(termSet); + m->assertEqualityEngine(&d_equalityEngine, &termSet); +} diff --cc src/theory/bv/bv_subtheory_core.h index 1adf813ff,000000000..d5235a864 mode 100644,000000..100644 --- a/src/theory/bv/bv_subtheory_core.h +++ b/src/theory/bv/bv_subtheory_core.h @@@ -1,107 -1,0 +1,108 @@@ +/********************* */ +/*! \file bv_subtheory_eq.h + ** \verbatim + ** Original author: dejan + ** Major contributors: none + ** Minor contributors (to current version): lianah, mdeters + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009-2012 New York University and The University of Iowa + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Algebraic solver. + ** + ** Algebraic solver. + **/ + +#pragma once + +#include "cvc4_private.h" +#include "theory/bv/bv_subtheory.h" +#include "context/cdhashmap.h" + +namespace CVC4 { +namespace theory { +namespace bv { + +class Slicer; +class Base; +/** + * Bitvector equality solver + */ +class CoreSolver : public SubtheorySolver { + + enum FactSource { + AXIOM = 0, // this is asserting that a node is equal to its decomposition + ASSERTION = 1, // externally visible assertion + SPLIT = 2 // fact resulting from a split + }; + + // NotifyClass: handles call-back from congruence closure module + + class NotifyClass : public eq::EqualityEngineNotify { + CoreSolver& d_solver; + + public: + NotifyClass(CoreSolver& solver): d_solver(solver) {} + bool eqNotifyTriggerEquality(TNode equality, bool value); + bool eqNotifyTriggerPredicate(TNode predicate, bool value); + bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value); + void eqNotifyConstantTermMerge(TNode t1, TNode t2); + void eqNotifyNewClass(TNode t) { } + void eqNotifyPreMerge(TNode t1, TNode t2) { } + void eqNotifyPostMerge(TNode t1, TNode t2) { } + void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) { } +}; + + + /** The notify class for d_equalityEngine */ + NotifyClass d_notify; + + /** Equality engine */ + eq::EqualityEngine d_equalityEngine; + + /** Store a propagation to the bv solver */ + bool storePropagation(TNode literal); + + /** Store a conflict from merging two constants */ + void conflict(TNode a, TNode b); + + /** FIXME: for debugging purposes only */ + context::CDList d_assertions; + __gnu_cxx::hash_map d_normalFormCache; + Slicer* d_slicer; + context::CDO d_isCoreTheory; + - bool assertFact(TNode fact, TNode reason); ++ bool assertFactToEqualityEngine(TNode fact, TNode reason); + bool decomposeFact(TNode fact); + Node getBaseDecomposition(TNode a); - bool addNewSplits(TNode n, Base& old_base, Base& new_base); ++ bool d_baseChanged; ++ bool d_checkCalled; +public: - bool isCoreTheory() {return d_isCoreTheory; } + CoreSolver(context::Context* c, TheoryBV* bv, Slicer* slicer); - void setMasterEqualityEngine(eq::EqualityEngine* eq); ++ bool isCoreTheory() { return d_isCoreTheory; } ++ void setMasterEqualityEngine(eq::EqualityEngine* eq); + void preRegister(TNode node); - bool addAssertions(const std::vector& assertions, Theory::Effort e); ++ bool check(Theory::Effort e); + void explain(TNode literal, std::vector& assumptions); + void collectModelInfo(TheoryModel* m); + void addSharedTerm(TNode t) { + d_equalityEngine.addTriggerTerm(t, THEORY_BV); + } + EqualityStatus getEqualityStatus(TNode a, TNode b) { + if (d_equalityEngine.areEqual(a, b)) { + // The terms are implied to be equal + return EQUALITY_TRUE; + } + if (d_equalityEngine.areDisequal(a, b, false)) { + // The terms are implied to be dis-equal + return EQUALITY_FALSE; + } + return EQUALITY_UNKNOWN; + } +}; + + +} +} +} diff --cc src/theory/bv/slicer.cpp index 2334ed2b0,000000000..ac668ab20 mode 100644,000000..100644 --- a/src/theory/bv/slicer.cpp +++ b/src/theory/bv/slicer.cpp @@@ -1,671 -1,0 +1,684 @@@ +/********************* */ +/*! \file slicer.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Bitvector theory. + ** + ** Bitvector theory. + **/ + +#include "theory/bv/slicer.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" + +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::bv; +using namespace std; + + +const TermId CVC4::theory::bv::UndefinedId = -1; + +/** + * Base + * + */ +Base::Base(uint32_t size) + : d_size(size), + d_repr(size/32 + (size % 32 == 0? 0 : 1), 0) +{ + Assert (d_size > 0); +} + + +void Base::sliceAt(Index index) { + Index vector_index = index / 32; + Assert (vector_index < d_size); + Index int_index = index % 32; + uint32_t bit_mask = utils::pow2(int_index); + d_repr[vector_index] = d_repr[vector_index] | bit_mask; +} + ++void Base::undoSliceAt(Index index) { ++ Index vector_index = index / 32; ++ Assert (vector_index < d_size); ++ Index int_index = index % 32; ++ uint32_t bit_mask = utils::pow2(int_index); ++ d_repr[vector_index] = d_repr[vector_index] ^ bit_mask; ++} ++ +void Base::sliceWith(const Base& other) { + Assert (d_size == other.d_size); + for (unsigned i = 0; i < d_repr.size(); ++i) { + d_repr[i] = d_repr[i] | other.d_repr[i]; + } +} + +bool Base::isCutPoint (Index index) const { + // there is an implicit cut point at the end and begining of the bv + if (index == d_size || index == 0) + return true; + + Index vector_index = index / 32; + Assert (vector_index < d_size); + Index int_index = index % 32; + uint32_t bit_mask = utils::pow2(int_index); + + return (bit_mask & d_repr[vector_index]) != 0; +} + +void Base::diffCutPoints(const Base& other, Base& res) const { + Assert (d_size == other.d_size && res.d_size == d_size); + for (unsigned i = 0; i < d_repr.size(); ++i) { + Assert (res.d_repr[i] == 0); + res.d_repr[i] = d_repr[i] ^ other.d_repr[i]; + } +} + +bool Base::isEmpty() const { + for (unsigned i = 0; i< d_repr.size(); ++i) { + if (d_repr[i] != 0) + return false; + } + return true; +} + +std::string Base::debugPrint() const { + std::ostringstream os; + os << "["; + bool first = true; + for (int i = d_size - 1; i >= 0; --i) { + if (isCutPoint(i)) { + if (first) + first = false; + else + os <<"| "; + + os << i ; + } + } + os << "]"; + return os.str(); +} + +/** + * ExtractTerm + * + */ + +std::string ExtractTerm::debugPrint() const { + ostringstream os; + os << "id" << id << "[" << high << ":" << low <<"] "; + return os.str(); +} + +/** + * NormalForm + * + */ + +std::pair NormalForm::getTerm(Index index, const UnionFind& uf) const { + Assert (index < base.getBitwidth()); + Index count = 0; + for (unsigned i = 0; i < decomp.size(); ++i) { + Index size = uf.getBitwidth(decomp[i]); + if ( count + size > index && index >= count) { + return pair(decomp[i], count); + } + count += size; + } + Unreachable(); +} + + + +std::string NormalForm::debugPrint(const UnionFind& uf) const { + ostringstream os; + os << "NF " << base.debugPrint() << endl; + os << "("; + for (int i = decomp.size() - 1; i>= 0; --i) { + os << decomp[i] << "[" << uf.getBitwidth(decomp[i]) <<"]"; + os << (i != 0? ", " : ""); + } + os << ") \n"; + return os.str(); +} +/** + * UnionFind::Node + * + */ + +std::string UnionFind::Node::debugPrint() const { + ostringstream os; + os << "Repr " << d_repr << " ["<< d_bitwidth << "] "; + os << "( " << d_ch1 <<", " << d_ch0 << ")" << endl; + return os.str(); +} + + +/** + * UnionFind + * + */ +TermId UnionFind::addTerm(Index bitwidth) { + Node node(bitwidth); + d_nodes.push_back(node); + ++(d_statistics.d_numNodes); + + TermId id = d_nodes.size() - 1; + // d_representatives.insert(id); + ++(d_statistics.d_numRepresentatives); + + Debug("bv-slicer-uf") << "UnionFind::addTerm " << id << " size " << bitwidth << endl; + return id; +} +/** + * At this point we assume the slicings of the two terms are properly aligned. + * + * @param t1 + * @param t2 + */ +void UnionFind::unionTerms(const ExtractTerm& t1, const ExtractTerm& t2) { + Debug("bv-slicer") << "UnionFind::unionTerms " << t1.debugPrint() << " and \n" + << " " << t2.debugPrint() << endl; + Assert (t1.getBitwidth() == t2.getBitwidth()); + + NormalForm nf1(t1.getBitwidth()); + NormalForm nf2(t2.getBitwidth()); + + getNormalForm(t1, nf1); + getNormalForm(t2, nf2); + + Assert (nf1.decomp.size() == nf2.decomp.size()); + Assert (nf1.base == nf2.base); + + for (unsigned i = 0; i < nf1.decomp.size(); ++i) { + merge (nf1.decomp[i], nf2.decomp[i]); + } +} + +/** + * Merge the two terms in the union find. Both t1 and t2 + * should be root terms. + * + * @param t1 + * @param t2 + */ +void UnionFind::merge(TermId t1, TermId t2) { + Debug("bv-slicer-uf") << "UnionFind::merge (" << t1 <<", " << t2 << ")" << endl; + ++(d_statistics.d_numMerges); + t1 = find(t1); + t2 = find(t2); + + if (t1 == t2) + return; + + Assert (! hasChildren(t1) && ! hasChildren(t2)); + setRepr(t1, t2); + recordOperation(UnionFind::MERGE, t1); + //d_representatives.erase(t1); + d_statistics.d_numRepresentatives += -1; +} + +TermId UnionFind::find(TermId id) { + TermId repr = getRepr(id); + if (repr != UndefinedId) { + TermId find_id = find(repr); + // setRepr(id, find_id); + return find_id; + } + return id; +} +/** + * Splits the representative of the term between i-1 and i + * + * @param id the id of the term + * @param i the index we are splitting at + * + * @return + */ +void UnionFind::split(TermId id, Index i) { + Debug("bv-slicer-uf") << "UnionFind::split " << id << " at " << i << endl; + id = find(id); + Debug("bv-slicer-uf") << " node: " << d_nodes[id].debugPrint() << endl; + + if (i == 0 || i == getBitwidth(id)) { + // nothing to do - return; ++ return; + } ++ + Assert (i < getBitwidth(id)); + if (!hasChildren(id)) { + // first time we split this term + TermId bottom_id = addTerm(i); + TermId top_id = addTerm(getBitwidth(id) - i); + setChildren(id, top_id, bottom_id); + recordOperation(UnionFind::SPLIT, id); + } else { + Index cut = getCutPoint(id); + if (i < cut ) + split(getChild(id, 0), i); + else + split(getChild(id, 1), i - cut); + } + ++(d_statistics.d_numSplits); +} + +void UnionFind::getNormalForm(const ExtractTerm& term, NormalForm& nf) { + nf.clear(); + getDecomposition(term, nf.decomp); + // update nf base + Index count = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + count += getBitwidth(nf.decomp[i]); + nf.base.sliceAt(count); + } + Debug("bv-slicer-uf") << "UnionFind::getNormalFrom term: " << term.debugPrint() << endl; + Debug("bv-slicer-uf") << " nf: " << nf.debugPrint(*this) << endl; +} + +void UnionFind::getDecomposition(const ExtractTerm& term, Decomposition& decomp) { + // making sure the term is aligned + TermId id = find(term.id); + + Assert (term.high < getBitwidth(id)); + // because we split the node, this must be the whole extract + if (!hasChildren(id)) { + Assert (term.high == getBitwidth(id) - 1 && + term.low == 0); + decomp.push_back(id); + return; + } + + Index cut = getCutPoint(id); + + if (term.low < cut && term.high < cut) { + // the extract falls entirely on the low child + ExtractTerm child_ex(getChild(id, 0), term.high, term.low); + getDecomposition(child_ex, decomp); + } + else if (term.low >= cut && term.high >= cut){ + // the extract falls entirely on the high child + ExtractTerm child_ex(getChild(id, 1), term.high - cut, term.low - cut); + getDecomposition(child_ex, decomp); + } + else { + // the extract is split over the two children + ExtractTerm low_child(getChild(id, 0), cut - 1, term.low); + getDecomposition(low_child, decomp); + ExtractTerm high_child(getChild(id, 1), term.high - cut, 0); + getDecomposition(high_child, decomp); + } +} +/** + * May cause reslicings of the decompositions. Must not assume the decompositons + * are the current normal form. + * + * @param d1 + * @param d2 + * @param common + */ +void UnionFind::handleCommonSlice(const Decomposition& decomp1, const Decomposition& decomp2, TermId common) { + Debug("bv-slicer") << "UnionFind::handleCommonSlice common = " << common << endl; + Index common_size = getBitwidth(common); + // find starting points of common slice + Index start1 = 0; + for (unsigned j = 0; j < decomp1.size(); ++j) { + if (decomp1[j] == common) + break; + start1 += getBitwidth(decomp1[j]); + } + + Index start2 = 0; + for (unsigned j = 0; j < decomp2.size(); ++j) { + if (decomp2[j] == common) + break; + start2 += getBitwidth(decomp2[j]); + } + if (start1 > start2) { + Index temp = start1; + start1 = start2; + start2 = temp; + } + + if (start2 - start1 < common_size) { + Index overlap = start1 + common_size - start2; + Assert (overlap > 0); + Index diff = common_size - overlap; + Assert (diff >= 0); + Index granularity = utils::gcd(diff, overlap); + // split the common part + for (unsigned i = 0; i < common_size; i+= granularity) { + split(common, i); + } + } + +} + +void UnionFind::alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2) { + Debug("bv-slicer") << "UnionFind::alignSlicings " << term1.debugPrint() << endl; + Debug("bv-slicer") << " " << term2.debugPrint() << endl; + NormalForm nf1(term1.getBitwidth()); + NormalForm nf2(term2.getBitwidth()); + + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); + + Assert (nf1.base.getBitwidth() == nf2.base.getBitwidth()); + + // first check if the two have any common slices + std::vector intersection; + utils::intersect(nf1.decomp, nf2.decomp, intersection); + for (unsigned i = 0; i < intersection.size(); ++i) { + // handle common slice may change the normal form + handleCommonSlice(nf1.decomp, nf2.decomp, intersection[i]); + } + // propagate cuts to a fixpoint + bool changed; + Base cuts(term1.getBitwidth()); + do { + changed = false; + // we need to update the normal form which may have changed + getNormalForm(term1, nf1); + getNormalForm(term2, nf2); + + // align the cuts points of the two slicings + // FIXME: this can be done more efficiently + cuts.sliceWith(nf1.base); + cuts.sliceWith(nf2.base); + + for (unsigned i = 0; i < cuts.getBitwidth(); ++i) { + if (cuts.isCutPoint(i)) { + if (!nf1.base.isCutPoint(i)) { + pair pair1 = nf1.getTerm(i, *this); + split(pair1.first, i - pair1.second); + changed = true; + } + if (!nf2.base.isCutPoint(i)) { + pair pair2 = nf2.getTerm(i, *this); + split(pair2.first, i - pair2.second); + changed = true; + } + } + } + } while (changed); +} +/** + * Given an extract term a[i:j] makes sure a is sliced + * at indices i and j. + * + * @param term + */ +void UnionFind::ensureSlicing(const ExtractTerm& term) { + //Debug("bv-slicer") << "Slicer::ensureSlicing " << term.debugPrint() << endl; + TermId id = find(term.id); + split(id, term.high + 1); + split(id, term.low); +} + +void UnionFind::backtrack() { ++ return; + int size = d_undoStack.size(); + for (int i = size; i > d_undoStackIndex.get(); --i) { + Operation op = d_undoStack.back(); + Assert (!d_undoStack.empty()); + d_undoStack.pop_back(); + if (op.op == UnionFind::MERGE) { + undoMerge(op.id); + } else { + Assert (op.op == UnionFind::SPLIT); + undoSplit(op.id); + } + } +} + +void UnionFind::undoMerge(TermId id) { + TermId repr = getRepr(id); + Assert (repr != id); + setRepr(id, UndefinedId); +} + +void UnionFind::undoSplit(TermId id) { + Assert (hasChildren(id)); + setChildren(id, UndefinedId, UndefinedId); +} + +void UnionFind::recordOperation(OperationKind op, TermId term) { ++ if (op == SPLIT) { ++ d_newSplit = true; ++ } + d_undoStackIndex.set(d_undoStackIndex.get() + 1); + d_undoStack.push_back(Operation(op, term)); + Assert (d_undoStack.size() == d_undoStackIndex); +} + +void UnionFind::getBase(TermId id, Base& base, Index offset) { + id = find(id); + if (!hasChildren(id)) + return; + TermId id1 = find(getChild(id, 1)); + TermId id0 = find(getChild(id, 0)); + Index cut = getCutPoint(id); + base.sliceAt(cut + offset); + getBase(id1, base, cut + offset); + getBase(id0, base, offset); +} + + +/** + * Slicer + * + */ + +ExtractTerm Slicer::registerTerm(TNode node) { + Index low = 0, high = utils::getSize(node) - 1; + TNode n = node; + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + n = node[0]; + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); + } + if (d_nodeToId.find(n) == d_nodeToId.end()) { + TermId id = d_unionFind.addTerm(utils::getSize(n)); + d_nodeToId[n] = id; + d_idToNode[id] = n; + } + TermId id = d_nodeToId[n]; + ExtractTerm res(id, high, low); + Debug("bv-slicer") << "Slicer::registerTerm " << node << " => " << res.debugPrint() << endl; + return res; +} + +void Slicer::processEquality(TNode eq) { + Debug("bv-slicer") << "Slicer::processEquality: " << eq << endl; + + Assert (eq.getKind() == kind::EQUAL); + TNode a = eq[0]; + TNode b = eq[1]; + ExtractTerm a_ex= registerTerm(a); + ExtractTerm b_ex= registerTerm(b); + + d_unionFind.ensureSlicing(a_ex); + d_unionFind.ensureSlicing(b_ex); + + d_unionFind.alignSlicings(a_ex, b_ex); + d_unionFind.unionTerms(a_ex, b_ex); + Debug("bv-slicer") << "Base of " << a_ex.id <<" " << d_unionFind.debugPrint(a_ex.id) << endl; + Debug("bv-slicer") << "Base of " << b_ex.id <<" " << d_unionFind.debugPrint(b_ex.id) << endl; + Debug("bv-slicer") << "Slicer::processEquality done. " << endl; +} + +void Slicer::getBaseDecomposition(TNode node, std::vector& decomp) { + Debug("bv-slicer") << "Slicer::getBaseDecomposition " << node << endl; + + Index high = utils::getSize(node) - 1; + Index low = 0; + TNode top = node; + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + high = utils::getExtractHigh(node); + low = utils::getExtractLow(node); + top = node[0]; + } + Assert (d_nodeToId.find(top) != d_nodeToId.end()); + TermId id = d_nodeToId[top]; + NormalForm nf(high-low+1); + d_unionFind.getNormalForm(ExtractTerm(id, high, low), nf); + + // construct actual extract nodes + Index current_low = 0; + Index current_high = 0; + for (unsigned i = 0; i < nf.decomp.size(); ++i) { + Index current_size = d_unionFind.getBitwidth(nf.decomp[i]); + current_high += current_size; + Node current = Rewriter::rewrite(utils::mkExtract(node, current_high - 1, current_low)); + current_low += current_size; + decomp.push_back(current); + } + + Debug("bv-slicer") << "as ["; + for (unsigned i = 0; i < decomp.size(); ++i) { + Debug("bv-slicer") << decomp[i] <<" "; + } + Debug("bv-slicer") << "]" << endl; + +} + +bool Slicer::isCoreTerm(TNode node) { + if (d_coreTermCache.find(node) == d_coreTermCache.end()) { + Kind kind = node.getKind(); + if (kind != kind::BITVECTOR_EXTRACT && + kind != kind::BITVECTOR_CONCAT && + kind != kind::EQUAL && kind != kind::NOT && + node.getMetaKind() != kind::metakind::VARIABLE && + kind != kind::CONST_BITVECTOR) { + d_coreTermCache[node] = false; + return false; + } else { + // we need to recursively check whether the term is a root term or not + bool isCore = true; + for (unsigned i = 0; i < node.getNumChildren(); ++i) { + isCore = isCore && isCoreTerm(node[i]); + } + d_coreTermCache[node] = isCore; + return isCore; + } + } + return d_coreTermCache[node]; +} +unsigned Slicer::d_numAddedEqualities = 0; + +void Slicer::splitEqualities(TNode node, std::vector& equalities) { + Assert (node.getKind() == kind::EQUAL); + TNode t1 = node[0]; + TNode t2 = node[1]; + + uint32_t width = utils::getSize(t1); + + Base base1(width); + if (t1.getKind() == kind::BITVECTOR_CONCAT) { + int size = 0; + // no need to count the last child since the end cut point is implicit + for (int i = t1.getNumChildren() - 1; i >= 1 ; --i) { + size = size + utils::getSize(t1[i]); + base1.sliceAt(size); + } + } + + Base base2(width); + if (t2.getKind() == kind::BITVECTOR_CONCAT) { + unsigned size = 0; + for (int i = t2.getNumChildren() - 1; i >= 1; --i) { + size = size + utils::getSize(t2[i]); + base2.sliceAt(size); + } + } + + base1.sliceWith(base2); + if (!base1.isEmpty()) { + // we split the equalities according to the base + int last = 0; + for (unsigned i = 1; i <= utils::getSize(t1); ++i) { + if (base1.isCutPoint(i)) { + Node extract1 = utils::mkExtract(t1, i-1, last); + Node extract2 = utils::mkExtract(t2, i-1, last); + last = i; + Assert (utils::getSize(extract1) == utils::getSize(extract2)); + equalities.push_back(utils::mkNode(kind::EQUAL, extract1, extract2)); + } + } + } else { + // just return same equality + equalities.push_back(node); + } + d_numAddedEqualities += equalities.size() - 1; +} + +/** + * Returns the base decomposition of the current term. + * + * @param id + * + * @return + */ +Base Slicer::getTopLevelBase(TNode node) { + if (node.getKind() == kind::BITVECTOR_EXTRACT) { + node = node[0]; + } + // if we haven't seen this node before it must not be sliced yet + if (d_nodeToId.find(node) == d_nodeToId.end()) { + return Base(utils::getSize(node)); + } + TermId id = d_nodeToId[node]; + Base base(d_unionFind.getBitwidth(id)); + d_unionFind.getBase(id, base, 0); + return base; +} + +std::string UnionFind::debugPrint(TermId id) { + ostringstream os; + if (hasChildren(id)) { + TermId id1 = find(getChild(id, 1)); + TermId id0 = find(getChild(id, 0)); + os << debugPrint(id1); + os << debugPrint(id0); + } else { + if (getRepr(id) == UndefinedId) { + os <<"id"<< id <<"[" << getBitwidth(id) <<"] "; + } else { + os << debugPrint(find(id)); + } + } + return os.str(); +} + +UnionFind::Statistics::Statistics(): + d_numNodes("theory::bv::slicer::NumberOfNodes", 0), + d_numRepresentatives("theory::bv::slicer::NumberOfRepresentatives", 0), + d_numSplits("theory::bv::slicer::NumberOfSplits", 0), + d_numMerges("theory::bv::slicer::NumberOfMerges", 0), + d_avgFindDepth("theory::bv::slicer::AverageFindDepth"), + d_numAddedEqualities("theory::bv::slicer::NumberOfEqualitiesAdded", Slicer::d_numAddedEqualities) +{ + StatisticsRegistry::registerStat(&d_numRepresentatives); + StatisticsRegistry::registerStat(&d_numSplits); + StatisticsRegistry::registerStat(&d_numMerges); + StatisticsRegistry::registerStat(&d_avgFindDepth); + StatisticsRegistry::registerStat(&d_numAddedEqualities); +} + +UnionFind::Statistics::~Statistics() { + StatisticsRegistry::unregisterStat(&d_numRepresentatives); + StatisticsRegistry::unregisterStat(&d_numSplits); + StatisticsRegistry::unregisterStat(&d_numMerges); + StatisticsRegistry::unregisterStat(&d_avgFindDepth); + StatisticsRegistry::unregisterStat(&d_numAddedEqualities); +} diff --cc src/theory/bv/slicer.h index 0508c67c1,000000000..6e09d971b mode 100644,000000..100644 --- a/src/theory/bv/slicer.h +++ b/src/theory/bv/slicer.h @@@ -1,290 -1,0 +1,338 @@@ +/********************* */ +/*! \file slicer.h + ** \verbatim + ** Original author: lianah + ** Major contributors: none + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010, 2011 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief Bitvector theory. + ** + ** Bitvector theory. + **/ + +#include "cvc4_private.h" + + +#include +#include +#include +#include + +#include "util/bitvector.h" +#include "util/statistics_registry.h" +#include "util/index.h" +#include "expr/node.h" +#include "theory/bv/theory_bv_utils.h" +#include "context/context.h" +#include "context/cdhashset.h" +#include "context/cdo.h" + +#ifndef __CVC4__THEORY__BV__SLICER_BV_H +#define __CVC4__THEORY__BV__SLICER_BV_H + + +namespace CVC4 { + +namespace theory { +namespace bv { + + + +typedef Index TermId; +extern const TermId UndefinedId; + ++class CDBase; + +/** + * Base + * + */ +class Base { + Index d_size; + std::vector d_repr; ++ void undoSliceAt(Index index); +public: + Base (Index size); - void sliceAt(Index index); ++ void sliceAt(Index index); ++ + void sliceWith(const Base& other); + bool isCutPoint(Index index) const; + void diffCutPoints(const Base& other, Base& res) const; + bool isEmpty() const; + std::string debugPrint() const; + Index getBitwidth() const { return d_size; } + void clear() { + for (unsigned i = 0; i < d_repr.size(); ++i) { + d_repr[i] = 0; + } + } + bool operator==(const Base& other) const { + if (other.getBitwidth() != getBitwidth()) + return false; + for (unsigned i = 0; i < d_repr.size(); ++i) { + if (d_repr[i] != other.d_repr[i]) + return false; + } + return true; + } ++ friend class CDBase; ++}; ++ ++ ++class CDBase : public context::ContextNotifyObj { ++ context::Context* d_ctx; ++ context::CDO d_undoIndex; ++ ++ std::vector d_undoStack; ++ Base d_base; ++ CDBase(context::Context* ctx, Index bitwidth) ++ : ContextNotifyObj(ctx), ++ d_ctx(ctx), ++ d_undoIndex(d_ctx), ++ d_undoStack(), ++ d_base(bitwidth) ++ {} ++ void sliceAt(Index i) { ++ Assert (!d_base.isCutPoint(i)); ++ d_undoStack.push_back(i); ++ d_undoIndex.set(d_undoIndex.get() + 1); ++ d_base.sliceAt(i); ++ } ++ bool isCutPoint(Index i) { ++ return d_base.isCutPoint(i); ++ } ++ Index getBitwidth() const {return d_base.getBitwidth(); } ++ virtual ~CDBase() throw(AssertionException) {} ++ void contextNotifyPop() { ++ backtrack(); ++ } ++ ++ void backtrack() { ++ for (unsigned i = d_undoIndex.get(); i < d_undoStack.size(); ++i) { ++ Index i = d_undoStack.back(); ++ d_undoStack.pop_back(); ++ d_base.undoSliceAt(i); ++ } ++ Assert(d_undoIndex.get() == d_undoStack.size()); ++ } ++ +}; + +/** + * UnionFind + * + */ +typedef context::CDHashSet > CDTermSet; +typedef std::vector Decomposition; + +struct ExtractTerm { + TermId id; + Index high; + Index low; + ExtractTerm(TermId i, Index h, Index l) + : id (i), + high(h), + low(l) + { + Assert (h >= l && id != UndefinedId); + } + Index getBitwidth() const { return high - low + 1; } + std::string debugPrint() const; +}; + +class UnionFind; + +struct NormalForm { + Base base; + Decomposition decomp; + + NormalForm(Index bitwidth) + : base(bitwidth), + decomp() + {} + /** + * Returns the term in the decomposition on which the index i + * falls in + * @param i + * + * @return + */ + std::pair getTerm(Index i, const UnionFind& uf) const; + std::string debugPrint(const UnionFind& uf) const; + void clear() { base.clear(); decomp.clear(); } +}; + + +class UnionFind : public context::ContextNotifyObj { + class Node { + Index d_bitwidth; + TermId d_ch1, d_ch0; + TermId d_repr; + public: + Node(Index b) + : d_bitwidth(b), + d_ch1(UndefinedId), + d_ch0(UndefinedId), + d_repr(UndefinedId) + {} + + TermId getRepr() const { return d_repr; } + Index getBitwidth() const { return d_bitwidth; } + bool hasChildren() const { return d_ch1 != UndefinedId && d_ch0 != UndefinedId; } + + TermId getChild(Index i) const { + Assert (i < 2); + return i == 0? d_ch0 : d_ch1; + } + void setRepr(TermId id) { + Assert (! hasChildren()); + d_repr = id; + } + void setChildren(TermId ch1, TermId ch0) { + // Assert (d_repr == UndefinedId && !hasChildren()); + d_ch1 = ch1; + d_ch0 = ch0; + } + std::string debugPrint() const; + }; + + /// map from TermId to the nodes that represent them + std::vector d_nodes; + /// a term is in this set if it is its own representative + //CDTermSet d_representatives; + + void getDecomposition(const ExtractTerm& term, Decomposition& decomp); + void handleCommonSlice(const Decomposition& d1, const Decomposition& d2, TermId common); + /// getter methods for the internal nodes + TermId getRepr(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getRepr(); + } + TermId getChild(TermId id, Index i) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getChild(i); + } + Index getCutPoint(TermId id) const { + return getBitwidth(getChild(id, 0)); + } + bool hasChildren(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].hasChildren(); + } + /// setter methods for the internal nodes + void setRepr(TermId id, TermId new_repr) { + Assert (id < d_nodes.size()); + d_nodes[id].setRepr(new_repr); + } + void setChildren(TermId id, TermId ch1, TermId ch0) { + Assert ((ch1 == UndefinedId && ch0 == UndefinedId) || + (id < d_nodes.size() && getBitwidth(id) == getBitwidth(ch1) + getBitwidth(ch0))); + d_nodes[id].setChildren(ch1, ch0); + } + + /* Backtracking mechanisms */ + + enum OperationKind { + MERGE, + SPLIT + }; + + struct Operation { + OperationKind op; + TermId id; + Operation(OperationKind o, TermId i) + : op(o), id(i) {} + }; + + std::vector d_undoStack; + context::CDO d_undoStackIndex; + + void backtrack(); + void undoMerge(TermId id); + void undoSplit(TermId id); + void recordOperation(OperationKind op, TermId term); + virtual ~UnionFind() throw(AssertionException) {} + class Statistics { + public: + IntStat d_numNodes; + IntStat d_numRepresentatives; + IntStat d_numSplits; + IntStat d_numMerges; + AverageStat d_avgFindDepth; + ReferenceStat d_numAddedEqualities; + Statistics(); + ~Statistics(); + }; + + Statistics d_statistics; ++ bool d_newSplit; +public: + UnionFind(context::Context* ctx) + : ContextNotifyObj(ctx), + d_nodes(), - // d_representatives(ctx), + d_undoStack(), + d_undoStackIndex(ctx), + d_statistics() + {} + + TermId addTerm(Index bitwidth); + void unionTerms(const ExtractTerm& t1, const ExtractTerm& t2); + void merge(TermId t1, TermId t2); + TermId find(TermId t1); + void split(TermId term, Index i); + + void getNormalForm(const ExtractTerm& term, NormalForm& nf); + void alignSlicings(const ExtractTerm& term1, const ExtractTerm& term2); + void ensureSlicing(const ExtractTerm& term); + Index getBitwidth(TermId id) const { + Assert (id < d_nodes.size()); + return d_nodes[id].getBitwidth(); + } + void getBase(TermId id, Base& base, Index offset); + std::string debugPrint(TermId id); + + void contextNotifyPop() { + backtrack(); + } ++ bool hasNewSplit() { return d_newSplit; } ++ void resetNewSplit() { d_newSplit = false; } + + friend class Slicer; +}; + +class Slicer { + __gnu_cxx::hash_map d_idToNode; + __gnu_cxx::hash_map d_nodeToId; + __gnu_cxx::hash_map d_coreTermCache; + UnionFind d_unionFind; + ExtractTerm registerTerm(TNode node); +public: + Slicer(context::Context* ctx) + : d_idToNode(), + d_nodeToId(), + d_coreTermCache(), + d_unionFind(ctx) + {} + + void getBaseDecomposition(TNode node, std::vector& decomp); + void processEquality(TNode eq); + bool isCoreTerm (TNode node); + Base getTopLevelBase(TNode node); + static void splitEqualities(TNode node, std::vector& equalities); - static unsigned d_numAddedEqualities; ++ static unsigned d_numAddedEqualities; ++ inline bool hasNewSplit() { return d_unionFind.hasNewSplit(); } ++ inline void resetNewSplit() { d_unionFind.resetNewSplit(); } +}; + + +}/* CVC4::theory::bv namespace */ +}/* CVC4::theory namespace */ +}/* CVC4 namespace */ + +#endif /* __CVC4__THEORY__BV__SLICER_BV_H */ diff --cc src/theory/bv/theory_bv.cpp index 6248782bd,57a77c0d2..5d034287d --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@@ -40,10 -39,8 +40,9 @@@ TheoryBV::TheoryBV(context::Context* c d_context(c), d_alreadyPropagatedSet(c), d_sharedTermsSet(c), + d_slicer(c), - d_bitblastAssertionsQueue(c), d_bitblastSolver(c, this), - d_equalitySolver(c, this), + d_coreSolver(c, this, &d_slicer), d_statistics(), d_conflict(c, false), d_literalsToPropagate(c), @@@ -110,29 -105,25 +109,23 @@@ void TheoryBV::check(Effort e return; } -- // getting the new assertions -- std::vector new_assertions; while (!done()) { -- Assertion assertion = get(); -- TNode fact = assertion.assertion; -- new_assertions.push_back(fact); - d_bitblastAssertionsQueue.push_back(fact); -- Debug("bitvector-assertions") << "TheoryBV::check assertion " << fact << "\n"; ++ TNode fact = get().assertion; ++ d_coreSolver.assertFact(fact); ++ d_bitblastSolver.assertFact(fact); } ++ bool ok = true; if (!inConflict()) { -- // sending assertions to the equality solver first - d_coreSolver.addAssertions(new_assertions, e); - d_equalitySolver.addAssertions(new_assertions, e); ++ ok = d_coreSolver.check(e); } - if (!inConflict()) { - // sending assertions to the bitblast solver - d_bitblastSolver.addAssertions(new_assertions, e); ++ Assert (!ok == inConflict()); + if (!inConflict() && !d_coreSolver.isCoreTheory()) { - // sending assertions to the bitblast solver if it's not just core theory - d_bitblastSolver.addAssertions(new_assertions, e); - } else { - // sending assertions to the bitblast solver if it's not just core theory - d_bitblastSolver.addAssertions(new_assertions, EFFORT_STANDARD); ++ ok = d_bitblastSolver.check(e); } - + ++ Assert (!ok == inConflict()); if (inConflict()) { sendConflict(); } diff --cc src/theory/bv/theory_bv.h index ec72f40e1,e38f3568c..3e14584ed --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@@ -25,12 -25,10 +25,11 @@@ #include "context/cdhashset.h" #include "theory/bv/theory_bv_utils.h" #include "util/statistics_registry.h" --#include "context/cdqueue.h" #include "theory/bv/bv_subtheory.h" #include "theory/bv/bv_subtheory_eq.h" +#include "theory/bv/bv_subtheory_core.h" #include "theory/bv/bv_subtheory_bitblast.h" +#include "theory/bv/slicer.h" namespace CVC4 { namespace theory { @@@ -45,13 -43,8 +44,10 @@@ class TheoryBV : public Theory context::CDHashSet d_alreadyPropagatedSet; context::CDHashSet d_sharedTermsSet; + Slicer d_slicer; - - context::CDQueue d_bitblastAssertionsQueue; - BitblastSolver d_bitblastSolver; - CoreSolver d_coreSolver; - EqualitySolver d_equalitySolver; ++ CoreSolver d_coreSolver; + public: TheoryBV(context::Context* c, context::UserContext* u, OutputChannel& out, Valuation valuation, const LogicInfo& logicInfo, QuantifiersEngine* qe);