Unify rewrites related to (str.contains x y) --> (= x y) (#2512)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 24 Sep 2018 21:52:08 +0000 (14:52 -0700)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 24 Sep 2018 21:52:08 +0000 (16:52 -0500)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
test/unit/theory/theory_strings_rewriter_white.h

index 7803224c603cacad891cab85eaed9e0a9b21090a..48b288ea3e43058e95d0eec8e70d82ccb3ef0ec9 100644 (file)
@@ -1587,6 +1587,8 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
 
       // (str.contains x (str.++ w (str.replace x y x) z)) --->
       //   (and (= w "") (= x (str.replace x y x)) (= z ""))
+      //
+      // TODO: Remove with under-/over-approximation
       if (node[0] == n[0] && node[0] == n[2])
       {
         Node ret;
@@ -1684,7 +1686,6 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
       }
     }
     Trace("strings-rewrite-multiset") << "For " << node << " : " << std::endl;
-    bool sameConst = true;
     for (const Node& ch : chars)
     {
       Trace("strings-rewrite-multiset") << "  # occurrences of substring ";
@@ -1696,64 +1697,6 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
         Node ret = NodeManager::currentNM()->mkConst(false);
         return returnRewrite(node, ret, "ctn-mset-nss");
       }
-      else if (count_const[0][ch] > count_const[1][ch])
-      {
-        sameConst = false;
-      }
-    }
-
-    if (sameConst)
-    {
-      // At this point, we know that both the first and the second argument
-      // both contain the same constants. Now we can check if there are
-      // non-const components that appear in the second argument but not the
-      // first. If there are, we know that the str.contains is true iff those
-      // components are empty, so we can pull them out of the str.contains. For
-      // example:
-      //
-      // (str.contains (str.++ "A" x) (str.++ y x "A")) -->
-      //   (and (str.contains (str.++ "A" x) (str.++ x "A")) (= y ""))
-      //
-      // These equalities can be used by other rewrites for subtitutions.
-
-      // Find all non-const components that appear more times in second
-      // argument than the first
-      std::unordered_set<Node, NodeHashFunction> nConstEmpty;
-      for (std::pair<const Node, unsigned>& nncp : num_nconst[1])
-      {
-        if (nncp.second > num_nconst[0][nncp.first])
-        {
-          nConstEmpty.insert(nncp.first);
-        }
-      }
-
-      // Check if there are any non-const components that must be empty
-      if (nConstEmpty.size() > 0)
-      {
-        // Generate str.contains of the (potentially) non-empty parts
-        std::vector<Node> cs;
-        std::vector<Node> nnc2;
-        for (const Node& n : nc2)
-        {
-          if (nConstEmpty.find(n) == nConstEmpty.end())
-          {
-            nnc2.push_back(n);
-          }
-        }
-        cs.push_back(nm->mkNode(
-            kind::STRING_STRCTN, node[0], mkConcat(kind::STRING_CONCAT, nnc2)));
-
-        // Generate equalities for the parts that must be empty
-        Node emptyStr = nm->mkConst(String(""));
-        for (const Node& n : nConstEmpty)
-        {
-          cs.push_back(nm->mkNode(kind::EQUAL, n, emptyStr));
-        }
-
-        Assert(cs.size() >= 2);
-        Node res = nm->mkNode(kind::AND, cs);
-        return returnRewrite(node, res, "ctn-mset-substs");
-      }
     }
 
     // TODO (#1180): count the number of 2,3,4,.. character substrings
@@ -1767,11 +1710,22 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   // TODO (#1180): abstract interpretation with multi-set domain
   // to show first argument is a strict subset of second argument
 
