Refactor strings equality rewriting (#2513)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 25 Sep 2018 02:28:31 +0000 (21:28 -0500)
committerAndres Noetzli <andres.noetzli@gmail.com>
Tue, 25 Sep 2018 02:28:31 +0000 (19:28 -0700)
This moves the extended rewrites for string equality to the main strings rewriter as a function rewriteEqualityExt, and makes this function called on every equality that is generated (from non-equalities) by our rewriter.

src/theory/quantifiers/extended_rewrite.cpp
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h

index df82e075037623153916a0d8767a53ee25c2ecb4..e64e1b7b23de7e64e67e22c2859c90475ac573e9 100644 (file)
@@ -1671,208 +1671,6 @@ Node ExtendedRewriter::extendedRewriteStrings(Node ret)
   Node new_ret;
   Trace("q-ext-rewrite-debug")
       << "Extended rewrite strings : " << ret << std::endl;
-  NodeManager* nm = NodeManager::currentNM();
-  if (ret.getKind() == EQUAL)
-  {
-    if (ret[0].getType().isString())
-    {
-      std::vector<Node> c[2];
-      for (unsigned i = 0; i < 2; i++)
-      {
-        strings::TheoryStringsRewriter::getConcat(ret[i], c[i]);
-      }
-
-      // ------- equality unification
-      bool changed = false;
-      for (unsigned i = 0; i < 2; i++)
-      {
-        while (!c[0].empty() && !c[1].empty() && c[0].back() == c[1].back())
-        {
-          c[0].pop_back();
-          c[1].pop_back();
-          changed = true;
-        }
-        // splice constants
-        if (!c[0].empty() && !c[1].empty() && c[0].back().isConst()
-            && c[1].back().isConst())
-        {
-          String cs[2];
-          for (unsigned j = 0; j < 2; j++)
-          {
-            cs[j] = c[j].back().getConst<String>();
-          }
-          unsigned larger = cs[0].size() > cs[1].size() ? 0 : 1;
-          unsigned smallerSize = cs[1 - larger].size();
-          if (cs[1 - larger]
-              == (i == 0 ? cs[larger].suffix(smallerSize)
-                         : cs[larger].prefix(smallerSize)))
-          {
-            unsigned sizeDiff = cs[larger].size() - smallerSize;
-            c[larger][c[larger].size() - 1] =
-                nm->mkConst(i == 0 ? cs[larger].prefix(sizeDiff)
-                                   : cs[larger].suffix(sizeDiff));
-            c[1 - larger].pop_back();
-            changed = true;
-          }
-        }
-        for (unsigned j = 0; j < 2; j++)
-        {
-          std::reverse(c[j].begin(), c[j].end());
-        }
-      }
-      if (changed)
-      {
-        // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y
-        Node s1 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[0]);
-        Node s2 = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT, c[1]);
-        new_ret = s1.eqNode(s2);
-        debugExtendedRewrite(ret, new_ret, "string-eq-unify");
-        return new_ret;
-      }
-
-      // ------- using the contains rewriter to reduce equalities
-      Node tcontains[2];
-      bool tcontainsOneTrue = false;
-      unsigned tcontainsTrueIndex = 0;
-      for (unsigned i = 0; i < 2; i++)
-      {
-        Node tc = nm->mkNode(STRING_STRCTN, ret[i], ret[1 - i]);
-        tcontains[i] = Rewriter::rewrite(tc);
-        if (tcontains[i].isConst())
-        {
-          if (tcontains[i].getConst<bool>())
-          {
-            tcontainsOneTrue = true;
-            tcontainsTrueIndex = i;
-          }
-          else
-          {
-            new_ret = tcontains[i];
-            // if str.contains( x, y ) ---> false  then   x = y ---> false
-            // Notice we may not catch this in the rewriter for strings
-            // equality, since it only calls the specific rewriter for
-            // contains and not the full rewriter.
-            debugExtendedRewrite(ret, new_ret, "eq-contains-one-false");
-            return new_ret;
-          }
-        }
-      }
-      if (tcontainsOneTrue)
-      {
-        // if str.contains( x, y ) ---> true
-        // then x = y ---> contains( y, x )
-        new_ret = tcontains[1 - tcontainsTrueIndex];
-        debugExtendedRewrite(ret, new_ret, "eq-contains-one-true");
-        return new_ret;
-      }
-      else if (tcontains[0] == tcontains[1] && tcontains[0] != ret)
-      {
-        // if str.contains( x, y ) ---> t and str.contains( y, x ) ---> t,
-        // then x = y ---> t
-        new_ret = tcontains[0];
-        debugExtendedRewrite(ret, new_ret, "eq-dual-contains-eq");
-        return new_ret;
-      }
-
-      // ------- homogeneous constants
-      if (d_aggr)
-      {
-        for (unsigned i = 0; i < 2; i++)
-        {
-          if (ret[i].isConst())
-          {
-            bool isHomogeneous = true;
-            unsigned hchar = 0;
-            String lhss = ret[i].getConst<String>();
-            std::vector<unsigned> vec = lhss.getVec();
-            if (vec.size() > 1)
-            {
-              hchar = vec[0];
-              for (unsigned j = 1, size = vec.size(); j < size; j++)
-              {
-                if (vec[j] != hchar)
-                {
-                  isHomogeneous = false;
-                  break;
-                }
-              }
-            }
-            if (isHomogeneous)
-            {
-              std::sort(c[1 - i].begin(), c[1 - i].end());
-              std::vector<Node> trimmed;
-              unsigned rmChar = 0;
-              for (unsigned j = 0, size = c[1 - i].size(); j < size; j++)
-              {
-                if (c[1 - i][j].isConst())
-                {
-                  // process the constant : either we have a conflict, or we
-                  // drop an equal number of constants on the LHS
-                  std::vector<unsigned> vecj =
-                      c[1 - i][j].getConst<String>().getVec();
-                  for (unsigned k = 0, sizev = vecj.size(); k < sizev; k++)
-                  {
-                    bool conflict = false;
-                    if (vec.empty())
-                    {
-                      // e.g. "" = x ++ "A" ---> false
-                      conflict = true;
-                    }
-                    else if (vecj[k] != hchar)
-                    {
-                      // e.g. "AA" = x ++ "B" ---> false
-                      conflict = true;
-                    }
-                    else
-                    {
-                      rmChar++;
-                      if (rmChar > lhss.size())
-                      {
-                        // e.g. "AA" = x ++ "AAA" ---> false
-                        conflict = true;
-                      }
-                    }
-                    if (conflict)
-                    {
-                      // The three conflict cases should mostly should be taken
-                      // care of by multiset reasoning in the strings rewriter,
-                      // but we recognize this conflict just in case.
-                      new_ret = nm->mkConst(false);
-                      debugExtendedRewrite(
-                          ret, new_ret, "string-eq-const-conflict");
-                      return new_ret;
-                    }
-                  }
-                }
-                else
-                {
-                  trimmed.push_back(c[1 - i][j]);
-                }
-              }
-              Node lhs = ret[i];
-              if (rmChar > 0)
-              {
-                Assert(lhss.size() >= rmChar);
-                // we trimmed
-                lhs = nm->mkConst(lhss.substr(0, lhss.size() - rmChar));
-              }
-              Node ss = strings::TheoryStringsRewriter::mkConcat(STRING_CONCAT,
-                                                                 trimmed);
-              if (lhs != ret[i] || ss != ret[1 - i])
-              {
-                // e.g.
-                //  "AA" = y ++ x ---> "AA" = x ++ y if x < y
-                //  "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z
-                new_ret = lhs.eqNode(ss);
-                debugExtendedRewrite(ret, new_ret, "string-eq-homog-const");
-                return new_ret;
-              }
-            }
-          }
-        }
-      }
-    }
-  }
 
   return new_ret;
 }
