Rewrites for substr of strings of length one (#1712)
authorAndres Noetzli <andres.noetzli@gmail.com>
Mon, 26 Mar 2018 15:52:40 +0000 (08:52 -0700)
committerGitHub <noreply@github.com>
Mon, 26 Mar 2018 15:52:40 +0000 (08:52 -0700)
This commit adds a rewrite for substrings of strings of length one to
the empty string if it can be shown that it is not possible that the
start position and the length are both greater than zero:

```
(str.substr "A" x y) --> "" if x = 0 |= 0 >= y
```

The commit introduces a set of functions to check such entailments
with assumptions.

src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
test/unit/Makefile.am
test/unit/theory/theory_strings_rewriter_white.h [new file with mode: 0644]

index 60d0d73b6fc2b98190e9349a4541b3d06c15b05a..8f9d4c596c885b9f2048dd937cddf669259a2c90 100644 (file)
@@ -20,6 +20,8 @@
 
 #include "options/strings_options.h"
 #include "smt/logic_exception.h"
+#include "theory/arith/arith_msum.h"
+#include "theory/theory.h"
 
 using namespace std;
 using namespace CVC4;
@@ -1381,6 +1383,8 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
 Node TheoryStringsRewriter::rewriteSubstr(Node node)
 {
   Assert(node.getKind() == kind::STRING_SUBSTR);
+
+  NodeManager* nm = NodeManager::currentNM();
   if (node[0].isConst())
   {
     if (node[0].getConst<String>().size() == 0)
@@ -1398,13 +1402,13 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
       {
         // start beyond the maximum size of strings
         // thus, it must be beyond the end point of this string
-        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        Node ret = nm->mkConst(::CVC4::String(""));
         return returnRewrite(node, ret, "ss-const-start-max-oob");
       }
       else if (node[1].getConst<Rational>().sgn() < 0)
       {
         // start before the beginning of the string
-        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        Node ret = nm->mkConst(::CVC4::String(""));
         return returnRewrite(node, ret, "ss-const-start-neg");
       }
       else
@@ -1413,20 +1417,19 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
         if (start >= s.size())
         {
           // start beyond the end of the string
-          Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+          Node ret = nm->mkConst(::CVC4::String(""));
           return returnRewrite(node, ret, "ss-const-start-oob");
         }
       }
       if (node[2].getConst<Rational>() > RMAXINT)
       {
         // take up to the end of the string
-        Node ret = NodeManager::currentNM()->mkConst(
-            ::CVC4::String(s.suffix(s.size() - start)));
+        Node ret = nm->mkConst(::CVC4::String(s.suffix(s.size() - start)));
         return returnRewrite(node, ret, "ss-const-len-max-oob");
       }
       else if (node[2].getConst<Rational>().sgn() <= 0)
       {
-        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        Node ret = nm->mkConst(::CVC4::String(""));
         return returnRewrite(node, ret, "ss-const-len-non-pos");
       }
       else
@@ -1436,31 +1439,29 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
         if (start + len > s.size())
         {
           // take up to the end of the string
-          Node ret = NodeManager::currentNM()->mkConst(
-              ::CVC4::String(s.suffix(s.size() - start)));
+          Node ret = nm->mkConst(::CVC4::String(s.suffix(s.size() - start)));
           return returnRewrite(node, ret, "ss-const-end-oob");
         }
         else
         {
           // compute the substr using the constant string
-          Node ret = NodeManager::currentNM()->mkConst(
-              ::CVC4::String(s.substr(start, len)));
+          Node ret = nm->mkConst(::CVC4::String(s.substr(start, len)));
           return returnRewrite(node, ret, "ss-const-ss");
         }
       }
     }
   }
-  Node zero = NodeManager::currentNM()->mkConst(CVC4::Rational(0));
+  Node zero = nm->mkConst(CVC4::Rational(0));
 
   // if entailed non-positive length or negative start point
   if (checkEntailArith(zero, node[1], true))
   {
-    Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+    Node ret = nm->mkConst(::CVC4::String(""));
     return returnRewrite(node, ret, "ss-start-neg");
   }
   else if (checkEntailArith(zero, node[2]))
   {
-    Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+    Node ret = nm->mkConst(::CVC4::String(""));
     return returnRewrite(node, ret, "ss-len-non-pos");
   }
 
