Add rewrites for str.replace in str.contains (#2623)
authorAndres Noetzli <andres.noetzli@gmail.com>
Fri, 12 Oct 2018 07:21:51 +0000 (00:21 -0700)
committerGitHub <noreply@github.com>
Fri, 12 Oct 2018 07:21:51 +0000 (00:21 -0700)
This commit adds two rewrites for `(str.contains (str.replace x y x) z) ---> (str.contains x z)`, either when `z = y` or `(str.len z) <= 1`. Additionally, the commit adds `(str.contains (str.replace x y z) w) ---> true` if `(str.contains x w) --> true` and `(str.contains z w) ---> true`.

src/theory/strings/theory_strings_rewriter.cpp
test/unit/theory/theory_strings_rewriter_white.h

index 28883b6b9e55b657bb538c85bc863ed744dd9d21..407279d224722789d702d357404150fc65a1d58e 100644 (file)
@@ -2108,8 +2108,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
       }
     }
   }
-
-  if (node[0].getKind() == kind::STRING_SUBSTR)
+  else if (node[0].getKind() == kind::STRING_SUBSTR)
   {
     // (str.contains (str.substr x n (str.len y)) y) --->
     //   (= (str.substr x n (str.len y)) y)
@@ -2121,6 +2120,27 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
       return returnRewrite(node, ret, "ctn-substr");
     }
   }
