Properly set up equality engine for BV bitblast solver. (#5905)
authorMathias Preiner <mathias.preiner@gmail.com>
Sat, 13 Feb 2021 14:08:37 +0000 (06:08 -0800)
committerGitHub <noreply@github.com>
Sat, 13 Feb 2021 14:08:37 +0000 (08:08 -0600)
Theory BV now sets up the default equality engine for BV solvers that do not use their own equality engine like e.g. the BV bitblast solver. This commit also adds the missing equality engine pieces to the BV bitblast solver (getEqualityStatus, explain).

src/theory/bv/bitblast/lazy_bitblaster.cpp
src/theory/bv/bv_solver.h
src/theory/bv/bv_solver_bitblast.cpp
src/theory/bv/bv_solver_bitblast.h
src/theory/bv/bv_solver_lazy.cpp
src/theory/bv/bv_solver_lazy.h
src/theory/bv/bv_solver_simple.cpp
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h

index f3adc4b21ff8dfd92680e1bc88d2bbf09305f9ba..0c541ba89ee471cb7a3f9cec87157158e597dc2e 100644 (file)
@@ -416,9 +416,9 @@ void TLazyBitblaster::MinisatNotify::notify(prop::SatClause& clause) {
       lemmab << d_cnf->getNode(clause[i]);
     }
     Node lemma = lemmab;
-    d_bv->d_inferManager.lemma(lemma, InferenceId::UNKNOWN);
+    d_bv->d_im.lemma(lemma, InferenceId::UNKNOWN);
   } else {
-    d_bv->d_inferManager.lemma(d_cnf->getNode(clause[0]), InferenceId::UNKNOWN);
+    d_bv->d_im.lemma(d_cnf->getNode(clause[0]), InferenceId::UNKNOWN);
   }
 }
 
@@ -429,7 +429,7 @@ void TLazyBitblaster::MinisatNotify::spendResource(ResourceManager::Resource r)
 
 void TLazyBitblaster::MinisatNotify::safePoint(ResourceManager::Resource r)
 {
-  d_bv->d_inferManager.safePoint(r);
+  d_bv->d_im.safePoint(r);
 }
 
 EqualityStatus TLazyBitblaster::getEqualityStatus(TNode a, TNode b)
index f4b5a9d111b7dbb4b7865211cb79a6694e3ecb97..9f4ac54df77e2e91a967b1201cd982586586a29a 100644 (file)
@@ -29,7 +29,7 @@ class BVSolver
 {
  public:
   BVSolver(TheoryState& state, TheoryInferenceManager& inferMgr)
-      : d_state(state), d_inferManager(inferMgr){};
+      : d_state(state), d_im(inferMgr){};
 
   virtual ~BVSolver(){};
 
@@ -112,7 +112,7 @@ class BVSolver
 
  protected:
   TheoryState& d_state;
-  TheoryInferenceManager& d_inferManager;
+  TheoryInferenceManager& d_im;
 };
 
 }  // namespace bv
index ce8bc3645ac6df72deab795dc34ddf7eae75d156..0b5d4cfef4ec26323e9d479ef595b34684a11204 100644 (file)
@@ -35,6 +35,8 @@ BVSolverBitblast::BVSolverBitblast(TheoryState* s,
       d_nullRegistrar(new prop::NullRegistrar()),
       d_nullContext(new context::Context()),
       d_facts(s->getSatContext()),
+      d_invalidateModelCache(s->getSatContext(), true),
+      d_inSatMode(s->getSatContext(), false),
       d_epg(pnm ? new EagerProofGenerator(pnm, s->getUserContext(), "")
                 : nullptr)
 {
@@ -82,7 +84,10 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
     node_map.emplace(lit, fact);
   }
 
+  d_invalidateModelCache.set(true);
   prop::SatValue val = d_satSolver->solve(assumptions);
+  d_inSatMode = val == prop::SatValue::SAT_VALUE_TRUE;
+  Debug("bv-bitblast") << "d_inSatMode: " << d_inSatMode << std::endl;
 
   if (val == prop::SatValue::SAT_VALUE_FALSE)
   {
@@ -97,7 +102,7 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
     }
 
     NodeManager* nm = NodeManager::currentNM();
-    d_inferManager.conflict(nm->mkAnd(conflict), InferenceId::UNKNOWN);
+    d_im.conflict(nm->mkAnd(conflict), InferenceId::UNKNOWN);
   }
 }
 