@@ -1476,11 +1477,10 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
     {
       if (curr != zero && !n1.empty())
       {
-        childrenr.push_back(
-            NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR,
-                                             mkConcat(kind::STRING_CONCAT, n1),
-                                             node[1],
-                                             curr));
+        childrenr.push_back(nm->mkNode(kind::STRING_SUBSTR,
+                                       mkConcat(kind::STRING_CONCAT, n1),
+                                       node[1],
+                                       curr));
       }
       Node ret = mkConcat(kind::STRING_CONCAT, childrenr);
       return returnRewrite(node, ret, "ss-len-include");
@@ -1502,24 +1502,32 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
     }
     else if (r == 1)
     {
-      Node tot_len = Rewriter::rewrite(
-          NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]));
-      Node end_pt = Rewriter::rewrite(
-          NodeManager::currentNM()->mkNode(kind::PLUS, node[1], node[2]));
+      Node tot_len =
+          Rewriter::rewrite(nm->mkNode(kind::STRING_LENGTH, node[0]));
+      Node end_pt = Rewriter::rewrite(nm->mkNode(kind::PLUS, node[1], node[2]));
       if (node[2] != tot_len)
       {
         if (checkEntailArith(node[2], tot_len))
         {
           // end point beyond end point of string, map to tot_len
-          Node ret = NodeManager::currentNM()->mkNode(
-              kind::STRING_SUBSTR, node[0], node[1], tot_len);
+          Node ret = nm->mkNode(kind::STRING_SUBSTR, node[0], node[1], tot_len);
           return returnRewrite(node, ret, "ss-end-pt-norm");
         }
         else
         {
           // strip up to ( str.len(node[0]) - end_pt ) off the end of the string
-          curr = Rewriter::rewrite(
-              NodeManager::currentNM()->mkNode(kind::MINUS, tot_len, end_pt));
+          curr = Rewriter::rewrite(nm->mkNode(kind::MINUS, tot_len, end_pt));
+        }
+      }
+
+      if (tot_len == nm->mkConst(Rational(1)))
+      {
+        Node n1_eq_zero =
+            Rewriter::rewrite(nm->mkNode(kind::EQUAL, node[1], zero));
+        if (checkEntailArithWithAssumption(n1_eq_zero, zero, node[2], false))
+        {
+          Node ret = nm->mkConst(::CVC4::String(""));
+          return returnRewrite(node, ret, "ss-len-one-unsat");
         }
       }
     }
@@ -1532,20 +1540,18 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
       {
         if (r == 0)
         {
-          Node ret = NodeManager::currentNM()->mkNode(
-              kind::STRING_SUBSTR,
-              mkConcat(kind::STRING_CONCAT, n1),
-              curr,
-              node[2]);
+          Node ret = nm->mkNode(kind::STRING_SUBSTR,
+                                mkConcat(kind::STRING_CONCAT, n1),
+                                curr,
+                                node[2]);
           return returnRewrite(node, ret, "ss-strip-start-pt");
         }
         else
         {
-          Node ret = NodeManager::currentNM()->mkNode(
-              kind::STRING_SUBSTR,
-              mkConcat(kind::STRING_CONCAT, n1),
-              node[1],
-              node[2]);
+          Node ret = nm->mkNode(kind::STRING_SUBSTR,
+                                mkConcat(kind::STRING_CONCAT, n1),
+                                node[1],
+                                node[2]);
           return returnRewrite(node, ret, "ss-strip-end-pt");
         }
       }
@@ -1564,8 +1570,8 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
 
       // the length of a string from the inner substr subtracts the start point
       // of the outer substr
-      Node len_from_inner = Rewriter::rewrite(NodeManager::currentNM()->mkNode(
-          kind::MINUS, node[0][2], start_outer));
+      Node len_from_inner =
+          Rewriter::rewrite(nm->mkNode(kind::MINUS, node[0][2], start_outer));
       Node len_from_outer = node[2];
       Node new_len;
       // take quantity that is for sure smaller than the other
@@ -1583,10 +1589,9 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
       }
       if (!new_len.isNull())
       {
-        Node new_start = NodeManager::currentNM()->mkNode(
-            kind::PLUS, start_inner, start_outer);
-        Node ret = NodeManager::currentNM()->mkNode(
-            kind::STRING_SUBSTR, node[0][0], new_start, new_len);
+        Node new_start = nm->mkNode(kind::PLUS, start_inner, start_outer);
+        Node ret =
+            nm->mkNode(kind::STRING_SUBSTR, node[0][0], new_start, new_len);
         return returnRewrite(node, ret, "ss-combine");
       }
     }