index 48b288ea3e43058e95d0eec8e70d82ccb3ef0ec9..f8bbeecf52ab883d70971acf61a6bce874379972 100644 (file)
@@ -315,10 +315,179 @@ Node TheoryStringsRewriter::rewriteEquality(Node node)
   {
     return NodeManager::currentNM()->mkNode(kind::EQUAL, node[1], node[0]);
   }
-  else
+  return node;
+}
+
+Node TheoryStringsRewriter::rewriteEqualityExt(Node node)
+{
+  Assert(node.getKind() == EQUAL);
+  if (!node[0].getType().isString())
   {
     return node;
   }
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<Node> c[2];
+  Node new_ret;
+  for (unsigned i = 0; i < 2; i++)
+  {
+    getConcat(node[i], c[i]);
+  }
+  // ------- equality unification
+  bool changed = false;
+  for (unsigned i = 0; i < 2; i++)
+  {
+    while (!c[0].empty() && !c[1].empty() && c[0].back() == c[1].back())
+    {
+      c[0].pop_back();
+      c[1].pop_back();
+      changed = true;
+    }
+    // splice constants
+    if (!c[0].empty() && !c[1].empty() && c[0].back().isConst()
+        && c[1].back().isConst())
+    {
+      String cs[2];
+      for (unsigned j = 0; j < 2; j++)
+      {
+        cs[j] = c[j].back().getConst<String>();
+      }
+      unsigned larger = cs[0].size() > cs[1].size() ? 0 : 1;
+      unsigned smallerSize = cs[1 - larger].size();
+      if (cs[1 - larger]
+          == (i == 0 ? cs[larger].suffix(smallerSize)
+                     : cs[larger].prefix(smallerSize)))
+      {
+        unsigned sizeDiff = cs[larger].size() - smallerSize;
+        c[larger][c[larger].size() - 1] = nm->mkConst(
+            i == 0 ? cs[larger].prefix(sizeDiff) : cs[larger].suffix(sizeDiff));
+        c[1 - larger].pop_back();
+        changed = true;
+      }
+    }
+    for (unsigned j = 0; j < 2; j++)
+    {
+      std::reverse(c[j].begin(), c[j].end());
+    }
+  }
+  if (changed)
+  {
+    // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y
+    Node s1 = mkConcat(STRING_CONCAT, c[0]);
+    Node s2 = mkConcat(STRING_CONCAT, c[1]);
+    new_ret = s1.eqNode(s2);
+    node = returnRewrite(node, new_ret, "str-eq-unify");
+  }
+
+  // ------- homogeneous constants
+  for (unsigned i = 0; i < 2; i++)
+  {
+    if (node[i].isConst())
+    {
+      bool isHomogeneous = true;
+      unsigned hchar = 0;
+      String lhss = node[i].getConst<String>();
+      std::vector<unsigned> vec = lhss.getVec();
+      if (vec.size() > 1)
+      {
+        hchar = vec[0];
+        for (unsigned j = 1, size = vec.size(); j < size; j++)
+        {
+          if (vec[j] != hchar)
+          {
+            isHomogeneous = false;
+            break;
+          }
+        }
+      }
+      if (isHomogeneous)
+      {
+        std::sort(c[1 - i].begin(), c[1 - i].end());
+        std::vector<Node> trimmed;
+        unsigned rmChar = 0;
+        for (unsigned j = 0, size = c[1 - i].size(); j < size; j++)
+        {
+          if (c[1 - i][j].isConst())
+          {
+            // process the constant : either we have a conflict, or we
+            // drop an equal number of constants on the LHS
+            std::vector<unsigned> vecj =
+                c[1 - i][j].getConst<String>().getVec();
+            for (unsigned k = 0, sizev = vecj.size(); k < sizev; k++)
+            {
+              bool conflict = false;
+              if (vec.empty())
+              {
+                // e.g. "" = x ++ "A" ---> false
+                conflict = true;
+              }
+              else if (vecj[k] != hchar)
+              {
+                // e.g. "AA" = x ++ "B" ---> false
+                conflict = true;
+              }
+              else
+              {
+                rmChar++;
+                if (rmChar > lhss.size())
+                {
+                  // e.g. "AA" = x ++ "AAA" ---> false
+                  conflict = true;
+                }
+              }
+              if (conflict)
+              {
+                // The three conflict cases should mostly should be taken
+                // care of by multiset reasoning in the strings rewriter,
+                // but we recognize this conflict just in case.
+                new_ret = nm->mkConst(false);
+                return returnRewrite(node, new_ret, "string-eq-const-conflict");
+              }
+            }
+          }
+          else
+          {
+            trimmed.push_back(c[1 - i][j]);
+          }
+        }
+        Node lhs = node[i];
+        if (rmChar > 0)
+        {
+          Assert(lhss.size() >= rmChar);
+          // we trimmed
+          lhs = nm->mkConst(lhss.substr(0, lhss.size() - rmChar));
+        }
+        Node ss = mkConcat(STRING_CONCAT, trimmed);
+        if (lhs != node[i] || ss != node[1 - i])
+        {
+          // e.g.
+          //  "AA" = y ++ x ---> "AA" = x ++ y if x < y
+          //  "AAA" = y ++ "A" ++ z ---> "AA" = y ++ z
+          new_ret = lhs.eqNode(ss);
+          node = returnRewrite(node, new_ret, "str-eq-homog-const");
+        }
+      }
+    }
+  }
+
+  Assert(node.getKind() == EQUAL);
+
+  // Try to rewrite (= x y) into a conjunction of equalities based on length
+  // entailment.
+  //
+  // (<= (str.len x) (str.++ y1 ... yn)) AND (= x (str.++ y1 ... yn)) --->
+  //  (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' ""))
+  //
+  // where yi' and yi'' correspond to some yj and
+  //   (<= (str.len x) (str.++ y1' ... ym'))
+  for (unsigned i = 0; i < 2; i++)
+  {
+    new_ret = inferEqsFromContains(node[i], node[1 - i]);
+    if (!new_ret.isNull())
+    {
+      return returnRewrite(node, new_ret, "str-eq-conj-len-entail");
+    }
+  }
+  return node;
 }
 
 // TODO (#1180) add rewrite
