[Sequences/Array Solver] Minor refactoring (#7843)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 20 Dec 2021 14:41:06 +0000 (06:41 -0800)
committerGitHub <noreply@github.com>
Mon, 20 Dec 2021 14:41:06 +0000 (14:41 +0000)
This commit performs a minor refactoring of our array core solver. It
adds more comments and avoids sending the
STRINGS_ARRAY_NTH_TERM_FROM_UPDATE lemma more than once in a given
user context.

src/theory/strings/array_core_solver.cpp
src/theory/strings/array_core_solver.h

index fd48eab6bd560ec4448752a8c5b06504d3c292f1..160f289ad5b1aaea784504dca201707ec49c25a5 100644 (file)
@@ -41,7 +41,8 @@ ArrayCoreSolver::ArrayCoreSolver(Env& env,
       d_csolver(cs),
       d_esolver(es),
       d_extt(extt),
-      d_lem(context())
+      d_lem(context()),
+      d_registeredUpdates(userContext())
 {
 }
 
@@ -86,83 +87,97 @@ void ArrayCoreSolver::checkUpdate(const std::vector<Node>& updateTerms)
 {
   NodeManager* nm = NodeManager::currentNM();
 
-  Trace("seq-array-debug") << "updateTerms number: " << updateTerms.size()
-                           << std::endl;
+  Trace("seq-array-core-debug")
+      << "number of update terms: " << updateTerms.size() << std::endl;
   for (const Node& n : updateTerms)
   {
-    // current term (seq.update x i a)
-
-    // inference rule is:
-    // (seq.update x i a) in TERMS
-    // (seq.nth t j) in TERMS
-    // t = (seq.update x i a)
-    // ---------------------------------------------------------------------
-    // (seq.nth (seq.update x i a) j) =
-    //   (ITE, j in range(i, i+len(a)), (seq.nth a (j - i)),  (seq.nth x j))
+    Trace("seq-array-core-debug") << "check term " << n << std::endl;
 
     // note that the term could rewrites to a skolem
     // get proxy variable for the update term as t
     Node termProxy = d_termReg.getProxyVariableFor(n);
-    Trace("seq-update") << "- " << termProxy << " = " << n << std::endl;
     std::vector<Node> exp;
     d_im.addToExplanation(termProxy, n, exp);
 
-    // reasoning about nth(t, n[1]) even if it does not exist.
-    // x = update(s, n, t)
-    // ---------------------------------------------------------------------
-    // nth(x, n) = ite(n in range(0, len(s)), nth(t, 0), nth(s, n))
-    Node left = nm->mkNode(SEQ_NTH, termProxy, n[1]);
-    Node cond =
-        nm->mkNode(AND,
-                   nm->mkNode(GEQ, n[1], nm->mkConstInt(Rational(0))),
-                   nm->mkNode(LT, n[1], nm->mkNode(STRING_LENGTH, n[0])));
-    Node body1 = nm->mkNode(SEQ_NTH, n[2], nm->mkConstInt(Rational(0)));
-    Node body2 = nm->mkNode(SEQ_NTH, n[0], n[1]);
-    Node right = nm->mkNode(ITE, cond, body1, body2);
-    Node lem = nm->mkNode(EQUAL, left, right);
-    sendInference(exp, lem, InferenceId::STRINGS_ARRAY_NTH_TERM_FROM_UPDATE);
+    if (d_registeredUpdates.find(n) == d_registeredUpdates.end())
+    {
+      Trace("seq-array-core-debug") << "... registering" << std::endl;
+      d_registeredUpdates.insert(n);
+      // Introduce nth(update(s, n, t), n) for all update(s, n, t) terms.
+      //
+      // x = update(s, n, t)
+      // ------------------------------------------------------------
+      // nth(x, n) = ite(n in range(0, len(s)), nth(t, 0), nth(s, n))
+      Node left = nm->mkNode(SEQ_NTH, termProxy, n[1]);
+      Node cond =
+          nm->mkNode(AND,
+                     nm->mkNode(GEQ, n[1], nm->mkConstInt(Rational(0))),
+                     nm->mkNode(LT, n[1], nm->mkNode(STRING_LENGTH, n[0])));
+      Node body1 = nm->mkNode(SEQ_NTH, n[2], nm->mkConstInt(Rational(0)));
+      Node body2 = nm->mkNode(SEQ_NTH, n[0], n[1]);
+      Node right = nm->mkNode(ITE, cond, body1, body2);
+      Node lem = nm->mkNode(EQUAL, left, right);
+      d_im.sendInference(exp,
+                         lem,
+                         InferenceId::STRINGS_ARRAY_NTH_TERM_FROM_UPDATE,
+                         false,
+                         true);
+    }
 
-    // enumerate possible index
-    for (const auto& nth : d_index_map)
+    for (const auto& nthIdxs : d_indexMap)
     {
-      Node seq = nth.first;
-      if (d_state.areEqual(seq, n) || d_state.areEqual(seq, n[0]))
+      // Enumerate n-th terms for sequences that are related to the current
+      // update term
+      Node seq = nthIdxs.first;
+      if (!d_state.areEqual(seq, n) && !d_state.areEqual(seq, n[0]))
       {
-        const std::set<Node>& indexes = nth.second;
-        for (Node j : indexes)
+        continue;
+      }
+
+      const std::set<Node>& indexes = nthIdxs.second;
+      Trace("seq-array-core-debug") << "  check nth for " << seq
+                                    << " with indices " << indexes << std::endl;
+      for (Node j : indexes)
+      {
+        if (n[2].getKind() == SEQ_UNIT)
         {
-          // optimization: add a short cut for special case
+          // Special case for updates using unit
+          //
           // x = update(s, n, unit(t))
-          // y = nth(s, m)
+          // y = nth(x, m)
           // -----------------------------------------
           // n != m => nth(x, m) = nth(s, m)
-          if (n[2].getKind() == SEQ_UNIT)
-          {
-            left = nm->mkNode(DISTINCT, n[1], j);
-            Node nth1 = nm->mkNode(SEQ_NTH, termProxy, j);
-            Node nth2 = nm->mkNode(SEQ_NTH, n[0], j);
-            right = nm->mkNode(EQUAL, nth1, nth2);
-            lem = nm->mkNode(IMPLIES, left, right);
-            sendInference(
-                exp, lem, InferenceId::STRINGS_ARRAY_NTH_UPDATE_WITH_UNIT);
-            continue;
-          }
-
-          // normal cases
-          left = nm->mkNode(SEQ_NTH, termProxy, j);
-          cond = nm->mkNode(
-              AND,
-              nm->mkNode(LEQ, n[1], j),
-              nm->mkNode(
-                  LT,
-                  j,
-                  nm->mkNode(PLUS, n[1], nm->mkNode(STRING_LENGTH, n[2]))));
-          body1 = nm->mkNode(SEQ_NTH, n[2], nm->mkNode(MINUS, j, n[1]));
-          body2 = nm->mkNode(SEQ_NTH, n[0], j);
-          right = nm->mkNode(ITE, cond, body1, body2);
-          lem = nm->mkNode(EQUAL, left, right);
-          sendInference(exp, lem, InferenceId::STRINGS_ARRAY_NTH_UPDATE);
+          Node left = n[1].eqNode(j).notNode();
+          Node nth1 = nm->mkNode(SEQ_NTH, termProxy, j);
+          Node nth2 = nm->mkNode(SEQ_NTH, n[0], j);
+          Node right = nm->mkNode(EQUAL, nth1, nth2);
+          Node lem = nm->mkNode(IMPLIES, left, right);
+          sendInference(
+              exp, lem, InferenceId::STRINGS_ARRAY_NTH_UPDATE_WITH_UNIT);
+          continue;
         }
+
+        // Regular case
+        //
+        // x = update(s, n, t)
+        // y = nth(x, m)
+        // -----------------------------------------
+        // y = ite(n <= m < n + len(t), nth(t, m - n), nth(s, m))
+        Node nth = nm->mkNode(SEQ_NTH, termProxy, j);
+        Node cond = nm->mkNode(
+            AND,
+            nm->mkNode(LEQ, n[1], j),
+            nm->mkNode(
+                LT,
+                j,
+                nm->mkNode(PLUS, n[1], nm->mkNode(STRING_LENGTH, n[2]))));
+        Node cases =
+            nm->mkNode(ITE,
+                       cond,
+                       nm->mkNode(SEQ_NTH, n[2], nm->mkNode(MINUS, j, n[1])),
+                       nm->mkNode(SEQ_NTH, n[0], j));
+        Node lem = nm->mkNode(EQUAL, nth, cases);
+        sendInference(exp, lem, InferenceId::STRINGS_ARRAY_NTH_UPDATE);
       }
     }
   }
@@ -192,7 +207,7 @@ void ArrayCoreSolver::check(const std::vector<Node>& nthTerms,
   }
   Trace("seq-update") << "SequencesArraySolver::check..." << std::endl;
   d_writeModel.clear();
-  d_index_map.clear();
+  d_indexMap.clear();
   for (const Node& n : nthTerms)
   {
     // (seq.nth n[0] n[1])
@@ -200,7 +215,7 @@ void ArrayCoreSolver::check(const std::vector<Node>& nthTerms,
     Trace("seq-update") << "- " << r << ": " << n[1] << " -> " << n
                         << std::endl;
     d_writeModel[r][n[1]] = n;
-    d_index_map[r].insert(n[1]);
+    d_indexMap[r].insert(n[1]);
 
     if (n[0].getKind() == STRING_REV)
     {
index 3873f6a691964bcd267be8834f9d51196638fb62..b21919259d5515081b351533672f72fdc0f36b28 100644 (file)
@@ -133,10 +133,12 @@ class ArrayCoreSolver : protected EnvObj
   std::map<Node, Node> d_connectedSeq;
   /** The set of lemmas been sent */
   context::CDHashSet<Node> d_lem;
+  /** Set of updates that have been registered */
+  context::CDHashSet<Node> d_registeredUpdates;
 
   // ========= data structure =========
   /** Map sequence variable to indices that occurred in nth terms */
-  std::map<Node, std::set<Node>> d_index_map;
+  std::map<Node, std::set<Node>> d_indexMap;
 };
 
 }  // namespace strings