bv: Refactor getEqualityStatus and use for both bitblasting solvers. (#6933)
authorMathias Preiner <mathias.preiner@gmail.com>
Tue, 27 Jul 2021 20:23:32 +0000 (13:23 -0700)
committerGitHub <noreply@github.com>
Tue, 27 Jul 2021 20:23:32 +0000 (20:23 +0000)
This commit refactors the getEqualityStatus handling for bitblast and bitblast-internal.

src/theory/bv/bitblast/proof_bitblaster.cpp
src/theory/bv/bitblast/proof_bitblaster.h
src/theory/bv/bv_solver.h
src/theory/bv/bv_solver_bitblast.cpp
src/theory/bv/bv_solver_bitblast.h
src/theory/bv/bv_solver_bitblast_internal.cpp
src/theory/bv/bv_solver_bitblast_internal.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h

index f714ffda9ee814b624dc8c6a7f4eed040a7c6c10..43618974b80307f8ba8a75910edfcd31e0306a07 100644 (file)
@@ -172,6 +172,11 @@ Node BBProof::getStoredBBAtom(TNode node)
   return d_bb->getStoredBBAtom(node);
 }
 
+void BBProof::getBBTerm(TNode node, Bits& bits) const
+{
+  d_bb->getBBTerm(node, bits);
+}
+
 bool BBProof::collectModelValues(TheoryModel* m,
                                  const std::set<Node>& relevantTerms)
 {
index f6aa71f218c586752d9dc2b135810eaca66894a8..bc99d27bff93f328ce28fe70b8d8cf6a2ccfe671 100644 (file)
@@ -43,6 +43,8 @@ class BBProof
   bool hasBBTerm(TNode node) const;
   /** Get bit-blasted node stored for atom. */
   Node getStoredBBAtom(TNode node);
+  /** Get bit-blasted bits stored for node. */
+  void getBBTerm(TNode node, Bits& bits) const;
   /** Collect model values for all relevant terms given in 'relevantTerms'. */
   bool collectModelValues(TheoryModel* m, const std::set<Node>& relevantTerms);
 
index 6ccc6c7c17890801e967c6630a37220d49d52866..c959fb64891ac73c70896ebeafc45979427f4d4e 100644 (file)
@@ -102,6 +102,14 @@ class BVSolver
     return EqualityStatus::EQUALITY_UNKNOWN;
   }
 
+  /**
+   * Get the current value of `node`.
+   *
+   * The `initialize` flag indicates whether bits should be zero-initialized
+   * if they don't have a value yet.
+   */
+  virtual Node getValue(TNode node, bool initialize) { return Node::null(); }
+
   /** Called by abstraction preprocessing pass. */
   virtual bool applyAbstraction(const std::vector<Node>& assertions,
                                 std::vector<Node>& new_assertions)
index 5b70fb3a26fabf4b554bde27c9b661194505d51e..ecd42e4a0a7cf4750657af5bd1c34543455d77a9 100644 (file)
@@ -119,8 +119,6 @@ BVSolverBitblast::BVSolverBitblast(TheoryState* s,
       d_bbInputFacts(s->getSatContext()),
       d_assumptions(s->getSatContext()),
       d_assertions(s->getSatContext()),
-      d_invalidateModelCache(s->getSatContext(), true),
-      d_inSatMode(s->getSatContext(), false),
       d_epg(pnm ? new EagerProofGenerator(pnm, s->getUserContext(), "")
                 : nullptr),
       d_factLiteralCache(s->getSatContext()),
@@ -208,12 +206,9 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
     d_assumptions.push_back(d_factLiteralCache[fact]);
   }
 
-  d_invalidateModelCache.set(true);
   std::vector<prop::SatLiteral> assumptions(d_assumptions.begin(),
                                             d_assumptions.end());
   prop::SatValue val = d_satSolver->solve(assumptions);
-  d_inSatMode = val == prop::SatValue::SAT_VALUE_TRUE;
-  Debug("bv-bitblast") << "d_inSatMode: " << d_inSatMode << std::endl;
 
   if (val == prop::SatValue::SAT_VALUE_FALSE)
   {
@@ -298,7 +293,7 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m,
       continue;
     }
 
-    Node value = getValueFromSatSolver(term, true);
+    Node value = getValue(term, true);
     Assert(value.isConst());
     if (!m->assertEquality(term, value, true))
     {
@@ -330,27 +325,6 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m,
   return true;
 }
 