-  if (checkEntailArithEq(len_n1, len_n2))
+  // Try to rewrite (str.contains x y) into an equality or a conjunction of
+  // equalities:
+  //
+  // (str.contains x y) ---> (= x y) if (<= (str.len x) (str.len y))
+  //
+  // or more generally:
+  //
+  // (str.contains x (str.++ y1 ... yn)) --->
+  //  (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' ""))
+  //
+  // where yi' and yi'' correspond to some yj and
+  // (<= (str.len x) (str.++ y1' ... ym'))
+  Node eqs = inferEqsFromContains(node[0], node[1]);
+  if (!eqs.isNull())
   {
-    // len( n2 ) = len( n1 ) => contains( n1, n2 ) ---> n1 = n2
-    Node ret = node[0].eqNode(node[1]);
-    return returnRewrite(node, ret, "ctn-len-eq");
+    return returnRewrite(node, eqs, "ctn-to-eqs");
   }
 
   // splitting
@@ -1822,9 +1776,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
     // (str.contains (str.substr x n (str.len y)) y) --->
     //   (= (str.substr x n (str.len y)) y)
     //
-    // TODO: generalize with over-/underapproximation to:
-    //
-    // (str.contains x y) ---> (= x y) if (<= (str.len x) (str.len y))
+    // TODO: Remove with under-/over-approximation
     if (node[0][2] == nm->mkNode(kind::STRING_LENGTH, node[1]))
     {
       Node ret = nm->mkNode(kind::EQUAL, node[0], node[1]);
@@ -1844,6 +1796,10 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
 
     // (str.contains x (str.replace "" x y)) --->
     //   (= "" (str.replace "" x y))
+    //
+    // Note: Length-based reasoning is not sufficient to get this rewrite. We
+    // can neither show that str.len(str.replace("", x, y)) - str.len(x) >= 0
+    // nor str.len(x) - str.len(str.replace("", x, y)) >= 0
     Node emp = nm->mkConst(CVC4::String(""));
     if (node[0] == node[1][1] && node[1][0] == emp)
     {
@@ -2606,6 +2562,15 @@ Node TheoryStringsRewriter::rewritePrefixSuffix(Node n)
   {
     val = NodeManager::currentNM()->mkNode(kind::MINUS, lent, lens);
   }
+
+  // Check if we can turn the prefix/suffix into equalities by showing that the
+  // prefix/suffix is at least as long as the string
+  Node eqs = inferEqsFromContains(n[1], n[0]);
+  if (!eqs.isNull())
+  {
+    return returnRewrite(n, eqs, "suf/prefix-to-eqs");
+  }
+
   // general reduction to equality + substr
   Node retNode = n[0].eqNode(
       NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, n[1], val, lens));
@@ -3809,6 +3774,141 @@ Node TheoryStringsRewriter::getStringOrEmpty(Node n)
   return res;
 }
 
