Add more (str.replace x y z) rewrites (#2628)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 15 Oct 2018 15:39:02 +0000 (08:39 -0700)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 15 Oct 2018 15:39:02 +0000 (10:39 -0500)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
test/unit/theory/theory_strings_rewriter_white.h

index 407279d224722789d702d357404150fc65a1d58e..e8a11e62e05119be1276b05b76cbfd764c38c1cb 100644 (file)
@@ -1651,6 +1651,28 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
       return returnRewrite(node, ret, "ss-start-geq-len");
     }
   }
+  else if (node[0].getKind() == STRING_STRREPL)
+  {
+    // (str.substr (str.replace x y z) 0 n)
+    //          ---> (str.replace (str.substr x 0 n) y z)
+    // if (str.len y) = 1 and (str.len z) = 1
+    if (node[1] == zero)
+    {
+      Node one = nm->mkConst(Rational(1));
+      Node n1len = nm->mkNode(kind::STRING_LENGTH, node[0][1]);
+      Node n2len = nm->mkNode(kind::STRING_LENGTH, node[0][2]);
+      if (checkEntailArith(one, n1len) && checkEntailArith(one, n2len)
+          && checkEntailNonEmpty(node[0][1]) && checkEntailNonEmpty(node[0][2]))
+      {
+        Node ret = nm->mkNode(
+            kind::STRING_STRREPL,
+            nm->mkNode(kind::STRING_SUBSTR, node[0][0], node[1], node[2]),
+            node[0][1],
+            node[0][2]);
+        return returnRewrite(node, ret, "substr-repl-swap");
+      }
+    }
+  }
 
   std::vector<Node> n1;
   getConcat(node[0], n1);
@@ -2140,6 +2162,16 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
         return returnRewrite(node, ret, "ctn-repl-len-one-to-ctn");
       }
     }
+
+    // (str.contains (str.replace x y z) z) --->
+    //   (or (str.contains x y) (str.contains x z))
+    if (node[0][2] == node[1])
+    {
+      Node ret = nm->mkNode(OR,
+                            nm->mkNode(STRING_STRCTN, node[0][0], node[0][1]),
+                            nm->mkNode(STRING_STRCTN, node[0][0], node[0][2]));
+      return returnRewrite(node, ret, "ctn-repl-to-ctn-disj");
+    }
   }
 
   if (node[1].getKind() == kind::STRING_STRREPL)
@@ -2439,6 +2471,31 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) {
     {
       return returnRewrite(node, node[0], "rpl-rpl-len-id");
     }
+
+    // (str.replace x y x) ---> (str.replace x (str.++ y1 ... yn) x)
+    // if 1 >= (str.len x) and (= y "") ---> (= y1 "") ... (= yn "")
+    if (checkEntailArith(nm->mkConst(Rational(1)), l0))
+    {
+      Node empty = nm->mkConst(String(""));
+      Node rn1 = Rewriter::rewrite(
+          rewriteEqualityExt(nm->mkNode(EQUAL, node[1], empty)));
+      if (rn1 != node[1])
+      {
+        std::vector<Node> emptyNodes;
+        bool allEmptyEqs;
+        std::tie(allEmptyEqs, emptyNodes) = collectEmptyEqs(rn1);
+
+        if (allEmptyEqs)
+        {
+          Node nn1 = mkConcat(STRING_CONCAT, emptyNodes);
+          if (node[1] != nn1)
+          {
+            Node ret = nm->mkNode(STRING_STRREPL, node[0], nn1, node[2]);
+            return returnRewrite(node, ret, "rpl-x-y-x-simp");
+          }
+        }
+      }
+    }
   }
 
   std::vector<Node> children1;
