Add rewrite for splitting equalities (#2957)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 30 Sep 2019 03:31:18 +0000 (20:31 -0700)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 30 Sep 2019 03:31:18 +0000 (22:31 -0500)
src/theory/strings/theory_strings_rewriter.cpp
test/unit/theory/theory_strings_rewriter_white.h

index 32190e093e901bf70079024787a30a52d4cedaca..fa3addf1f71597781aae4c57d30d6e52d4a2e23b 100644 (file)
@@ -609,6 +609,107 @@ Node TheoryStringsRewriter::rewriteStrEqualityExt(Node node)
     }
   }
 
+  if (node[0].getKind() == STRING_CONCAT && node[1].getKind() == STRING_CONCAT)
+  {
+    // (= (str.++ x_1 ... x_i x_{i + 1} ... x_n)
+    //    (str.++ y_1 ... y_j y_{j + 1} ... y_m)) --->
+    //  (and (= (str.++ x_1 ... x_i) (str.++ y_1 ... y_j))
+    //       (= (str.++ x_{i + 1} ... x_n) (str.++ y_{j + 1} ... y_m)))
+    //
+    // if (str.len (str.++ x_1 ... x_i)) = (str.len (str.++ y_1 ... y_j))
+    //
+    // This rewrite performs length-based equality splitting: If we can show
+    // that two prefixes have the same length, we can split an equality into
+    // two equalities, one over the prefixes and another over the suffixes.
+    std::vector<Node> v0, v1;
+    utils::getConcat(node[0], v0);
+    utils::getConcat(node[1], v1);
+    size_t startRhs = 0;
+    for (size_t i = 0, size0 = v0.size(); i <= size0; i++)
+    {
+      std::vector<Node> pfxv0(v0.begin(), v0.begin() + i);
+      Node pfx0 = utils::mkConcat(STRING_CONCAT, pfxv0);
+      for (size_t j = startRhs, size1 = v1.size(); j <= size1; j++)
+      {
+        if (!(i == 0 && j == 0) && !(i == v0.size() && j == v1.size()))
+        {
+          std::vector<Node> pfxv1(v1.begin(), v1.begin() + j);
+          Node pfx1 = utils::mkConcat(STRING_CONCAT, pfxv1);
+          Node lenPfx0 = nm->mkNode(STRING_LENGTH, pfx0);
+          Node lenPfx1 = nm->mkNode(STRING_LENGTH, pfx1);
+
+          if (checkEntailArithEq(lenPfx0, lenPfx1))
+          {
+            std::vector<Node> sfxv0(v0.begin() + i, v0.end());
+            std::vector<Node> sfxv1(v1.begin() + j, v1.end());
+            Node ret =
+                nm->mkNode(kind::AND,
+                           pfx0.eqNode(pfx1),
+                           utils::mkConcat(STRING_CONCAT, sfxv0)
+                               .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1)));
+            return returnRewrite(node, ret, "split-eq");
+          }
+          else if (checkEntailArith(lenPfx1, lenPfx0, true))
+          {
+            // The prefix on the right-hand side is strictly longer than the
+            // prefix on the left-hand side, so we try to strip the right-hand
+            // prefix by the length of the left-hand prefix
+            //
+            // Example:
+            // (= (str.++ "A" x y) (str.++ x "AB" z)) --->
+            //   (and (= (str.++ "A" x) (str.++ x "A")) (= y (str.++ "B" z)))
+            std::vector<Node> rpfxv1;
+            if (stripSymbolicLength(pfxv1, rpfxv1, 1, lenPfx0))
+            {
+              std::vector<Node> sfxv0(v0.begin() + i, v0.end());
+              pfxv1.insert(pfxv1.end(), v1.begin() + j, v1.end());
+              Node ret = nm->mkNode(
+                  kind::AND,
+                  pfx0.eqNode(utils::mkConcat(STRING_CONCAT, rpfxv1)),
+                  utils::mkConcat(STRING_CONCAT, sfxv0)
+                      .eqNode(utils::mkConcat(STRING_CONCAT, pfxv1)));
+              return returnRewrite(node, ret, "split-eq-strip-r");
+            }
+
+            // If the prefix of the right-hand side is (strictly) longer than
+            // the prefix of the left-hand side, we can advance the left-hand
+            // side (since the length of the right-hand side is only increasing
+            // in the inner loop)
+            break;
+          }
+          else if (checkEntailArith(lenPfx0, lenPfx1, true))
+          {
+            // The prefix on the left-hand side is strictly longer than the
+            // prefix on the right-hand side, so we try to strip the left-hand
+            // prefix by the length of the right-hand prefix
+            //
+            // Example:
+            // (= (str.++ x "AB" z) (str.++ "A" x y)) --->
+            //   (and (= (str.++ x "A") (str.++ "A" x)) (= (str.++ "B" z) y))
+            std::vector<Node> rpfxv0;
+            if (stripSymbolicLength(pfxv0, rpfxv0, 1, lenPfx1))
+            {
+              pfxv0.insert(pfxv0.end(), v0.begin() + i, v0.end());
+              std::vector<Node> sfxv1(v1.begin() + j, v1.end());
+              Node ret = nm->mkNode(
+                  kind::AND,
+                  utils::mkConcat(STRING_CONCAT, rpfxv0).eqNode(pfx1),
+                  utils::mkConcat(STRING_CONCAT, pfxv0)
+                      .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1)));
+              return returnRewrite(node, ret, "split-eq-strip-l");
+            }
+
+            // If the prefix of the left-hand side is (strictly) longer than
+            // the prefix of the right-hand side, then we don't need to check
+            // that right-hand prefix for future left-hand prefixes anymore
+            // (since they are increasing in length)
+            startRhs = j + 1;
+          }
+        }
+      }
+    }
+  }
+
   return node;
 }
 