@@ -1710,22 +1879,11 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   // TODO (#1180): abstract interpretation with multi-set domain
   // to show first argument is a strict subset of second argument
 
-  // Try to rewrite (str.contains x y) into an equality or a conjunction of
-  // equalities:
-  //
-  // (str.contains x y) ---> (= x y) if (<= (str.len x) (str.len y))
-  //
-  // or more generally:
-  //
-  // (str.contains x (str.++ y1 ... yn)) --->
-  //  (and (= x (str.++ y1' ... ym')) (= y1'' "") ... (= yk'' ""))
-  //
-  // where yi' and yi'' correspond to some yj and
-  // (<= (str.len x) (str.++ y1' ... ym'))
-  Node eqs = inferEqsFromContains(node[0], node[1]);
-  if (!eqs.isNull())
+  if (checkEntailArith(len_n2, len_n1, false))
   {
-    return returnRewrite(node, eqs, "ctn-to-eqs");
+    // len( n2 ) >= len( n1 ) => contains( n1, n2 ) ---> n1 = n2
+    Node ret = node[0].eqNode(node[1]);
+    return returnRewrite(node, ret, "ctn-len-ineq-nstrict");
   }
 
   // splitting
@@ -2574,14 +2732,7 @@ Node TheoryStringsRewriter::rewritePrefixSuffix(Node n)
   // general reduction to equality + substr
   Node retNode = n[0].eqNode(
       NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, n[1], val, lens));
