Split eager solver from strings solver state (#5775)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 13 Jan 2021 22:01:53 +0000 (16:01 -0600)
committerGitHub <noreply@github.com>
Wed, 13 Jan 2021 22:01:53 +0000 (16:01 -0600)
This splits the eager solver from solver state. The solver state contains the EqcInfo data, while the eager solver is responsible for populating it.

This is in preparation for adding new techniques to the eager solver.

No behavior changes in this PR, only reorganization.

src/CMakeLists.txt
src/theory/strings/eager_solver.cpp [new file with mode: 0644]
src/theory/strings/eager_solver.h [new file with mode: 0644]
src/theory/strings/solver_state.cpp
src/theory/strings/solver_state.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 15d39f06ad4eb3762f5c0c30a77715fe8c33e2da..6f5647e9ab6fe78907219c152c976fbd64ccd750 100644 (file)
@@ -872,6 +872,8 @@ libcvc4_add_sources(
   theory/strings/base_solver.h
   theory/strings/core_solver.cpp
   theory/strings/core_solver.h
+  theory/strings/eager_solver.cpp
+  theory/strings/eager_solver.h
   theory/strings/extf_solver.cpp
   theory/strings/extf_solver.h
   theory/strings/eqc_info.cpp
diff --git a/src/theory/strings/eager_solver.cpp b/src/theory/strings/eager_solver.cpp
new file mode 100644 (file)
index 0000000..d67b390
--- /dev/null
@@ -0,0 +1,151 @@
+/*********************                                                        */
+/*! \file eager_solver.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The eager solver
+ **/
+
+#include "theory/strings/eager_solver.h"
+
+#include "theory/strings/theory_strings_utils.h"
+
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace strings {
+
+EagerSolver::EagerSolver(SolverState& state) : d_state(state) {}
+
+EagerSolver::~EagerSolver() {}
+
+void EagerSolver::eqNotifyNewClass(TNode t)
+{
+  Kind k = t.getKind();
+  if (k == STRING_LENGTH || k == STRING_TO_CODE)
+  {
+    eq::EqualityEngine* ee = d_state.getEqualityEngine();
+    Node r = ee->getRepresentative(t[0]);
+    EqcInfo* ei = d_state.getOrMakeEqcInfo(r);
+    if (k == STRING_LENGTH)
+    {
+      ei->d_lengthTerm = t[0];
+    }
+    else
+    {
+      ei->d_codeTerm = t[0];
+    }
+  }
+  else if (t.isConst())
+  {
+    if (t.getType().isStringLike())
+    {
+      EqcInfo* ei = d_state.getOrMakeEqcInfo(t);
+      ei->d_prefixC = t;
+      ei->d_suffixC = t;
+    }
+  }
+  else if (k == STRING_CONCAT)
+  {
+    addEndpointsToEqcInfo(t, t, t);
+  }
+}
+
+void EagerSolver::eqNotifyMerge(TNode t1, TNode t2)
+{
+  EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false);
+  if (e2 == nullptr)
+  {
+    return;
+  }
+  Assert(t1.getType().isStringLike());
+  EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1);
+  // add information from e2 to e1
+  if (!e2->d_lengthTerm.get().isNull())
+  {
+    e1->d_lengthTerm.set(e2->d_lengthTerm);
+  }
+  if (!e2->d_codeTerm.get().isNull())
+  {
+    e1->d_codeTerm.set(e2->d_codeTerm);
+  }
+  if (!e2->d_prefixC.get().isNull())
+  {
+    d_state.setPendingPrefixConflictWhen(
+        e1->addEndpointConst(e2->d_prefixC, Node::null(), false));
+  }
+  if (!e2->d_suffixC.get().isNull())
+  {
+    d_state.setPendingPrefixConflictWhen(
+        e1->addEndpointConst(e2->d_suffixC, Node::null(), true));
+  }
+  if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
+  {
+    e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
+  }
+  if (!e2->d_normalizedLength.get().isNull())
+  {
+    e1->d_normalizedLength.set(e2->d_normalizedLength);
+  }
+}
+
+void EagerSolver::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+{
+  if (t1.getType().isStringLike())
+  {
+    // store disequalities between strings, may need to check if their lengths
+    // are equal/disequal
+    d_state.addDisequality(t1, t2);
+  }
+}
+
+void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
+{
+  Assert(concat.getKind() == STRING_CONCAT
+         || concat.getKind() == REGEXP_CONCAT);
+  EqcInfo* ei = nullptr;
+  // check each side
+  for (unsigned r = 0; r < 2; r++)
+  {
+    unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1;
+    Node c = utils::getConstantComponent(concat[index]);
+    if (!c.isNull())
+    {
+      if (ei == nullptr)
+      {
+        ei = d_state.getOrMakeEqcInfo(eqc);
+      }
+      Trace("strings-eager-pconf-debug")
+          << "New term: " << concat << " for " << t << " with prefix " << c
+          << " (" << (r == 1) << ")" << std::endl;
+      d_state.setPendingPrefixConflictWhen(ei->addEndpointConst(t, c, r == 1));
+    }
+  }
+}
+
+void EagerSolver::notifyFact(TNode atom,
+                             bool polarity,
+                             TNode fact,
+                             bool isInternal)
+{
+  if (atom.getKind() == STRING_IN_REGEXP)
+  {
+    if (polarity && atom[1].getKind() == REGEXP_CONCAT)
+    {
+      eq::EqualityEngine* ee = d_state.getEqualityEngine();
+      Node eqc = ee->getRepresentative(atom[0]);
+      addEndpointsToEqcInfo(atom, atom[1], eqc);
+    }
+  }
+}
+
+}  // namespace strings
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/strings/eager_solver.h b/src/theory/strings/eager_solver.h
new file mode 100644 (file)
index 0000000..ee520f7
--- /dev/null
@@ -0,0 +1,68 @@
+/*********************                                                        */
+/*! \file eager_solver.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The eager solver
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__STRINGS__EAGER_SOLVER_H
+#define CVC4__THEORY__STRINGS__EAGER_SOLVER_H
+
+#include <map>
+
+#include "expr/node.h"
+#include "theory/strings/eqc_info.h"
+#include "theory/strings/solver_state.h"
+
+namespace CVC4 {
+namespace theory {
+namespace strings {
+
+/**
+ * Eager solver, which is responsible for tracking of eager information and
+ * reporting conflicts to the solver state.
+ */
+class EagerSolver
+{
+ public:
+  EagerSolver(SolverState& state);
+  ~EagerSolver();
+  /** called when a new equivalence class is created */
+  void eqNotifyNewClass(TNode t);
+  /** called when two equivalence classes merge */
+  void eqNotifyMerge(TNode t1, TNode t2);
+  /** called when two equivalence classes are made disequal */
+  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
+  /** notify fact, called when a fact is asserted to theory of strings */
+  void notifyFact(TNode atom, bool polarity, TNode fact, bool isInternal);
+
+ private:
+  /** add endpoints to eqc info
+   *
+   * This method is called when term t is the explanation for why equivalence
+   * class eqc may have a constant endpoint due to a concatentation term concat.
+   * For example, we may call this method on:
+   *   t := (str.++ x y), concat := (str.++ x y), eqc
+   * for some eqc that is currently equal to t. Another example is:
+   *   t := (str.in.re z (re.++ r s)), concat := (re.++ r s), eqc
+   * for some eqc that is currently equal to z.
+   */
+  void addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
+  /** Reference to the solver state */
+  SolverState& d_state;
+};
+
+}  // namespace strings
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__STRINGS__EAGER_SOLVER_H */
index d9fac671884229385bd399d6b55f0e4ccda0c3ce..f341e681d315cbe4a7e48bf0746fa4d91b2bd885 100644 (file)
@@ -47,82 +47,9 @@ const context::CDList<Node>& SolverState::getDisequalityList() const
   return d_eeDisequalities;
 }
 
