Make sets and strings solver states inherit from TheoryState (#4918)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 19 Aug 2020 18:36:59 +0000 (13:36 -0500)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 18:36:59 +0000 (13:36 -0500)
This is towards the new standard for theory solvers.

This PR makes the custom states of sets and strings inherit from the standard base class TheoryState. It also makes a minor change to InferenceManager/SolverState to make sets more in line with the plan for a standard base class InferenceManager.

Followup PRs will establish the official TheoryState classes for all other theories (which in most cases will be an instance of the base class).

14 files changed:
src/theory/sets/inference_manager.cpp
src/theory/sets/inference_manager.h
src/theory/sets/solver_state.cpp
src/theory/sets/solver_state.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/strings/inference_manager.cpp
src/theory/strings/solver_state.cpp
src/theory/strings/solver_state.h
src/theory/strings/theory_strings.cpp
src/theory/theory.cpp
src/theory/theory_state.cpp
src/theory/theory_state.h

index f99dad91e652f615824d8ed10c4e8784cf64ed7c..8f25f651150a723f028e6ea28cfe6450199c8d2d 100644 (file)
@@ -72,7 +72,7 @@ bool InferenceManager::assertFactRec(Node fact, Node exp, int inferType)
     if (fact == d_false)
     {
       Trace("sets-lemma") << "Conflict : " << exp << std::endl;
-      d_state.setConflict(exp);
+      conflict(exp);
       return true;
     }
     return false;
@@ -233,6 +233,12 @@ bool InferenceManager::hasProcessed() const
 bool InferenceManager::hasSentLemma() const { return d_sentLemma; }
 bool InferenceManager::hasAddedFact() const { return d_addedFact; }
 
+void InferenceManager::conflict(Node conf)
+{
+  d_parent.getOutputChannel()->conflict(conf);
+  d_state.notifyInConflict();
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace CVC4
index ba6be9905707eb357d9f3e7ebe54fdae40b620b7..3278b848e0405ba9ef51b313536b5f72bfb0e362 100644 (file)
@@ -109,6 +109,12 @@ class InferenceManager
   /** Have we sent lem as a lemma in the current user context? */
   bool hasLemmaCached(Node lem) const;
 
+  /** 
+   * Send conflict.
+   * @param conf The conflict node to be sent on the output channel
+   */
+  void conflict(Node conf);
+
  private:
   /** constants */
   Node d_true;
index f3371cf61e09fe1cff84ce2bacf3d3f44b2d2afe..5e5e9d22a80b41c4bbc815e9cb373dde145269a9 100644 (file)
@@ -27,19 +27,14 @@ namespace sets {
 
 SolverState::SolverState(TheorySetsPrivate& p,
                          context::Context* c,
-                         context::UserContext* u)
-    : d_conflict(c), d_parent(p), d_ee(nullptr), d_proxy(u), d_proxy_to_term(u)
+                         context::UserContext* u,
+                         Valuation val)
+    : TheoryState(c, u, val), d_parent(p), d_proxy(u), d_proxy_to_term(u)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
-void SolverState::finishInit(eq::EqualityEngine* ee)
-{
-  Assert(ee != nullptr);
-  d_ee = ee;
-}
-
 void SolverState::reset()
 {
   d_set_eqc.clear();
@@ -169,52 +164,6 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
   }
 }
 
-Node SolverState::getRepresentative(Node a) const
-{
-  if (d_ee->hasTerm(a))
-  {
-    return d_ee->getRepresentative(a);
-  }
-  return a;
-}
-
-bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); }
-
-bool SolverState::areEqual(Node a, Node b) const
-{
-  if (a == b)
-  {
-    return true;
-  }
-  if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
-  {
-    return d_ee->areEqual(a, b);
-  }
-  return false;
-}
-
-bool SolverState::areDisequal(Node a, Node b) const
-{
-  if (a == b)
-  {
-    return false;
-  }
-  else if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
-  {
-    return d_ee->areDisequal(a, b, false);
-  }
-  return a.isConst() && b.isConst();
-}
-
-eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; }
-
-void SolverState::setConflict() { d_conflict = true; }
-void SolverState::setConflict(Node conf)
-{
-  d_parent.getOutputChannel()->conflict(conf);
-  d_conflict = true;
-}
-
 void SolverState::addEqualityToExp(Node a, Node b, std::vector<Node>& exp) const
 {
   if (a != b)
index dce90c2d3b44b02c714761edf939441d825fdf5d..3a40befbd3f239803cc6504f9df3a818cef78a2d 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "context/cdhashset.h"
 #include "theory/sets/skolem_cache.h"
+#include "theory/theory_state.h"
 #include "theory/uf/equality_engine.h"
 
 namespace CVC4 {
@@ -42,19 +43,15 @@ class TheorySetsPrivate;
  * to initialize the information in this class regarding full effort checks.
  * Other query calls are then valid for the remainder of the full effort check.
  */
-class SolverState
+class SolverState : public TheoryState
 {
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeMap;
 
  public:
   SolverState(TheorySetsPrivate& p,
               context::Context* c,
-              context::UserContext* u);
-  /**
-   * Finish initialize, there ee is a pointer to the official equality engine
-   * of theory of strings.
-   */
-  void finishInit(eq::EqualityEngine* ee);
+              context::UserContext* u,
+              Valuation val);
   //-------------------------------- initialize per check
   /** reset, clears the data structures maintained by this class. */
   void reset();
@@ -63,28 +60,6 @@ class SolverState
   /** register term n of type tnn in the equivalence class of r */
   void registerTerm(Node r, TypeNode tnn, Node n);
   //-------------------------------- end initialize per check
-  /** Are we currently in conflict? */
-  bool isInConflict() const { return d_conflict; }
-  /**
-   * Indicate that we are in conflict, without a conflict clause. This is
-   * called, for instance, when we have propagated a conflicting literal.
-   */
-  void setConflict();
-  /** Set conf is a conflict node to be sent on the output channel.  */
-  void setConflict(Node conf);
-  /**
-   * Get the representative of a in the equality engine of this class, or a
-   * itself if it is not registered as a term.
-   */
-  Node getRepresentative(Node a) const;
-  /** Is a registered as a term in the equality engine of this class? */
-  bool hasTerm(Node a) const;
-  /** Is a=b according to equality reasoning in the current context? */
-  bool areEqual(Node a, Node b) const;
-  /** Is a!=b according to equality reasoning in the current context? */
-  bool areDisequal(Node a, Node b) const;
-  /** get equality engine */
-  eq::EqualityEngine* getEqualityEngine() const;
   /** add equality to explanation
    *
    * This adds a = b to exp if a and b are syntactically disequal. The equality
@@ -229,12 +204,8 @@ class SolverState
   /** the empty vector and map */
   std::vector<Node> d_emptyVec;
   std::map<Node, Node> d_emptyMap;
-  /** Whether or not we are in conflict. This flag is SAT context dependent. */
-  context::CDO<bool> d_conflict;
   /** Reference to the parent theory of sets */
   TheorySetsPrivate& d_parent;
-  /** Pointer to the official equality engine of theory of sets */
-  eq::EqualityEngine* d_ee;
   /** The list of all equivalence classes of type set in the current context */
   std::vector<Node> d_set_eqc;
   /** Maps types to the equivalence class containing empty set of that type */
index fd9af488fcbdfc9eb8c0010bb6732bb1bb224cbd..fc544f46f56a61aa5a1bab59bf399a852a2526e3 100644 (file)
@@ -34,13 +34,11 @@ TheorySets::TheorySets(context::Context* c,
                        const LogicInfo& logicInfo,
                        ProofNodeManager* pnm)
     : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm),
-      d_internal(new TheorySetsPrivate(*this, c, u)),
+      d_internal(new TheorySetsPrivate(*this, c, u, valuation)),
       d_notify(*d_internal.get())
 {
-  // Do not move me to the header.
-  // The constructor + destructor are not in the header as d_internal is a
-  // unique_ptr<TheorySetsPrivate> and TheorySetsPrivate is an opaque type in
-  // the header (Pimpl). See https://herbsutter.com/gotw/_100/ .
+  // use the state object as the official theory state
+  d_theoryState = d_internal->getSolverState();
 }
 
 TheorySets::~TheorySets()
index bb94235706cdcf9fc9b135a4475b02653eb8ca72..879862d151fc71741ca169febbdbf7de2a07da08 100644 (file)
@@ -36,14 +36,15 @@ namespace sets {
 
 TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
                                      context::Context* c,
-                                     context::UserContext* u)
+                                     context::UserContext* u,
+                                     Valuation valuation)
     : d_members(c),
       d_deq(c),
       d_termProcessed(u),
       d_keep(c),
       d_full_check_incomplete(false),
       d_external(external),
