More rewrites for indexof (#1648)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 21 Mar 2018 21:29:41 +0000 (16:29 -0500)
committerGitHub <noreply@github.com>
Wed, 21 Mar 2018 21:29:41 +0000 (16:29 -0500)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h

index e668615791e25c920ce5998174b6a2e18c99a42e..60d0d73b6fc2b98190e9349a4541b3d06c15b05a 100644 (file)
@@ -1859,9 +1859,49 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
     }
   }
 
-  Node len0 = nm->mkNode(kind::STRING_LENGTH, node[0]);
-  Node len1 = nm->mkNode(kind::STRING_LENGTH, node[1]);
-  Node len0m2 = nm->mkNode(kind::MINUS, len0, node[2]);
+  if (node[0] == node[1])
+  {
+    if (node[2].isConst())
+    {
+      if (node[2].getConst<Rational>().sgn() == 0)
+      {
+        // indexof( x, x, 0 ) --> 0
+        Node zero = nm->mkConst(Rational(0));
+        return returnRewrite(node, zero, "idof-eq-cst-start");
+      }
+    }
+    if (checkEntailArith(node[2], true))
+    {
+      // y>0  implies  indexof( x, x, y ) --> -1
+      Node negone = nm->mkConst(Rational(-1));
+      return returnRewrite(node, negone, "idof-eq-nstart");
+    }
+    Node emp = nm->mkConst(CVC4::String(""));
+    if (node[0] != emp)
+    {
+      // indexof( x, x, z ) ---> indexof( "", "", z )
+      Node ret = nm->mkNode(STRING_STRIDOF, emp, emp, node[2]);
+      return returnRewrite(node, ret, "idof-eq-norm");
+    }
+  }
+
+  Node len0 = nm->mkNode(STRING_LENGTH, node[0]);
+  Node len1 = nm->mkNode(STRING_LENGTH, node[1]);
+  Node len0m2 = nm->mkNode(MINUS, len0, node[2]);
+
+  if (node[1].isConst())
+  {
+    CVC4::String t = node[1].getConst<String>();
+    if (t.size() == 0)
+    {
+      if (checkEntailArith(len0, node[2]) && checkEntailArith(node[2]))
+      {
+        // len(x)>=z ^ z >=0 implies indexof( x, "", z ) ---> z
+        return returnRewrite(node, node[2], "idof-emp-idof");
+      }
+    }
+  }
+
   if (checkEntailArith(len1, len0m2, true))
   {
     // len(x)-z < len(y)  implies  indexof( x, y, z ) ----> -1
@@ -1877,7 +1917,12 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
   }
 
   Node cmp_con = nm->mkNode(kind::STRING_STRCTN, fstr, node[1]);
+  Trace("strings-rewrite-debug")
+      << "For " << node << ", check " << cmp_con << std::endl;
   Node cmp_conr = Rewriter::rewrite(cmp_con);
+  Trace("strings-rewrite-debug") << "...got " << cmp_conr << std::endl;
+  std::vector<Node> children1;
+  getConcat(node[1], children1);
   if (cmp_conr.isConst())
   {
     if (cmp_conr.getConst<bool>())
@@ -1885,8 +1930,6 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
       if (node[2].isConst() && node[2].getConst<Rational>().sgn() == 0)
       {
         // past the first position in node[0] that contains node[1], we can drop
-        std::vector<Node> children1;
-        getConcat(node[1], children1);
         std::vector<Node> nb;
         std::vector<Node> ne;
         int cc = componentContains(children0, children1, nb, ne, true, 1);
@@ -1900,30 +1943,22 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
         }
       }
 
