Fixes for update/nth over constant sequences (#7631)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 11 Nov 2021 16:24:38 +0000 (10:24 -0600)
committerGitHub <noreply@github.com>
Thu, 11 Nov 2021 16:24:38 +0000 (16:24 +0000)
This includes fixes related to how we process seq.update / seq.nth applied to sequence constants.

src/theory/strings/array_solver.cpp

index 09e3aefddd4708f524ef6c57013e833279f1fc6e..b3e2aa457cdfb8a54d73390ddea8ee468f734b72 100644 (file)
 
 #include "theory/strings/array_solver.h"
 
+#include "expr/sequence.h"
 #include "theory/strings/arith_entail.h"
 #include "theory/strings/theory_strings_utils.h"
+#include "theory/strings/word.h"
 #include "util/rational.h"
 
 using namespace cvc5::context;
@@ -81,6 +83,7 @@ void ArraySolver::checkTerms(Kind k)
     Node r = d_state.getRepresentative(t[0]);
     NormalForm& nf = d_csolver.getNormalForm(r);
     Trace("seq-array-debug") << "...normal form " << nf.d_nf << std::endl;
+    std::vector<Node> nfChildren;
     if (nf.d_nf.empty())
     {
       // updates should have been reduced (UPD_EMPTYSTR)
@@ -92,8 +95,16 @@ void ArraySolver::checkTerms(Kind k)
     {
       Trace("seq-array-debug") << "...norm form size 1" << std::endl;
       // NOTE: could split on n=0 if needed, do not introduce ITE
-      if (nf.d_nf[0].getKind() == SEQ_UNIT)
+      Kind ck = nf.d_nf[0].getKind();
+      // Note that (seq.unit c) is rewritten to CONST_SEQUENCE{c}, hence we
+      // check two cases here. It is important for completeness of this schema
+      // to handle this differently from STRINGS_ARRAY_UPDATE_CONCAT /
+      // STRINGS_ARRAY_NTH_CONCAT. Otherwise we would conclude a trivial
+      // equality when update/nth is applied to a constant of length one.
+      if (ck == SEQ_UNIT
+          || (ck == CONST_SEQUENCE && Word::getLength(nf.d_nf[0]) == 1))
       {
+        Trace("seq-array-debug") << "...unit case" << std::endl;
         // do we know whether n = 0 ?
         // x = (seq.unit m) => (seq.update x n z) = ite(n=0, z, (seq.unit m))
         // x = (seq.unit m) => (seq.nth x n) = ite(n=0, m, Uf(x, n))
@@ -109,7 +120,15 @@ void ArraySolver::checkTerms(Kind k)
         else
         {
           Assert(k == SEQ_NTH);
-          thenBranch = nf.d_nf[0][0];
+          if (ck == CONST_SEQUENCE)
+          {
+            const Sequence& seq = nf.d_nf[0].getConst<Sequence>();
+            thenBranch = seq.getVec()[0];
+          }
+          else
+          {
+            thenBranch = nf.d_nf[0][0];
+          }
           Node uf = SkolemCache::mkSkolemSeqNth(t[0].getType(), "Uf");
           elseBranch = nm->mkNode(APPLY_UF, uf, t[0], t[1]);
           iid = InferenceId::STRINGS_ARRAY_NTH_UNIT;
@@ -126,17 +145,33 @@ void ArraySolver::checkTerms(Kind k)
           d_eqProc.insert(eq);
           d_im.sendInference(exp, eq, iid);
         }
+        continue;
       }
-      // otherwise, the equivalence class is pure wrt concatenation
-      d_currTerms[k].push_back(t);
-      continue;
+      else if (ck != CONST_SEQUENCE)
+      {
+        // otherwise, if the normal form is not a constant sequence, the
+        // equivalence class is pure wrt concatenation.
+        d_currTerms[k].push_back(t);
+        continue;
+      }
+      // if the normal form is a constant sequence, it is treated as a
+      // concatenation. We split per character and case split on whether the
+      // nth/update falls on each character below, which must have a size
+      // greater than one.
+      std::vector<Node> chars = Word::getChars(nf.d_nf[0]);
+      Assert (chars.size()>1);
+      nfChildren.insert(nfChildren.end(), chars.begin(), chars.end());
+    }
+    else
+    {
+      nfChildren.insert(nfChildren.end(), nf.d_nf.begin(), nf.d_nf.end());
     }
     // otherwise, we are the concatenation of the components
     // NOTE: for nth, split on index vs component lengths, do not introduce ITE
     std::vector<Node> cond;
     std::vector<Node> cchildren;
     std::vector<Node> lacc;
-    for (const Node& c : nf.d_nf)
+    for (const Node& c : nfChildren)
     {
       Trace("seq-array-debug") << "...process " << c << std::endl;
       Node clen = nm->mkNode(STRING_LENGTH, c);
@@ -146,26 +181,49 @@ void ArraySolver::checkTerms(Kind k)
         Node currSum = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc);
         currIndex = nm->mkNode(MINUS, currIndex, currSum);
       }
-      if (k == STRING_UPDATE)
+      Node cc;
+      // If it is a constant of length one, then the update/nth is determined
+      // in this interval. Notice this is done here as
+      // an optimization to short cut introducing terms like
+      // (seq.nth (seq.unit c) i), which by construction is only relevant in
+      // the context where i = 0, hence we replace by c here.
+      if (c.getKind() == CONST_SEQUENCE)
       {
-        Node cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]);
-        Trace("seq-array-debug") << "......component " << cc << std::endl;
-        cchildren.push_back(cc);
+        const Sequence& seq = c.getConst<Sequence>();
+        if (seq.size() == 1)
+        {
+          if (k == STRING_UPDATE)
+          {
+            cc = nm->mkNode(ITE, t[1].eqNode(d_zero), t[2], c);
+          }
+          else
+          {
+            cc = seq.getVec()[0];
+          }
+        }
       }
-      else
+      // if we did not process as a constant of length one
+      if (cc.isNull())
       {
-        Assert(k == SEQ_NTH);
-        Node cc = nm->mkNode(SEQ_NTH, c, currIndex);
-        Trace("seq-array-debug") << "......component " << cc << std::endl;
-        cchildren.push_back(cc);
+        if (k == STRING_UPDATE)
+        {
+          cc = nm->mkNode(STRING_UPDATE, c, currIndex, t[2]);
+        }
+        else
+        {
+          Assert(k == SEQ_NTH);
+          cc = nm->mkNode(SEQ_NTH, c, currIndex);
+        }
       }
+      Trace("seq-array-debug") << "......component " << cc << std::endl;
+      cchildren.push_back(cc);
       lacc.push_back(clen);
       if (k == SEQ_NTH)
       {
         Node currSumPost = lacc.size() == 1 ? lacc[0] : nm->mkNode(PLUS, lacc);
-        Node cc = nm->mkNode(LT, t[1], currSumPost);
-        Trace("seq-array-debug") << "......condition " << cc << std::endl;
-        cond.push_back(cc);
+        Node cf = nm->mkNode(LT, t[1], currSumPost);
+        Trace("seq-array-debug") << "......condition " << cf << std::endl;
+        cond.push_back(cf);
       }
     }
     // z = (seq.++ x y) =>