Refactor shared solver to use theory builtin inference manager (#6960)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 3 Aug 2021 16:50:51 +0000 (11:50 -0500)
committerGitHub <noreply@github.com>
Tue, 3 Aug 2021 16:50:51 +0000 (16:50 +0000)
This ensures that e.g. COMBINATION_SPLIT shows up under theory::builtin::inferencesLemmas, and -t im. It also removes outdated interfaces from OutputChannel, and makes the feature TheoryEngine::ensureLemmaAtoms more modular, which was required for making these interfaces more consistent.

It also ensures that TheoryBuiltin has an inference manager, which will simplify special casing in #6956.

13 files changed:
src/theory/builtin/theory_builtin.cpp
src/theory/builtin/theory_builtin.h
src/theory/ee_manager_central.cpp
src/theory/engine_output_channel.cpp
src/theory/engine_output_channel.h
src/theory/output_channel.cpp
src/theory/output_channel.h
src/theory/shared_solver.cpp
src/theory/shared_solver.h
src/theory/theory_engine.cpp
src/theory/theory_engine.h
test/unit/test_smt.h
test/unit/theory/theory_white.cpp

index 80c5dba3db210719a0555ef01213d42f59d7eee0..1db03d22bd7ab8e77ac744d83320fcefbf6d6c5e 100644 (file)
@@ -31,8 +31,13 @@ TheoryBuiltin::TheoryBuiltin(context::Context* c,
                              Valuation valuation,
                              const LogicInfo& logicInfo,
                              ProofNodeManager* pnm)
-    : Theory(THEORY_BUILTIN, c, u, out, valuation, logicInfo, pnm)
+    : Theory(THEORY_BUILTIN, c, u, out, valuation, logicInfo, pnm),
+      d_state(c, u, valuation),
+      d_im(*this, d_state, pnm, "theory::builtin::")
 {
+  // indicate we are using the default theory state and inference managers
+  d_theoryState = &d_state;
+  d_inferManager = &d_im;
 }
 
 TheoryRewriter* TheoryBuiltin::getTheoryRewriter() { return &d_rewriter; }
index 27492cb532afd196282ce4cf3f0dc0a203d451b9..72485f0ea8093b78142587817bc30cece8d6d5ef 100644 (file)
@@ -21,6 +21,8 @@
 #include "theory/builtin/proof_checker.h"
 #include "theory/builtin/theory_builtin_rewriter.h"
 #include "theory/theory.h"
+#include "theory/theory_inference_manager.h"
+#include "theory/theory_state.h"
 
 namespace cvc5 {
 namespace theory {
@@ -51,6 +53,10 @@ class TheoryBuiltin : public Theory
   TheoryBuiltinRewriter d_rewriter;
   /** Proof rule checker */
   BuiltinProofRuleChecker d_checker;
+  /** A (default) theory state object */
+  TheoryState d_state;
+  /** A (default) inference manager */
+  TheoryInferenceManager d_im;
 }; /* class TheoryBuiltin */
 
 }  // namespace builtin
index f98bce970d78aec6b29e1691342f1ddde045e598..23946512ef81a378be7f8e288fe757e36d37b085 100644 (file)
@@ -298,7 +298,8 @@ void EqEngineManagerCentral::eqNotifyConstantTermMerge(TNode t1, TNode t2)
   Node conflict = d_centralEqualityEngine.mkExplainLit(lit);
   Trace("eem-central") << "...explained conflict of " << lit << " ... "
                        << conflict << std::endl;
-  d_sharedSolver.sendConflict(TrustNode::mkTrustConflict(conflict));
+  d_sharedSolver.sendConflict(TrustNode::mkTrustConflict(conflict),
+                              InferenceId::EQ_CONSTANT_MERGE);
   return;
 }
 