-  // add length constraint if it cannot be shown by simple entailment check
-  if (!checkEntailArith(lent, lens))
-  {
-    retNode = NodeManager::currentNM()->mkNode(
-        kind::AND,
-        retNode,
-        NodeManager::currentNM()->mkNode(kind::GEQ, lent, lens));
-  }
+
   return retNode;
 }
 
@@ -3913,5 +4064,35 @@ Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c)
 {
   Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c
                            << "." << std::endl;
+  // standard post-processing
+  // We rewrite (string) equalities immediately here. This allows us to forego
+  // the standard invariant on equality rewrites (that s=t must rewrite to one
+  // of { s=t, t=s, true, false } ).
+  Kind retk = ret.getKind();
+  if (retk == OR || retk == AND)
+  {
+    std::vector<Node> children;
+    bool childChanged = false;
+    for (const Node& cret : ret)
+    {
+      Node creter = cret;
+      if (cret.getKind() == EQUAL)
+      {
+        creter = rewriteEqualityExt(cret);
+      }
+      childChanged = childChanged || cret != creter;
+      children.push_back(creter);
+    }
+    if (childChanged)
+    {
+      ret = NodeManager::currentNM()->mkNode(retk, children);
+    }
+  }
+  else if (retk == EQUAL && node.getKind() != EQUAL)
+  {
+    Trace("strings-rewrite")
+        << "Apply extended equality rewrite on " << ret << std::endl;
+    ret = rewriteEqualityExt(ret);
+  }
   return ret;
 }
index 70c573d9e77ccea575d2efbc86f78d6b2e33725f..c0aa913606abcde8688f35de07c5dd5f0d528d47 100644 (file)
@@ -98,12 +98,16 @@ class TheoryStringsRewriter {
    * a is in rewritten form.
    */
   static bool checkEntailArithInternal(Node a);
-  /** return rewrite
+  /**
    * Called when node rewrites to ret.
-   * The string c indicates the justification
-   * for the rewrite, which is printed by this
-   * function for debugging.
-   * This function returns ret.
+   *
+   * The string c indicates the justification for the rewrite, which is printed
+   * by this function for debugging.
+   *
+   * If node is not an equality and ret is an equality, this method applies
+   * an additional rewrite step (rewriteEqualityExt) that performs
+   * additional rewrites on ret, after which we return the result of this call.
+   * Otherwise, this method simply returns ret.
    */
   static Node returnRewrite(Node node, Node ret, const char* c);
 
@@ -118,9 +122,18 @@ class TheoryStringsRewriter {
   /** rewrite equality
    *
    * This method returns a formula that is equivalent to the equality between
-   * two strings, given by node.
+   * two strings s = t, given by node. The result of rewrite is one of
+   * { s = t, t = s, true, false }.
    */
   static Node rewriteEquality(Node node);
+  /** rewrite equality extended
+   *
+   * This method returns a formula that is equivalent to the equality between
+   * two strings s = t, given by node. Specifically, this function performs
+   * rewrites whose conclusion is not necessarily one of
+   * { s = t, t = s, true, false }.
+   */
+  static Node rewriteEqualityExt(Node node);
   /** rewrite concat
   * This is the entry point for post-rewriting terms node of the form
   *   str.++( t1, .., tn )