-void SolverState::eqNotifyNewClass(TNode t)
+void SolverState::addDisequality(TNode t1, TNode t2)
 {
-  Kind k = t.getKind();
-  if (k == STRING_LENGTH || k == STRING_TO_CODE)
-  {
-    Node r = d_ee->getRepresentative(t[0]);
-    EqcInfo* ei = getOrMakeEqcInfo(r);
-    if (k == STRING_LENGTH)
-    {
-      ei->d_lengthTerm = t[0];
-    }
-    else
-    {
-      ei->d_codeTerm = t[0];
-    }
-  }
-  else if (t.isConst())
-  {
-    if (t.getType().isStringLike())
-    {
-      EqcInfo* ei = getOrMakeEqcInfo(t);
-      ei->d_prefixC = t;
-      ei->d_suffixC = t;
-    }
-  }
-  else if (k == STRING_CONCAT)
-  {
-    addEndpointsToEqcInfo(t, t, t);
-  }
-}
-
-void SolverState::eqNotifyMerge(TNode t1, TNode t2)
-{
-  EqcInfo* e2 = getOrMakeEqcInfo(t2, false);
-  if (e2)
-  {
-    Assert(t1.getType().isStringLike());
-    EqcInfo* e1 = getOrMakeEqcInfo(t1);
-    // add information from e2 to e1
-    if (!e2->d_lengthTerm.get().isNull())
-    {
-      e1->d_lengthTerm.set(e2->d_lengthTerm);
-    }
-    if (!e2->d_codeTerm.get().isNull())
-    {
-      e1->d_codeTerm.set(e2->d_codeTerm);
-    }
-    if (!e2->d_prefixC.get().isNull())
-    {
-      setPendingPrefixConflictWhen(
-          e1->addEndpointConst(e2->d_prefixC, Node::null(), false));
-    }
-    if (!e2->d_suffixC.get().isNull())
-    {
-      setPendingPrefixConflictWhen(
-          e1->addEndpointConst(e2->d_suffixC, Node::null(), true));
-    }
-    if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
-    {
-      e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
-    }
-    if (!e2->d_normalizedLength.get().isNull())
-    {
-      e1->d_normalizedLength.set(e2->d_normalizedLength);
-    }
-  }
-}
-
-void SolverState::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
-{
-  if (t1.getType().isStringLike())
-  {
-    // store disequalities between strings, may need to check if their lengths
-    // are equal/disequal
-    d_eeDisequalities.push_back(t1.eqNode(t2));
-  }
+  d_eeDisequalities.push_back(t1.eqNode(t2));
 }
 
 EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake)