-      // these rewrites are only possible if we will not return -1
-      Node l1 = nm->mkNode(kind::STRING_LENGTH, node[1]);
-      Node zero = NodeManager::currentNM()->mkConst(CVC4::Rational(0));
-      bool is_non_empty = checkEntailArith(l1, zero, true);
-
-      if (is_non_empty)
+      // strip symbolic length
+      Node new_len = node[2];
+      std::vector<Node> nr;
+      if (stripSymbolicLength(children0, nr, 1, new_len))
       {
-        // strip symbolic length
-        Node new_len = node[2];
-        std::vector<Node> nr;
-        if (stripSymbolicLength(children0, nr, 1, new_len))
-        {
-          // For example:
-          // z>str.len( x1 ) and str.len( y )>0 and str.contains( x2, y )-->true
-          // implies
-          // str.indexof( str.++( x1, x2 ), y, z ) --->
-          // str.len( x1 ) + str.indexof( x2, y, z-str.len(x1) )
-          Node nn = mkConcat(kind::STRING_CONCAT, children0);
-          Node ret = nm->mkNode(
-              kind::PLUS,
-              nm->mkNode(kind::MINUS, node[2], new_len),
-              nm->mkNode(kind::STRING_STRIDOF, nn, node[1], new_len));
-          return returnRewrite(node, ret, "idof-strip-sym-len");
-        }
+        // For example:
+        // z>str.len( x1 ) and str.contains( x2, y )-->true
+        // implies
+        // str.indexof( str.++( x1, x2 ), y, z ) --->
+        // str.len( x1 ) + str.indexof( x2, y, z-str.len(x1) )
+        Node nn = mkConcat(kind::STRING_CONCAT, children0);
+        Node ret =
+            nm->mkNode(kind::PLUS,
+                       nm->mkNode(kind::MINUS, node[2], new_len),
+                       nm->mkNode(kind::STRING_STRIDOF, nn, node[1], new_len));
+        return returnRewrite(node, ret, "idof-strip-sym-len");
       }
     }
     else
@@ -1934,6 +1969,20 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
     }
   }
 
+  if (node[2].isConst() && node[2].getConst<Rational>().sgn()==0)
+  {
+    std::vector<Node> cb;
+    std::vector<Node> ce;
+    if (stripConstantEndpoints(children0, children1, cb, ce, -1))
+    {
+      Node ret = mkConcat(kind::STRING_CONCAT, children0);
+      ret = nm->mkNode(STRING_STRIDOF, ret, node[1], node[2]);
+      // For example:
+      // str.indexof( str.++( x, "A" ), "B", 0 ) ---> str.indexof( x, "B", 0 )
+      return returnRewrite(node, ret, "rpl-pull-endpt");
+    }
+  }
+
   Trace("strings-rewrite-nf") << "No rewrites for : " << node << std::endl;
   return node;
 }