index b1f35821f1144558e1607b152d150a332d5c2aaf..ee07dcd5797ec64b5c8f392d5d7df9e6369ae871 100644 (file)
@@ -60,31 +60,7 @@ void EngineOutputChannel::safePoint(Resource r)
 
 void EngineOutputChannel::lemma(TNode lemma, LemmaProperty p)
 {
-  Trace("theory::lemma") << "EngineOutputChannel<" << d_theory << ">::lemma("
-                         << lemma << ")"
-                         << ", properties = " << p << std::endl;
-  ++d_statistics.lemmas;
-  d_engine->d_outputChannelUsed = true;
-
-  TrustNode tlem = TrustNode::mkTrustLemma(lemma);
-  d_engine->lemma(tlem,
-                  p,
-                  isLemmaPropertySendAtoms(p) ? d_theory : theory::THEORY_LAST,
-                  d_theory);
-}
-
-void EngineOutputChannel::splitLemma(TNode lemma, bool removable)
-{
-  Trace("theory::lemma") << "EngineOutputChannel<" << d_theory << ">::lemma("
-                         << lemma << ")" << std::endl;
-  ++d_statistics.lemmas;
-  d_engine->d_outputChannelUsed = true;
-
-  Trace("pf::explain") << "EngineOutputChannel::splitLemma( " << lemma << " )"
-                       << std::endl;
-  TrustNode tlem = TrustNode::mkTrustLemma(lemma);
-  LemmaProperty p = removable ? LemmaProperty::REMOVABLE : LemmaProperty::NONE;
-  d_engine->lemma(tlem, p, d_theory);
+  trustedLemma(TrustNode::mkTrustLemma(lemma), p);
 }
 
 bool EngineOutputChannel::propagate(TNode literal)
@@ -172,10 +148,13 @@ void EngineOutputChannel::trustedLemma(TrustNode plem, LemmaProperty p)
   }
   ++d_statistics.lemmas;
   d_engine->d_outputChannelUsed = true;
+  if (isLemmaPropertySendAtoms(p))
+  {
+    d_engine->ensureLemmaAtoms(plem.getNode(), d_theory);
+  }
   // now, call the normal interface for lemma
   d_engine->lemma(plem,
                   p,
-                  isLemmaPropertySendAtoms(p) ? d_theory : theory::THEORY_LAST,
                   d_theory);
 }
 
index dcf8fba556440c5e27eb8af850f83b2c91cb1e1a..cc1d8ece73eebd552248ee0a3de97c72e13eb364 100644 (file)
@@ -53,8 +53,6 @@ class EngineOutputChannel : public theory::OutputChannel
 
   void lemma(TNode lemma, LemmaProperty p = LemmaProperty::NONE) override;
 
-  void splitLemma(TNode lemma, bool removable = false) override;
-
   void demandRestart() override;
 
   void requirePhase(TNode n, bool phase) override;
index 2d85f4de67af88b3716b2a63e23e44ed8087e5b1..5ab91b6e0ddc42a9872ce0941d63c73ca6569496 100644 (file)
@@ -77,8 +77,6 @@ std::ostream& operator<<(std::ostream& out, LemmaProperty p)
   return out;
 }
 
