Make buffered inference manager more robust to backtracking (#6833)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 5 Jul 2021 13:03:33 +0000 (08:03 -0500)
committerGitHub <noreply@github.com>
Mon, 5 Jul 2021 13:03:33 +0000 (13:03 +0000)
This makes TheoryEngine notify all theories when a theory sends a conflict. This means that buffered inference managers always clear their buffers when any theory sends a conflict.

This is required for making theories robust to conflicts that may arise when using the central equality engine, where a different theory may raise a conflict during another theory's check.

src/theory/inference_manager_buffered.cpp
src/theory/inference_manager_buffered.h
src/theory/theory.cpp
src/theory/theory.h
src/theory/theory_engine.cpp
src/theory/theory_engine.h
src/theory/theory_inference_manager.cpp
src/theory/theory_inference_manager.h

index 534d59aeb062663eda6e07100ba05a5beb8b8f51..015da9372be83df7557190ef177fc76f4e81d4a9 100644 (file)
@@ -187,5 +187,12 @@ void InferenceManagerBuffered::assertInternalFactTheoryInference(
   assertInternalFact(atom, pol, fact->getId(), exp, pg);
 }
 
+void InferenceManagerBuffered::notifyInConflict()
+{
+  d_theoryState.notifyInConflict();
+  // also clear the pending facts, which will be stale after backtracking
+  clearPending();
+}
+
 }  // namespace theory
 }  // namespace cvc5
index 080033562f4b240ed2146bbcfc23ed21f16545f8..cc4bd7ba4c8a6f418dde87ec52ab6a1c0f481053 100644 (file)
@@ -159,6 +159,14 @@ class InferenceManagerBuffered : public TheoryInferenceManager
    */
   void assertInternalFactTheoryInference(TheoryInference* fact);
 
+  /**
+   * Notify this inference manager that a conflict was sent in this SAT context.
+   * This method is called via TheoryEngine when a conflict is sent. This
+   * method will clear all pending facts, lemmas, and phase requirements, as
+   * these will be stale after the solver backtracks.
+   */
+  void notifyInConflict() override;
+
  protected:
   /** A set of pending inferences to be processed as lemmas */
   std::vector<std::unique_ptr<TheoryInference>> d_pendingLem;
index b9dc1ba42871ecd9393452266d7b64493f0f3e3a..10c31edb76afaf0565cabdae4a94d81165b33a18 100644 (file)
@@ -275,6 +275,14 @@ void Theory::notifySharedTerm(TNode n)
   // do nothing
 }
 