-EqualityStatus BVSolverBitblast::getEqualityStatus(TNode a, TNode b)
-{
-  Debug("bv-bitblast") << "getEqualityStatus on " << a << " and " << b
-                       << std::endl;
-  if (!d_inSatMode)
-  {
-    Debug("bv-bitblast") << EQUALITY_UNKNOWN << std::endl;
-    return EQUALITY_UNKNOWN;
-  }
-  Node value_a = getValue(a);
-  Node value_b = getValue(b);
-
-  if (value_a == value_b)
-  {
-    Debug("bv-bitblast") << EQUALITY_TRUE_IN_MODEL << std::endl;
-    return EQUALITY_TRUE_IN_MODEL;
-  }
-  Debug("bv-bitblast") << EQUALITY_FALSE_IN_MODEL << std::endl;
-  return EQUALITY_FALSE_IN_MODEL;
-}
-
 void BVSolverBitblast::initSatSolver()
 {
   switch (options::bvSatSolver())
@@ -372,7 +346,7 @@ void BVSolverBitblast::initSatSolver()
                                         "theory::bv::BVSolverBitblast"));
 }
 
-Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize)
+Node BVSolverBitblast::getValue(TNode node, bool initialize)
 {
   if (node.isConst())
   {
@@ -405,76 +379,6 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize)
   return utils::mkConst(bits.size(), value);
 }
 