+bool TheoryStringsRewriter::inferZerosInSumGeq(Node x,
+                                               std::vector<Node>& ys,
+                                               std::vector<Node>& zeroYs)
+{
+  Assert(zeroYs.empty());
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  // Check if we can show that y1 + ... + yn >= x
+  Node sum = (ys.size() > 1) ? nm->mkNode(PLUS, ys) : ys[0];
+  if (!checkEntailArith(sum, x))
+  {
+    return false;
+  }
+
+  // Try to remove yi one-by-one and check if we can still show:
+  //
+  // y1 + ... + yi-1 +  yi+1 + ... + yn >= x
+  //
+  // If that's the case, we know that yi can be zero and the inequality still
+  // holds.
+  size_t i = 0;
+  while (i < ys.size())
+  {
+    Node yi = ys[i];
+    std::vector<Node>::iterator pos = ys.erase(ys.begin() + i);
+    if (ys.size() > 1)
+    {
+      sum = nm->mkNode(PLUS, ys);
+    }
+    else
+    {
+      sum = ys.size() == 1 ? ys[0] : nm->mkConst(Rational(0));
+    }
+
+    if (checkEntailArith(sum, x))
+    {
+      zeroYs.push_back(yi);
+    }
+    else
+    {
+      ys.insert(pos, yi);
+      i++;
+    }
+  }
+  return true;
+}
+
+Node TheoryStringsRewriter::inferEqsFromContains(Node x, Node y)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Node emp = nm->mkConst(String(""));
+
+  Node xLen = nm->mkNode(STRING_LENGTH, x);
+  std::vector<Node> yLens;
+  if (y.getKind() != STRING_CONCAT)
+  {
+    yLens.push_back(nm->mkNode(STRING_LENGTH, y));
+  }
+  else
+  {
+    for (const Node& yi : y)
+    {
+      yLens.push_back(nm->mkNode(STRING_LENGTH, yi));
+    }
+  }
+
+  std::vector<Node> zeroLens;
+  if (x == emp)
+  {
+    // If x is the empty string, then all ys must be empty, too, and we can
+    // skip the expensive checks. Note that this is just a performance
+    // optimization.
+    zeroLens.swap(yLens);
+  }
+  else
+  {
+    // Check if we can infer that str.len(x) <= str.len(y). If that is the
+    // case, try to minimize the sum in str.len(x) <= str.len(y1) + ... +
+    // str.len(yn) (where y = y1 ++ ... ++ yn) while keeping the inequality
+    // true. The terms that can have length zero without making the inequality
+    // false must be all be empty if (str.contains x y) is true.
+    if (!inferZerosInSumGeq(xLen, yLens, zeroLens))
+    {
+      // We could not prove that the inequality holds
+      return Node::null();
+    }
+    else if (yLens.size() == y.getNumChildren())
+    {
+      // We could only prove that the inequality holds but not that any of the
+      // ys must be empty
+      return nm->mkNode(EQUAL, x, y);
+    }
+  }
+
+  if (y.getKind() != STRING_CONCAT)
+  {
+    if (zeroLens.size() == 1)
+    {
+      // y is not a concatenation and we found that it must be empty, so just
+      // return (= y "")
+      Assert(zeroLens[0][0] == y);
+      return nm->mkNode(EQUAL, y, emp);
+    }
+    else
+    {
+      Assert(yLens.size() == 1 && yLens[0][0] == y);
+      return nm->mkNode(EQUAL, x, y);
+    }
+  }
+
+  std::vector<Node> cs;
+  for (const Node& yiLen : yLens)
+  {
+    Assert(std::find(y.begin(), y.end(), yiLen[0]) != y.end());
+    cs.push_back(yiLen[0]);
+  }
+
+  NodeBuilder<> nb(AND);
+  // (= x (str.++ y1' ... ym'))
+  if (!cs.empty())
+  {
+    nb << nm->mkNode(EQUAL, x, mkConcat(STRING_CONCAT, cs));
+  }
+  // (= y1'' "") ... (= yk'' "")
+  for (const Node& zeroLen : zeroLens)
+  {
+    Assert(std::find(y.begin(), y.end(), zeroLen[0]) != y.end());
+    nb << nm->mkNode(EQUAL, zeroLen[0], emp);
+  }
+
+  // (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' ""))
+  return nb.constructNode();
+}
+
 Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c)
 {
   Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c
index 5937e778f6066dcc76a29017bf6fd469d770ead5..70c573d9e77ccea575d2efbc86f78d6b2e33725f 100644 (file)
@@ -534,6 +534,42 @@ class TheoryStringsRewriter {
    * because the function could not compute a simpler
    */
   static Node getStringOrEmpty(Node n);
+
+  /**
+   * Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the
+   * original inequality still holds. Returns true if the original inequality
+   * holds and false otherwise. The list of ys is modified to contain a subset
+   * of the original ys.
+   *
+   * Example:
+   *
+   * inferZerosInSumGeq( (str.len x), [ (str.len x), (str.len y), 1 ], [] )
+   * --> returns true with ys = [ (str.len x) ] and zeroYs = [ (str.len y), 1 ]
+   *     (can be used to rewrite the inequality to false)
+   *
+   * inferZerosInSumGeq( (str.len x), [ (str.len y) ], [] )
+   * --> returns false because it is not possible to show
+   *     str.len(y) >= str.len(x)
+   */
+  static bool inferZerosInSumGeq(Node x,
+                                 std::vector<Node>& ys,
+                                 std::vector<Node>& zeroYs);
+
+  /**
+   * Infers a conjunction of equalities that correspond to (str.contains x y)
+   * if it can show that the length of y is greater or equal to the length of
+   * x. If y is a concatentation, we get x = y1 ++ ... ++ yn, the conjunction
+   * is of the form:
+   *
+   * (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' ""))
+   *
+   * where each yi'' are yi that must be empty for (= x y) to hold and yi' are
+   * yi that the function could not infer anything about.  Returns a null node
+   * if the function cannot infer that str.len(y) >= str.len(x). Returns (= x
+   * y) if the function can infer that str.len(y) >= str.len(x) but cannot
+   * infer that any of the yi must be empty.
+   */
+  static Node inferEqsFromContains(Node x, Node y);
 };/* class TheoryStringsRewriter */
 
 }/* CVC4::theory::strings namespace */