-      d_state(*this, c, u),
+      d_state(*this, c, u, valuation),
       d_im(*this, d_state, c, u),
       d_rels(new TheorySetsRels(d_state, d_im, u)),
       d_cardSolver(new CardinalityExtension(d_state, d_im, c, u)),
@@ -67,7 +68,6 @@ void TheorySetsPrivate::finishInit()
 {
   d_equalityEngine = d_external.getEqualityEngine();
   Assert(d_equalityEngine != nullptr);
-  d_state.finishInit(d_equalityEngine);
 }
 
 void TheorySetsPrivate::eqNotifyNewClass(TNode t)
@@ -178,7 +178,7 @@ void TheorySetsPrivate::eqNotifyMerge(TNode t1, TNode t2)
               // conflict
               Trace("sets-prop")
                   << "Propagate eq-mem conflict : " << exp << std::endl;
-              d_state.setConflict(exp);
+              d_im.conflict(exp);
               return;
             }
           }
@@ -316,7 +316,7 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp)
             {
               Trace("sets-prop")
                   << "Propagate mem-eq conflict : " << pexp << std::endl;
-              d_state.setConflict(pexp);
+              d_im.conflict(pexp);
             }
           }
         }
@@ -1410,7 +1410,7 @@ bool TheorySetsPrivate::propagate(TNode literal)
   bool ok = d_external.d_out->propagate(literal);
   if (!ok)
   {
-    d_state.setConflict();
+    d_state.notifyInConflict();
   }
 
   return ok;