@@ -108,6 +113,12 @@ bool BVSolverBitblast::preNotifyFact(
   return false;  // Return false to enable equality engine reasoning in Theory.
 }
 
+TrustNode BVSolverBitblast::explain(TNode n)
+{
+  Debug("bv-bitblast") << "explain called on " << n << std::endl;
+  return d_im.explainLit(n);
+}
+
 bool BVSolverBitblast::collectModelValues(TheoryModel* m,
                                           const std::set<Node>& termSet)
 {
@@ -118,7 +129,7 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m,
       continue;
     }
 
-    Node value = getValueFromSatSolver(term);
+    Node value = getValueFromSatSolver(term, true);
     Assert(value.isConst());
     if (!m->assertEquality(term, value, true))
     {
@@ -128,12 +139,37 @@ bool BVSolverBitblast::collectModelValues(TheoryModel* m,
   return true;
 }
 
-Node BVSolverBitblast::getValueFromSatSolver(TNode node)
+EqualityStatus BVSolverBitblast::getEqualityStatus(TNode a, TNode b)
+{
+  Debug("bv-bitblast") << "getEqualityStatus on " << a << " and " << b
+                       << std::endl;
+  if (!d_inSatMode)
+  {
+    Debug("bv-bitblast") << EQUALITY_UNKNOWN << std::endl;
+    return EQUALITY_UNKNOWN;
+  }
+  Node value_a = getValue(a);
+  Node value_b = getValue(b);
+
+  if (value_a == value_b)
+  {
+    Debug("bv-bitblast") << EQUALITY_TRUE_IN_MODEL << std::endl;
+    return EQUALITY_TRUE_IN_MODEL;
+  }
+  Debug("bv-bitblast") << EQUALITY_FALSE_IN_MODEL << std::endl;
+  return EQUALITY_FALSE_IN_MODEL;
+}
+
+Node BVSolverBitblast::getValueFromSatSolver(TNode node, bool initialize)
 {
-  /* If node was not bit-blasted return zero-initialized bit-vector. */
+  if (node.isConst())
+  {
+    return node;
+  }
+
   if (!d_bitblaster->hasBBTerm(node))
   {
-    return utils::mkConst(utils::getSize(node), 0u);
+    return initialize ? utils::mkConst(utils::getSize(node), 0u) : Node();
   }
 
   std::vector<Node> bits;
@@ -149,6 +185,7 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node)
     }
     else
     {
+      if (!initialize) return Node();
       bit = zero;
     }
     value = value * 2 + bit;
@@ -156,6 +193,76 @@ Node BVSolverBitblast::getValueFromSatSolver(TNode node)
   return utils::mkConst(bits.size(), value);
 }
 
+Node BVSolverBitblast::getValue(TNode node)
+{
+  if (d_invalidateModelCache.get())
+  {
+    d_modelCache.clear();
+  }
+  d_invalidateModelCache.set(false);
+
+  std::vector<TNode> visit;
+
+  TNode cur;
+  visit.push_back(node);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+
+    auto it = d_modelCache.find(cur);
+    if (it != d_modelCache.end() && !it->second.isNull())
+    {
+      continue;
+    }
+
+    if (d_bitblaster->hasBBTerm(cur))
+    {
+      Node value = getValueFromSatSolver(cur, false);
+      if (value.isConst())
+      {
+        d_modelCache[cur] = value;
+        continue;
+      }
+    }
+    if (Theory::isLeafOf(cur, theory::THEORY_BV))
+    {
+      Node value = getValueFromSatSolver(cur, true);
+      d_modelCache[cur] = value;
+      continue;
+    }
+
+    if (it == d_modelCache.end())
+    {
+      visit.push_back(cur);
+      d_modelCache.emplace(cur, Node());
+      visit.insert(visit.end(), cur.begin(), cur.end());
+    }
+    else if (it->second.isNull())
+    {
+      NodeBuilder<> nb(cur.getKind());
+      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        nb << cur.getOperator();
+      }
+
+      std::unordered_map<Node, Node, NodeHashFunction>::iterator iit;
+      for (const TNode& child : cur)
+      {
+        iit = d_modelCache.find(child);
+        Assert(iit != d_modelCache.end());
+        Assert(iit->second.isConst());
+        nb << iit->second;
+      }
+      it->second = Rewriter::rewrite(nb.constructNode());
+    }
+  } while (!visit.empty());
+
+  auto it = d_modelCache.find(node);
+  Assert(it != d_modelCache.end());
+  return it->second;
+}
+
 }  // namespace bv
 }  // namespace theory
 }  // namespace CVC4