index cb23c34c1098f7210a6c1603d16833564239e89d..0b569394d2d5219c401db04bdd5c86d008533a39 100644 (file)
@@ -424,6 +424,8 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node c = d_nm->mkConst(::CVC4::String("C"));
     Node x = d_nm->mkVar("x", strType);
     Node y = d_nm->mkVar("y", strType);
+    Node xy = d_nm->mkNode(kind::STRING_CONCAT, x, y);
+    Node yx = d_nm->mkNode(kind::STRING_CONCAT, y, x);
     Node z = d_nm->mkVar("z", strType);
     Node n = d_nm->mkVar("n", intType);
     Node one = d_nm->mkConst(Rational(2));
@@ -488,10 +490,8 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     // (str.contains (str.++ y x) (str.++ x z y))
     //
     // (and (str.contains (str.++ y x) (str.++ x y)) (= z ""))
-    Node yx = d_nm->mkNode(kind::STRING_CONCAT, y, x);
     Node yx_cnts_xzy = d_nm->mkNode(
         kind::STRING_STRCTN, yx, d_nm->mkNode(kind::STRING_CONCAT, x, z, y));
-    Node xy = d_nm->mkNode(kind::STRING_CONCAT, x, y);
     Node yx_cnts_xy = d_nm->mkNode(kind::AND,
                                    d_nm->mkNode(kind::EQUAL, z, empty),
                                    d_nm->mkNode(kind::STRING_STRCTN, yx, xy));
@@ -556,6 +556,109 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node eq_repl_empty = d_nm->mkNode(
         kind::EQUAL, empty, d_nm->mkNode(kind::STRING_STRREPL, empty, x, y));
     sameNormalForm(ctn_repl_empty, eq_repl_empty);