@@ -1426,7 +1426,7 @@ Valuation& TheorySetsPrivate::getValuation() { return d_external.d_valuation; }
 void TheorySetsPrivate::conflict(TNode a, TNode b)
 {
   Node conf = explain(a.eqNode(b));
-  d_state.setConflict(conf);
+  d_im.conflict(conf);
   Debug("sets") << "[sets] conflict: " << a << " iff " << b << ", explanation "
                 << conf << std::endl;
   Trace("sets-lemma") << "Equality Conflict : " << conf << std::endl;
index 27ea6a9b876ab6d694ccffd43883b37a91dd6c5a..9a786598cee71f55571bc104ebd218f299271793 100644 (file)
@@ -156,12 +156,16 @@ class TheorySetsPrivate {
    */
   TheorySetsPrivate(TheorySets& external,
                     context::Context* c,
-                    context::UserContext* u);
+                    context::UserContext* u,
+                    Valuation valuation);
 
   ~TheorySetsPrivate();
 
   TheoryRewriter* getTheoryRewriter() { return &d_rewriter; }
 
+  /** Get the solver state */
+  SolverState* getSolverState() { return &d_state; }
+
   /**
    * Finish initialize, called after the equality engine of theory sets has
    * been determined.
index 88cf6d958ab1131d7e496c5f6c250240c1adcbec..a8ebd921a548189d4a74ec85b1ce522979921a20 100644 (file)
@@ -171,7 +171,7 @@ void InferenceManager::sendInference(const InferInfo& ii, bool asLemma)
       // only keep stats if we process it here
       d_statistics.d_inferences << ii.d_id;
       d_out.conflict(conf);
-      d_state.setConflict();
+      d_state.notifyInConflict();
       return;
     }
     Trace("strings-infer-debug") << "...as lemma" << std::endl;
@@ -435,7 +435,7 @@ void InferenceManager::assertPendingFact(Node atom, bool polarity, Node exp)
       Trace("strings-pending")
           << "Process pending conflict " << pc << std::endl;
       Node conflictNode = mkExplain(a);
-      d_state.setConflict();
+      d_state.notifyInConflict();
       Trace("strings-conflict")
           << "CONFLICT: Eager prefix : " << conflictNode << std::endl;
       ++(d_statistics.d_conflictsEagerPrefix);
index 8634478fdac8eb275db7a913efb9af0c98e3e8d6..fd0f0174f9952fac8c293823448b947ae2644fe0 100644 (file)
@@ -28,13 +28,7 @@ namespace strings {
 SolverState::SolverState(context::Context* c,
                          context::UserContext* u,
                          Valuation& v)
-    : d_context(c),
-      d_ucontext(u),
-      d_ee(nullptr),
-      d_eeDisequalities(c),
-      d_valuation(v),
-      d_conflict(c, false),
-      d_pendingConflict(c)
+    : TheoryState(c, u, v), d_eeDisequalities(c), d_pendingConflict(c)
 {
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
 }
@@ -47,59 +41,6 @@ SolverState::~SolverState()
   }
 }
 
-void SolverState::finishInit(eq::EqualityEngine* ee)
-{
-  Assert(ee != nullptr);
-  d_ee = ee;
-}
-
-context::Context* SolverState::getSatContext() const { return d_context; }
-context::UserContext* SolverState::getUserContext() const { return d_ucontext; }
-
-Node SolverState::getRepresentative(Node t) const
-{
-  if (d_ee->hasTerm(t))
-  {
-    return d_ee->getRepresentative(t);
-  }
-  return t;
-}
-
-bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); }
-
-bool SolverState::areEqual(Node a, Node b) const
-{
-  if (a == b)
-  {
-    return true;
-  }
-  else if (hasTerm(a) && hasTerm(b))
-  {
-    return d_ee->areEqual(a, b);
-  }
-  return false;
-}
-
-bool SolverState::areDisequal(Node a, Node b) const
-{
-  if (a == b)
-  {
-    return false;
-  }
-  else if (hasTerm(a) && hasTerm(b))
-  {
-    Node ar = d_ee->getRepresentative(a);
-    Node br = d_ee->getRepresentative(b);
-    return (ar != br && ar.isConst() && br.isConst())
-           || d_ee->areDisequal(ar, br, false);
-  }
-  Node ar = getRepresentative(a);
-  Node br = getRepresentative(b);
-  return ar != br && ar.isConst() && br.isConst();
-}
-
-eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; }
-
 const context::CDList<Node>& SolverState::getDisequalityList() const
 {
   return d_eeDisequalities;
@@ -199,7 +140,7 @@ EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake)
   return nullptr;
 }
 
-TheoryModel* SolverState::getModel() const { return d_valuation.getModel(); }
+TheoryModel* SolverState::getModel() { return d_valuation.getModel(); }
 
 void SolverState::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
 {
@@ -286,9 +227,6 @@ bool SolverState::isEqualEmptyWord(Node s, Node& emps)
   return false;
 }
 
-void SolverState::setConflict() { d_conflict = true; }
-bool SolverState::isInConflict() const { return d_conflict; }
-
 void SolverState::setPendingConflictWhen(Node conf)
 {
   if (!conf.isNull() && d_pendingConflict.get().isNull())
index 0322abdb74f1be5535daf619ddaab8956fd4498c..fc27b847b5383270d33ed15a209235cea3f272e8 100644 (file)
@@ -39,7 +39,7 @@ namespace strings {
  * (2) Whether the set of assertions is in conflict.
  * (3) Equivalence class information as in the class above.
  */
-class SolverState
+class SolverState : public TheoryState
 {
   typedef context::CDList<Node> NodeList;
 
@@ -48,35 +48,7 @@ class SolverState
               context::UserContext* u,
               Valuation& v);
   ~SolverState();
-  /**
-   * Finish initialize, ee is a pointer to the official equality engine
-   * of theory of strings.
-   */
-  void finishInit(eq::EqualityEngine* ee);
-  /** Get the SAT context */
-  context::Context* getSatContext() const;
-  /** Get the user context */
-  context::UserContext* getUserContext() const;
   //-------------------------------------- equality information
-  /**
-   * Get the representative of t in the equality engine of this class, or t
-   * itself if it is not registered as a term.
-   */
-  Node getRepresentative(Node t) const;
-  /** Is t registered as a term in the equality engine of this class? */
-  bool hasTerm(Node a) const;
-  /**
-   * Are a and b equal according to the equality engine of this class? Also
-   * returns true if a and b are identical.
-   */
-  bool areEqual(Node a, Node b) const;
-  /**
-   * Are a and b disequal according to the equality engine of this class? Also
-   * returns true if the representative of a and b are distinct constants.
-   */
-  bool areDisequal(Node a, Node b) const;
-  /** get equality engine */
-  eq::EqualityEngine* getEqualityEngine() const;
   /**
    * Get the list of disequalities that are currently asserted to the equality
    * engine.
@@ -92,14 +64,6 @@ class SolverState
   void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
   //-------------------------------------- end notifications for equalities
   //------------------------------------------ conflicts
-  /**
-   * Set that the current state of the solver is in conflict. This should be
-   * called immediately after a call to conflict(...) on the output channel of
-   * the theory of strings.
-   */
-  void setConflict();
-  /** Are we currently in conflict? */
-  bool isInConflict() const;
   /** set pending conflict
    *
    * If conf is non-null, this is called when conf is a conjunction of literals
@@ -153,7 +117,7 @@ class SolverState
    */
   EqcInfo* getOrMakeEqcInfo(Node eqc, bool doMake = true);
   /** Get pointer to the model object of the Valuation object */
-  TheoryModel* getModel() const;
+  TheoryModel* getModel();
 
   /** add endpoints to eqc info
    *
@@ -186,21 +150,11 @@ class SolverState
  private:
   /** Common constants */
   Node d_zero;
-  /** Pointer to the SAT context object used by the theory of strings. */
-  context::Context* d_context;
-  /** Pointer to the user context object used by the theory of strings. */
-  context::UserContext* d_ucontext;
-  /** Pointer to equality engine of the theory of strings. */
-  eq::EqualityEngine* d_ee;
   /**
    * The (SAT-context-dependent) list of disequalities that have been asserted
    * to the equality engine above.
    */
   NodeList d_eeDisequalities;
-  /** Reference to the valuation of the theory of strings */
-  Valuation& d_valuation;
-  /** Are we in conflict? */
-  context::CDO<bool> d_conflict;
   /** The pending conflict if one exists */
   context::CDO<Node> d_pendingConflict;
   /** Map from representatives to their equivalence class information */
index c78e8dc2aaec04f3a1bfc36f13d48eac7312c652..6d81c742a1b3de5ef21aaaddad448529cdac2610 100644 (file)
@@ -81,6 +81,8 @@ TheoryStrings::TheoryStrings(context::Context* c,
     // add checkers
     d_sProofChecker.registerTo(pc);
   }
+  // use the state object as the official theory state
+  d_theoryState = &d_state;
 }
 
 TheoryStrings::~TheoryStrings() {
@@ -126,8 +128,6 @@ void TheoryStrings::finishInit()
   d_equalityEngine->addFunctionKind(kind::STRING_TOLOWER, eagerEval);
   d_equalityEngine->addFunctionKind(kind::STRING_TOUPPER, eagerEval);
   d_equalityEngine->addFunctionKind(kind::STRING_REV, eagerEval);
-
-  d_state.finishInit(d_equalityEngine);
 }
 
 std::string TheoryStrings::identify() const