@@ -2511,45 +2568,9 @@ Node TheoryStringsRewriter::rewriteReplace( Node node ) {
     //
     Node empty = nm->mkConst(::CVC4::String(""));
 
-    // Collect the equalities of the form (= x "") (sorted)
-    std::set<TNode> emptyNodes;
-    bool allEmptyEqs = true;
-    if (cmp_conr.getKind() == kind::EQUAL)
-    {
-      if (cmp_conr[0] == empty)
-      {
-        emptyNodes.insert(cmp_conr[1]);
-      }
-      else if (cmp_conr[1] == empty)
-      {
-        emptyNodes.insert(cmp_conr[0]);
-      }
-      else
-      {
-        allEmptyEqs = false;
-      }
-    }
-    else
-    {
-      for (const Node& c : cmp_conr)
-      {
-        if (c.getKind() == kind::EQUAL)
-        {
-          if (c[0] == empty)
-          {
-            emptyNodes.insert(c[1]);
-          }
-          else if (c[1] == empty)
-          {
-            emptyNodes.insert(c[0]);
-          }
-        }
-        else
-        {
-          allEmptyEqs = false;
-        }
-      }
-    }
+    std::vector<Node> emptyNodes;
+    bool allEmptyEqs;
+    std::tie(allEmptyEqs, emptyNodes) = collectEmptyEqs(cmp_conr);
 
     if (emptyNodes.size() > 0)
     {
@@ -4761,6 +4782,61 @@ Node TheoryStringsRewriter::inferEqsFromContains(Node x, Node y)
   return nb.constructNode();
 }
 
+std::pair<bool, std::vector<Node> > TheoryStringsRewriter::collectEmptyEqs(
+    Node x)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Node empty = nm->mkConst(::CVC4::String(""));
+
+  // Collect the equalities of the form (= x "") (sorted)
+  std::set<TNode> emptyNodes;
+  bool allEmptyEqs = true;
+  if (x.getKind() == kind::EQUAL)
+  {
+    if (x[0] == empty)
+    {
+      emptyNodes.insert(x[1]);
+    }
+    else if (x[1] == empty)
+    {
+      emptyNodes.insert(x[0]);
+    }
+    else
+    {
+      allEmptyEqs = false;
+    }
+  }
+  else
+  {
+    for (const Node& c : x)
+    {
+      if (c.getKind() == kind::EQUAL)
+      {
+        if (c[0] == empty)
+        {
+          emptyNodes.insert(c[1]);
+        }
+        else if (c[1] == empty)
+        {
+          emptyNodes.insert(c[0]);
+        }
+      }
+      else
+      {
+        allEmptyEqs = false;
+      }
+    }
+  }
+
+  if (emptyNodes.size() == 0)
+  {
+    allEmptyEqs = false;
+  }
+
+  return std::make_pair(
+      allEmptyEqs, std::vector<Node>(emptyNodes.begin(), emptyNodes.end()));
+}
+
 Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c)
 {
   Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c
index ed42ce762b23d4881ae1f815e3c43c35b13be7fa..2c38ce8dc88feef9b7416c51a17154ccbc29dc2b 100644 (file)
@@ -18,6 +18,9 @@
 #ifndef __CVC4__THEORY__STRINGS__THEORY_STRINGS_REWRITER_H
 #define __CVC4__THEORY__STRINGS__THEORY_STRINGS_REWRITER_H
 
+#include <utility>
+#include <vector>
+
 #include "theory/rewriter.h"
 #include "theory/type_enumerator.h"
 #include "expr/attribute.h"
@@ -632,6 +635,26 @@ class TheoryStringsRewriter {
    * infer that any of the yi must be empty.
    */
   static Node inferEqsFromContains(Node x, Node y);
+
+  /**
+   * Collects equal-to-empty nodes from a conjunction or a single
+   * node. Returns a list of nodes that are compared to empty nodes
+   * and a boolean that indicates whether all nodes in the
+   * conjunction were a comparison with the empty node. The nodes in
+   * the list are sorted and duplicates removed.
+   *
+   * Examples:
+   *
+   * collectEmptyEqs( (= "" x) ) = { true, [x] }
+   * collectEmptyEqs( (and (= "" x) (= "" y)) ) = { true, [x, y] }
+   * collectEmptyEqs( (and (= "A" x) (= "" y) (= "" y)) ) = { false, [y] }
+   *
+   * @param x The conjunction of equalities or a single equality
+   * @return A pair of a boolean that indicates whether the
+   * conjunction consists only of comparisons to the empty string
+   * and the list of nodes that are compared to the empty string
+   */
+  static std::pair<bool, std::vector<Node> > collectEmptyEqs(Node x);
 };/* class TheoryStringsRewriter */
 
 }/* CVC4::theory::strings namespace */
index f82140181de6169418510e04de8b221804ae2c25..c9259722469643dd88b86c06da822912428f393b 100644 (file)
@@ -173,10 +173,12 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node b = d_nm->mkConst(::CVC4::String("B"));
     Node abcd = d_nm->mkConst(::CVC4::String("ABCD"));
     Node zero = d_nm->mkConst(Rational(0));
+    Node one = d_nm->mkConst(Rational(1));
     Node two = d_nm->mkConst(Rational(2));
     Node three = d_nm->mkConst(Rational(3));
 
     Node s = d_nm->mkVar("s", strType);
+    Node s2 = d_nm->mkVar("s2", strType);
     Node x = d_nm->mkVar("x", intType);
     Node y = d_nm->mkVar("y", intType);
 