@@ -143,30 +70,6 @@ EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake)
 
 TheoryModel* SolverState::getModel() { return d_valuation.getModel(); }
 
-void SolverState::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
-{
-  Assert(concat.getKind() == STRING_CONCAT
-         || concat.getKind() == REGEXP_CONCAT);
-  EqcInfo* ei = nullptr;
-  // check each side
-  for (unsigned r = 0; r < 2; r++)
-  {
-    unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1;
-    Node c = utils::getConstantComponent(concat[index]);
-    if (!c.isNull())
-    {
-      if (ei == nullptr)
-      {
-        ei = getOrMakeEqcInfo(eqc);
-      }
-      Trace("strings-eager-pconf-debug")
-          << "New term: " << concat << " for " << t << " with prefix " << c
-          << " (" << (r == 1) << ")" << std::endl;
-      setPendingPrefixConflictWhen(ei->addEndpointConst(t, c, r == 1));
-    }
-  }
-}
-
 Node SolverState::getLengthExp(Node t, std::vector<Node>& exp, Node te)
 {
   Assert(areEqual(t, te));
index 291a15feb5a06bacd8ec71b2515be5c57b06d908..ee8cdb8d79eff2abb9c23f790667d1421d952f86 100644 (file)
@@ -51,21 +51,18 @@ class SolverState : public TheoryState
               context::UserContext* u,
               Valuation& v);
   ~SolverState();
-  //-------------------------------------- equality information
+  //-------------------------------------- disequality information
   /**
    * Get the list of disequalities that are currently asserted to the equality
    * engine.
    */
   const context::CDList<Node>& getDisequalityList() const;
-  //-------------------------------------- end equality information
-  //-------------------------------------- notifications for equalities
-  /** called when a new equivalence class is created */
-  void eqNotifyNewClass(TNode t);
-  /** called when two equivalence classes merge */
-  void eqNotifyMerge(TNode t1, TNode t2);
-  /** called when two equivalence classes are made disequal */
-  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
-  //-------------------------------------- end notifications for equalities
+  /**
+   * notify the state that disequality (not (= t1 t2)) holds in the current
+   * context. This will be included in the return of the above method.
+   */
+  void addDisequality(TNode t1, TNode t2);
+  //-------------------------------------- end disequality information
   //------------------------------------------ conflicts
   /** set pending prefix conflict
    *
@@ -128,18 +125,6 @@ class SolverState : public TheoryState
   EqcInfo* getOrMakeEqcInfo(Node eqc, bool doMake = true);
   /** Get pointer to the model object of the Valuation object */
   TheoryModel* getModel();
-
-  /** add endpoints to eqc info
-   *
-   * This method is called when term t is the explanation for why equivalence
-   * class eqc may have a constant endpoint due to a concatentation term concat.
-   * For example, we may call this method on:
-   *   t := (str.++ x y), concat := (str.++ x y), eqc
-   * for some eqc that is currently equal to t. Another example is:
-   *   t := (str.in.re z (re.++ r s)), concat := (re.++ r s), eqc
-   * for some eqc that is currently equal to z.
-   */
-  void addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
   /** Entailment check
    *
    * This calls entailmentCheck on the Valuation object of theory of strings.
@@ -172,7 +157,7 @@ class SolverState : public TheoryState
   InferInfo d_pendingConflict;
   /** Map from representatives to their equivalence class information */
   std::map<Node, EqcInfo*> d_eqcInfo;