index df0f2e0857c1ef68fe6fec2d8835f52f16fbe8dc..d9b4a26e98be6f063e1fb1b2b035b7efbb0be6f5 100644 (file)
@@ -56,6 +56,8 @@ class BVSolverBitblast : public BVSolver
                      bool isPrereg,
                      bool isInternal) override;
 
+  TrustNode explain(TNode n) override;
+
   std::string identify() const override { return "BVSolverBitblast"; };
 
   Theory::PPAssertStatus ppAssert(
@@ -64,17 +66,39 @@ class BVSolverBitblast : public BVSolver
     return Theory::PPAssertStatus::PP_ASSERT_STATUS_UNSOLVED;
   }
 
+  EqualityStatus getEqualityStatus(TNode a, TNode b) override;
+
   bool collectModelValues(TheoryModel* m,
                           const std::set<Node>& termSet) override;
 
  private:
-  /** Get value of `node` from SAT solver. */
-  Node getValueFromSatSolver(TNode node);
+  /**
+   * Get value of `node` from SAT solver.
+   *
+   * The `initialize` flag indicates whether bits should be zero-initialized
+   * if they were not bit-blasted yet.
+   */
+  Node getValueFromSatSolver(TNode node, bool initialize);
+
+  /**
+   * Get the current value of `node`.
+   *
+   * Computes the value if `node` was not yet bit-blasted.
+   */
+  Node getValue(TNode node);
+
+  /**
+   * Cache for getValue() calls.
+   *
+   * Is cleared at the beginning of a getValue() call if the
+   * `d_invalidateModelCache` flag is set to true.
+   */
+  std::unordered_map<Node, Node, NodeHashFunction> d_modelCache;
 
   /** Bit-blaster used to bit-blast atoms/terms. */
   std::unique_ptr<BBSimple> d_bitblaster;
 
-  /** Used for initializing CnfStream> */
+  /** Used for initializing `d_cnfStream`. */
   std::unique_ptr<prop::NullRegistrar> d_nullRegistrar;
   std::unique_ptr<context::Context> d_nullContext;
 
@@ -86,6 +110,12 @@ class BVSolverBitblast : public BVSolver
   /** Facts sent to this solver. */
   context::CDList<Node> d_facts;
 
+  /** Flag indicating whether `d_modelCache` should be invalidated. */
+  context::CDO<bool> d_invalidateModelCache;
+
+  /** Indicates whether the last check() call was satisfiable. */
+  context::CDO<bool> d_inSatMode;
+
   /** Proof generator that manages proofs for lemmas generated by this class. */
   std::unique_ptr<EagerProofGenerator> d_epg;
 
index 9c44e32f12c584c9c4e5eb56a1e1169fa9ec9d52..0e81d064862ce295eb7cd290a2c254362f279814 100644 (file)
@@ -42,7 +42,7 @@ BVSolverLazy::BVSolverLazy(TheoryBV& bv,
                            context::UserContext* u,
                            ProofNodeManager* pnm,
                            std::string name)
-    : BVSolver(bv.d_state, bv.d_inferMgr),
+    : BVSolver(bv.d_state, bv.d_im),
       d_bv(bv),
       d_context(c),
       d_alreadyPropagatedSet(c),
@@ -119,7 +119,7 @@ void BVSolverLazy::finishInit()
 
 void BVSolverLazy::spendResource(ResourceManager::Resource r)
 {
-  d_inferManager.spendResource(r);
+  d_im.spendResource(r);
 }
 
 BVSolverLazy::Statistics::Statistics()
@@ -196,7 +196,7 @@ void BVSolverLazy::sendConflict()
   {
     Debug("bitvector") << indent() << "BVSolverLazy::check(): conflict "
                        << d_conflictNode << std::endl;
-    d_inferManager.conflict(d_conflictNode, InferenceId::UNKNOWN);
+    d_im.conflict(d_conflictNode, InferenceId::UNKNOWN);
     d_statistics.d_avgConflictSize.addEntry(d_conflictNode.getNumChildren());
     d_conflictNode = Node::null();
   }
@@ -287,11 +287,11 @@ void BVSolverLazy::check(Theory::Effort e)
     {
       if (assertions.size() == 1)
       {
-        d_inferManager.conflict(assertions[0], InferenceId::UNKNOWN);
+        d_im.conflict(assertions[0], InferenceId::UNKNOWN);
         return;
       }
       Node conflict = utils::mkAnd(assertions);
-      d_inferManager.conflict(conflict, InferenceId::UNKNOWN);
+      d_im.conflict(conflict, InferenceId::UNKNOWN);
       return;
     }
     return;
@@ -426,7 +426,7 @@ void BVSolverLazy::propagate(Theory::Effort e)
     {
       Debug("bitvector::propagate")
           << "BVSolverLazy:: propagating " << literal << "\n";
-      ok = d_inferManager.propagateLit(literal);
+      ok = d_im.propagateLit(literal);
     }
   }
 
@@ -670,7 +670,7 @@ bool BVSolverLazy::storePropagation(TNode literal, SubTheory subtheory)
   constexpr bool ok = true;
   if (subtheory == SUB_CORE)
   {
-    d_inferManager.propagateLit(literal);
+    d_im.propagateLit(literal);
     if (!ok)
     {
       setConflict();
index da5f1cbf88fa5cc51c48bc89bd950c99ffc867b8..46d01d129d6301cf00ef10d6eee6f3806a51b1e4 100644 (file)
@@ -203,7 +203,7 @@ class BVSolverLazy : public BVSolver
 
   void lemma(TNode node)
   {
-    d_inferManager.lemma(node, InferenceId::UNKNOWN);
+    d_im.lemma(node, InferenceId::UNKNOWN);
     d_lemmasAdded = true;
   }
 
index c4a40404105e9f83e28040750e499d7e3dd25c47..02196a4ede181ce0dad68e4319e957e866f543df 100644 (file)
@@ -93,12 +93,12 @@ void BVSolverSimple::addBBLemma(TNode fact)
 
   if (d_epg == nullptr)
   {
-    d_inferManager.lemma(lemma, InferenceId::UNKNOWN);
+    d_im.lemma(lemma, InferenceId::UNKNOWN);
   }
   else
   {
     TrustNode tlem = d_epg->mkTrustNode(lemma, PfRule::BV_BITBLAST, {}, {fact});
-    d_inferManager.trustedLemma(tlem, InferenceId::UNKNOWN);
+    d_im.trustedLemma(tlem, InferenceId::UNKNOWN);
   }
 }
 
@@ -123,13 +123,13 @@ bool BVSolverSimple::preNotifyFact(
 
     if (d_epg == nullptr)
     {
-      d_inferManager.lemma(lemma, InferenceId::UNKNOWN);
+      d_im.lemma(lemma, InferenceId::UNKNOWN);
     }
     else
     {
       TrustNode tlem =
           d_epg->mkTrustNode(lemma, PfRule::BV_EAGER_ATOM, {}, {fact});
-      d_inferManager.trustedLemma(tlem, InferenceId::UNKNOWN);
+      d_im.trustedLemma(tlem, InferenceId::UNKNOWN);
     }
 
     std::unordered_set<Node, NodeHashFunction> bv_atoms;
index 87cc0bc4de76d7b31b537924d02ae009e4602ef0..b839066888b0414dbfcdf4b28de093506d2db48f 100644 (file)
@@ -188,8 +188,7 @@ void CoreSolver::explain(TNode literal, std::vector<TNode>& assumptions) {
 bool CoreSolver::check(Theory::Effort e) {
   Trace("bitvector::core") << "CoreSolver::check \n";
 
-  d_bv->d_inferManager.spendResource(
-      ResourceManager::Resource::TheoryCheckStep);
+  d_bv->d_im.spendResource(ResourceManager::Resource::TheoryCheckStep);
 
   d_checkCalled = true;
   Assert(!d_bv->inConflict());
@@ -560,7 +559,7 @@ bool CoreSolver::doExtfInferences(std::vector<Node>& terms)
                               nm->mkNode(kind::LT, n, max));
         Trace("bv-extf-lemma")
             << "BV extf lemma (range) : " << lem << std::endl;
-        d_bv->d_inferManager.lemma(lem, InferenceId::UNKNOWN);
+        d_bv->d_im.lemma(lem, InferenceId::UNKNOWN);
         sentLemma = true;
       }
     }
@@ -609,7 +608,7 @@ bool CoreSolver::doExtfInferences(std::vector<Node>& terms)
           //   (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k
           Trace("bv-extf-lemma")
               << "BV extf lemma (collapse) : " << lem << std::endl;
-          d_bv->d_inferManager.lemma(lem, InferenceId::UNKNOWN);
+          d_bv->d_im.lemma(lem, InferenceId::UNKNOWN);
           sentLemma = true;
         }
       }