-void OutputChannel::split(TNode n) { splitLemma(n.orNode(n.notNode())); }
-
 void OutputChannel::trustedConflict(TrustNode pconf)
 {
   Unreachable() << "OutputChannel::trustedConflict: no implementation"
index b681dad17d6c899381895bc0f8cbfbcf2780957e..80115d4381226b8e905372e4dbeb8cefe6da68f7 100644 (file)
@@ -121,16 +121,6 @@ class OutputChannel {
    */
   virtual void lemma(TNode n, LemmaProperty p = LemmaProperty::NONE) = 0;
 
-  /**
-   * Request a split on a new theory atom.  This is equivalent to
-   * calling lemma({OR n (NOT n)}).
-   *
-   * @param n - a theory atom; must be of Boolean type
-   */
-  void split(TNode n);
-
-  virtual void splitLemma(TNode n, bool removable = false) = 0;
-
   /**
    * If a decision is made on n, it must be in the phase specified.
    * Note that this is enforced *globally*, i.e., it is completely
index b020a39386cff8fa5077e0fef236c176be55cf45..95558ead164ea08c1399ead1b0920fa595e1cf96 100644 (file)
@@ -19,6 +19,7 @@
 #include "theory/ee_setup_info.h"
 #include "theory/logic_info.h"
 #include "theory/theory_engine.h"
+#include "theory/theory_inference_manager.h"
 
 namespace cvc5 {
 namespace theory {
@@ -35,7 +36,7 @@ SharedSolver::SharedSolver(TheoryEngine& te, ProofNodeManager* pnm)
       d_sharedTerms(&d_te, d_te.getSatContext(), d_te.getUserContext(), pnm),
       d_preRegistrationVisitor(&te, d_te.getSatContext()),
       d_sharedTermsVisitor(&te, d_sharedTerms, d_te.getSatContext()),
-      d_out(te.theoryOf(THEORY_BUILTIN)->getOutputChannel())
+      d_im(te.theoryOf(THEORY_BUILTIN)->getInferenceManager())
 {
 }
 
@@ -113,9 +114,9 @@ bool SharedSolver::propagateLit(TNode predicate, bool value)
 {
   if (value)
   {
-    return d_out.propagate(predicate);
+    return d_im->propagateLit(predicate);
   }
-  return d_out.propagate(predicate.notNode());
+  return d_im->propagateLit(predicate.notNode());
 }
 
 bool SharedSolver::propagateSharedEquality(theory::TheoryId theory,
@@ -141,11 +142,18 @@ bool SharedSolver::isShared(TNode t) const { return d_sharedTerms.isShared(t); }
 
 void SharedSolver::sendLemma(TrustNode trn, TheoryId atomsTo, InferenceId id)
 {
-  Trace("im") << "(lemma " << id << " " << trn.getProven() << ")" << std::endl;
-  d_te.lemma(trn, LemmaProperty::NONE, atomsTo);
+  // Do we need to check atoms
+  if (atomsTo != theory::THEORY_LAST)
+  {
+    d_te.ensureLemmaAtoms(trn.getNode(), atomsTo);
+  }
+  d_im->trustedLemma(trn, id);
 }
 
-void SharedSolver::sendConflict(TrustNode trn) { d_out.trustedConflict(trn); }
+void SharedSolver::sendConflict(TrustNode trn, InferenceId id)
+{
+  d_im->trustedConflict(trn, id);
+}
 
 }  // namespace theory
 }  // namespace cvc5
index e2cda0fbc927624a4e038172e66a7399dd7e18ce..a7f9ceff594fb8978491be69a09bacda00bda017 100644 (file)
@@ -33,7 +33,7 @@ class TheoryEngine;
 namespace theory {
 
 struct EeSetupInfo;
-class OutputChannel;
+class TheoryInferenceManager;
 
 /**
  * A base class for shared solver. The shared solver is the component of theory
@@ -124,7 +124,7 @@ class SharedSolver
   /** Send lemma to the theory engine, atomsTo is the theory to send atoms to */
   void sendLemma(TrustNode trn, TheoryId atomsTo, InferenceId id);
   /** Send conflict to the theory engine */
-  void sendConflict(TrustNode trn);
+  void sendConflict(TrustNode trn, InferenceId id);
 
  protected:
   /** Solver-specific pre-register shared */
@@ -139,8 +139,8 @@ class SharedSolver
   PreRegisterVisitor d_preRegistrationVisitor;
   /** Visitor for collecting shared terms */
   SharedTermsVisitor d_sharedTermsVisitor;
-  /** Output channel of theory builtin */
-  OutputChannel& d_out;
+  /** Theory inference manager of theory builtin */
+  TheoryInferenceManager* d_im;
 };
 
 }  // namespace theory
index 63fd6d9b7557978744711cdff286638c86f5ef9c..fb93403b96dfb4d7eacfdc5a66614e45a5a3cd84 100644 (file)
@@ -1246,6 +1246,16 @@ struct AtomsCollect {
   }
 };
 