@@ -196,7 +196,7 @@ bool TheoryStrings::propagate(TNode literal) {
   // Propagate out
   bool ok = d_out->propagate(literal);
   if (!ok) {
-    d_state.setConflict();
+    d_state.notifyInConflict();
   }
   return ok;
 }
@@ -762,7 +762,7 @@ void TheoryStrings::conflict(TNode a, TNode b){
   if (!d_state.isInConflict())
   {
     Debug("strings-conflict") << "Making conflict..." << std::endl;
-    d_state.setConflict();
+    d_state.notifyInConflict();
     TrustNode conflictNode = explain(a.eqNode(b));
     Trace("strings-conflict")
         << "CONFLICT: Eq engine conflict : " << conflictNode.getNode()
index 9669d97e01af2e6dc4ff25142da38b98bc3f5225..7220e2e1c88e66a24ac47e357a74416220bfa90c 100644 (file)
@@ -104,6 +104,10 @@ void Theory::setEqualityEngine(eq::EqualityEngine* ee)
 {
   // set the equality engine pointer
   d_equalityEngine = ee;
+  if (d_theoryState != nullptr)
+  {
+    d_theoryState->setEqualityEngine(ee);
+  }
 }
 void Theory::setQuantifiersEngine(QuantifiersEngine* qe)
 {
@@ -127,7 +131,7 @@ void Theory::finishInitStandalone()
     d_allocEqualityEngine.reset(new eq::EqualityEngine(
         *esi.d_notify, d_satContext, esi.d_name, esi.d_constantsAreTriggers));
     // use it as the official equality engine
-    d_equalityEngine = d_allocEqualityEngine.get();
+    setEqualityEngine(d_allocEqualityEngine.get());
   }
   finishInit();
 }
index bc8e53245ccf84aca9624a5b41ef51ce8bc40b17..8afe3be96b33d831c05430b6724d1a9c4470362d 100644 (file)
@@ -30,7 +30,7 @@ TheoryState::TheoryState(context::Context* c,
 {
 }
 
-void TheoryState::finishInit(eq::EqualityEngine* ee) { d_ee = ee; }
+void TheoryState::setEqualityEngine(eq::EqualityEngine* ee) { d_ee = ee; }
 
 context::Context* TheoryState::getSatContext() const { return d_context; }
 
index 71197dddc3f2fd6f4bd8100fa1e94ac73746b8b6..de6e6d47790834b42811a2c94ce5af768bdc3492 100644 (file)
@@ -34,10 +34,10 @@ class TheoryState
   TheoryState(context::Context* c, context::UserContext* u, Valuation val);
   virtual ~TheoryState() {}
   /**
-   * Finish initialize, ee is a pointer to the official equality engine
+   * Set equality engine, where ee is a pointer to the official equality engine
    * of theory.
    */
-  virtual void finishInit(eq::EqualityEngine* ee);
+  void setEqualityEngine(eq::EqualityEngine* ee);
   /** Get the SAT context */
   context::Context* getSatContext() const;
   /** Get the user context */