From: Mathias Preiner Date: Sat, 13 Feb 2021 14:08:37 +0000 (-0800) Subject: Properly set up equality engine for BV bitblast solver. (#5905) X-Git-Tag: cvc5-1.0.0~2280 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=10f6aae991a550e2b970c6234ebdd75742d078dd;p=cvc5.git Properly set up equality engine for BV bitblast solver. (#5905) Theory BV now sets up the default equality engine for BV solvers that do not use their own equality engine like e.g. the BV bitblast solver. This commit also adds the missing equality engine pieces to the BV bitblast solver (getEqualityStatus, explain). --- diff --git a/src/theory/bv/bitblast/lazy_bitblaster.cpp b/src/theory/bv/bitblast/lazy_bitblaster.cpp index f3adc4b21..0c541ba89 100644 --- a/src/theory/bv/bitblast/lazy_bitblaster.cpp +++ b/src/theory/bv/bitblast/lazy_bitblaster.cpp @@ -416,9 +416,9 @@ void TLazyBitblaster::MinisatNotify::notify(prop::SatClause& clause) { lemmab << d_cnf->getNode(clause[i]); } Node lemma = lemmab; - d_bv->d_inferManager.lemma(lemma, InferenceId::UNKNOWN); + d_bv->d_im.lemma(lemma, InferenceId::UNKNOWN); } else { - d_bv->d_inferManager.lemma(d_cnf->getNode(clause[0]), InferenceId::UNKNOWN); + d_bv->d_im.lemma(d_cnf->getNode(clause[0]), InferenceId::UNKNOWN); } } @@ -429,7 +429,7 @@ void TLazyBitblaster::MinisatNotify::spendResource(ResourceManager::Resource r) void TLazyBitblaster::MinisatNotify::safePoint(ResourceManager::Resource r) { - d_bv->d_inferManager.safePoint(r); + d_bv->d_im.safePoint(r); } EqualityStatus TLazyBitblaster::getEqualityStatus(TNode a, TNode b) diff --git a/src/theory/bv/bv_solver.h b/src/theory/bv/bv_solver.h index f4b5a9d11..9f4ac54df 100644 --- a/src/theory/bv/bv_solver.h +++ b/src/theory/bv/bv_solver.h @@ -29,7 +29,7 @@ class BVSolver { public: BVSolver(TheoryState& state, TheoryInferenceManager& inferMgr) - : d_state(state), d_inferManager(inferMgr){}; + : d_state(state), d_im(inferMgr){}; virtual ~BVSolver(){}; @@ -112,7 +112,7 @@ class BVSolver protected: TheoryState& d_state; - TheoryInferenceManager& d_inferManager; + TheoryInferenceManager& d_im; }; } // namespace bv diff --git a/src/theory/bv/bv_solver_bitblast.cpp b/src/theory/bv/bv_solver_bitblast.cpp index ce8bc3645..0b5d4cfef 100644 --- a/src/theory/bv/bv_solver_bitblast.cpp +++ b/src/theory/bv/bv_solver_bitblast.cpp @@ -35,6 +35,8 @@ BVSolverBitblast::BVSolverBitblast(TheoryState* s, d_nullRegistrar(new prop::NullRegistrar()), d_nullContext(new context::Context()), d_facts(s->getSatContext()), + d_invalidateModelCache(s->getSatContext(), true), + d_inSatMode(s->getSatContext(), false), d_epg(pnm ? new EagerProofGenerator(pnm, s->getUserContext(), "") : nullptr) { @@ -82,7 +84,10 @@ void BVSolverBitblast::postCheck(Theory::Effort level) node_map.emplace(lit, fact); } + d_invalidateModelCache.set(true); prop::SatValue val = d_satSolver->solve(assumptions); + d_inSatMode = val == prop::SatValue::SAT_VALUE_TRUE; + Debug("bv-bitblast") << "d_inSatMode: " << d_inSatMode << std::endl; if (val == prop::SatValue::SAT_VALUE_FALSE) { @@ -97,7 +102,7 @@ void BVSolverBitblast::postCheck(Theory::Effort level) } NodeManager* nm = NodeManager::currentNM(); - d_inferManager.conflict(nm->mkAnd(conflict), InferenceId::UNKNOWN); + d_im.conflict(nm->mkAnd(conflict), InferenceId::UNKNOWN); } } @@ -108,6 +113,12 @@ bool BVSolverBitblast::preNotifyFact( return false; // Return false to enable equality engine reasoning in Theory. } +TrustNode BVSolverBitblast::explain(TNode n) +{ + Debug("bv-bitblast") << "explain called on " << n << std::endl; + return d_im.explainLit(n); +} + bool BVSolverBitblast::collectModelValues(TheoryModel* m, const std::set& termSet) { @@ -118,7 +129,7 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m, continue; } - Node value = getValueFromSatSolver(term); + Node value = getValueFromSatSolver(term, true); Assert(value.isConst()); if (!m->assertEquality(term, value, true)) { @@ -128,12 +139,37 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m, return true; } -Node BVSolverBitblast::getValueFromSatSolver(TNode node) +EqualityStatus BVSolverBitblast::getEqualityStatus(TNode a, TNode b) +{ + Debug("bv-bitblast") << "getEqualityStatus on " << a << " and " << b + << std::endl; + if (!d_inSatMode) + { + Debug("bv-bitblast") << EQUALITY_UNKNOWN << std::endl; + return EQUALITY_UNKNOWN; + } + Node value_a = getValue(a); + Node value_b = getValue(b); + + if (value_a == value_b) + { + Debug("bv-bitblast") << EQUALITY_TRUE_IN_MODEL << std::endl; + return EQUALITY_TRUE_IN_MODEL; + } + Debug("bv-bitblast") << EQUALITY_FALSE_IN_MODEL << std::endl; + return EQUALITY_FALSE_IN_MODEL; +} + +Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize) { - /* If node was not bit-blasted return zero-initialized bit-vector. */ + if (node.isConst()) + { + return node; + } + if (!d_bitblaster->hasBBTerm(node)) { - return utils::mkConst(utils::getSize(node), 0u); + return initialize ? utils::mkConst(utils::getSize(node), 0u) : Node(); } std::vector bits; @@ -149,6 +185,7 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node) } else { + if (!initialize) return Node(); bit = zero; } value = value * 2 + bit; @@ -156,6 +193,76 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node) return utils::mkConst(bits.size(), value); } +Node BVSolverBitblast::getValue(TNode node) +{ + if (d_invalidateModelCache.get()) + { + d_modelCache.clear(); + } + d_invalidateModelCache.set(false); + + std::vector visit; + + TNode cur; + visit.push_back(node); + do + { + cur = visit.back(); + visit.pop_back(); + + auto it = d_modelCache.find(cur); + if (it != d_modelCache.end() && !it->second.isNull()) + { + continue; + } + + if (d_bitblaster->hasBBTerm(cur)) + { + Node value = getValueFromSatSolver(cur, false); + if (value.isConst()) + { + d_modelCache[cur] = value; + continue; + } + } + if (Theory::isLeafOf(cur, theory::THEORY_BV)) + { + Node value = getValueFromSatSolver(cur, true); + d_modelCache[cur] = value; + continue; + } + + if (it == d_modelCache.end()) + { + visit.push_back(cur); + d_modelCache.emplace(cur, Node()); + visit.insert(visit.end(), cur.begin(), cur.end()); + } + else if (it->second.isNull()) + { + NodeBuilder<> nb(cur.getKind()); + if (cur.getMetaKind() == kind::metakind::PARAMETERIZED) + { + nb << cur.getOperator(); + } + + std::unordered_map::iterator iit; + for (const TNode& child : cur) + { + iit = d_modelCache.find(child); + Assert(iit != d_modelCache.end()); + Assert(iit->second.isConst()); + nb << iit->second; + } + it->second = Rewriter::rewrite(nb.constructNode()); + } + } while (!visit.empty()); + + auto it = d_modelCache.find(node); + Assert(it != d_modelCache.end()); + return it->second; +} + } // namespace bv } // namespace theory } // namespace CVC4 diff --git a/src/theory/bv/bv_solver_bitblast.h b/src/theory/bv/bv_solver_bitblast.h index df0f2e085..d9b4a26e9 100644 --- a/src/theory/bv/bv_solver_bitblast.h +++ b/src/theory/bv/bv_solver_bitblast.h @@ -56,6 +56,8 @@ class BVSolverBitblast : public BVSolver bool isPrereg, bool isInternal) override; + TrustNode explain(TNode n) override; + std::string identify() const override { return "BVSolverBitblast"; }; Theory::PPAssertStatus ppAssert( @@ -64,17 +66,39 @@ class BVSolverBitblast : public BVSolver return Theory::PPAssertStatus::PP_ASSERT_STATUS_UNSOLVED; } + EqualityStatus getEqualityStatus(TNode a, TNode b) override; + bool collectModelValues(TheoryModel* m, const std::set& termSet) override; private: - /** Get value of `node` from SAT solver. */ - Node getValueFromSatSolver(TNode node); + /** + * Get value of `node` from SAT solver. + * + * The `initialize` flag indicates whether bits should be zero-initialized + * if they were not bit-blasted yet. + */ + Node getValueFromSatSolver(TNode node, bool initialize); + + /** + * Get the current value of `node`. + * + * Computes the value if `node` was not yet bit-blasted. + */ + Node getValue(TNode node); + + /** + * Cache for getValue() calls. + * + * Is cleared at the beginning of a getValue() call if the + * `d_invalidateModelCache` flag is set to true. + */ + std::unordered_map d_modelCache; /** Bit-blaster used to bit-blast atoms/terms. */ std::unique_ptr d_bitblaster; - /** Used for initializing CnfStream> */ + /** Used for initializing `d_cnfStream`. */ std::unique_ptr d_nullRegistrar; std::unique_ptr d_nullContext; @@ -86,6 +110,12 @@ class BVSolverBitblast : public BVSolver /** Facts sent to this solver. */ context::CDList d_facts; + /** Flag indicating whether `d_modelCache` should be invalidated. */ + context::CDO d_invalidateModelCache; + + /** Indicates whether the last check() call was satisfiable. */ + context::CDO d_inSatMode; + /** Proof generator that manages proofs for lemmas generated by this class. */ std::unique_ptr d_epg; diff --git a/src/theory/bv/bv_solver_lazy.cpp b/src/theory/bv/bv_solver_lazy.cpp index 9c44e32f1..0e81d0648 100644 --- a/src/theory/bv/bv_solver_lazy.cpp +++ b/src/theory/bv/bv_solver_lazy.cpp @@ -42,7 +42,7 @@ BVSolverLazy::BVSolverLazy(TheoryBV& bv, context::UserContext* u, ProofNodeManager* pnm, std::string name) - : BVSolver(bv.d_state, bv.d_inferMgr), + : BVSolver(bv.d_state, bv.d_im), d_bv(bv), d_context(c), d_alreadyPropagatedSet(c), @@ -119,7 +119,7 @@ void BVSolverLazy::finishInit() void BVSolverLazy::spendResource(ResourceManager::Resource r) { - d_inferManager.spendResource(r); + d_im.spendResource(r); } BVSolverLazy::Statistics::Statistics() @@ -196,7 +196,7 @@ void BVSolverLazy::sendConflict() { Debug("bitvector") << indent() << "BVSolverLazy::check(): conflict " << d_conflictNode << std::endl; - d_inferManager.conflict(d_conflictNode, InferenceId::UNKNOWN); + d_im.conflict(d_conflictNode, InferenceId::UNKNOWN); d_statistics.d_avgConflictSize.addEntry(d_conflictNode.getNumChildren()); d_conflictNode = Node::null(); } @@ -287,11 +287,11 @@ void BVSolverLazy::check(Theory::Effort e) { if (assertions.size() == 1) { - d_inferManager.conflict(assertions[0], InferenceId::UNKNOWN); + d_im.conflict(assertions[0], InferenceId::UNKNOWN); return; } Node conflict = utils::mkAnd(assertions); - d_inferManager.conflict(conflict, InferenceId::UNKNOWN); + d_im.conflict(conflict, InferenceId::UNKNOWN); return; } return; @@ -426,7 +426,7 @@ void BVSolverLazy::propagate(Theory::Effort e) { Debug("bitvector::propagate") << "BVSolverLazy:: propagating " << literal << "\n"; - ok = d_inferManager.propagateLit(literal); + ok = d_im.propagateLit(literal); } } @@ -670,7 +670,7 @@ bool BVSolverLazy::storePropagation(TNode literal, SubTheory subtheory) constexpr bool ok = true; if (subtheory == SUB_CORE) { - d_inferManager.propagateLit(literal); + d_im.propagateLit(literal); if (!ok) { setConflict(); diff --git a/src/theory/bv/bv_solver_lazy.h b/src/theory/bv/bv_solver_lazy.h index da5f1cbf8..46d01d129 100644 --- a/src/theory/bv/bv_solver_lazy.h +++ b/src/theory/bv/bv_solver_lazy.h @@ -203,7 +203,7 @@ class BVSolverLazy : public BVSolver void lemma(TNode node) { - d_inferManager.lemma(node, InferenceId::UNKNOWN); + d_im.lemma(node, InferenceId::UNKNOWN); d_lemmasAdded = true; } diff --git a/src/theory/bv/bv_solver_simple.cpp b/src/theory/bv/bv_solver_simple.cpp index c4a404041..02196a4ed 100644 --- a/src/theory/bv/bv_solver_simple.cpp +++ b/src/theory/bv/bv_solver_simple.cpp @@ -93,12 +93,12 @@ void BVSolverSimple::addBBLemma(TNode fact) if (d_epg == nullptr) { - d_inferManager.lemma(lemma, InferenceId::UNKNOWN); + d_im.lemma(lemma, InferenceId::UNKNOWN); } else { TrustNode tlem = d_epg->mkTrustNode(lemma, PfRule::BV_BITBLAST, {}, {fact}); - d_inferManager.trustedLemma(tlem, InferenceId::UNKNOWN); + d_im.trustedLemma(tlem, InferenceId::UNKNOWN); } } @@ -123,13 +123,13 @@ bool BVSolverSimple::preNotifyFact( if (d_epg == nullptr) { - d_inferManager.lemma(lemma, InferenceId::UNKNOWN); + d_im.lemma(lemma, InferenceId::UNKNOWN); } else { TrustNode tlem = d_epg->mkTrustNode(lemma, PfRule::BV_EAGER_ATOM, {}, {fact}); - d_inferManager.trustedLemma(tlem, InferenceId::UNKNOWN); + d_im.trustedLemma(tlem, InferenceId::UNKNOWN); } std::unordered_set bv_atoms; diff --git a/src/theory/bv/bv_subtheory_core.cpp b/src/theory/bv/bv_subtheory_core.cpp index 87cc0bc4d..b83906688 100644 --- a/src/theory/bv/bv_subtheory_core.cpp +++ b/src/theory/bv/bv_subtheory_core.cpp @@ -188,8 +188,7 @@ void CoreSolver::explain(TNode literal, std::vector& assumptions) { bool CoreSolver::check(Theory::Effort e) { Trace("bitvector::core") << "CoreSolver::check \n"; - d_bv->d_inferManager.spendResource( - ResourceManager::Resource::TheoryCheckStep); + d_bv->d_im.spendResource(ResourceManager::Resource::TheoryCheckStep); d_checkCalled = true; Assert(!d_bv->inConflict()); @@ -560,7 +559,7 @@ bool CoreSolver::doExtfInferences(std::vector& terms) nm->mkNode(kind::LT, n, max)); Trace("bv-extf-lemma") << "BV extf lemma (range) : " << lem << std::endl; - d_bv->d_inferManager.lemma(lem, InferenceId::UNKNOWN); + d_bv->d_im.lemma(lem, InferenceId::UNKNOWN); sentLemma = true; } } @@ -609,7 +608,7 @@ bool CoreSolver::doExtfInferences(std::vector& terms) // (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k Trace("bv-extf-lemma") << "BV extf lemma (collapse) : " << lem << std::endl; - d_bv->d_inferManager.lemma(lem, InferenceId::UNKNOWN); + d_bv->d_im.lemma(lem, InferenceId::UNKNOWN); sentLemma = true; } } diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index b27bd04e1..f6e056f42 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -39,12 +39,13 @@ TheoryBV::TheoryBV(context::Context* c, d_ufRemByZero(), d_rewriter(), d_state(c, u, valuation), - d_inferMgr(*this, d_state, nullptr) + d_im(*this, d_state, nullptr), + d_notify(d_im) { switch (options::bvSolver()) { case options::BVSolver::BITBLAST: - d_internal.reset(new BVSolverBitblast(&d_state, d_inferMgr, pnm)); + d_internal.reset(new BVSolverBitblast(&d_state, d_im, pnm)); break; case options::BVSolver::LAZY: @@ -53,10 +54,10 @@ TheoryBV::TheoryBV(context::Context* c, default: AlwaysAssert(options::bvSolver() == options::BVSolver::SIMPLE); - d_internal.reset(new BVSolverSimple(&d_state, d_inferMgr, pnm)); + d_internal.reset(new BVSolverSimple(&d_state, d_im, pnm)); } d_theoryState = &d_state; - d_inferManager = &d_inferMgr; + d_inferManager = &d_im; } TheoryBV::~TheoryBV() {} @@ -65,7 +66,16 @@ TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; } bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi) { - return d_internal->needsEqualityEngine(esi); + bool need_ee = d_internal->needsEqualityEngine(esi); + + /* Set up default notify class for equality engine. */ + if (need_ee && esi.d_notify == nullptr) + { + esi.d_notify = &d_notify; + esi.d_name = "theory::bv::ee"; + } + + return need_ee; } void TheoryBV::finishInit() @@ -194,6 +204,19 @@ TrustNode TheoryBV::expandDefinition(Node node) void TheoryBV::preRegisterTerm(TNode node) { d_internal->preRegisterTerm(node); + + eq::EqualityEngine* ee = getEqualityEngine(); + if (ee) + { + if (node.getKind() == kind::EQUAL) + { + ee->addTriggerPredicate(node); + } + else + { + ee->addTerm(node); + } + } } bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); } diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 306b1ff93..2aa722e48 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -23,6 +23,7 @@ #include "theory/bv/theory_bv_rewriter.h" #include "theory/theory.h" +#include "theory/theory_eq_notify.h" namespace CVC4 { namespace theory { @@ -130,7 +131,10 @@ class TheoryBV : public Theory TheoryState d_state; /** A (default) theory inference manager. */ - TheoryInferenceManager d_inferMgr; + TheoryInferenceManager d_im; + + /** The notify class for equality engine. */ + TheoryEqNotifyClass d_notify; }; /* class TheoryBV */