@@ -238,6 +240,42 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
                             s,
                             d_nm->mkNode(kind::STRING_SUBSTR, b, x, x));
     sameNormalForm(lhs, rhs);
+
+    // Same normal form:
+    //
+    // (str.substr (str.replace s "A" "B") 0 x)
+    //
+    // (str.replace (str.substr s 0 x) "A" "B")
+    Node substr_repl = d_nm->mkNode(kind::STRING_SUBSTR,
+                                    d_nm->mkNode(kind::STRING_STRREPL, s, a, b),
+                                    zero,
+                                    x);
+    Node repl_substr =
+        d_nm->mkNode(kind::STRING_STRREPL,
+                     d_nm->mkNode(kind::STRING_SUBSTR, s, zero, x),
+                     a,
+                     b);
+    sameNormalForm(substr_repl, repl_substr);
+
+    // Same normal form:
+    //
+    // (str.substr (str.replace s (str.substr (str.++ s2 "A") 0 1) "B") 0 x)
+    //
+    // (str.replace (str.substr s 0 x) (str.substr (str.++ s2 "A") 0 1) "B")
+    Node substr_y = d_nm->mkNode(kind::STRING_SUBSTR,
+                                 d_nm->mkNode(kind::STRING_CONCAT, s2, a),
+                                 zero,
+                                 one);
+    substr_repl =
+        d_nm->mkNode(kind::STRING_SUBSTR,
+                     d_nm->mkNode(kind::STRING_STRREPL, s, substr_y, b),
+                     zero,
+                     x);
+    repl_substr = d_nm->mkNode(kind::STRING_STRREPL,
+                               d_nm->mkNode(kind::STRING_SUBSTR, s, zero, x),
+                               substr_y,
+                               b);
+    sameNormalForm(substr_repl, repl_substr);
   }
 
   void testRewriteConcat()
@@ -389,6 +427,7 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
 
   void testRewriteReplace()
   {
+    TypeNode intType = d_nm->integerType();
     TypeNode strType = d_nm->stringType();
 
     Node empty = d_nm->mkConst(::CVC4::String(""));
@@ -401,6 +440,7 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node z = d_nm->mkVar("z", strType);
     Node zero = d_nm->mkConst(Rational(0));
     Node one = d_nm->mkConst(Rational(1));
+    Node n = d_nm->mkVar("n", intType);
 
     // (str.replace (str.replace x "B" x) x "A") -->
     //   (str.replace (str.replace x "B" "A") x "A")
@@ -498,6 +538,18 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
                              y);
     Node repl = d_nm->mkNode(kind::STRING_STRREPL, empty, x, y);
     sameNormalForm(repl_repl, repl);
+
+    // Same normal form:
+    //
+    // (str.replace "B" (str.replace x "A" "B") "B")
+    //
+    // (str.replace "B" x "B"))
+    repl_repl = d_nm->mkNode(kind::STRING_STRREPL,
+                             b,
+                             d_nm->mkNode(kind::STRING_STRREPL, x, a, b),
+                             b);
+    repl = d_nm->mkNode(kind::STRING_STRREPL, b, x, b);
+    sameNormalForm(repl_repl, repl);
   }
 
   void testRewriteContains()
@@ -716,6 +768,17 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
                        substr_z);
     rhs = d_nm->mkNode(kind::STRING_STRCTN, x, substr_z);
     sameNormalForm(lhs, rhs);
+
+    // Same normal form for:
+    //
+    // (str.contains (str.replace x y z) z)
+    //
+    // (str.contains (str.replace x z y) y)
+    lhs = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, x, y, z), z);
+    rhs = d_nm->mkNode(
+        kind::STRING_STRCTN, d_nm->mkNode(kind::STRING_STRREPL, x, z, y), y);
+    sameNormalForm(lhs, rhs);
   }
 
   void testInferEqsFromContains()
@@ -914,6 +977,16 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node eq_xxa_repl_xxa = d_nm->mkNode(
         kind::EQUAL, xxa, d_nm->mkNode(kind::STRING_STRREPL, xxa, a, b));
     sameNormalForm(eq_xxa_repl_xxa, f);
+
+    // Same normal form for:
+    //
+    // (= (str.replace x "A" "B") "")
+    //
+    // (= x "")
+    Node eq_repl = d_nm->mkNode(
+        kind::EQUAL, d_nm->mkNode(kind::STRING_STRREPL, x, a, b), empty);
+    Node eq_x = d_nm->mkNode(kind::EQUAL, x, empty);
+    sameNormalForm(eq_repl, eq_x);
   }
 
  private: