Strings: Strengthen multiset reasoning (#2817)
authorAndres Noetzli <andres.noetzli@gmail.com>
Wed, 23 Jan 2019 02:47:08 +0000 (18:47 -0800)
committerGitHub <noreply@github.com>
Wed, 23 Jan 2019 02:47:08 +0000 (18:47 -0800)
This commit introduces three helper methods for performing multiset
reasoning: an entailment check whether a term is always a strict subset
of another term in the multiset domain (`checkEntailMultisetSubset()`),
a check whether a string term is always homogeneous
(`checkEntailHomogeneousString()`), and an overapproximation for the
multiset domain (`getMultisetApproximation()`). It also adds unit tests
related to multiset reasoning.

src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
test/unit/theory/theory_strings_rewriter_white.h

index e5740325035cd2ca0c4f16941f3716c59af84382..8e5e22d38423503c2b3407482f8367a1b0f634cc 100644 (file)
@@ -385,91 +385,77 @@ Node TheoryStringsRewriter::rewriteStrEqualityExt(Node node)
   // ------- homogeneous constants
   for (unsigned i = 0; i < 2; i++)
   {
-    if (node[i].isConst())
-    {
-      bool isHomogeneous = true;
-      unsigned hchar = 0;
-      String lhss = node[i].getConst<String>();
-      std::vector<unsigned> vec = lhss.getVec();
-      if (vec.size() >= 1)
-      {
-        hchar = vec[0];
-        for (unsigned j = 1, size = vec.size(); j < size; j++)
+    Node cn = checkEntailHomogeneousString(node[i]);
+    if (!cn.isNull() && cn.getConst<String>().size() > 0)
+    {
+      Assert(cn.isConst());
+      Assert(cn.getConst<String>().size() == 1);
+      unsigned hchar = cn.getConst<String>().front();
+
+      // The operands of the concat on each side of the equality without
+      // constant strings
+      std::vector<Node> trimmed[2];
+      // Counts the number of `hchar`s on each side
+      size_t numHChars[2] = {0, 0};
+      for (size_t j = 0; j < 2; j++)
+      {
+        // Sort the operands of the concats on both sides of the equality
+        // (since both sides may only contain one char, the order does not
+        // matter)
+        std::sort(c[j].begin(), c[j].end());
+        for (const Node& cc : c[j])
         {
-          if (vec[j] != hchar)
+          if (cc.getKind() == CONST_STRING)
           {
-            isHomogeneous = false;
-            break;
-          }
-        }
-      }
-      if (isHomogeneous)
-      {
-        std::sort(c[1 - i].begin(), c[1 - i].end());
-        std::vector<Node> trimmed;
-        unsigned rmChar = 0;
-        for (unsigned j = 0, size = c[1 - i].size(); j < size; j++)
-        {
-          if (c[1 - i][j].isConst())
-          {
-            // process the constant : either we have a conflict, or we
-            // drop an equal number of constants on the LHS
-            std::vector<unsigned> vecj =
-                c[1 - i][j].getConst<String>().getVec();
-            for (unsigned k = 0, sizev = vecj.size(); k < sizev; k++)
+            // Count the number of `hchar`s in the string constant and make
+            // sure that all chars are `hchar`s
+            std::vector<unsigned> veccc = cc.getConst<String>().getVec();
+            for (size_t k = 0, size = veccc.size(); k < size; k++)
             {
-              bool conflict = false;
-              if (vec.empty())
-              {
-                // e.g. "" = x ++ "A" ---> false
-                conflict = true;
-              }
-              else if (vecj[k] != hchar)
-              {
-                // e.g. "AA" = x ++ "B" ---> false
-                conflict = true;
-              }
-              else
+              if (veccc[k] != hchar)
               {
-                rmChar++;
-                if (rmChar > lhss.size())
-                {
-                  // e.g. "AA" = x ++ "AAA" ---> false
-                  conflict = true;
-                }
-              }
-              if (conflict)
-              {
-                // The three conflict cases should mostly should be taken
-                // care of by multiset reasoning in the strings rewriter,
-                // but we recognize this conflict just in case.
+                // This conflict case should mostly should be taken care of by
+                // multiset reasoning in the strings rewriter, but we recognize
+                // this conflict just in case.
                 new_ret = nm->mkConst(false);
-                return returnRewrite(node, new_ret, "string-eq-const-conflict");
+                return returnRewrite(
+                    node, new_ret, "string-eq-const-conflict-non-homog");
               }
+              numHChars[j]++;
             }
           }
           else
           {
-            trimmed.push_back(c[1 - i][j]);
+            trimmed[j].push_back(cc);
           }
         }
-        Node lhs = node[i];
-        if (rmChar > 0)
-        {
-          Assert(lhss.size() >= rmChar);
-          // we trimmed
-          lhs = nm->mkConst(lhss.substr(0, lhss.size() - rmChar));
-        }
-        Node ss = mkConcat(STRING_CONCAT, trimmed);
-        if (lhs != node[i] || ss != node[1 - i])
+      }
+
+      // We have to remove the same number of `hchar`s from both sides, so the
+      // side with less `hchar`s determines how many we can remove
+      size_t trimmedConst = std::min(numHChars[0], numHChars[1]);
+      for (size_t j = 0; j < 2; j++)
+      {
+        size_t diff = numHChars[j] - trimmedConst;
+        if (diff != 0)
         {
-          // e.g.
-          //  "AA" = y ++ x ---> "AA" = x ++ y if x < y
-          //  "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z
-          new_ret = lhs.eqNode(ss);
-          node = returnRewrite(node, new_ret, "str-eq-homog-const");
+          // Add a constant string to the side with more `hchar`s to restore
+          // the difference in number of `hchar`s
+          std::vector<unsigned> vec(diff, hchar);
+          trimmed[j].push_back(nm->mkConst(String(vec)));
         }
       }
+
+      Node lhs = mkConcat(STRING_CONCAT, trimmed[i]);
+      Node ss = mkConcat(STRING_CONCAT, trimmed[1 - i]);
+      if (lhs != node[i] || ss != node[1 - i])
+      {
+        // e.g.
+        //  "AA" = y ++ x ---> "AA" = x ++ y if x < y
+        //  "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z
+        new_ret = lhs.eqNode(ss);
+        node = returnRewrite(node, new_ret, "str-eq-homog-const");
+      }
     }
   }
 
@@ -2025,84 +2011,12 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   //   For example, contains( str.++( x, "b" ), str.++( "a", x ) ) ---> false
   //   since the number of a's in the second argument is greater than the number
   //   of a's in the first argument
-  std::map<Node, unsigned> num_nconst[2];
-  std::map<Node, unsigned> num_const[2];
-  for (unsigned j = 0; j < 2; j++)
+  if (checkEntailMultisetSubset(node[0], node[1]))
   {
-    std::vector<Node>& ncj = j == 0 ? nc1 : nc2;
-    for (const Node& cc : ncj)
-    {
-      if (cc.isConst())
-      {
-        num_const[j][cc]++;
-      }
-      else
-      {
-        num_nconst[j][cc]++;
-      }
-    }
-  }
-  bool ms_success = true;
-  for (std::pair<const Node, unsigned>& nncp : num_nconst[0])
-  {
-    if (nncp.second > num_nconst[1][nncp.first])
-    {
-      ms_success = false;
-      break;
-    }
-  }
-  if (ms_success)
-  {
-    // count the number of constant characters in the first argument
-    std::map<Node, unsigned> count_const[2];
-    std::vector<Node> chars;
-    for (unsigned j = 0; j < 2; j++)
-    {
-      for (std::pair<const Node, unsigned>& ncp : num_const[j])
-      {
-        Node cn = ncp.first;
-        Assert(cn.isConst());
-        std::vector<unsigned> cc_vec;
-        const std::vector<unsigned>& cvec = cn.getConst<String>().getVec();
-        for (unsigned i = 0, size = cvec.size(); i < size; i++)
-        {
-          // make the character
-          cc_vec.clear();
-          cc_vec.insert(cc_vec.end(), cvec.begin() + i, cvec.begin() + i + 1);
-          Node ch = NodeManager::currentNM()->mkConst(String(cc_vec));
-          count_const[j][ch] += ncp.second;
-          if (std::find(chars.begin(), chars.end(), ch) == chars.end())
-          {
-            chars.push_back(ch);
-          }
-        }
-      }
-    }
-    Trace("strings-rewrite-multiset") << "For " << node << " : " << std::endl;
-    for (const Node& ch : chars)
-    {
-      Trace("strings-rewrite-multiset") << "  # occurrences of substring ";
-      Trace("strings-rewrite-multiset") << ch << " in arguments is ";
-      Trace("strings-rewrite-multiset") << count_const[0][ch] << " / "
-                                        << count_const[1][ch] << std::endl;
-      if (count_const[0][ch] < count_const[1][ch])
-      {
-        Node ret = NodeManager::currentNM()->mkConst(false);
-        return returnRewrite(node, ret, "ctn-mset-nss");
-      }
-    }
-
-    // TODO (#1180): count the number of 2,3,4,.. character substrings
-    // for example:
-    // str.contains( str.++( x, "cbabc" ), str.++( "cabbc", x ) ) ---> false
-    // since the second argument contains more occurrences of "bb".
-    // note this is orthogonal reasoning to inductive reasoning
-    // via regular membership reduction in Liang et al CAV 2015.
+    Node ret = nm->mkConst(false);
+    return returnRewrite(node, ret, "ctn-mset-nss");
   }
 
-  // TODO (#1180): abstract interpretation with multi-set domain
-  // to show first argument is a strict subset of second argument
-
   if (checkEntailArith(len_n2, len_n1, false))
   {
     // len( n2 ) >= len( n1 ) => contains( n1, n2 ) ---> n1 = n2
@@ -4488,6 +4402,162 @@ void TheoryStringsRewriter::getArithApproximations(Node a,
   Trace("strings-ent-approx-debug") << "Return " << approx.size() << std::endl;
 }
 
+bool TheoryStringsRewriter::checkEntailMultisetSubset(Node a, Node b)
+{
+  NodeManager* nm = NodeManager::currentNM();
+
+  std::vector<Node> avec;
+  getConcat(getMultisetApproximation(a), avec);
+  std::vector<Node> bvec;
+  getConcat(b, bvec);
+
+  std::map<Node, unsigned> num_nconst[2];
+  std::map<Node, unsigned> num_const[2];
+  for (unsigned j = 0; j < 2; j++)
+  {
+    std::vector<Node>& jvec = j == 0 ? avec : bvec;
+    for (const Node& cc : jvec)
+    {
+      if (cc.isConst())
+      {
+        num_const[j][cc]++;
+      }
+      else
+      {
+        num_nconst[j][cc]++;
+      }
+    }
+  }
+  bool ms_success = true;
+  for (std::pair<const Node, unsigned>& nncp : num_nconst[0])
+  {
+    if (nncp.second > num_nconst[1][nncp.first])
+    {
+      ms_success = false;
+      break;
+    }
+  }
+  if (ms_success)
+  {
+    // count the number of constant characters in the first argument
+    std::map<Node, unsigned> count_const[2];
+    std::vector<Node> chars;
+    for (unsigned j = 0; j < 2; j++)
+    {
+      for (std::pair<const Node, unsigned>& ncp : num_const[j])
+      {
+        Node cn = ncp.first;
+        Assert(cn.isConst());
+        std::vector<unsigned> cc_vec;
+        const std::vector<unsigned>& cvec = cn.getConst<String>().getVec();
+        for (unsigned i = 0, size = cvec.size(); i < size; i++)
+        {
+          // make the character
+          cc_vec.clear();
+          cc_vec.insert(cc_vec.end(), cvec.begin() + i, cvec.begin() + i + 1);
+          Node ch = nm->mkConst(String(cc_vec));
+          count_const[j][ch] += ncp.second;
+          if (std::find(chars.begin(), chars.end(), ch) == chars.end())
+          {
+            chars.push_back(ch);
+          }
+        }
+      }
+    }
+    Trace("strings-entail-ms-ss")
+        << "For " << a << " and " << b << " : " << std::endl;
+    for (const Node& ch : chars)
+    {
+      Trace("strings-entail-ms-ss") << "  # occurrences of substring ";
+      Trace("strings-entail-ms-ss") << ch << " in arguments is ";
+      Trace("strings-entail-ms-ss")
+          << count_const[0][ch] << " / " << count_const[1][ch] << std::endl;
+      if (count_const[0][ch] < count_const[1][ch])
+      {
+        return true;
+      }
+    }
+
+    // TODO (#1180): count the number of 2,3,4,.. character substrings
+    // for example:
+    // str.contains( str.++( x, "cbabc" ), str.++( "cabbc", x ) ) ---> false
+    // since the second argument contains more occurrences of "bb".
+    // note this is orthogonal reasoning to inductive reasoning
+    // via regular membership reduction in Liang et al CAV 2015.
+  }
+  return false;
+}
+
+Node TheoryStringsRewriter::checkEntailHomogeneousString(Node a)
+{
+  NodeManager* nm = NodeManager::currentNM();
+
+  std::vector<Node> avec;
+  getConcat(getMultisetApproximation(a), avec);
+
+  bool cValid = false;
+  unsigned c = 0;
+  for (const Node& ac : avec)
+  {
+    if (ac.getKind() == CONST_STRING)
+    {
+      std::vector<unsigned> acv = ac.getConst<String>().getVec();
+      for (unsigned cc : acv)
+      {
+        if (!cValid)
+        {
+          cValid = true;
+          c = cc;
+        }
+        else if (c != cc)
+        {
+          // Found a different character
+          return Node::null();
+        }
+      }
+    }
+    else
+    {
+      // Could produce a different character
+      return Node::null();
+    }
+  }
+
+  if (!cValid)
+  {
+    return nm->mkConst(String(""));
+  }
+
+  std::vector<unsigned> cv = {c};
+  return nm->mkConst(String(cv));
+}
+
+Node TheoryStringsRewriter::getMultisetApproximation(Node a)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  if (a.getKind() == STRING_SUBSTR)
+  {
+    return a[0];
+  }
+  else if (a.getKind() == STRING_STRREPL)
+  {
+    return getMultisetApproximation(nm->mkNode(STRING_CONCAT, a[0], a[2]));
+  }
+  else if (a.getKind() == STRING_CONCAT)
+  {
+    NodeBuilder<> nb(STRING_CONCAT);
+    for (const Node& ac : a)
+    {
+      nb << getMultisetApproximation(ac);
+    }
+    return nb.constructNode();
+  }
+  else
+  {
+    return a;
+  }
+}
+
 bool TheoryStringsRewriter::checkEntailArithWithEqAssumption(Node assumption,
                                                              Node a,
                                                              bool strict)
index e4b76036d29e5e1b9eedc2c42e22b5ff7bff51c1..8b0072f52671383db6f688c625df1b58a162930c 100644 (file)
@@ -540,6 +540,52 @@ class TheoryStringsRewriter {
                                      std::vector<Node>& approx,
                                      bool isOverApprox = false);
 
+  /**
+   * Checks whether it is always true that `a` is a strict subset of `b` in the
+   * multiset domain.
+   *
+   * Examples:
+   *
+   * a = (str.++ "A" x), b = (str.++ "A" x "B") ---> true
+   * a = (str.++ "A" x), b = (str.++ "B" x "AA") ---> true
+   * a = (str.++ "A" x), b = (str.++ "B" y "AA") ---> false
+   *
+   * @param a The term for which it should be checked if it is a strict subset
+   * of `b` in the multiset domain
+   * @param b The term for which it should be checked if it is a strict
+   * superset of `a` in the multiset domain
+   * @return True if it is always the case that `a` is a strict subset of `b`,
+   * false otherwise.
+   */
+  static bool checkEntailMultisetSubset(Node a, Node b);
+
+  /**
+   * Returns a character `c` if it is always the case that str.in.re(a, c*),
+   * i.e. if all possible values of `a` only consist of `c` characters, and the
+   * null node otherwise. If `a` is the empty string, the function returns an
+   * empty string.
+   *
+   * @param a The node to check for homogeneity
+   * @return If `a` is homogeneous, the only character that it may contain, the
+   * empty string if `a` is empty, and the null node otherwise
+   */
+  static Node checkEntailHomogeneousString(Node a);
+
+  /**
+   * Simplifies a given node `a` s.t. the result is a concatenation of string
+   * terms that can be interpreted as a multiset and which contains all
+   * multisets that `a` could form.
+   *
+   * Examples:
+   *
+   * (str.substr "AA" 0 n) ---> "AA"
+   * (str.replace "AAA" x "BB") ---> (str.++ "AAA" "BB")
+   *
+   * @param a The node to simplify
+   * @return A concatenation that can be interpreted as a multiset
+   */
+  static Node getMultisetApproximation(Node a);
+
   /**
    * Checks whether assumption |= a >= 0 (if strict is false) or
    * assumption |= a > 0 (if strict is true), where assumption is an equality
index 8139f1c2e01028a831c2f82e1baa177d651df271..59d36d9e83cc01090d2a39f1f26962616150c905 100644 (file)
@@ -866,6 +866,14 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
                        d_nm->mkNode(kind::STRING_STRREPL, x, a, empty),
                        a);
     sameNormalForm(lhs, rhs);
+
+    {
+      // (str.contains (str.++ x "A") (str.++ "B" x)) ---> false
+      Node ctn = d_nm->mkNode(kind::STRING_STRCTN,
+                              d_nm->mkNode(kind::STRING_CONCAT, x, a),
+                              d_nm->mkNode(kind::STRING_CONCAT, b, x));
+      sameNormalForm(ctn, f);
+    }
   }
 
   void testInferEqsFromContains()
@@ -929,7 +937,7 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
 
     // Same normal form for:
     //
-    // (str.prefix x (str.++ x y))
+    // (str.prefix (str.++ x y) x)
     //
     // (= y "")
     Node p_xy = d_nm->mkNode(kind::STRING_PREFIX, xy, x);
@@ -938,16 +946,14 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
 
     // Same normal form for:
     //
-    // (str.suffix x (str.++ x x))
+    // (str.suffix (str.++ x x) x)
     //
     // (= x "")
     Node p_xx = d_nm->mkNode(kind::STRING_SUFFIX, xx, x);
     Node empty_x = d_nm->mkNode(kind::EQUAL, x, empty);
     sameNormalForm(p_xx, empty_x);
 
-    // (str.suffix x (str.++ x x "A")) --> false
-    //
-    // (= x "")
+    // (str.suffix x (str.++ x x "A")) ---> false
     Node p_xxa = d_nm->mkNode(kind::STRING_SUFFIX, xxa, x);
     sameNormalForm(p_xxa, f);
   }
@@ -959,6 +965,7 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
 
     Node empty = d_nm->mkConst(::CVC4::String(""));
     Node a = d_nm->mkConst(::CVC4::String("A"));
+    Node aaa = d_nm->mkConst(::CVC4::String("AAA"));
     Node b = d_nm->mkConst(::CVC4::String("B"));
     Node x = d_nm->mkVar("x", strType);
     Node y = d_nm->mkVar("y", strType);
@@ -1075,16 +1082,88 @@ class TheoryStringsRewriterWhite : public CxxTest::TestSuite
     Node eq_x = d_nm->mkNode(kind::EQUAL, x, empty);
     sameNormalForm(eq_repl, eq_x);
 
-    // Same normal form for:
-    //
-    // (= (str.replace y "A" "B") "B")
-    //
-    // (= (str.replace y "B" "A") "A")
-    Node lhs = d_nm->mkNode(
-        kind::EQUAL, d_nm->mkNode(kind::STRING_STRREPL, x, a, b), b);
-    Node rhs = d_nm->mkNode(
-        kind::EQUAL, d_nm->mkNode(kind::STRING_STRREPL, x, b, a), a);
-    sameNormalForm(lhs, rhs);
+    {
+      // Same normal form for:
+      //
+      // (= (str.replace y "A" "B") "B")
+      //
+      // (= (str.replace y "B" "A") "A")
+      Node lhs = d_nm->mkNode(
+          kind::EQUAL, d_nm->mkNode(kind::STRING_STRREPL, x, a, b), b);
+      Node rhs = d_nm->mkNode(
+          kind::EQUAL, d_nm->mkNode(kind::STRING_STRREPL, x, b, a), a);
+      sameNormalForm(lhs, rhs);
+    }
+
+    {
+      // Same normal form for:
+      //
+      // (= (str.++ x "A" y) (str.++ "A" "A" (str.substr "AAA" 0 n)))
+      //
+      // (= (str.++ y x) (str.++ (str.substr "AAA" 0 n) "A"))
+      Node lhs = d_nm->mkNode(
+          kind::EQUAL,
+          d_nm->mkNode(kind::STRING_CONCAT, x, a, y),
+          d_nm->mkNode(kind::STRING_CONCAT,
+                       a,
+                       a,
+                       d_nm->mkNode(kind::STRING_SUBSTR, aaa, zero, n)));
+      Node rhs = d_nm->mkNode(
+          kind::EQUAL,
+          d_nm->mkNode(kind::STRING_CONCAT, x, y),
+          d_nm->mkNode(kind::STRING_CONCAT,
+                       d_nm->mkNode(kind::STRING_SUBSTR, aaa, zero, n),
+                       a));
+      sameNormalForm(lhs, rhs);
+    }
+
+    {
+      // Same normal form for:
+      //
+      // (= (str.++ "A" x) "A")
+      //
+      // (= x "")
+      Node lhs =
+          d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::STRING_CONCAT, a, x), a);
+      Node rhs = d_nm->mkNode(kind::EQUAL, x, empty);
+      sameNormalForm(lhs, rhs);
+    }
+
+    {
+      // (= (str.++ x "A") "") ---> false
+      Node eq = d_nm->mkNode(
+          kind::EQUAL, d_nm->mkNode(kind::STRING_CONCAT, x, a), empty);
+      sameNormalForm(eq, f);
+    }
+
+    {
+      // (= (str.++ x "B") "AAA") ---> false
+      Node eq = d_nm->mkNode(
+          kind::EQUAL, d_nm->mkNode(kind::STRING_CONCAT, x, b), aaa);
+      sameNormalForm(eq, f);
+    }
+
+    {
+      // (= (str.++ x "AAA") "A") ---> false
+      Node eq = d_nm->mkNode(
+          kind::EQUAL, d_nm->mkNode(kind::STRING_CONCAT, x, aaa), a);
+      sameNormalForm(eq, f);
+    }
+
+    {
+      // (= (str.++ "AAA" (str.substr "A" 0 n)) (str.++ x "B")) ---> false
+      Node eq = d_nm->mkNode(
+          kind::EQUAL,
+          d_nm->mkNode(
+              kind::STRING_CONCAT,
+              aaa,
+              d_nm->mkNode(kind::STRING_CONCAT,
+                           a,
+                           a,
+                           d_nm->mkNode(kind::STRING_SUBSTR, x, zero, n))),
+          d_nm->mkNode(kind::STRING_CONCAT, x, b));
+      sameNormalForm(eq, f);
+    }
   }
 
  private: