Add bit-level propagation support to BV bitblast solver. (#5906)
authorMathias Preiner <mathias.preiner@gmail.com>
Wed, 17 Feb 2021 00:36:57 +0000 (16:36 -0800)
committerGitHub <noreply@github.com>
Wed, 17 Feb 2021 00:36:57 +0000 (16:36 -0800)
This commit adds support for bit-level propagation for the BV bitblast solver to quickly detect conflicts on effort levels != FULL. Bit-level propagation for the bitblast solver is by default disabled for now. Further, bit-blasting of facts is now handled more lazily with a bit-blast queue.

src/prop/cadical.cpp
src/prop/cadical.h
src/prop/sat_solver.h
src/smt/set_defaults.cpp
src/theory/bv/bitblast/simple_bitblaster.cpp
src/theory/bv/bv_solver_bitblast.cpp
src/theory/bv/bv_solver_bitblast.h

index b3563bf28adbdbff05c20f42bb963ef8a5b51a1d..0324f8128c8daaa549620b1dbffa990c80b05e2f 100644 (file)
@@ -61,6 +61,7 @@ CadicalSolver::CadicalSolver(StatisticsRegistry* registry,
       // Note: CaDiCaL variables start with index 1 rather than 0 since negated
       //       literals are represented as the negation of the index.
       d_nextVarIdx(1),
+      d_inSatMode(false),
       d_statistics(registry, name)
 {
 }
@@ -111,10 +112,10 @@ SatVariable CadicalSolver::falseVar() { return d_false; }
 
 SatValue CadicalSolver::solve()
 {
-  d_assumptions.clear();
   TimerStat::CodeTimer codeTimer(d_statistics.d_solveTime);
+  d_assumptions.clear();
   SatValue res = toSatValue(d_solver->solve());
-  d_okay = (res == SAT_VALUE_TRUE);
+  d_inSatMode = (res == SAT_VALUE_TRUE);
   ++d_statistics.d_numSatCalls;
   return res;
 }
@@ -126,19 +127,25 @@ SatValue CadicalSolver::solve(long unsigned int&)
 
 SatValue CadicalSolver::solve(const std::vector<SatLiteral>& assumptions)
 {
-  d_assumptions.clear();
   TimerStat::CodeTimer codeTimer(d_statistics.d_solveTime);
+  d_assumptions.clear();
   for (const SatLiteral& lit : assumptions)
   {
     d_solver->assume(toCadicalLit(lit));
     d_assumptions.push_back(lit);
   }
   SatValue res = toSatValue(d_solver->solve());
-  d_okay = (res == SAT_VALUE_TRUE);
+  d_inSatMode = (res == SAT_VALUE_TRUE);
   ++d_statistics.d_numSatCalls;
   return res;
 }
 
+bool CadicalSolver::setPropagateOnly()
+{
+  d_solver->limit("decisions", 0); /* Gets reset after next solve() call. */
+  return true;
+}
+
 void CadicalSolver::getUnsatAssumptions(std::vector<SatLiteral>& assumptions)
 {
   for (const SatLiteral& lit : d_assumptions)
@@ -154,13 +161,13 @@ void CadicalSolver::interrupt() { d_solver->terminate(); }
 
 SatValue CadicalSolver::value(SatLiteral l)
 {
-  Assert(d_okay);
+  Assert(d_inSatMode);
   return toSatValueLit(d_solver->val(toCadicalLit(l)));
 }
 
 SatValue CadicalSolver::modelValue(SatLiteral l)
 {
-  Assert(d_okay);
+  Assert(d_inSatMode);
   return value(l);
 }
 
@@ -169,7 +176,7 @@ unsigned CadicalSolver::getAssertionLevel() const
   Unreachable() << "CaDiCaL does not support assertion levels.";
 }
 
-bool CadicalSolver::ok() const { return d_okay; }
+bool CadicalSolver::ok() const { return d_inSatMode; }
 
 CadicalSolver::Statistics::Statistics(StatisticsRegistry* registry,
                                       const std::string& prefix)
index 6a7258091ba40cde333a62daed0914b5adf03566..f8b2d3bcd556b0574f68813727f327becf7c7948 100644 (file)
@@ -50,6 +50,7 @@ class CadicalSolver : public SatSolver
   SatValue solve() override;
   SatValue solve(long unsigned int&) override;
   SatValue solve(const std::vector<SatLiteral>& assumptions) override;
+  bool setPropagateOnly() override;
   void getUnsatAssumptions(std::vector<SatLiteral>& assumptions) override;
 
   void interrupt() override;
@@ -82,7 +83,7 @@ class CadicalSolver : public SatSolver
   std::vector<SatLiteral> d_assumptions;
 
   unsigned d_nextVarIdx;
-  bool d_okay;
+  bool d_inSatMode;
   SatVariable d_true;
   SatVariable d_false;
 
index 896233f413bf31102f7a773a0ae3d5b98213df9e..d421155aee7a0ac66b61d012ed022e15ee493e45 100644 (file)
@@ -81,6 +81,13 @@ public:
     return SAT_VALUE_UNKNOWN;
   };
 
+  /**
+   * Tell SAT solver to only do propagation on next solve().
+   *
+   * @return true if feature is supported, otherwise false.
+   */
+  virtual bool setPropagateOnly() { return false; }
+
   /** Interrupt the solver */
   virtual void interrupt() = 0;
 
index d18b30430d91f5e9224f217e101240d7443b82de..93196fde4ee644aec2aa15cbaa78f841f5a65c65 100644 (file)
@@ -132,6 +132,12 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     options::bvLazyRewriteExtf.set(false);
   }
 
+  /* Disable bit-level propagation by default for the BITBLAST solver. */
+  if (options::bvSolver() == options::BVSolver::BITBLAST)
+  {
+    options::bitvectorPropagate.set(false);
+  }
+
   if (options::solveIntAsBV() > 0)
   {
     // not compatible with incremental
index 8eb209b53c14cadbf01c7c446253d6d8368a3960..551c186124d13ea955e16fa5e51bab52ef912818 100644 (file)
@@ -38,7 +38,7 @@ void BBSimple::bbAtom(TNode node)
           ? d_atomBBStrategies[normalized.getKind()](normalized, this)
           : normalized;
 
-  storeBBAtom(node, atom_bb);
+  storeBBAtom(node, Rewriter::rewrite(atom_bb));
 }
 
 void BBSimple::storeBBAtom(TNode atom, Node atom_bb)
index 0b5d4cfef4ec26323e9d479ef595b34684a11204..bf264f9cd757ce7887bcac27470511ea7e065f8d 100644 (file)
@@ -34,11 +34,15 @@ BVSolverBitblast::BVSolverBitblast(TheoryState* s,
       d_bitblaster(new BBSimple(s)),
       d_nullRegistrar(new prop::NullRegistrar()),
       d_nullContext(new context::Context()),
-      d_facts(s->getSatContext()),
+      d_bbFacts(s->getSatContext()),
+      d_assumptions(s->getSatContext()),
       d_invalidateModelCache(s->getSatContext(), true),
       d_inSatMode(s->getSatContext(), false),
       d_epg(pnm ? new EagerProofGenerator(pnm, s->getUserContext(), "")
-                : nullptr)
+                : nullptr),
+      d_factLiteralCache(s->getSatContext()),
+      d_literalFactCache(s->getSatContext()),
+      d_propagate(options::bitvectorPropagate())
 {
   if (pnm != nullptr)
   {
@@ -66,25 +70,35 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
 {
   if (level != Theory::Effort::EFFORT_FULL)
   {
-    return;
+    /* Do bit-level propagation only if the SAT solver supports it. */
+    if (!d_propagate || !d_satSolver->setPropagateOnly())
+    {
+      return;
+    }
   }
 
-  std::vector<prop::SatLiteral> assumptions;
-  std::unordered_map<prop::SatLiteral, TNode, prop::SatLiteralHashFunction>
-      node_map;
-  for (const TNode fact : d_facts)
+  /* Process bit-blast queue and store SAT literals. */
+  while (!d_bbFacts.empty())
   {
-    /* Bitblast fact */
-    d_bitblaster->bbAtom(fact);
-    Node bb_fact = Rewriter::rewrite(d_bitblaster->getStoredBBAtom(fact));
-    d_cnfStream->ensureLiteral(bb_fact);
+    Node fact = d_bbFacts.front();
+    d_bbFacts.pop();
+    /* Bit-blast fact and cache literal. */
+    if (d_factLiteralCache.find(fact) == d_factLiteralCache.end())
+    {
+      d_bitblaster->bbAtom(fact);
+      Node bb_fact = d_bitblaster->getStoredBBAtom(fact);
+      d_cnfStream->ensureLiteral(bb_fact);
 
-    prop::SatLiteral lit = d_cnfStream->getLiteral(bb_fact);
-    assumptions.push_back(lit);
-    node_map.emplace(lit, fact);
+      prop::SatLiteral lit = d_cnfStream->getLiteral(bb_fact);
+      d_factLiteralCache[fact] = lit;
+      d_literalFactCache[lit] = fact;
+    }
+    d_assumptions.push_back(d_factLiteralCache[fact]);
   }
 
   d_invalidateModelCache.set(true);
+  std::vector<prop::SatLiteral> assumptions(d_assumptions.begin(),
+                                            d_assumptions.end());
   prop::SatValue val = d_satSolver->solve(assumptions);
   d_inSatMode = val == prop::SatValue::SAT_VALUE_TRUE;
   Debug("bv-bitblast") << "d_inSatMode: " << d_inSatMode << std::endl;
@@ -98,7 +112,9 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
     std::vector<Node> conflict;
     for (const prop::SatLiteral& lit : unsat_assumptions)
     {
-      conflict.push_back(node_map[lit]);
+      conflict.push_back(d_literalFactCache[lit]);
+      Debug("bv-bitblast") << "unsat assumption (" << lit
+                           << "): " << conflict.back() << std::endl;
     }
 
     NodeManager* nm = NodeManager::currentNM();
@@ -109,7 +125,7 @@ void BVSolverBitblast::postCheck(Theory::Effort level)
 bool BVSolverBitblast::preNotifyFact(
     TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
 {
-  d_facts.push_back(fact);
+  d_bbFacts.push_back(fact);
   return false;  // Return false to enable equality engine reasoning in Theory.
 }
 
index d9b4a26e98be6f063e1fb1b2b035b7efbb0be6f5..4eec45e4b87d186f5c0edd1ba69f5cfb8a27de7f 100644 (file)
@@ -107,8 +107,15 @@ class BVSolverBitblast : public BVSolver
   /** CNF stream. */
   std::unique_ptr<prop::CnfStream> d_cnfStream;
 
-  /** Facts sent to this solver. */
-  context::CDList<Node> d_facts;
+  /**
+   * Bit-blast queue for facts sent to this solver.
+   *
+   * Get populated on preNotifyFact().
+   */
+  context::CDQueue<Node> d_bbFacts;
+
+  /** Corresponds to the SAT literals of the currently asserted facts. */
+  context::CDList<prop::SatLiteral> d_assumptions;
 
   /** Flag indicating whether `d_modelCache` should be invalidated. */
   context::CDO<bool> d_invalidateModelCache;
@@ -120,6 +127,17 @@ class BVSolverBitblast : public BVSolver
   std::unique_ptr<EagerProofGenerator> d_epg;
 
   BVProofRuleChecker d_bvProofChecker;
+
+  /** Stores the SatLiteral for a given fact. */
+  context::CDHashMap<Node, prop::SatLiteral, NodeHashFunction>
+      d_factLiteralCache;
+
+  /** Reverse map of `d_factLiteralCache`. */
+  context::CDHashMap<prop::SatLiteral, Node, prop::SatLiteralHashFunction>
+      d_literalFactCache;
+
+  /** Option to enable/disable bit-level propagation. */
+  bool d_propagate;
 };
 
 }  // namespace bv