Rewrite `str.contains` used for character matching (#3519)
author Andres Noetzli <andres.noetzli@gmail.com>
Tue, 3 Dec 2019 18:05:01 +0000 (10:05 -0800)
committer Andrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 3 Dec 2019 18:05:01 +0000 (12:05 -0600)
src/theory/strings/theory_strings_rewriter.cpp
test/unit/theory/theory_strings_rewriter_white.h

index 1e5b2a65a4d3662e3f28c828ed7690c262e17d2a..fa3650b589b510e5b51d8dac60101fad3c6f1218 100644 (file)
@@ -2048,6 +2048,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
           NodeManager::currentNM()->mkConst(s.find(t) != std::string::npos);
       return returnRewrite(node, ret, "ctn-const");
     }else{
+      Node t = node[1];
       if (s.size() == 0)
       {
         Node len1 =
@@ -2061,6 +2062,50 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
           return returnRewrite(node, ret, "ctn-lhs-emptystr");
         }
       }
+      else if (checkEntailLengthOne(t))
+      {
+        std::vector<unsigned> svec = s.getVec();
+        std::sort(svec.begin(), svec.end());
+
+        NodeBuilder<> nb(OR);
+        nb << nm->mkConst(String("")).eqNode(t);
+
+        Node tc = nm->mkNode(STRING_CODE, t);
+        unsigned lb = svec[0];
+        unsigned curr = lb;
+        for (size_t i = 0, size = svec.size(); i <= size; i++)
+        {
+          if (i == size || (svec[i] != curr && svec[i] != curr + 1))
+          {
+            Node nlb = nm->mkConst(Rational(CVC4::String::convertUnsignedIntToCode(lb)));
+            Node nub = nm->mkConst(Rational(CVC4::String::convertUnsignedIntToCode(svec[i - 1])));
+            if (nlb == nub)
+            {
+              nb << nm->mkNode(EQUAL, tc, nlb);
+            }
+            else
+            {
+              nb << nm->mkNode(
+                  AND, nm->mkNode(LEQ, nlb, tc), nm->mkNode(LEQ, tc, nub));
+            }
+
+            if (i != size) {
+            lb = svec[i];
+            curr = lb;
+            }
+          } else {
+            curr = svec[i];
+          }
+        }
+
+        Node ret = nb;
+
+        // str.contains("ABCDEFabcdef", t) --->
+        // t = "" v str.code("A") <= str.code(t) <= str.code("F") v
+        //          str.code("a") <= str.code(t) <= str.code("f")
+        // if len(t) <= 1
+        return returnRewrite(node, ret, "ctn-split");
+      }
       else if (node[1].getKind() == kind::STRING_CONCAT)
       {
         int firstc, lastc;
index c5db12c6ca39a3319fd1c5c0f203c2c2e41be743..b5eb92900fc373b496e28660e7a09b525bfe9023 100644 (file)
@@ -703,9 +703,14 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node ab = d_nm->mkConst(::CVC4::String("AB"));
     Node b = d_nm->mkConst(::CVC4::String("B"));
     Node c = d_nm->mkConst(::CVC4::String("C"));
+    Node e = d_nm->mkConst(::CVC4::String("E"));
+    Node h = d_nm->mkConst(::CVC4::String("H"));
+    Node j = d_nm->mkConst(::CVC4::String("J"));
+    Node p = d_nm->mkConst(::CVC4::String("P"));
     Node abc = d_nm->mkConst(::CVC4::String("ABC"));
     Node def = d_nm->mkConst(::CVC4::String("DEF"));
     Node ghi = d_nm->mkConst(::CVC4::String("GHI"));
+    Node abbchijp = d_nm->mkConst(::CVC4::String("ABBCHIJP"));
     Node x = d_nm->mkVar("x", strType);
     Node y = d_nm->mkVar("y", strType);
     Node xy = d_nm->mkNode(kind::STRING_CONCAT, x, y);
@@ -1000,6 +1005,37 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
       rhs = d_nm->mkNode(kind::STRING_STRCTN, x, ab);
       sameNormalForm(lhs, rhs);
     }
+
+    {
+      // Same normal form for:
+      //
+      // (str.contains "ABBCHIJP" (str.at x n))
+      //
+      // (or (= (str.at x n) "")
+      //     (and (<= (str.code "A") (str.code (str.at x n)))
+      //          (<= (str.code (str.at x n)) (str.code "C")))
+      //     (and (<= (str.code "H") (str.code (str.at x n)))
+      //          (<= (str.code (str.at x n)) (str.code "J")))
+      //     (= (str.code (str.at x n)) (str.code "P")))
+      Node cat = d_nm->mkNode(kind::STRING_CHARAT, x, n);
+      lhs = d_nm->mkNode(kind::STRING_STRCTN, abbchijp, cat);
+      Node ca = d_nm->mkNode(kind::STRING_CODE, a);
+      Node cc = d_nm->mkNode(kind::STRING_CODE, c);
+      Node ch = d_nm->mkNode(kind::STRING_CODE, h);
+      Node cj = d_nm->mkNode(kind::STRING_CODE, j);
+      Node cp = d_nm->mkNode(kind::STRING_CODE, p);
+      Node ccat = d_nm->mkNode(kind::STRING_CODE, cat);
+      rhs = d_nm->mkNode(kind::OR,
+                         d_nm->mkNode(kind::EQUAL, cat, empty),
+                         d_nm->mkNode(kind::AND,
+                                      d_nm->mkNode(kind::LEQ, ca, ccat),
+                                      d_nm->mkNode(kind::LEQ, ccat, cc)),
+                         d_nm->mkNode(kind::AND,
+                                      d_nm->mkNode(kind::LEQ, ch, ccat),
+                                      d_nm->mkNode(kind::LEQ, ccat, cj)),
+                         d_nm->mkNode(kind::EQUAL, ccat, cp));
+      sameNormalForm(lhs, rhs);
+    }
   }
 
   void testInferEqsFromContains()