[Strings] Minor refactor of eager solver (#7628)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 15 Nov 2021 15:49:26 +0000 (07:49 -0800)
committerGitHub <noreply@github.com>
Mon, 15 Nov 2021 15:49:26 +0000 (15:49 +0000)
This moves code that is not strictly related to the eager solver out of
the eager solver and into TheoryStrings. This is cleaner and makes it
easier to enable/disable the eager solver.

src/theory/strings/eager_solver.cpp
src/theory/strings/eager_solver.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 21fdd6fa281696705206d45b33d408f7fb7f13d7..ce47ff4fc4673f5631e8fc43dfa1b285aab5ebae 100644 (file)
@@ -37,41 +37,30 @@ EagerSolver::~EagerSolver() {}
 void EagerSolver::eqNotifyNewClass(TNode t)
 {
   Kind k = t.getKind();
-  if (k == STRING_LENGTH || k == STRING_TO_CODE)
+  if (k == STRING_LENGTH)
   {
-    eq::EqualityEngine* ee = d_state.getEqualityEngine();
-    Node r = ee->getRepresentative(t[0]);
-    EqcInfo* ei = d_state.getOrMakeEqcInfo(r);
-    if (k == STRING_LENGTH)
+    // also assume it as upper/lower bound as applicable for the equivalence
+    // class info of t.
+    EqcInfo* eil = nullptr;
+    for (size_t i = 0; i < 2; i++)
     {
-      ei->d_lengthTerm = t;
-      // also assume it as upper/lower bound as applicable for the equivalence
-      // class info of t.
-      EqcInfo* eil = nullptr;
-      for (size_t i = 0; i < 2; i++)
+      Node b = getBoundForLength(t, i == 0);
+      if (b.isNull())
       {
-        Node b = getBoundForLength(t, i == 0);
-        if (b.isNull())
-        {
-          continue;
-        }
-        if (eil == nullptr)
-        {
-          eil = d_state.getOrMakeEqcInfo(t);
-        }
-        if (i == 0)
-        {
-          eil->d_firstBound = t;
-        }
-        else if (i == 1)
-        {
-          eil->d_secondBound = t;
-        }
+        continue;
+      }
+      if (eil == nullptr)
+      {
+        eil = d_state.getOrMakeEqcInfo(t);
+      }
+      if (i == 0)
+      {
+        eil->d_firstBound = t;
+      }
+      else if (i == 1)
+      {
+        eil->d_secondBound = t;
       }
-    }
-    else
-    {
-      ei->d_codeTerm = t[0];
     }
   }
   else if (t.isConst())
@@ -90,15 +79,10 @@ void EagerSolver::eqNotifyNewClass(TNode t)
   }
 }
 
-void EagerSolver::eqNotifyMerge(TNode t1, TNode t2)
+void EagerSolver::eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2)
 {
-  EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false);
-  if (e2 == nullptr)
-  {
-    return;
-  }
-  // always create it if e2 was non-null
-  EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1);
+  Assert(e1 != nullptr);
+  Assert(e2 != nullptr);
   // check for conflict
   Node conf = checkForMergeConflict(t1, t2, e1, e2);
   if (!conf.isNull())
@@ -109,33 +93,6 @@ void EagerSolver::eqNotifyMerge(TNode t1, TNode t2)
     d_state.setPendingMergeConflict(conf, id);
     return;
   }
-  // add information from e2 to e1
-  if (!e2->d_lengthTerm.get().isNull())
-  {
-    e1->d_lengthTerm.set(e2->d_lengthTerm);
-  }
-  if (!e2->d_codeTerm.get().isNull())
-  {
-    e1->d_codeTerm.set(e2->d_codeTerm);
-  }
-  if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
-  {
-    e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
-  }
-  if (!e2->d_normalizedLength.get().isNull())
-  {
-    e1->d_normalizedLength.set(e2->d_normalizedLength);
-  }
-}
-
-void EagerSolver::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
-{
-  if (t1.getType().isStringLike())
-  {
-    // store disequalities between strings, may need to check if their lengths
-    // are equal/disequal
-    d_state.addDisequality(t1, t2);
-  }
 }
 
 void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
index 03fb0ff63d31316be6b25bbf467a41330be487f2..4181a15c366fee70b2052aa4b5bde04687249c2d 100644 (file)
@@ -46,9 +46,7 @@ class EagerSolver : protected EnvObj
   /** called when a new equivalence class is created */
   void eqNotifyNewClass(TNode t);
   /** called when two equivalence classes merge */