+
+    // Same normal form for:
+    //
+    // (str.contains x (str.++ x y))
+    //
+    // (= "" y)
+    Node ctn_x_x_y = d_nm->mkNode(
+        kind::STRING_STRCTN, x, d_nm->mkNode(kind::STRING_CONCAT, x, y));
+    Node eq_emp_y = d_nm->mkNode(kind::EQUAL, empty, y);
+    sameNormalForm(ctn_x_x_y, eq_emp_y);
+
+    // Same normal form for:
+    //
+    // (str.contains (str.++ y x) (str.++ x y))
+    //
+    // (= (str.++ y x) (str.++ x y))
+    Node ctn_yxxy = d_nm->mkNode(kind::STRING_STRCTN, yx, xy);
+    Node eq_yxxy = d_nm->mkNode(kind::EQUAL, yx, xy);
+    sameNormalForm(ctn_yxxy, eq_yxxy);
+  }
+
+  void testInferEqsFromContains()
+  {
+    TypeNode strType = d_nm->stringType();
+
+    Node empty = d_nm->mkConst(::CVC4::String(""));
+    Node a = d_nm->mkConst(::CVC4::String("A"));
+    Node b = d_nm->mkConst(::CVC4::String("B"));
+    Node x = d_nm->mkVar("x", strType);
+    Node y = d_nm->mkVar("y", strType);
+    Node xy = d_nm->mkNode(kind::STRING_CONCAT, x, y);
+    Node f = d_nm->mkConst(false);
+
+    // inferEqsFromContains("", (str.++ x y)) returns something equivalent to
+    // (= "" y)
+    Node empty_x_y = d_nm->mkNode(kind::AND,
+                                  d_nm->mkNode(kind::EQUAL, empty, x),
+                                  d_nm->mkNode(kind::EQUAL, empty, y));
+    sameNormalForm(TheoryStringsRewriter::inferEqsFromContains(empty, xy),
+                   empty_x_y);
+
+    // inferEqsFromContains(x, (str.++ x y)) returns false
+    Node bxya = d_nm->mkNode(kind::STRING_CONCAT, b, y, x, a);
+    sameNormalForm(TheoryStringsRewriter::inferEqsFromContains(x, bxya), f);
+
+    // inferEqsFromContains(x, y) returns null
+    Node n = TheoryStringsRewriter::inferEqsFromContains(x, y);
+    TS_ASSERT(n.isNull());
+
+    // inferEqsFromContains(x, x) returns something equivalent to (= x x)
+    Node eq_x_x = d_nm->mkNode(kind::EQUAL, x, x);
+    sameNormalForm(TheoryStringsRewriter::inferEqsFromContains(x, x), eq_x_x);
+
+    // inferEqsFromContains((str.replace x "B" "A"), x) returns something
+    // equivalent to (= (str.replace x "B" "A") x)
+    Node repl = d_nm->mkNode(kind::STRING_STRREPL, x, b, a);
+    Node eq_repl_x = d_nm->mkNode(kind::EQUAL, repl, x);
+    sameNormalForm(TheoryStringsRewriter::inferEqsFromContains(repl, x),
+                   eq_repl_x);
+
+    // inferEqsFromContains(x, (str.replace x "B" "A")) returns something
+    // equivalent to (= (str.replace x "B" "A") x)
+    Node eq_x_repl = d_nm->mkNode(kind::EQUAL, x, repl);
+    sameNormalForm(TheoryStringsRewriter::inferEqsFromContains(x, repl),
+                   eq_x_repl);
+  }
+
+  void testRewritePrefixSuffix()
+  {
+    TypeNode strType = d_nm->stringType();
+
+    Node empty = d_nm->mkConst(::CVC4::String(""));
+    Node a = d_nm->mkConst(::CVC4::String("A"));
+    Node x = d_nm->mkVar("x", strType);
+    Node y = d_nm->mkVar("y", strType);
+    Node xx = d_nm->mkNode(kind::STRING_CONCAT, x, x);
+    Node xxa = d_nm->mkNode(kind::STRING_CONCAT, x, x, a);
+    Node xy = d_nm->mkNode(kind::STRING_CONCAT, x, y);
+    Node f = d_nm->mkConst(false);
+
+    // Same normal form for:
+    //
+    // (str.prefix x (str.++ x y))
+    //
+    // (= y "")
+    Node p_xy = d_nm->mkNode(kind::STRING_PREFIX, xy, x);
+    Node empty_y = d_nm->mkNode(kind::EQUAL, y, empty);
+    sameNormalForm(p_xy, empty_y);
+
+    // Same normal form for:
+    //
+    // (str.suffix x (str.++ x x))
+    //
+    // (= x "")
+    Node p_xx = d_nm->mkNode(kind::STRING_SUFFIX, xx, x);
+    Node empty_x = d_nm->mkNode(kind::EQUAL, x, empty);
+    sameNormalForm(p_xx, empty_x);
+
+    // (str.suffix x (str.++ x x "A")) --> false
+    //
+    // (= x "")
+    Node p_xxa = d_nm->mkNode(kind::STRING_SUFFIX, xxa, x);
+    sameNormalForm(p_xxa, f);
   }
 
  private: