Add rewrites for `seq.update`/`seq.nth` (#7966)
authorAndres Noetzli <andres.noetzli@gmail.com>
Wed, 19 Jan 2022 23:45:21 +0000 (15:45 -0800)
committerGitHub <noreply@github.com>
Wed, 19 Jan 2022 23:45:21 +0000 (23:45 +0000)
This commit adds rewrites for `seq.update` and `seq.nth`. It adds a
rewrite to check if an update is out of bounds and it adds rewrites to
evaluate these operators using `stripSymbolicLength()`. This allows us
for example to evaluate updates on concatenations of non-constant
`seq.unit`s.

src/theory/strings/arith_entail.cpp
src/theory/strings/rewrites.cpp
src/theory/strings/rewrites.h
src/theory/strings/sequences_rewriter.cpp
test/unit/theory/sequences_rewriter_white.cpp

index 19f1bc97df52bdc8e69cb1963c6f1ffe50428f1d..0659a2ff97ab9fec986ea69c61b6ca0124b0ea9d 100644 (file)
@@ -861,12 +861,17 @@ Node ArithEntail::getConstantBoundLength(TNode s, bool isLower) const
     return ret;
   }
   NodeManager* nm = NodeManager::currentNM();
+  Kind sk = s.getKind();
   if (s.isConst())
   {
     size_t len = Word::getLength(s);
     ret = nm->mkConstInt(Rational(len));
   }
-  else if (s.getKind() == STRING_CONCAT)
+  else if (sk == SEQ_UNIT)
+  {
+    ret = nm->mkConstInt(1);
+  }
+  else if (sk == STRING_CONCAT)
   {
     Rational sum(0);
     bool success = true;
index 262459b4f6531f5348a27139c66600a1ee037790..98420faaac3b55d255ca3826d933dfe7e5e76df1 100644 (file)
@@ -151,11 +151,13 @@ const char* toString(Rewrite r)
     case Rewrite::SS_STRIP_END_PT: return "SS_STRIP_END_PT";
     case Rewrite::SS_STRIP_START_PT: return "SS_STRIP_START_PT";
     case Rewrite::UPD_EVAL: return "UPD_EVAL";
+    case Rewrite::UPD_EVAL_SYM: return "UPD_EVAL_SYM";
     case Rewrite::UPD_EMPTYSTR: return "UPD_EMPTYSTR";
     case Rewrite::UPD_CONST_INDEX_MAX_OOB: return "UPD_CONST_INDEX_MAX_OOB";
     case Rewrite::UPD_CONST_INDEX_NEG: return "UPD_CONST_INDEX_NEG";
     case Rewrite::UPD_CONST_INDEX_OOB: return "UPD_CONST_INDEX_OOB";
     case Rewrite::UPD_REV: return "UPD_REV";
+    case Rewrite::UPD_OOB: return "UPD_OOB";
     case Rewrite::STOI_CONCAT_NONNUM: return "STOI_CONCAT_NONNUM";
     case Rewrite::STOI_EVAL: return "STOI_EVAL";
     case Rewrite::STR_CONV_CONST: return "STR_CONV_CONST";
@@ -225,7 +227,7 @@ const char* toString(Rewrite r)
     case Rewrite::SEQ_UNIT_EVAL: return "SEQ_UNIT_EVAL";
     case Rewrite::SEQ_NTH_EVAL: return "SEQ_NTH_EVAL";
     case Rewrite::SEQ_NTH_TOTAL_OOB: return "SEQ_NTH_TOTAL_OOB";
-    case Rewrite::SEQ_NTH_UNIT: return "SEQ_NTH_UNIT";
+    case Rewrite::SEQ_NTH_EVAL_SYM: return "SEQ_NTH_EVAL_SYM";
     default: return "?";
   }
 }
index eedd0fb402dd9c25295d6e6392d8039b5c6860cc..de8af0b4b7ddbc76fcc8c2088cadcc7cb9a6894a 100644 (file)
@@ -152,11 +152,13 @@ enum class Rewrite : uint32_t
   SS_STRIP_END_PT,
   SS_STRIP_START_PT,
   UPD_EVAL,
+  UPD_EVAL_SYM,
   UPD_EMPTYSTR,
   UPD_CONST_INDEX_MAX_OOB,
   UPD_CONST_INDEX_NEG,
   UPD_CONST_INDEX_OOB,
   UPD_REV,
+  UPD_OOB,
   STOI_CONCAT_NONNUM,
   STOI_EVAL,
   STR_CONV_CONST,
@@ -226,7 +228,7 @@ enum class Rewrite : uint32_t
   SEQ_UNIT_EVAL,
   SEQ_NTH_EVAL,
   SEQ_NTH_TOTAL_OOB,
-  SEQ_NTH_UNIT
+  SEQ_NTH_EVAL_SYM
 };
 
 /**
index 374b43c34c18380587520f07b5faa4661aaa51bd..e4db2f325371ad4e4746444d7d74d6c84cd9861b 100644 (file)
@@ -1765,10 +1765,18 @@ Node SequencesRewriter::rewriteSeqNth(Node node)
     }
   }
 
-  if (s.getKind() == SEQ_UNIT && i.isConst() && i.getConst<Rational>().isZero())
+  std::vector<Node> prefix, suffix;
+  utils::getConcat(s, suffix);
+  if ((i.isConst() && i.getConst<Rational>().isZero())
+      || d_stringsEntail.stripSymbolicLength(suffix, prefix, 1, i, true))
   {
-    Node ret = s[0];
-    return returnRewrite(node, ret, Rewrite::SEQ_NTH_UNIT);
+    if (suffix.size() > 0 && suffix[0].getKind() == SEQ_UNIT)
+    {
+      // (seq.nth (seq.++ prefix (seq.unit x) suffix) n) ---> x
+      // if len(prefix) = n
+      Node ret = suffix[0][0];
+      return returnRewrite(node, ret, Rewrite::SEQ_NTH_EVAL_SYM);
+    }
   }
 
   return node;
@@ -2088,11 +2096,39 @@ Node SequencesRewriter::rewriteUpdate(Node node)
     }
   }
 
+  NodeManager* nm = NodeManager::currentNM();
+  Node zero = nm->mkConstInt(0);
+  Node sLen = nm->mkNode(STRING_LENGTH, s);
+  if (d_arithEntail.check(zero, i, true) || d_arithEntail.check(i, sLen))
+  {
+    // (seq.update s i x) ---> s if x < 0 or x >= len(s)
+    Node ret = s;
+    return returnRewrite(node, ret, Rewrite::UPD_OOB);
+  }
+
+  std::vector<Node> prefix, suffix;
+  utils::getConcat(s, suffix);
+  if ((i.isConst() && i.getConst<Rational>().isZero())
+      || d_stringsEntail.stripSymbolicLength(suffix, prefix, 1, i, true))
+  {
+    Node updateLen = nm->mkNode(STRING_LENGTH, x);
+    std::vector<Node> replaced;
+    if (d_stringsEntail.stripSymbolicLength(
+            suffix, replaced, 1, updateLen, true))
+    {
+      // (seq.update (seq.++ p r s) i x) ---> (seq.++ p x s)
+      // if len(p) = i and len(r) = len(x)
+      prefix.emplace_back(x);
+      prefix.insert(prefix.end(), suffix.begin(), suffix.end());
+      Node ret = utils::mkConcat(prefix, node.getType());
+      return returnRewrite(node, ret, Rewrite::UPD_EVAL_SYM);
+    }
+  }
+
   if (s.getKind() == STRING_REV && d_stringsEntail.checkLengthOne(x))
   {
     // str.update(str.rev(s), n, t) --->
     //   str.rev(str.update(s, len(s) - (n + 1), t))
-    NodeManager* nm = NodeManager::currentNM();
     Node idx = nm->mkNode(MINUS,
                           nm->mkNode(STRING_LENGTH, s),
                           nm->mkNode(PLUS, i, nm->mkConstInt(Rational(1))));
index c468aa4630da97eea55fef207645e681242c2188..8ce982291208640f0cfb52b6d18cea587e6901ba 100644 (file)
@@ -216,6 +216,58 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
       false));
 }
 
+TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_nth)
+{
+  TypeNode intType = d_nodeManager->integerType();
+
+  Node x = d_nodeManager->mkVar("x", intType);
+  Node y = d_nodeManager->mkVar("y", intType);
+  Node z = d_nodeManager->mkVar("z", intType);
+  Node w = d_nodeManager->mkVar("w", intType);
+  Node v = d_nodeManager->mkVar("v", intType);
+
+  Node zero = d_nodeManager->mkConstInt(0);
+  Node one = d_nodeManager->mkConstInt(1);
+
+  Node sx = d_nodeManager->mkNode(SEQ_UNIT, x);
+  Node sy = d_nodeManager->mkNode(SEQ_UNIT, y);
+  Node sz = d_nodeManager->mkNode(SEQ_UNIT, z);
+  Node sw = d_nodeManager->mkNode(SEQ_UNIT, w);
+  Node sv = d_nodeManager->mkNode(SEQ_UNIT, v);
+  Node xyz = d_nodeManager->mkNode(STRING_CONCAT, sx, sy, sz);
+  Node wv = d_nodeManager->mkNode(STRING_CONCAT, sw, sv);
+
+  {
+    // Same normal form for:
+    //
+    // (seq.nth (seq.unit x) 0)
+    //
+    // x
+    Node n = d_nodeManager->mkNode(SEQ_NTH, sx, zero);
+    sameNormalForm(n, x);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.nth (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) 0)
+    //
+    // x
+    Node n = d_nodeManager->mkNode(SEQ_NTH, xyz, zero);
+    sameNormalForm(n, x);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.nth (seq.++ (seq.unit x) (seq.unit y) (seq.unit z)) 0)
+    //
+    // x
+    Node n = d_nodeManager->mkNode(SEQ_NTH, xyz, one);
+    sameNormalForm(n, y);
+  }
+}
+
 TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
 {
   StringsRewriter sr(d_rewriter, nullptr);
@@ -356,6 +408,111 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
   sameNormalForm(substr, empty);
 }
 
+TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_update)
+{
+  TypeNode intType = d_nodeManager->integerType();
+
+  Node x = d_nodeManager->mkVar("x", intType);
+  Node y = d_nodeManager->mkVar("y", intType);
+  Node z = d_nodeManager->mkVar("z", intType);
+  Node w = d_nodeManager->mkVar("w", intType);
+  Node v = d_nodeManager->mkVar("v", intType);
+
+  Node negOne = d_nodeManager->mkConstInt(-1);
+  Node zero = d_nodeManager->mkConstInt(0);
+  Node one = d_nodeManager->mkConstInt(1);
+  Node three = d_nodeManager->mkConstInt(3);
+
+  Node sx = d_nodeManager->mkNode(SEQ_UNIT, x);
+  Node sy = d_nodeManager->mkNode(SEQ_UNIT, y);
+  Node sz = d_nodeManager->mkNode(SEQ_UNIT, z);
+  Node sw = d_nodeManager->mkNode(SEQ_UNIT, w);
+  Node sv = d_nodeManager->mkNode(SEQ_UNIT, v);
+  Node xyz = d_nodeManager->mkNode(STRING_CONCAT, sx, sy, sz);
+  Node wv = d_nodeManager->mkNode(STRING_CONCAT, sw, sv);
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.unit x))
+    //   0
+    //   (seq.unit w))
+    //
+    // (seq.unit w)
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, sx, zero, sw);
+    sameNormalForm(n, sw);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    //   0
+    //   (seq.unit w))
+    //
+    // (seq.++ (seq.unit w) (seq.unit y) (seq.unit z))
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, zero, sw);
+    Node wyz = d_nodeManager->mkNode(STRING_CONCAT, sw, sy, sz);
+    sameNormalForm(n, wyz);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    //   1
+    //   (seq.unit w))
+    //
+    // (seq.++ (seq.unit x) (seq.unit w) (seq.unit z))
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, one, sw);
+    Node xwz = d_nodeManager->mkNode(STRING_CONCAT, sx, sw, sz);
+    sameNormalForm(n, xwz);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    //   1
+    //   (seq.++ (seq.unit w) (seq.unit v)))
+    //
+    // (seq.++ (seq.unit x) (seq.unit w) (seq.unit v))
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, one, wv);
+    Node xwv = d_nodeManager->mkNode(STRING_CONCAT, sx, sw, sv);
+    sameNormalForm(n, xwv);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    //   -1
+    //   (seq.++ (seq.unit w) (seq.unit v)))
+    //
+    //  (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, negOne, wv);
+    sameNormalForm(n, xyz);
+  }
+
+  {
+    // Same normal form for:
+    //
+    // (seq.update
+    //   (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    //   3
+    //   w)
+    //
+    //  (seq.++ (seq.unit x) (seq.unit y) (seq.unit z))
+    Node n = d_nodeManager->mkNode(STRING_UPDATE, xyz, three, sw);
+    sameNormalForm(n, xyz);
+  }
+}
+
 TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_concat)
 {
   TypeNode intType = d_nodeManager->integerType();