+void Theory::notifyInConflict()
+{
+  if (d_inferManager != nullptr)
+  {
+    d_inferManager->notifyInConflict();
+  }
+}
+
 void Theory::computeCareGraph() {
   Debug("sharing") << "Theory::computeCareGraph<" << getId() << ">()" << endl;
   for (unsigned i = 0; i < d_sharedTerms.size(); ++ i) {
@@ -362,11 +370,14 @@ bool Theory::collectModelInfo(TheoryModel* m, const std::set<Node>& termSet)
   // if we are using an equality engine, assert it to the model
   if (d_equalityEngine != nullptr)
   {
+    Trace("model-builder") << "Assert Equality engine for " << d_id
+                           << std::endl;
     if (!m->assertEqualityEngine(d_equalityEngine, &termSet))
     {
       return false;
     }
   }
+  Trace("model-builder") << "Collect Model values for " << d_id << std::endl;
   // now, collect theory-specific value assigments
   return collectModelValues(m, termSet);
 }
index 378305c75d4bf48cdacd20eb01dea23107707e09..4dbb7a4365c271794db6eaafe83be17bf06de716 100644 (file)
@@ -227,7 +227,6 @@ class Theory {
 
   /** Pointer to proof node manager */
   ProofNodeManager* d_pnm;
-
   /**
    * Are proofs enabled?
    *
@@ -309,6 +308,12 @@ class Theory {
    */
   virtual void notifySharedTerm(TNode n);
 
+  /**
+   * Notify in conflict, called when a conflict clause is added to TheoryEngine
+   * by any theory (not necessarily this one). This signals that the theory
+   * should suspend what it is currently doing and wait for backtracking.
+   */
+  virtual void notifyInConflict();
  public:
   //--------------------------------- initialization
   /**
index 85cd7e6b3ff67d8b71c0413f23718a25823bbabc..bf196273e5242ff74b3d2df5547579a5c93c7288 100644 (file)
@@ -126,6 +126,7 @@ std::string getTheoryString(theory::TheoryId id)
 
 void TheoryEngine::finishInit()
 {
+  Trace("theory") << "Begin TheoryEngine::finishInit" << std::endl;
   // NOTE: This seems to be required since
   // theory::TheoryTraits<THEORY>::isParametric cannot be accessed without
   // using the CVC5_FOR_EACH_THEORY_STATEMENT macro. -AJR
@@ -202,6 +203,7 @@ void TheoryEngine::finishInit()
     // finish initializing the theory
     t->finishInit();
   }
+  Trace("theory") << "End TheoryEngine::finishInit" << std::endl;
 }
 
 ProofNodeManager* TheoryEngine::getProofNodeManager() const { return d_pnm; }
@@ -836,7 +838,6 @@ void TheoryEngine::notifyPreprocessedAssertions(
 }
 
 bool TheoryEngine::markPropagation(TNode assertion, TNode originalAssertion, theory::TheoryId toTheoryId, theory::TheoryId fromTheoryId) {
-
   // What and where we are asserting
   NodeTheoryPair toAssert(assertion, toTheoryId, d_propagationMapTimestamp);
   // What and where it came from
@@ -861,7 +862,6 @@ bool TheoryEngine::markPropagation(TNode assertion, TNode originalAssertion, the
 
 
 void TheoryEngine::assertToTheory(TNode assertion, TNode originalAssertion, theory::TheoryId toTheoryId, theory::TheoryId fromTheoryId) {
-
   Trace("theory::assertToTheory") << "TheoryEngine::assertToTheory(" << assertion << ", " << originalAssertion << "," << toTheoryId << ", " << fromTheoryId << ")" << endl;
 
   Assert(toTheoryId != fromTheoryId);
@@ -1030,7 +1030,8 @@ void TheoryEngine::assertFact(TNode literal)
         const AtomRequests::Request& request = it.get();
         Node toAssert =
             polarity ? (Node)request.d_atom : request.d_atom.notNode();
-        Debug("theory::atoms") << "TheoryEngine::assertFact(" << literal << "): sending requested " << toAssert << endl;
+        Debug("theory::atoms") << "TheoryEngine::assertFact(" << literal
+                               << "): sending requested " << toAssert << endl;
         assertToTheory(
             toAssert, literal, request.d_toTheory, THEORY_SAT_SOLVER);
         it.next();
@@ -1047,7 +1048,8 @@ void TheoryEngine::assertFact(TNode literal)
 }
 
 bool TheoryEngine::propagate(TNode literal, theory::TheoryId theory) {
-  Debug("theory::propagate") << "TheoryEngine::propagate(" << literal << ", " << theory << ")" << endl;
+  Debug("theory::propagate")
+      << "TheoryEngine::propagate(" << literal << ", " << theory << ")" << endl;
 
   Trace("dtview::prop") << std::string(d_env.getContext()->getLevel(), ' ')
                         << ":THEORY-PROP: " << literal << endl;
@@ -1250,7 +1252,8 @@ void TheoryEngine::ensureLemmaAtoms(const std::vector<TNode>& atoms, theory::The
     // Rewrite the equality
     Node eqNormalized = Rewriter::rewrite(atoms[i]);
 
-    Debug("theory::atoms") << "TheoryEngine::ensureLemmaAtoms(): " << eq << " with nf " << eqNormalized << endl;
+    Debug("theory::atoms") << "TheoryEngine::ensureLemmaAtoms(): " << eq
+                           << " with nf " << eqNormalized << endl;
 
     // If the equality is a boolean constant, we send immediately
     if (eqNormalized.isConst()) {
@@ -1333,7 +1336,8 @@ void TheoryEngine::lemma(TrustNode tlemma,
 
   // Do we need to check atoms
   if (atomsTo != theory::THEORY_LAST) {
-    Debug("theory::atoms") << "TheoryEngine::lemma(" << node << ", " << atomsTo << ")" << endl;
+    Debug("theory::atoms") << "TheoryEngine::lemma(" << node << ", " << atomsTo
+                           << ")" << endl;
     AtomsCollect collectAtoms;
     NodeVisitor<AtomsCollect>::run(collectAtoms, node);
     ensureLemmaAtoms(collectAtoms.getAtoms(), atomsTo);
@@ -1368,11 +1372,23 @@ void TheoryEngine::lemma(TrustNode tlemma,
   d_lemmasAdded = true;
 }
 
+void TheoryEngine::markInConflict()
+{
+#ifdef CVC5_FOR_EACH_THEORY_STATEMENT
+#undef CVC5_FOR_EACH_THEORY_STATEMENT
+#endif
+#define CVC5_FOR_EACH_THEORY_STATEMENT(THEORY) \
+  theoryOf(THEORY)->notifyInConflict();
+  CVC5_FOR_EACH_THEORY;
+  d_inConflict = true;
+}
+
 void TheoryEngine::conflict(TrustNode tconflict, TheoryId theoryId)
 {
   Assert(tconflict.getKind() == TrustNodeKind::CONFLICT);
+
   TNode conflict = tconflict.getNode();
-  Trace("theory::conflict") << "TheoryEngine::conflict(" << conflict << ", "
+  Debug("theory::conflict") << "TheoryEngine::conflict(" << conflict << ", "
                             << theoryId << ")" << endl;
   Trace("te-proof-debug") << "Check closed conflict" << std::endl;
   // doesn't require proof generator, yet, since THEORY_LEMMA is added below
@@ -1382,7 +1398,7 @@ void TheoryEngine::conflict(TrustNode tconflict, TheoryId theoryId)
   Trace("dtview::conflict") << ":THEORY-CONFLICT: " << conflict << std::endl;
 
   // Mark that we are in conflict
-  d_inConflict = true;
+  markInConflict();
 
   if(Dump.isOn("t-conflicts")) {
     const Printer& printer = d_outMgr.getPrinter();
@@ -1464,7 +1480,9 @@ void TheoryEngine::conflict(TrustNode tconflict, TheoryId theoryId)
     // pass the processed trust node
     TrustNode tconf =
         TrustNode::mkTrustConflict(fullConflict, d_lazyProof.get());
-    Debug("theory::conflict") << "TheoryEngine::conflict(" << conflict << ", " << theoryId << "): full = " << fullConflict << endl;
+    Debug("theory::conflict")
+        << "TheoryEngine::conflict(" << conflict << ", " << theoryId
+        << "): full = " << fullConflict << endl;
     Assert(properConflict(fullConflict));
     Trace("te-proof-debug")
         << "Check closed conflict with sharing" << std::endl;
@@ -1552,7 +1570,8 @@ TrustNode TheoryEngine::getExplanation(
     // If from the SAT solver, keep it
     if (toExplain.d_theory == THEORY_SAT_SOLVER)
     {
-      Debug("theory::explain") << "\tLiteral came from THEORY_SAT_SOLVER. Kepping it." << endl;
+      Debug("theory::explain")
+          << "\tLiteral came from THEORY_SAT_SOLVER. Keeping it." << endl;
       exp.insert(explanationVector[i++].d_node);
       // it will be a free assumption in the proof
       Trace("te-proof-exp") << "- keep " << toExplain.d_node << std::endl;
index f293a2cc815b72038fcb4450beb1b8e996dd6c02..cfcdcce13cb794a2010c52fe368e1b35a846e1c1 100644 (file)
@@ -193,6 +193,9 @@ class TheoryEngine {
    */
   void conflict(TrustNode conflict, theory::TheoryId theoryId);
 
+  /** set in conflict */
+  void markInConflict();
+
   /**
    * Debugging flag to ensure that shutdown() is called before the
    * destructor.
index ad988e5345b862202b65fe98c8f0d75ae4a0064b..c152481b59b9790ec02644c7d59d87a4b147ebca 100644 (file)
@@ -122,7 +122,6 @@ void TheoryInferenceManager::trustedConflict(TrustNode tconf, InferenceId id)
   smt::currentResourceManager()->spendResource(id);
   Trace("im") << "(conflict " << id << " " << tconf.getProven() << ")"
               << std::endl;
-  d_theoryState.notifyInConflict();
   d_out.trustedConflict(tconf);
   ++d_numConflicts;
 }
@@ -374,10 +373,10 @@ bool TheoryInferenceManager::processInternalFact(TNode atom,
 {
   d_factIdStats << iid;
   smt::currentResourceManager()->spendResource(iid);
-  Trace("im") << "(fact " << iid << " " << (pol ? Node(atom) : atom.notNode())
-              << ")" << std::endl;
   // make the node corresponding to the explanation
   Node expn = NodeManager::currentNM()->mkAnd(exp);
+  Trace("im") << "(fact " << iid << " " << (pol ? Node(atom) : atom.notNode())
+              << " " << expn << ")" << std::endl;
   // call the pre-notify fact method with preReg = false, isInternal = true
   if (d_theory.preNotifyFact(atom, pol, expn, false, true))
   {
@@ -387,6 +386,7 @@ bool TheoryInferenceManager::processInternalFact(TNode atom,
   }
   Assert(d_ee != nullptr);
   Trace("infer-manager") << "TheoryInferenceManager::assertInternalFact: "
+                         << (pol ? Node(atom) : atom.notNode()) << " from "
                          << expn << std::endl;
   d_numCurrentFacts++;
   // Now, assert the fact. How to do so depends on whether proofs are enabled.
@@ -524,5 +524,10 @@ void TheoryInferenceManager::setIncomplete(IncompleteId id)
   d_out.setIncomplete(id);
 }
 
+void TheoryInferenceManager::notifyInConflict()
+{
+  d_theoryState.notifyInConflict();
+}
+
 }  // namespace theory
 }  // namespace cvc5
index 06806f8d430f65a5265d0ad6b10d8c2b5de73802..181e678765755622b9fc7d49834757eb1936e25d 100644 (file)
@@ -371,6 +371,11 @@ class TheoryInferenceManager
    * this context level.
    */
   void setIncomplete(IncompleteId id);
+  /**
+   * Notify this inference manager that a conflict was sent in this SAT context.
+   * This method is called via TheoryEngine when a conflict is sent.
+   */
+  virtual void notifyInConflict();
 
  protected:
   /**