index 00eb0b4952e2149b0b591596f0de7c7d53cf7a48..c5db12c6ca39a3319fd1c5c0f203c2c2e41be743 100644 (file)
@@ -1093,8 +1093,11 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node a = d_nm->mkConst(::CVC4::String("A"));
     Node aaa = d_nm->mkConst(::CVC4::String("AAA"));
     Node b = d_nm->mkConst(::CVC4::String("B"));
+    Node ba = d_nm->mkConst(::CVC4::String("BA"));
+    Node w = d_nm->mkVar("w", strType);
     Node x = d_nm->mkVar("x", strType);
     Node y = d_nm->mkVar("y", strType);
+    Node z = d_nm->mkVar("z", strType);
     Node xxa = d_nm->mkNode(kind::STRING_CONCAT, x, x, a);
     Node f = d_nm->mkConst(false);
     Node n = d_nm->mkVar("n", intType);
@@ -1300,6 +1303,49 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
           a);
       differentNormalForms(eq, f);
     }
+
+    {
+      // (= (str.++ "A" x y) (str.++ x "B" z)) --> false
+      Node eq = d_nm->mkNode(
+          kind::EQUAL,
+          d_nm->mkNode(kind::STRING_CONCAT, a, x, y),
+          d_nm->mkNode(kind::STRING_CONCAT, x, b, z));
+      sameNormalForm(eq, f);
+    }
+
+    {
+      // (= (str.++ "B" x y) (str.++ x "AAA" z)) --> false
+      Node eq = d_nm->mkNode(kind::EQUAL,
+                             d_nm->mkNode(kind::STRING_CONCAT, b, x, y),
+                             d_nm->mkNode(kind::STRING_CONCAT, x, aaa, z));
+      sameNormalForm(eq, f);
+    }
+
+    {
+      Node xrepl = d_nm->mkNode(kind::STRING_STRREPL, x, a, b);
+
+      // Same normal form for:
+      //
+      // (= (str.++ "B" (str.replace x "A" "B") z y w)
+      //    (str.++ z x "BA" z))
+      //
+      // (and (= (str.++ "B" (str.replace x "A" "B") z)
+      //         (str.++ z x "B"))
+      //      (= (str.++ y w) (str.++ "A" z)))
+      Node lhs =
+          d_nm->mkNode(kind::EQUAL,
+                       d_nm->mkNode(kind::STRING_CONCAT, b, xrepl, z, y, w),
+                       d_nm->mkNode(kind::STRING_CONCAT, z, x, ba, z));
+      Node rhs = d_nm->mkNode(
+          kind::AND,
+          d_nm->mkNode(kind::EQUAL,
+                       d_nm->mkNode(kind::STRING_CONCAT, b, xrepl, z),
+                       d_nm->mkNode(kind::STRING_CONCAT, z, x, b)),
+          d_nm->mkNode(kind::EQUAL,
+                       d_nm->mkNode(kind::STRING_CONCAT, y, w),
+                       d_nm->mkNode(kind::STRING_CONCAT, a, z)));
+      sameNormalForm(lhs, rhs);
+    }
   }
 
   void testStripConstantEndpoints()