+  else if (node[0].getKind() == kind::STRING_STRREPL)
+  {
+    if (node[0][0] == node[0][2])
+    {
+      // (str.contains (str.replace x y x) y) ---> (str.contains x y)
+      if (node[0][1] == node[1])
+      {
+        Node ret = nm->mkNode(kind::STRING_STRCTN, node[0][0], node[1]);
+        return returnRewrite(node, ret, "ctn-repl-to-ctn");
+      }
+
+      // (str.contains (str.replace x y x) z) ---> (str.contains x z)
+      // if (str.len z) <= 1
+      Node one = nm->mkConst(Rational(1));
+      if (checkEntailArith(one, len_n2))
+      {
+        Node ret = nm->mkNode(kind::STRING_STRCTN, node[0][0], node[1]);
+        return returnRewrite(node, ret, "ctn-repl-len-one-to-ctn");
+      }
+    }
+  }
 
   if (node[1].getKind() == kind::STRING_STRREPL)
   {
@@ -3344,6 +3364,9 @@ bool TheoryStringsRewriter::componentContainsBase(
 {
   Assert(n1rb.isNull());
   Assert(n1re.isNull());
+
+  NodeManager* nm = NodeManager::currentNM();
+
   if (n1 == n2)
   {
     return true;
@@ -3362,8 +3385,7 @@ bool TheoryStringsRewriter::componentContainsBase(
           {
             if (computeRemainder)
             {
-              n1rb = NodeManager::currentNM()->mkConst(
-                  ::CVC4::String(s.prefix(s.size() - t.size())));
+              n1rb = nm->mkConst(::CVC4::String(s.prefix(s.size() - t.size())));
             }
             return true;
           }
@@ -3374,8 +3396,7 @@ bool TheoryStringsRewriter::componentContainsBase(
           {
             if (computeRemainder)
             {
-              n1re = NodeManager::currentNM()->mkConst(
-                  ::CVC4::String(s.suffix(s.size() - t.size())));
+              n1re = nm->mkConst(::CVC4::String(s.suffix(s.size() - t.size())));
             }
             return true;
           }
@@ -3389,12 +3410,11 @@ bool TheoryStringsRewriter::componentContainsBase(
             {
               if (f > 0)
               {
-                n1rb = NodeManager::currentNM()->mkConst(
-                    ::CVC4::String(s.prefix(f)));
+                n1rb = nm->mkConst(::CVC4::String(s.prefix(f)));
               }
               if (s.size() > f + t.size())
               {
-                n1re = NodeManager::currentNM()->mkConst(
+                n1re = nm->mkConst(
                     ::CVC4::String(s.suffix(s.size() - (f + t.size()))));
               }
             }
@@ -3413,10 +3433,8 @@ bool TheoryStringsRewriter::componentContainsBase(
         {
           bool success = true;
           Node start_pos = n2[1];
-          Node end_pos =
-              NodeManager::currentNM()->mkNode(kind::PLUS, n2[1], n2[2]);
-          Node len_n2s =
-              NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, n2[0]);
+          Node end_pos = nm->mkNode(kind::PLUS, n2[1], n2[2]);
+          Node len_n2s = nm->mkNode(kind::STRING_LENGTH, n2[0]);
           if (dir == 1)
           {
             // To be a suffix, start + length must be greater than
@@ -3444,22 +3462,39 @@ bool TheoryStringsRewriter::componentContainsBase(
               }
               if (dir != 1)
               {
-                n1rb = NodeManager::currentNM()->mkNode(
-                    kind::STRING_SUBSTR,
-                    n2[0],
-                    NodeManager::currentNM()->mkConst(Rational(0)),
-                    start_pos);
+                n1rb = nm->mkNode(kind::STRING_SUBSTR,
+                                  n2[0],
+                                  nm->mkConst(Rational(0)),
+                                  start_pos);
               }
               if (dir != -1)
               {
-                n1re = NodeManager::currentNM()->mkNode(
-                    kind::STRING_SUBSTR, n2[0], end_pos, len_n2s);
+                n1re = nm->mkNode(kind::STRING_SUBSTR, n2[0], end_pos, len_n2s);
               }
             }
             return true;
           }
         }
       }
+
+      if (!computeRemainder && dir == 0)
+      {
+        if (n1.getKind() == STRING_STRREPL)
+        {
+          // (str.contains (str.replace x y z) w) ---> true
+          // if (str.contains x w) --> true and (str.contains z w) ---> true
+          Node xCtnW = Rewriter::rewrite(nm->mkNode(STRING_STRCTN, n1[0], n2));
+          if (xCtnW.isConst() && xCtnW.getConst<bool>())
+          {
+            Node zCtnW =
+                Rewriter::rewrite(nm->mkNode(STRING_STRCTN, n1[2], n2));
+            if (zCtnW.isConst() && zCtnW.getConst<bool>())
+            {
+              return true;
+            }
+          }
+        }
+      }
     }
   }
   return false;
index cc29efb23a9a11f10a19bab7559d55624bc4c396..f82140181de6169418510e04de8b221804ae2c25 100644 (file)
@@ -80,6 +80,24 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     TS_ASSERT_DIFFERS(res_t1, res_t2);
   }
 
+  void testCheckEntailArith()
+  {
+    TypeNode intType = d_nm->integerType();
+    TypeNode strType = d_nm->stringType();
+
+    Node z = d_nm->mkVar("z", strType);
+    Node n = d_nm->mkVar("n", intType);
+    Node one = d_nm->mkConst(Rational(1));
+
+    // 1 >= (str.len (str.substr z n 1)) ---> true
+    Node substr_z = d_nm->mkNode(kind::STRING_LENGTH,
+                                 d_nm->mkNode(kind::STRING_SUBSTR, z, n, one));
+    TS_ASSERT(TheoryStringsRewriter::checkEntailArith(one, substr_z));
+
+    // (str.len (str.substr z n 1)) >= 1 ---> false
+    TS_ASSERT(!TheoryStringsRewriter::checkEntailArith(substr_z, one));
+  }
+
   void testCheckEntailArithWithAssumption()
   {
     TypeNode intType = d_nm->integerType();
@@ -497,9 +515,11 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node yx = d_nm->mkNode(kind::STRING_CONCAT, y, x);
     Node z = d_nm->mkVar("z", strType);
     Node n = d_nm->mkVar("n", intType);
-    Node one = d_nm->mkConst(Rational(2));
+    Node one = d_nm->mkConst(Rational(1));
+    Node two = d_nm->mkConst(Rational(2));
     Node three = d_nm->mkConst(Rational(3));
     Node four = d_nm->mkConst(Rational(4));
+    Node t = d_nm->mkConst(true);
     Node f = d_nm->mkConst(false);
 
     // Same normal form for:
@@ -644,6 +664,58 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node ctn_yxxy = d_nm->mkNode(kind::STRING_STRCTN, yx, xy);
     Node eq_yxxy = d_nm->mkNode(kind::EQUAL, yx, xy);
     sameNormalForm(ctn_yxxy, eq_yxxy);
+
+    // (str.contains (str.replace x y x) x) ---> true
+    ctn_repl = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, x, y, x), x);
+    sameNormalForm(ctn_repl, t);
+
+    // (str.contains (str.replace (str.++ x y) z (str.++ y x)) x) ---> true
+    ctn_repl = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, xy, z, yx), x);
+    sameNormalForm(ctn_repl, t);
+
+    // (str.contains (str.++ z (str.replace (str.++ x y) z (str.++ y x))) x)
+    //   ---> true
+    ctn_repl = d_nm->mkNode(
+        kind::STRING_STRCTN,
+        d_nm->mkNode(kind::STRING_CONCAT,
+                     z,
+                     d_nm->mkNode(kind::STRING_STRREPL, xy, z, yx)),
+        x);
+    sameNormalForm(ctn_repl, t);
+
+    // Same normal form for:
+    //
+    // (str.contains (str.replace x y x) y)
+    //
+    // (str.contains x y)
+    Node lhs = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, x, y, x), y);
+    Node rhs = d_nm->mkNode(kind::STRING_STRCTN, x, y);
+    sameNormalForm(lhs, rhs);
+
+    // Same normal form for:
+    //
+    // (str.contains (str.replace x y x) "B")
+    //
+    // (str.contains x "B")
+    lhs = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, x, y, x), b);
+    rhs = d_nm->mkNode(kind::STRING_STRCTN, x, b);
+    sameNormalForm(lhs, rhs);
+
+    // Same normal form for:
+    //
+    // (str.contains (str.replace x y x) (str.substr z n 1))
+    //
+    // (str.contains x (str.substr z n 1))
+    Node substr_z = d_nm->mkNode(kind::STRING_SUBSTR, z, n, one);
+    lhs = d_nm->mkNode(kind::STRING_STRCTN,
+                       d_nm->mkNode(kind::STRING_STRREPL, x, y, x),
+                       substr_z);
+    rhs = d_nm->mkNode(kind::STRING_STRCTN, x, substr_z);
+    sameNormalForm(lhs, rhs);
   }
 
   void testInferEqsFromContains()