@@ -2377,7 +2382,7 @@ bool TheoryStringsRewriter::stripSymbolicLength(std::vector<Node>& n1,
       if (n1[sindex_use].isConst())
       {
         // could strip part of a constant
-        Node lowerBound = getConstantArithBound(curr);
+        Node lowerBound = getConstantArithBound(Rewriter::rewrite(curr));
         if (!lowerBound.isNull())
         {
           Assert(lowerBound.isConst());
@@ -2991,6 +2996,112 @@ bool TheoryStringsRewriter::checkEntailArith(Node a, bool strict)
   }
 }
 
+bool TheoryStringsRewriter::checkEntailArithWithEqAssumption(Node assumption,
+                                                             Node a,
+                                                             bool strict)
+{
+  Assert(assumption.getKind() == kind::EQUAL);
+  Assert(Rewriter::rewrite(assumption) == assumption);
+
+  // Find candidates variables to compute substitutions for
+  std::unordered_set<Node, NodeHashFunction> candVars;
+  std::vector<Node> toVisit = {assumption};
+  while (!toVisit.empty())
+  {
+    Node curr = toVisit.back();
+    toVisit.pop_back();
+
+    if (curr.getKind() == kind::PLUS || curr.getKind() == kind::MULT
+        || curr.getKind() == kind::MINUS || curr.getKind() == kind::EQUAL)
+    {
+      for (const auto& currChild : curr)
+      {
+        toVisit.push_back(currChild);
+      }
+    }
+    else if (curr.getKind() == kind::VARIABLE
+             && Theory::theoryOf(curr) == THEORY_ARITH)
+    {
+      candVars.insert(curr);
+    }
+  }
+
+  // Check if any of the candidate variables are in n
+  Node v;
+  Assert(toVisit.empty());
+  toVisit.push_back(a);
+  while (!toVisit.empty())
+  {
+    Node curr = toVisit.back();
+    toVisit.pop_back();
+
+    for (const auto& currChild : curr)
+    {
+      toVisit.push_back(currChild);
+    }
+
+    if (curr.getKind() == kind::VARIABLE
+        && Theory::theoryOf(curr) == THEORY_ARITH
+        && candVars.find(curr) != candVars.end())
+    {
+      v = curr;
+      break;
+    }
+  }
+
+  if (v.isNull())
+  {
+    // No suitable candidate found
+    return false;
+  }
+
+  Node solution = ArithMSum::solveEqualityFor(assumption, v);
+  if (solution.isNull())
+  {
+    // Could not solve for v
+    return false;
+  }
+
+  a = a.substitute(TNode(v), TNode(solution));
+  return checkEntailArith(a, strict);
+}
+
+bool TheoryStringsRewriter::checkEntailArithWithAssumption(Node assumption,
+                                                           Node a,
+                                                           Node b,
+                                                           bool strict)
+{
+  // TODO: Add support for inequality assumptions.
+  Assert(assumption.getKind() == kind::EQUAL);
+  Assert(Rewriter::rewrite(assumption) == assumption);
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  Node diff = nm->mkNode(kind::MINUS, a, b);
+  return checkEntailArithWithEqAssumption(assumption, diff, strict);
+}
+
+bool TheoryStringsRewriter::checkEntailArithWithAssumptions(
+    std::vector<Node> assumptions, Node a, Node b, bool strict)
+{
+  // TODO: We currently try to show the entailment with each assumption
+  // independently. In the future, we should make better use of multiple
+  // assumptions.
+  bool res = false;
+  for (const auto& assumption : assumptions)
+  {
+    Assert(assumption.getKind() == kind::EQUAL);
+    Assert(Rewriter::rewrite(assumption) == assumption);
+
+    if (checkEntailArithWithAssumption(assumption, a, b, strict))
+    {
+      res = true;
+      break;
+    }
+  }
+  return res;
+}
+
 Node TheoryStringsRewriter::getConstantArithBound(Node a, bool isLower)
 {
   Assert(Rewriter::rewrite(a) == a);
index 3aaf3eab735a7440c2f34ad4aba5354633704ba5..31ad1406ad47220123c74fedf5ff9a41314d5847 100644 (file)
@@ -356,6 +356,56 @@ private:
    * Returns true if it is always the case that a >= 0.
    */
   static bool checkEntailArith(Node a, bool strict = false);
+
+  /**
+   * Checks whether assumption |= a >= 0 (if strict is false) or
+   * assumption |= a > 0 (if strict is true), where assumption is an equality
+   * assumption. The assumption must be in rewritten form.
+   *
+   * Example:
+   *
+   * checkEntailArithWithEqAssumption(x + (str.len y) = 0, -x, false) = true
+   *
+   * Because: x = -(str.len y), so -x >= 0 --> (str.len y) >= 0 --> true
+   */
+  static bool checkEntailArithWithEqAssumption(Node assumption,
+                                               Node a,
+                                               bool strict = false);
+
+  /**
+   * Checks whether assumption |= a >= b (if strict is false) or
+   * assumption |= a > b (if strict is true). The function returns true if it
+   * can be shown that the entailment holds and false otherwise. Assumption
+   * must be in rewritten form and an equality assumption.
+   *
+   * Example:
+   *
+   * checkEntailArithWithAssumption(x + (str.len y) = 0, 0, x, false) = true
+   *
+   * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true
+   */
+  static bool checkEntailArithWithAssumption(Node assumption,
+                                             Node a,
+                                             Node b,
+                                             bool strict = false);
+
+  /**
+   * Checks whether assumptions |= a >= b (if strict is false) or
+   * assumptions |= a > b (if strict is true). The function returns true if it
+   * can be shown that the entailment holds and false otherwise. Assumptions
+   * must be in rewritten form and must be equality assumptions.
+   *
+   * Example:
+   *
+   * checkEntailArithWithAssumptions([x + (str.len y) = 0], 0, x, false) = true
+   *
+   * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true
+   */
+  static bool checkEntailArithWithAssumptions(std::vector<Node> assumptions,
+                                              Node a,
+                                              Node b,
+                                              bool strict = false);
+
   /** get arithmetic lower bound
    * If this function returns a non-null Node ret,
    * then ret is a rational constant and
index 3d5f2702dec1554adbedcd1592b17ea2d4425f87..b445583ac8219cf734d44bdf6a180139cd50ebf8 100644 (file)
@@ -7,14 +7,15 @@ UNIT_TESTS = \
 if WHITE_AND_BLACK_TESTS
 UNIT_TESTS += \
        theory/logic_info_white \
-       theory/theory_engine_white \
-       theory/theory_black \
-       theory/theory_white \
        theory/theory_arith_white \
-       theory/theory_bv_white \
+       theory/theory_black \
        theory/theory_bv_bvgauss_white \
+       theory/theory_bv_white \
+       theory/theory_engine_white \
        theory/theory_quantifiers_bv_instantiator_white \
        theory/theory_quantifiers_bv_inverter_white \
+       theory/theory_strings_rewriter_white \
+       theory/theory_white \
        theory/type_enumerator_white \
        expr/node_white \
        expr/node_black \
diff --git a/test/unit/theory/theory_strings_rewriter_white.h b/test/unit/theory/theory_strings_rewriter_white.h
new file mode 100644 (file)
index 0000000..a1e878f
--- /dev/null
@@ -0,0 +1,136 @@
+/*********************                                                        */
+/*! \file theory_strings_rewriter_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andres Noetzli
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Unit tests for the strings rewriter
+ **
+ ** Unit tests for the strings rewriter.
+ **/
+
+#include "expr/node.h"
+#include "expr/node_manager.h"
+#include "smt/smt_engine.h"
+#include "smt/smt_engine_scope.h"
+#include "theory/rewriter.h"
+#include "theory/strings/theory_strings_rewriter.h"
+
+#include <cxxtest/TestSuite.h>
+#include <vector>
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::theory::strings;
+
+class TheoryStringsRewriterWhite : public CxxTest::TestSuite
+{
+  ExprManager *d_em;
+  NodeManager *d_nm;
+  SmtEngine *d_smt;
+  SmtScope *d_scope;
+
+ public:
+  TheoryStringsRewriterWhite() {}
+
+  void setUp()
+  {
+    Options opts;
+    opts.setOutputLanguage(language::output::LANG_SMTLIB_V2);
+    d_em = new ExprManager(opts);
+    d_nm = NodeManager::fromExprManager(d_em);
+    d_smt = new SmtEngine(d_em);
+    d_scope = new SmtScope(d_smt);
+  }
+
+  void tearDown()
+  {
+    delete d_scope;
+    delete d_smt;
+    delete d_em;
+  }
+
+  void testCheckEntailArithWithAssumption()
+  {
+    TypeNode intType = d_nm->integerType();
+    TypeNode strType = d_nm->stringType();
+
+    Node x = d_nm->mkVar("x", intType);
+    Node y = d_nm->mkVar("y", strType);
+    Node z = d_nm->mkVar("z", intType);
+
+    Node zero = d_nm->mkConst(Rational(0));
+
+    Node slen_y = d_nm->mkNode(kind::STRING_LENGTH, y);
+    Node x_plus_slen_y = d_nm->mkNode(kind::PLUS, x, slen_y);
+    Node x_plus_slen_y_eq_zero =
+        Rewriter::rewrite(d_nm->mkNode(kind::EQUAL, x_plus_slen_y, zero));
+
+    // x + (str.len y) = 0 /\ 0 >= x --> true
+    TS_ASSERT(TheoryStringsRewriter::checkEntailArithWithAssumption(
+        x_plus_slen_y_eq_zero, zero, x, false));
+
+    // x + (str.len y) = 0 /\ 0 > x --> false
+    TS_ASSERT(!TheoryStringsRewriter::checkEntailArithWithAssumption(
+        x_plus_slen_y_eq_zero, zero, x, true));
+
+    Node x_plus_slen_y_plus_z_eq_zero = Rewriter::rewrite(d_nm->mkNode(
+        kind::EQUAL, d_nm->mkNode(kind::PLUS, x_plus_slen_y, z), zero));
+
+    // x + (str.len y) + z = 0 /\ 0 > x --> false
+    TS_ASSERT(!TheoryStringsRewriter::checkEntailArithWithAssumption(
+        x_plus_slen_y_plus_z_eq_zero, zero, x, true));
+
+    Node x_plus_slen_y_plus_slen_y_eq_zero = Rewriter::rewrite(d_nm->mkNode(
+        kind::EQUAL, d_nm->mkNode(kind::PLUS, x_plus_slen_y, slen_y), zero));
+
+    // x + (str.len y) + (str.len y) = 0 /\ 0 >= x --> true
+    TS_ASSERT(TheoryStringsRewriter::checkEntailArithWithAssumption(
+        x_plus_slen_y_plus_slen_y_eq_zero, zero, x, false));
+  }
+
+  void testRewriteSubstr()
+  {
+    TypeNode intType = d_nm->integerType();
+    TypeNode strType = d_nm->stringType();
+
+    Node empty = d_nm->mkConst(::CVC4::String(""));
+    Node a = d_nm->mkConst(::CVC4::String("A"));
+    Node s = d_nm->mkVar("s", strType);
+    Node x = d_nm->mkVar("x", intType);
+    Node y = d_nm->mkVar("y", intType);
+
+    // (str.substr "A" x x) --> ""
+    Node n = d_nm->mkNode(kind::STRING_SUBSTR, a, x, x);
+    Node res = TheoryStringsRewriter::rewriteSubstr(n);
+    TS_ASSERT_EQUALS(res, empty);
+
+    // (str.substr "A" (+ x 1) x) -> ""
+    n = d_nm->mkNode(kind::STRING_SUBSTR,
+                     a,
+                     d_nm->mkNode(kind::PLUS, x, d_nm->mkConst(Rational(1))),
+                     x);
+    res = TheoryStringsRewriter::rewriteSubstr(n);
+    TS_ASSERT_EQUALS(res, empty);
+
+    // (str.substr "A" (+ x (str.len s2)) x) -> ""
+    n = d_nm->mkNode(
+        kind::STRING_SUBSTR,
+        a,
+        d_nm->mkNode(kind::PLUS, x, d_nm->mkNode(kind::STRING_LENGTH, s)),
+        x);
+    res = TheoryStringsRewriter::rewriteSubstr(n);
+    TS_ASSERT_EQUALS(res, empty);
+
+    // (str.substr "A" x y) -> (str.substr "A" x y)
+    n = d_nm->mkNode(kind::STRING_SUBSTR, a, x, y);
+    res = TheoryStringsRewriter::rewriteSubstr(n);
+    TS_ASSERT_EQUALS(res, n);
+  }
+};