index b27bd04e1d9eee063e0085e0d0c315bbed87c93b..f6e056f4241e3fb7da90f0705873e568a8ad1cb3 100644 (file)
@@ -39,12 +39,13 @@ TheoryBV::TheoryBV(context::Context* c,
       d_ufRemByZero(),
       d_rewriter(),
       d_state(c, u, valuation),
-      d_inferMgr(*this, d_state, nullptr)
+      d_im(*this, d_state, nullptr),
+      d_notify(d_im)
 {
   switch (options::bvSolver())
   {
     case options::BVSolver::BITBLAST:
-      d_internal.reset(new BVSolverBitblast(&d_state, d_inferMgr, pnm));
+      d_internal.reset(new BVSolverBitblast(&d_state, d_im, pnm));
       break;
 
     case options::BVSolver::LAZY:
@@ -53,10 +54,10 @@ TheoryBV::TheoryBV(context::Context* c,
 
     default:
       AlwaysAssert(options::bvSolver() == options::BVSolver::SIMPLE);
-      d_internal.reset(new BVSolverSimple(&d_state, d_inferMgr, pnm));
+      d_internal.reset(new BVSolverSimple(&d_state, d_im, pnm));
   }
   d_theoryState = &d_state;
-  d_inferManager = &d_inferMgr;
+  d_inferManager = &d_im;
 }
 
 TheoryBV::~TheoryBV() {}
@@ -65,7 +66,16 @@ TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; }
 
 bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi)
 {
-  return d_internal->needsEqualityEngine(esi);
+  bool need_ee = d_internal->needsEqualityEngine(esi);
+
+  /* Set up default notify class for equality engine. */
+  if (need_ee && esi.d_notify == nullptr)
+  {
+    esi.d_notify = &d_notify;
+    esi.d_name = "theory::bv::ee";
+  }
+
+  return need_ee;
 }
 
 void TheoryBV::finishInit()
@@ -194,6 +204,19 @@ TrustNode TheoryBV::expandDefinition(Node node)
 void TheoryBV::preRegisterTerm(TNode node)
 {
   d_internal->preRegisterTerm(node);
+
+  eq::EqualityEngine* ee = getEqualityEngine();
+  if (ee)
+  {
+    if (node.getKind() == kind::EQUAL)
+    {
+      ee->addTriggerPredicate(node);
+    }
+    else
+    {
+      ee->addTerm(node);
+    }
+  }
 }
 
 bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); }
index 306b1ff937eb0331e61a69a5c839eb3d4626527b..2aa722e48f025bb561a01587e624430b633ee516 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "theory/bv/theory_bv_rewriter.h"
 #include "theory/theory.h"
+#include "theory/theory_eq_notify.h"
 
 namespace CVC4 {
 namespace theory {
@@ -130,7 +131,10 @@ class TheoryBV : public Theory
   TheoryState d_state;
 
   /** A (default) theory inference manager. */
-  TheoryInferenceManager d_inferMgr;
+  TheoryInferenceManager d_im;
+
+  /** The notify class for equality engine. */
+  TheoryEqNotifyClass d_notify;
 
 }; /* class TheoryBV */