+void TheoryEngine::ensureLemmaAtoms(TNode n, theory::TheoryId atomsTo)
+{
+  Assert(atomsTo != THEORY_LAST);
+  Debug("theory::atoms") << "TheoryEngine::ensureLemmaAtoms(" << n << ", "
+                         << atomsTo << ")" << endl;
+  AtomsCollect collectAtoms;
+  NodeVisitor<AtomsCollect>::run(collectAtoms, n);
+  ensureLemmaAtoms(collectAtoms.getAtoms(), atomsTo);
+}
+
 void TheoryEngine::ensureLemmaAtoms(const std::vector<TNode>& atoms, theory::TheoryId atomsTo) {
   for (unsigned i = 0; i < atoms.size(); ++ i) {
 
@@ -1314,7 +1324,6 @@ void TheoryEngine::ensureLemmaAtoms(const std::vector<TNode>& atoms, theory::The
 
 void TheoryEngine::lemma(TrustNode tlemma,
                          theory::LemmaProperty p,
-                         theory::TheoryId atomsTo,
                          theory::TheoryId from)
 {
   // For resource-limiting (also does a time check).
@@ -1346,15 +1355,6 @@ void TheoryEngine::lemma(TrustNode tlemma,
     tlemma.debugCheckClosed("te-proof-debug", "TheoryEngine::lemma_initial");
   }
 
-  // Do we need to check atoms
-  if (atomsTo != theory::THEORY_LAST) {
-    Debug("theory::atoms") << "TheoryEngine::lemma(" << node << ", " << atomsTo
-                           << ")" << endl;
-    AtomsCollect collectAtoms;
-    NodeVisitor<AtomsCollect>::run(collectAtoms, node);
-    ensureLemmaAtoms(collectAtoms.getAtoms(), atomsTo);
-  }
-
   if(Dump.isOn("t-lemmas")) {
     // we dump the negation of the lemma, to show validity of the lemma
     Node n = lemma.negate();
@@ -1504,7 +1504,7 @@ void TheoryEngine::conflict(TrustNode tconflict, TheoryId theoryId)
     // When only one theory, the conflict should need no processing
     Assert(properConflict(conflict));
     // pass the trust node that was sent from the theory
-    lemma(tconflict, LemmaProperty::REMOVABLE, THEORY_LAST, theoryId);
+    lemma(tconflict, LemmaProperty::REMOVABLE, theoryId);
   }
 }
 
index cfcdcce13cb794a2010c52fe368e1b35a846e1c1..5d16f04ba475ed58b6e2bc85f8cd680e0695f2c8 100644 (file)
@@ -277,11 +277,13 @@ class TheoryEngine {
    */
   void lemma(TrustNode node,
              theory::LemmaProperty p,
-             theory::TheoryId atomsTo = theory::THEORY_LAST,
              theory::TheoryId from = theory::THEORY_LAST);
 
-  /** Enusre that the given atoms are send to the given theory */
-  void ensureLemmaAtoms(const std::vector<TNode>& atoms, theory::TheoryId theory);
+  /** Ensure atoms from the given node are sent to the given theory */
+  void ensureLemmaAtoms(TNode n, theory::TheoryId atomsTo);
+  /** Ensure that the given atoms are sent to the given theory */
+  void ensureLemmaAtoms(const std::vector<TNode>& atoms,
+                        theory::TheoryId atomsTo);
 
   /** sort inference module */
   std::unique_ptr<theory::SortInference> d_sortInfer;
index 672693366f623bcdd0eda983ee61fabaaff39dc7..4226f809595911d6cfe7721fb71ac6daccb6cdaa 100644 (file)
@@ -138,8 +138,6 @@ class DummyOutputChannel : public cvc5::theory::OutputChannel
   void setIncomplete(theory::IncompleteId id) override {}
   void handleUserAttribute(const char* attr, theory::Theory* t) override {}
 
-  void splitLemma(TNode n, bool removable = false) override { push(LEMMA, n); }
-
   void clear() { d_callHistory.clear(); }
 
   Node getIthNode(int i) const
index 5a469ed97a80f262b46e5c8c68b7f8928c665c31..94021a9e364b46b0f477cde6a3c4032f0d26c895 100644 (file)
@@ -100,7 +100,7 @@ TEST_F(TestTheoryWhite, outputChannel)
 {
   Node n = d_atom0.orNode(d_atom1);
   d_outputChannel.lemma(n);
-  d_outputChannel.split(d_atom0);
+  d_outputChannel.lemma(d_atom0.orNode(d_atom0.notNode()));
   Node s = d_atom0.orNode(d_atom0.notNode());
   ASSERT_EQ(d_outputChannel.d_callHistory.size(), 2u);
   ASSERT_EQ(d_outputChannel.d_callHistory[0], std::make_pair(LEMMA, n));