-  void eqNotifyMerge(TNode t1, TNode t2);
-  /** called when two equivalence classes are made disequal */
-  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
+  void eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2);
   /** notify fact, called when a fact is asserted to theory of strings */
   void notifyFact(TNode atom, bool polarity, TNode fact, bool isInternal);
 
index ed00758a85730e41a47370ce96afaacfb156b08d..caeb8065e2e3ca552be5fe80ea64421069d203df 100644 (file)
@@ -749,10 +749,63 @@ void TheoryStrings::eqNotifyNewClass(TNode t){
     Trace("strings-debug") << "New length eqc : " << t << std::endl;
     //we care about the length of this string
     d_termReg.registerTerm(t[0], 1);
+
+    eq::EqualityEngine* ee = d_state.getEqualityEngine();
+    Node r = ee->getRepresentative(t[0]);
+    EqcInfo* ei = d_state.getOrMakeEqcInfo(r);
+    if (k == STRING_LENGTH)
+    {
+      ei->d_lengthTerm = t;
+    }
+    else
+    {
+      ei->d_codeTerm = t[0];
+    }
   }
   d_eagerSolver.eqNotifyNewClass(t);
 }
 
+void TheoryStrings::eqNotifyMerge(TNode t1, TNode t2)
+{
+  EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false);
+  if (e2 == nullptr)
+  {
+    return;
+  }
+  // always create it if e2 was non-null
+  EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1);
+
+  d_eagerSolver.eqNotifyMerge(e1, t1, e2, t2);
+
+  // add information from e2 to e1
+  if (!e2->d_lengthTerm.get().isNull())
+  {
+    e1->d_lengthTerm.set(e2->d_lengthTerm);
+  }
+  if (!e2->d_codeTerm.get().isNull())
+  {
+    e1->d_codeTerm.set(e2->d_codeTerm);
+  }
+  if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
+  {
+    e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
+  }
+  if (!e2->d_normalizedLength.get().isNull())
+  {
+    e1->d_normalizedLength.set(e2->d_normalizedLength);
+  }
+}
+
+void TheoryStrings::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+{
+  if (t1.getType().isStringLike())
+  {
+    // store disequalities between strings, may need to check if their lengths
+    // are equal/disequal
+    d_state.addDisequality(t1, t2);
+  }
+}
+
 void TheoryStrings::addCarePairs(TNodeTrie* t1,
                                  TNodeTrie* t2,
                                  unsigned arity,
index dbb04580f2794ea21cd13f5d716db06a55fe7908..21db7da0c3dbae51e32ff04e0ed8432d2c459d57 100644 (file)
@@ -108,6 +108,10 @@ class TheoryStrings : public Theory {
   void conflict(TNode a, TNode b);
   /** called when a new equivalence class is created */
   void eqNotifyNewClass(TNode t);
+  /** Called just after the merge of two equivalence classes */
+  void eqNotifyMerge(TNode t1, TNode t2);
+  /** called a disequality is added */
+  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
   /** preprocess rewrite */
   TrustNode ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) override;
   /** Collect model values in m based on the relevant terms given by termSet */
@@ -118,9 +122,7 @@ class TheoryStrings : public Theory {
   /** NotifyClass for equality engine */
   class NotifyClass : public eq::EqualityEngineNotify {
   public:
-   NotifyClass(TheoryStrings& ts) : d_str(ts), d_eagerSolver(ts.d_eagerSolver)
-   {
-   }
+   NotifyClass(TheoryStrings& ts) : d_str(ts) {}
    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override
    {
      Debug("strings") << "NotifyClass::eqNotifyTriggerPredicate(" << predicate
@@ -156,19 +158,17 @@ class TheoryStrings : public Theory {
     {
       Debug("strings") << "NotifyClass::eqNotifyMerge(" << t1 << ", " << t2
                        << std::endl;
-      d_eagerSolver.eqNotifyMerge(t1, t2);
+      d_str.eqNotifyMerge(t1, t2);
     }
     void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override
     {
       Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl;
-      d_eagerSolver.eqNotifyDisequal(t1, t2, reason);
+      d_str.eqNotifyDisequal(t1, t2, reason);
     }
 
    private:
     /** The theory of strings object to notify */
     TheoryStrings& d_str;
-    /** The eager solver of the theory of strings */
-    EagerSolver& d_eagerSolver;
   };/* class TheoryStrings::NotifyClass */
   /** compute care graph */
   void computeCareGraph() override;