-Node BVSolverBitblast::getValue(TNode node)
-{
-  if (d_invalidateModelCache.get())
-  {
-    d_modelCache.clear();
-  }
-  d_invalidateModelCache.set(false);
-
-  std::vector<TNode> visit;
-
-  TNode cur;
-  visit.push_back(node);
-  do
-  {
-    cur = visit.back();
-    visit.pop_back();
-
-    auto it = d_modelCache.find(cur);
-    if (it != d_modelCache.end() && !it->second.isNull())
-    {
-      continue;
-    }
-
-    if (d_bitblaster->hasBBTerm(cur))
-    {
-      Node value = getValueFromSatSolver(cur, false);
-      if (value.isConst())
-      {
-        d_modelCache[cur] = value;
-        continue;
-      }
-    }
-    if (Theory::isLeafOf(cur, theory::THEORY_BV))
-    {
-      Node value = getValueFromSatSolver(cur, true);
-      d_modelCache[cur] = value;
-      continue;
-    }
-
-    if (it == d_modelCache.end())
-    {
-      visit.push_back(cur);
-      d_modelCache.emplace(cur, Node());
-      visit.insert(visit.end(), cur.begin(), cur.end());
-    }
-    else if (it->second.isNull())
-    {
-      NodeBuilder nb(cur.getKind());
-      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
-      {
-        nb << cur.getOperator();
-      }
-
-      std::unordered_map<Node, Node>::iterator iit;
-      for (const TNode& child : cur)
-      {
-        iit = d_modelCache.find(child);
-        Assert(iit != d_modelCache.end());
-        Assert(iit->second.isConst());
-        nb << iit->second;
-      }
-      it->second = Rewriter::rewrite(nb.constructNode());
-    }
-  } while (!visit.empty());
-
-  auto it = d_modelCache.find(node);
-  Assert(it != d_modelCache.end());
-  return it->second;
-}
-
 void BVSolverBitblast::handleEagerAtom(TNode fact, bool assertFact)
 {
   Assert(fact.getKind() == kind::BITVECTOR_EAGER_ATOM);
index 8dee3c2c42391e6d5ebfe14b7f15fd54748e2eca..3f4ab50250797335c13391241e4ef3cce5736f16 100644 (file)
@@ -63,31 +63,22 @@ class BVSolverBitblast : public BVSolver
 
   std::string identify() const override { return "BVSolverBitblast"; };
 
-  EqualityStatus getEqualityStatus(TNode a, TNode b) override;
-
   void computeRelevantTerms(std::set<Node>& termSet) override;
 
   bool collectModelValues(TheoryModel* m,
                           const std::set<Node>& termSet) override;
 
- private:
-  /** Initialize SAT solver and CNF stream.  */
-  void initSatSolver();
-
   /**
-   * Get value of `node` from SAT solver.
+   * Get the current value of `node`.
    *
    * The `initialize` flag indicates whether bits should be zero-initialized
    * if they were not bit-blasted yet.
    */
-  Node getValueFromSatSolver(TNode node, bool initialize);
+  Node getValue(TNode node, bool initialize) override;
 
-  /**
-   * Get the current value of `node`.
-   *
-   * Computes the value if `node` was not yet bit-blasted.
-   */
-  Node getValue(TNode node);
+ private:
+  /** Initialize SAT solver and CNF stream.  */
+  void initSatSolver();
 
   /**
    * Handle BITVECTOR_EAGER_ATOM atoms and assert/assume to CnfStream.
@@ -97,14 +88,6 @@ class BVSolverBitblast : public BVSolver
    */
   void handleEagerAtom(TNode fact, bool assertFact);
 
-  /**
-   * Cache for getValue() calls.
-   *
-   * Is cleared at the beginning of a getValue() call if the
-   * `d_invalidateModelCache` flag is set to true.
-   */
-  std::unordered_map<Node, Node> d_modelCache;
-
   /** Bit-blaster used to bit-blast atoms/terms. */
   std::unique_ptr<NodeBitblaster> d_bitblaster;
 
@@ -137,12 +120,6 @@ class BVSolverBitblast : public BVSolver
   /** Stores the current input assertions. */
   context::CDList<Node> d_assertions;
 
-  /** Flag indicating whether `d_modelCache` should be invalidated. */
-  context::CDO<bool> d_invalidateModelCache;
-
-  /** Indicates whether the last check() call was satisfiable. */
-  context::CDO<bool> d_inSatMode;
-
   /** Proof generator that manages proofs for lemmas generated by this class. */
   std::unique_ptr<EagerProofGenerator> d_epg;
 
index bd47cc45eaabfe3a52d3d8d5f9d3fee939edc058..ef4f3559be18c53501879478cf4c73a788a35583 100644 (file)
@@ -147,6 +147,40 @@ bool BVSolverBitblastInternal::collectModelValues(TheoryModel* m,
   return d_bitblaster->collectModelValues(m, termSet);
 }
 
+Node BVSolverBitblastInternal::getValue(TNode node, bool initialize)
+{
+  if (node.isConst())
+  {
+    return node;
+  }
+
+  if (!d_bitblaster->hasBBTerm(node))
+  {
+    return initialize ? utils::mkConst(utils::getSize(node), 0u) : Node();
+  }
+
+  Valuation& val = d_state.getValuation();
+
+  std::vector<Node> bits;
+  d_bitblaster->getBBTerm(node, bits);
+  Integer value(0), one(1), zero(0), bit;
+  for (size_t i = 0, size = bits.size(), j = size - 1; i < size; ++i, --j)
+  {
+    bool satValue;
+    if (val.hasSatValue(bits[j], satValue))
+    {
+      bit = satValue ? one : zero;
+    }
+    else
+    {
+      if (!initialize) return Node();
+      bit = zero;
+    }
+    value = value * 2 + bit;
+  }
+  return utils::mkConst(bits.size(), value);
+}
+
 BVProofRuleChecker* BVSolverBitblastInternal::getProofChecker()
 {
   return &d_checker;
index 8a1886084ce3f3a1cb8aafb5fc6afce350bbf833..1ec3ec1fec8eaef12b23f7ba47bd3402b1670beb 100644 (file)
@@ -42,6 +42,8 @@ class BVSolverBitblastInternal : public BVSolver
                            ProofNodeManager* pnm);
   ~BVSolverBitblastInternal() = default;
 
+  bool needsEqualityEngine(EeSetupInfo& esi) override { return true; }
+
   void preRegisterTerm(TNode n) override {}
 
   bool preNotifyFact(TNode atom,
@@ -55,6 +57,8 @@ class BVSolverBitblastInternal : public BVSolver
   bool collectModelValues(TheoryModel* m,
                           const std::set<Node>& termSet) override;
 
+  Node getValue(TNode node, bool initialize) override;
+
   /** get the proof checker of this theory */
   BVProofRuleChecker* getProofChecker();
 
index 37881f9b20373fee2db0dec137f38a65abcfa59b..547d24b23abf5d6b8614483f18c24b6d020d42f8 100644 (file)
@@ -43,6 +43,7 @@ TheoryBV::TheoryBV(context::Context* c,
       d_state(c, u, valuation),
       d_im(*this, d_state, nullptr, "theory::bv::"),
       d_notify(d_im),
+      d_invalidateModelCache(c, true),
       d_stats("theory::bv::")
 {
   switch (options::bvSolver())
@@ -158,7 +159,11 @@ void TheoryBV::preRegisterTerm(TNode node)
 
 bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); }
 
-void TheoryBV::postCheck(Effort e) { d_internal->postCheck(e); }
+void TheoryBV::postCheck(Effort e)
+{
+  d_invalidateModelCache = true;
+  d_internal->postCheck(e);
+}
 
 bool TheoryBV::preNotifyFact(
     TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
@@ -282,7 +287,27 @@ void TheoryBV::presolve() { d_internal->presolve(); }
 
 EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b)
 {
-  return d_internal->getEqualityStatus(a, b);
+  EqualityStatus status = d_internal->getEqualityStatus(a, b);
+
+  if (status == EqualityStatus::EQUALITY_UNKNOWN)
+  {
+    Node value_a = getValue(a);
+    Node value_b = getValue(b);
+
+    if (value_a.isNull() || value_b.isNull())
+    {
+      return status;
+    }
+
+    if (value_a == value_b)
+    {
+      Debug("theory-bv") << EQUALITY_TRUE_IN_MODEL << std::endl;
+      return EQUALITY_TRUE_IN_MODEL;
+    }
+    Debug("theory-bv") << EQUALITY_FALSE_IN_MODEL << std::endl;
+    return EQUALITY_FALSE_IN_MODEL;
+  }
+  return status;
 }
 
 TrustNode TheoryBV::explain(TNode node) { return d_internal->explain(node); }
@@ -303,6 +328,80 @@ bool TheoryBV::applyAbstraction(const std::vector<Node>& assertions,
   return d_internal->applyAbstraction(assertions, new_assertions);
 }
 
+Node TheoryBV::getValue(TNode node)
+{
+  if (d_invalidateModelCache.get())
+  {
+    d_modelCache.clear();
+  }
+  d_invalidateModelCache.set(false);
+
+  std::vector<TNode> visit;
+
+  TNode cur;
+  visit.push_back(node);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+
+    auto it = d_modelCache.find(cur);
+    if (it != d_modelCache.end() && !it->second.isNull())
+    {
+      continue;
+    }
+
+    if (cur.isConst())
+    {
+      d_modelCache[cur] = cur;
+      continue;
+    }
+
+    Node value = d_internal->getValue(cur, false);
+    if (value.isConst())
+    {
+      d_modelCache[cur] = value;
+      continue;
+    }
+
+    if (Theory::isLeafOf(cur, theory::THEORY_BV))
+    {
+      value = d_internal->getValue(cur, true);
+      d_modelCache[cur] = value;
+      continue;
+    }
+
+    if (it == d_modelCache.end())
+    {
+      visit.push_back(cur);
+      d_modelCache.emplace(cur, Node());
+      visit.insert(visit.end(), cur.begin(), cur.end());
+    }
+    else if (it->second.isNull())
+    {
+      NodeBuilder nb(cur.getKind());
+      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        nb << cur.getOperator();
+      }
+
+      std::unordered_map<Node, Node>::iterator iit;
+      for (const TNode& child : cur)
+      {
+        iit = d_modelCache.find(child);
+        Assert(iit != d_modelCache.end());
+        Assert(iit->second.isConst());
+        nb << iit->second;
+      }
+      it->second = Rewriter::rewrite(nb.constructNode());
+    }
+  } while (!visit.empty());
+
+  auto it = d_modelCache.find(node);
+  Assert(it != d_modelCache.end());
+  return it->second;
+}
+
 TheoryBV::Statistics::Statistics(const std::string& name)
     : d_solveSubstitutions(
         smtStatisticsRegistry().registerInt(name + "NumSolveSubstitutions"))
index f2d6bb47ec6c3c361c46358f18d5ba729e4dc8d8..da44d7022fa424c27db960aad42502502383e049 100644 (file)
@@ -109,6 +109,8 @@ class TheoryBV : public Theory
  private:
   void notifySharedTerm(TNode t) override;
 
+  Node getValue(TNode node);
+
   /** Internal BV solver. */
   std::unique_ptr<BVSolver> d_internal;
 
@@ -124,6 +126,17 @@ class TheoryBV : public Theory
   /** The notify class for equality engine. */
   TheoryEqNotifyClass d_notify;
 
+  /** Flag indicating whether `d_modelCache` should be invalidated. */
+  context::CDO<bool> d_invalidateModelCache;
+
+  /**
+   * Cache for getValue() calls.
+   *
+   * Is cleared at the beginning of a getValue() call if the
+   * `d_invalidateModelCache` flag is set to true.
+   */
+  std::unordered_map<Node, Node> d_modelCache;
+
   /** TheoryBV statistics. */
   struct Statistics
   {