-}; /* class TheoryStrings */
+};
 
 }  // namespace strings
 }  // namespace theory
index dbe03afee79782e9825d4b10a30ada68c2f9b209..3189b297a319f784b38dcd836fe9ee3a78d785cd 100644 (file)
@@ -45,6 +45,7 @@ TheoryStrings::TheoryStrings(context::Context* c,
       d_notify(*this),
       d_statistics(),
       d_state(c, u, d_valuation),
+      d_eagerSolver(d_state),
       d_termReg(d_state, out, d_statistics, pnm),
       d_extTheoryCb(),
       d_extTheory(d_extTheoryCb, c, u, out),
@@ -617,14 +618,7 @@ void TheoryStrings::notifyFact(TNode atom,
                                TNode fact,
                                bool isInternal)
 {
-  if (atom.getKind() == STRING_IN_REGEXP)
-  {
-    if (polarity && atom[1].getKind() == REGEXP_CONCAT)
-    {
-      Node eqc = d_equalityEngine->getRepresentative(atom[0]);
-      d_state.addEndpointsToEqcInfo(atom, atom[1], eqc);
-    }
-  }
+  d_eagerSolver.notifyFact(atom, polarity, fact, isInternal);
   // process pending conflicts due to reasoning about endpoints
   if (!d_state.isInConflict() && d_state.hasPendingConflict())
   {
@@ -765,7 +759,7 @@ void TheoryStrings::eqNotifyNewClass(TNode t){
     //we care about the length of this string
     d_termReg.registerTerm(t[0], 1);
   }
-  d_state.eqNotifyNewClass(t);
+  d_eagerSolver.eqNotifyNewClass(t);
 }
 
 void TheoryStrings::addCarePairs(TNodeTrie* t1,
index 13b5a5ebaab2f8659e01e4ea0d6057ea9423dba5..ebded2aec9bbbf30cbb7fb9e8f5f89cf09c672ef 100644 (file)
@@ -28,6 +28,7 @@
 #include "theory/ext_theory.h"
 #include "theory/strings/base_solver.h"
 #include "theory/strings/core_solver.h"
+#include "theory/strings/eager_solver.h"
 #include "theory/strings/extf_solver.h"
 #include "theory/strings/infer_info.h"
 #include "theory/strings/inference_manager.h"
@@ -124,14 +125,18 @@ class TheoryStrings : public Theory {
   /** NotifyClass for equality engine */
   class NotifyClass : public eq::EqualityEngineNotify {
   public:
-   NotifyClass(TheoryStrings& ts) : d_str(ts), d_state(ts.d_state) {}
-    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override
-    {
-      Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate << ", " << (value ? "true" : "false") << ")" << std::endl;
-      if (value) {
-        return d_str.propagateLit(predicate);
-      }
-      return d_str.propagateLit(predicate.notNode());
+   NotifyClass(TheoryStrings& ts) : d_str(ts), d_eagerSolver(ts.d_eagerSolver)
+   {
+   }
+   bool eqNotifyTriggerPredicate(TNode predicate, bool value) override
+   {
+     Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate
+                      << ", " << (value ? "true" : "false") << ")" << std::endl;
+     if (value)
+     {
+       return d_str.propagateLit(predicate);
+     }
+     return d_str.propagateLit(predicate.notNode());
     }
     bool eqNotifyTriggerTermEquality(TheoryId tag,
                                      TNode t1,
@@ -158,19 +163,19 @@ class TheoryStrings : public Theory {
     {
       Debug("strings") << "NotifyClass::eqNotifyMerge(" << t1 << ", " << t2
                        << std::endl;
-      d_state.eqNotifyMerge(t1, t2);
+      d_eagerSolver.eqNotifyMerge(t1, t2);
     }
     void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override
     {
       Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl;
-      d_state.eqNotifyDisequal(t1, t2, reason);
+      d_eagerSolver.eqNotifyDisequal(t1, t2, reason);
     }
 
    private:
     /** The theory of strings object to notify */
     TheoryStrings& d_str;
-    /** The solver state of the theory of strings */
-    SolverState& d_state;
+    /** The eager solver of the theory of strings */
+    EagerSolver& d_eagerSolver;
   };/* class TheoryStrings::NotifyClass */
   /** compute care graph */
   void computeCareGraph() override;
@@ -255,6 +260,8 @@ class TheoryStrings : public Theory {
   SequencesStatistics d_statistics;
   /** The solver state object */
   SolverState d_state;
+  /** The eager solver */
+  EagerSolver d_eagerSolver;
   /** The term registry for this theory */
   TermRegistry d_termReg;
   /** The extended theory callback */