@@ -2686,17 +2735,21 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1,
   // for ( forwards, backwards ) direction
   for (unsigned r = 0; r < 2; r++)
   {
-    if (dir == 0 || (r == 0 && dir == -1) || (r == 1 && dir == 1))
+    if (dir == 0 || (r == 0 && dir == 1) || (r == 1 && dir == -1))
     {
       unsigned index0 = r == 0 ? 0 : n1.size() - 1;
       unsigned index1 = r == 0 ? 0 : n2.size() - 1;
       bool removeComponent = false;
-      Trace("strings-rewrite-debug2") << "stripConstantEndpoints : Compare "
-                                      << n1[index0] << " " << n2[index1]
-                                      << ", dir = " << dir << std::endl;
-      if (n1[index0].isConst())
+      Node n1cmp = n1[index0];
+      std::vector<Node> sss;
+      std::vector<Node> sls;
+      n1cmp = decomposeSubstrChain(n1cmp, sss, sls);
+      Trace("strings-rewrite-debug2")
+          << "stripConstantEndpoints : Compare " << n1cmp << " " << n2[index1]
+          << ", dir = " << dir << std::endl;
+      if (n1cmp.isConst())
       {
-        CVC4::String s = n1[index0].getConst<String>();
+        CVC4::String s = n1cmp.getConst<String>();
         // overlap is an overapproximation of the number of characters
         // n2[index1] can match in s
         unsigned overlap = s.size();
@@ -2713,7 +2766,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1,
               //   str.contains( "", str.++( "ba", x ) )
               removeComponent = true;
             }
-            else
+            else if (sss.empty())  // only if not substr
             {
               // check how much overlap there is
               // This is used to partially strip off the endpoint
@@ -2722,7 +2775,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1,
               overlap = r == 0 ? s.overlap(t) : t.overlap(s);
             }
           }
-          else
+          else if (sss.empty())  // only if not substr
           {
             Assert(ret < s.size());
             // can strip off up to the find position, e.g.
@@ -2745,7 +2798,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1,
             {
               break;
             }
-            else
+            else if (sss.empty())  // only if not substr
             {
               // e.g. str.contains( str.++( "a", x ), int.to.str(y) ) -->
               // str.contains( x, int.to.str(y) )
@@ -2784,7 +2837,7 @@ bool TheoryStringsRewriter::stripConstantEndpoints(std::vector<Node>& n1,
           }
         }
       }
-      else if (n1[index0].getKind() == kind::STRING_ITOS)
+      else if (n1cmp.getKind() == kind::STRING_ITOS)
       {
         if (n2[index1].isConst())
         {
@@ -3051,6 +3104,35 @@ bool TheoryStringsRewriter::checkEntailArithInternal(Node a)
   return false;
 }
 
+Node TheoryStringsRewriter::decomposeSubstrChain(Node s,
+                                                 std::vector<Node>& ss,
+                                                 std::vector<Node>& ls)
+{
+  Assert( ss.empty() );
+  Assert( ls.empty() );
+  while (s.getKind() == STRING_SUBSTR)
+  {
+    ss.push_back(s[1]);
+    ls.push_back(s[2]);
+    s = s[0];
+  }
+  std::reverse(ss.begin(), ss.end());
+  std::reverse(ls.begin(), ls.end());
+  return s;
+}
+
+Node TheoryStringsRewriter::mkSubstrChain(Node base,
+                                          const std::vector<Node>& ss,
+                                          const std::vector<Node>& ls)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  for (unsigned i = 0, size = ss.size(); i < size; i++)
+  {
+    base = nm->mkNode(STRING_SUBSTR, base, ss[i], ls[i]);
+  }
+  return base;
+}
+
 Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c)
 {
   Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c
index 217546c71aa8dea11ce151e3e104fed79700a9a9..3aaf3eab735a7440c2f34ad4aba5354633704ba5 100644 (file)
@@ -327,6 +327,10 @@ private:
   *   n1 is updated to { "b", x, "d" }
   *   nb is updated to { "a" }
   *   ne is updated to { "e" }
+  * stripConstantEndpoints({ "ad", substr("ccc",x,y) }, { "d" }, {}, {}, -1)
+  *   returns true,
+  *   n1 is updated to {"ad"}
+  *   ne is updated to { substr("ccc",x,y) }
   */
   static bool stripConstantEndpoints(std::vector<Node>& n1,
                                      std::vector<Node>& n2,
@@ -366,6 +370,22 @@ private:
    *   checkEntailArith( a, strict ) = true.
    */
   static Node getConstantArithBound(Node a, bool isLower = true);
+  /** decompose substr chain
+   *
+   * If s is substr( ... substr( base, x1, y1 ) ..., xn, yn ), then this
+   * function returns base, adds { x1 ... xn } to ss, and { y1 ... yn } to ls.
+   */
+  static Node decomposeSubstrChain(Node s,
+                                   std::vector<Node>& ss,
+                                   std::vector<Node>& ls);
+  /** make substr chain
+   *
+   * If ss is { x1 ... xn } and ls is { y1 ... yn }, this returns the term
+   * substr( ... substr( base, x1, y1 ) ..., xn, yn ).
+   */
+  static Node mkSubstrChain(Node base,
+                            const std::vector<Node>& ss,
+                            const std::vector<Node>& ls);
 };/* class TheoryStringsRewriter */
 
 }/* CVC4::theory::strings namespace */