From 56b5ebfed26283db73c55bbcc9391d2e06897727 Mon Sep 17 00:00:00 2001 From: Mathias Preiner Date: Tue, 27 Jul 2021 13:23:32 -0700 Subject: [PATCH] bv: Refactor getEqualityStatus and use for both bitblasting solvers. (#6933) This commit refactors the getEqualityStatus handling for bitblast and bitblast-internal. --- src/theory/bv/bitblast/proof_bitblaster.cpp | 5 + src/theory/bv/bitblast/proof_bitblaster.h | 2 + src/theory/bv/bv_solver.h | 8 ++ src/theory/bv/bv_solver_bitblast.cpp | 100 +---------------- src/theory/bv/bv_solver_bitblast.h | 33 +----- src/theory/bv/bv_solver_bitblast_internal.cpp | 34 ++++++ src/theory/bv/bv_solver_bitblast_internal.h | 4 + src/theory/bv/theory_bv.cpp | 103 +++++++++++++++++- src/theory/bv/theory_bv.h | 13 +++ 9 files changed, 174 insertions(+), 128 deletions(-) diff --git a/src/theory/bv/bitblast/proof_bitblaster.cpp b/src/theory/bv/bitblast/proof_bitblaster.cpp index f714ffda9..43618974b 100644 --- a/src/theory/bv/bitblast/proof_bitblaster.cpp +++ b/src/theory/bv/bitblast/proof_bitblaster.cpp @@ -172,6 +172,11 @@ Node BBProof::getStoredBBAtom(TNode node) return d_bb->getStoredBBAtom(node); } +void BBProof::getBBTerm(TNode node, Bits& bits) const +{ + d_bb->getBBTerm(node, bits); +} + bool BBProof::collectModelValues(TheoryModel* m, const std::set& relevantTerms) { diff --git a/src/theory/bv/bitblast/proof_bitblaster.h b/src/theory/bv/bitblast/proof_bitblaster.h index f6aa71f21..bc99d27bf 100644 --- a/src/theory/bv/bitblast/proof_bitblaster.h +++ b/src/theory/bv/bitblast/proof_bitblaster.h @@ -43,6 +43,8 @@ class BBProof bool hasBBTerm(TNode node) const; /** Get bit-blasted node stored for atom. */ Node getStoredBBAtom(TNode node); + /** Get bit-blasted bits stored for node. */ + void getBBTerm(TNode node, Bits& bits) const; /** Collect model values for all relevant terms given in 'relevantTerms'. */ bool collectModelValues(TheoryModel* m, const std::set& relevantTerms); diff --git a/src/theory/bv/bv_solver.h b/src/theory/bv/bv_solver.h index 6ccc6c7c1..c959fb648 100644 --- a/src/theory/bv/bv_solver.h +++ b/src/theory/bv/bv_solver.h @@ -102,6 +102,14 @@ class BVSolver return EqualityStatus::EQUALITY_UNKNOWN; } + /** + * Get the current value of `node`. + * + * The `initialize` flag indicates whether bits should be zero-initialized + * if they don't have a value yet. + */ + virtual Node getValue(TNode node, bool initialize) { return Node::null(); } + /** Called by abstraction preprocessing pass. */ virtual bool applyAbstraction(const std::vector& assertions, std::vector& new_assertions) diff --git a/src/theory/bv/bv_solver_bitblast.cpp b/src/theory/bv/bv_solver_bitblast.cpp index 5b70fb3a2..ecd42e4a0 100644 --- a/src/theory/bv/bv_solver_bitblast.cpp +++ b/src/theory/bv/bv_solver_bitblast.cpp @@ -119,8 +119,6 @@ BVSolverBitblast::BVSolverBitblast(TheoryState* s, d_bbInputFacts(s->getSatContext()), d_assumptions(s->getSatContext()), d_assertions(s->getSatContext()), - d_invalidateModelCache(s->getSatContext(), true), - d_inSatMode(s->getSatContext(), false), d_epg(pnm ? new EagerProofGenerator(pnm, s->getUserContext(), "") : nullptr), d_factLiteralCache(s->getSatContext()), @@ -208,12 +206,9 @@ void BVSolverBitblast::postCheck(Theory::Effort level) d_assumptions.push_back(d_factLiteralCache[fact]); } - d_invalidateModelCache.set(true); std::vector assumptions(d_assumptions.begin(), d_assumptions.end()); prop::SatValue val = d_satSolver->solve(assumptions); - d_inSatMode = val == prop::SatValue::SAT_VALUE_TRUE; - Debug("bv-bitblast") << "d_inSatMode: " << d_inSatMode << std::endl; if (val == prop::SatValue::SAT_VALUE_FALSE) { @@ -298,7 +293,7 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m, continue; } - Node value = getValueFromSatSolver(term, true); + Node value = getValue(term, true); Assert(value.isConst()); if (!m->assertEquality(term, value, true)) { @@ -330,27 +325,6 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m, return true; } -EqualityStatus BVSolverBitblast::getEqualityStatus(TNode a, TNode b) -{ - Debug("bv-bitblast") << "getEqualityStatus on " << a << " and " << b - << std::endl; - if (!d_inSatMode) - { - Debug("bv-bitblast") << EQUALITY_UNKNOWN << std::endl; - return EQUALITY_UNKNOWN; - } - Node value_a = getValue(a); - Node value_b = getValue(b); - - if (value_a == value_b) - { - Debug("bv-bitblast") << EQUALITY_TRUE_IN_MODEL << std::endl; - return EQUALITY_TRUE_IN_MODEL; - } - Debug("bv-bitblast") << EQUALITY_FALSE_IN_MODEL << std::endl; - return EQUALITY_FALSE_IN_MODEL; -} - void BVSolverBitblast::initSatSolver() { switch (options::bvSatSolver()) @@ -372,7 +346,7 @@ void BVSolverBitblast::initSatSolver() "theory::bv::BVSolverBitblast")); } -Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize) +Node BVSolverBitblast::getValue(TNode node, bool initialize) { if (node.isConst()) { @@ -405,76 +379,6 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize) return utils::mkConst(bits.size(), value); } -Node BVSolverBitblast::getValue(TNode node) -{ - if (d_invalidateModelCache.get()) - { - d_modelCache.clear(); - } - d_invalidateModelCache.set(false); - - std::vector visit; - - TNode cur; - visit.push_back(node); - do - { - cur = visit.back(); - visit.pop_back(); - - auto it = d_modelCache.find(cur); - if (it != d_modelCache.end() && !it->second.isNull()) - { - continue; - } - - if (d_bitblaster->hasBBTerm(cur)) - { - Node value = getValueFromSatSolver(cur, false); - if (value.isConst()) - { - d_modelCache[cur] = value; - continue; - } - } - if (Theory::isLeafOf(cur, theory::THEORY_BV)) - { - Node value = getValueFromSatSolver(cur, true); - d_modelCache[cur] = value; - continue; - } - - if (it == d_modelCache.end()) - { - visit.push_back(cur); - d_modelCache.emplace(cur, Node()); - visit.insert(visit.end(), cur.begin(), cur.end()); - } - else if (it->second.isNull()) - { - NodeBuilder nb(cur.getKind()); - if (cur.getMetaKind() == kind::metakind::PARAMETERIZED) - { - nb << cur.getOperator(); - } - - std::unordered_map::iterator iit; - for (const TNode& child : cur) - { - iit = d_modelCache.find(child); - Assert(iit != d_modelCache.end()); - Assert(iit->second.isConst()); - nb << iit->second; - } - it->second = Rewriter::rewrite(nb.constructNode()); - } - } while (!visit.empty()); - - auto it = d_modelCache.find(node); - Assert(it != d_modelCache.end()); - return it->second; -} - void BVSolverBitblast::handleEagerAtom(TNode fact, bool assertFact) { Assert(fact.getKind() == kind::BITVECTOR_EAGER_ATOM); diff --git a/src/theory/bv/bv_solver_bitblast.h b/src/theory/bv/bv_solver_bitblast.h index 8dee3c2c4..3f4ab5025 100644 --- a/src/theory/bv/bv_solver_bitblast.h +++ b/src/theory/bv/bv_solver_bitblast.h @@ -63,31 +63,22 @@ class BVSolverBitblast : public BVSolver std::string identify() const override { return "BVSolverBitblast"; }; - EqualityStatus getEqualityStatus(TNode a, TNode b) override; - void computeRelevantTerms(std::set& termSet) override; bool collectModelValues(TheoryModel* m, const std::set& termSet) override; - private: - /** Initialize SAT solver and CNF stream. */ - void initSatSolver(); - /** - * Get value of `node` from SAT solver. + * Get the current value of `node`. * * The `initialize` flag indicates whether bits should be zero-initialized * if they were not bit-blasted yet. */ - Node getValueFromSatSolver(TNode node, bool initialize); + Node getValue(TNode node, bool initialize) override; - /** - * Get the current value of `node`. - * - * Computes the value if `node` was not yet bit-blasted. - */ - Node getValue(TNode node); + private: + /** Initialize SAT solver and CNF stream. */ + void initSatSolver(); /** * Handle BITVECTOR_EAGER_ATOM atoms and assert/assume to CnfStream. @@ -97,14 +88,6 @@ class BVSolverBitblast : public BVSolver */ void handleEagerAtom(TNode fact, bool assertFact); - /** - * Cache for getValue() calls. - * - * Is cleared at the beginning of a getValue() call if the - * `d_invalidateModelCache` flag is set to true. - */ - std::unordered_map d_modelCache; - /** Bit-blaster used to bit-blast atoms/terms. */ std::unique_ptr d_bitblaster; @@ -137,12 +120,6 @@ class BVSolverBitblast : public BVSolver /** Stores the current input assertions. */ context::CDList d_assertions; - /** Flag indicating whether `d_modelCache` should be invalidated. */ - context::CDO d_invalidateModelCache; - - /** Indicates whether the last check() call was satisfiable. */ - context::CDO d_inSatMode; - /** Proof generator that manages proofs for lemmas generated by this class. */ std::unique_ptr d_epg; diff --git a/src/theory/bv/bv_solver_bitblast_internal.cpp b/src/theory/bv/bv_solver_bitblast_internal.cpp index bd47cc45e..ef4f3559b 100644 --- a/src/theory/bv/bv_solver_bitblast_internal.cpp +++ b/src/theory/bv/bv_solver_bitblast_internal.cpp @@ -147,6 +147,40 @@ bool BVSolverBitblastInternal::collectModelValues(TheoryModel* m, return d_bitblaster->collectModelValues(m, termSet); } +Node BVSolverBitblastInternal::getValue(TNode node, bool initialize) +{ + if (node.isConst()) + { + return node; + } + + if (!d_bitblaster->hasBBTerm(node)) + { + return initialize ? utils::mkConst(utils::getSize(node), 0u) : Node(); + } + + Valuation& val = d_state.getValuation(); + + std::vector bits; + d_bitblaster->getBBTerm(node, bits); + Integer value(0), one(1), zero(0), bit; + for (size_t i = 0, size = bits.size(), j = size - 1; i < size; ++i, --j) + { + bool satValue; + if (val.hasSatValue(bits[j], satValue)) + { + bit = satValue ? one : zero; + } + else + { + if (!initialize) return Node(); + bit = zero; + } + value = value * 2 + bit; + } + return utils::mkConst(bits.size(), value); +} + BVProofRuleChecker* BVSolverBitblastInternal::getProofChecker() { return &d_checker; diff --git a/src/theory/bv/bv_solver_bitblast_internal.h b/src/theory/bv/bv_solver_bitblast_internal.h index 8a1886084..1ec3ec1fe 100644 --- a/src/theory/bv/bv_solver_bitblast_internal.h +++ b/src/theory/bv/bv_solver_bitblast_internal.h @@ -42,6 +42,8 @@ class BVSolverBitblastInternal : public BVSolver ProofNodeManager* pnm); ~BVSolverBitblastInternal() = default; + bool needsEqualityEngine(EeSetupInfo& esi) override { return true; } + void preRegisterTerm(TNode n) override {} bool preNotifyFact(TNode atom, @@ -55,6 +57,8 @@ class BVSolverBitblastInternal : public BVSolver bool collectModelValues(TheoryModel* m, const std::set& termSet) override; + Node getValue(TNode node, bool initialize) override; + /** get the proof checker of this theory */ BVProofRuleChecker* getProofChecker(); diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 37881f9b2..547d24b23 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -43,6 +43,7 @@ TheoryBV::TheoryBV(context::Context* c, d_state(c, u, valuation), d_im(*this, d_state, nullptr, "theory::bv::"), d_notify(d_im), + d_invalidateModelCache(c, true), d_stats("theory::bv::") { switch (options::bvSolver()) @@ -158,7 +159,11 @@ void TheoryBV::preRegisterTerm(TNode node) bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); } -void TheoryBV::postCheck(Effort e) { d_internal->postCheck(e); } +void TheoryBV::postCheck(Effort e) +{ + d_invalidateModelCache = true; + d_internal->postCheck(e); +} bool TheoryBV::preNotifyFact( TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal) @@ -282,7 +287,27 @@ void TheoryBV::presolve() { d_internal->presolve(); } EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b) { - return d_internal->getEqualityStatus(a, b); + EqualityStatus status = d_internal->getEqualityStatus(a, b); + + if (status == EqualityStatus::EQUALITY_UNKNOWN) + { + Node value_a = getValue(a); + Node value_b = getValue(b); + + if (value_a.isNull() || value_b.isNull()) + { + return status; + } + + if (value_a == value_b) + { + Debug("theory-bv") << EQUALITY_TRUE_IN_MODEL << std::endl; + return EQUALITY_TRUE_IN_MODEL; + } + Debug("theory-bv") << EQUALITY_FALSE_IN_MODEL << std::endl; + return EQUALITY_FALSE_IN_MODEL; + } + return status; } TrustNode TheoryBV::explain(TNode node) { return d_internal->explain(node); } @@ -303,6 +328,80 @@ bool TheoryBV::applyAbstraction(const std::vector& assertions, return d_internal->applyAbstraction(assertions, new_assertions); } +Node TheoryBV::getValue(TNode node) +{ + if (d_invalidateModelCache.get()) + { + d_modelCache.clear(); + } + d_invalidateModelCache.set(false); + + std::vector visit; + + TNode cur; + visit.push_back(node); + do + { + cur = visit.back(); + visit.pop_back(); + + auto it = d_modelCache.find(cur); + if (it != d_modelCache.end() && !it->second.isNull()) + { + continue; + } + + if (cur.isConst()) + { + d_modelCache[cur] = cur; + continue; + } + + Node value = d_internal->getValue(cur, false); + if (value.isConst()) + { + d_modelCache[cur] = value; + continue; + } + + if (Theory::isLeafOf(cur, theory::THEORY_BV)) + { + value = d_internal->getValue(cur, true); + d_modelCache[cur] = value; + continue; + } + + if (it == d_modelCache.end()) + { + visit.push_back(cur); + d_modelCache.emplace(cur, Node()); + visit.insert(visit.end(), cur.begin(), cur.end()); + } + else if (it->second.isNull()) + { + NodeBuilder nb(cur.getKind()); + if (cur.getMetaKind() == kind::metakind::PARAMETERIZED) + { + nb << cur.getOperator(); + } + + std::unordered_map::iterator iit; + for (const TNode& child : cur) + { + iit = d_modelCache.find(child); + Assert(iit != d_modelCache.end()); + Assert(iit->second.isConst()); + nb << iit->second; + } + it->second = Rewriter::rewrite(nb.constructNode()); + } + } while (!visit.empty()); + + auto it = d_modelCache.find(node); + Assert(it != d_modelCache.end()); + return it->second; +} + TheoryBV::Statistics::Statistics(const std::string& name) : d_solveSubstitutions( smtStatisticsRegistry().registerInt(name + "NumSolveSubstitutions")) diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index f2d6bb47e..da44d7022 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -109,6 +109,8 @@ class TheoryBV : public Theory private: void notifySharedTerm(TNode t) override; + Node getValue(TNode node); + /** Internal BV solver. */ std::unique_ptr d_internal; @@ -124,6 +126,17 @@ class TheoryBV : public Theory /** The notify class for equality engine. */ TheoryEqNotifyClass d_notify; + /** Flag indicating whether `d_modelCache` should be invalidated. */ + context::CDO d_invalidateModelCache; + + /** + * Cache for getValue() calls. + * + * Is cleared at the beginning of a getValue() call if the + * `d_invalidateModelCache` flag is set to true. + */ + std::unordered_map d_modelCache; + /** TheoryBV statistics. */ struct Statistics { -- 2.30.2