Add some basic rewrites for regular expression intersection (#7629)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 12 Nov 2021 15:02:43 +0000 (09:02 -0600)
committerGitHub <noreply@github.com>
Fri, 12 Nov 2021 15:02:43 +0000 (15:02 +0000)
Adds some basic rewrites for re.inter and re.union involving children of the form (str.to_re c) for constants c.

This is towards addressing bottlenecks for Zelkova benchmarks.

src/theory/strings/rewrites.cpp
src/theory/strings/rewrites.h
src/theory/strings/sequences_rewriter.cpp

index 549d33efcbc4eda70106e303e5f4862f1a79654e..f73f78c24bc9656731038a69cca34c64661a7a30 100644 (file)
@@ -76,6 +76,10 @@ const char* toString(Rewrite r)
     case Rewrite::RE_AND_EMPTY: return "RE_AND_EMPTY";
     case Rewrite::RE_ANDOR_FLATTEN: return "RE_ANDOR_FLATTEN";
     case Rewrite::RE_ANDOR_INC_CONFLICT: return "RE_ANDOR_INC_CONFLICT";
+    case Rewrite::RE_INTER_CONST_CONST_CONFLICT:
+      return "RE_INTER_CONST_CONST_CONFLICT";
+    case Rewrite::RE_INTER_CONST_RE_CONFLICT:
+      return "RE_INTER_CONST_RE_CONFLICT";
     case Rewrite::RE_CHAR_IN_STR_STAR: return "RE_CHAR_IN_STR_STAR";
     case Rewrite::RE_CONCAT: return "RE_CONCAT";
     case Rewrite::RE_CONCAT_FLATTEN: return "RE_CONCAT_FLATTEN";
index 5e63c55c8e76f3a5fad382aeccbdd6c7b8be348e..684ee0e8aacac35de0509807a63e91cfbf648ae5 100644 (file)
@@ -81,6 +81,8 @@ enum class Rewrite : uint32_t
   RE_AND_EMPTY,
   RE_ANDOR_FLATTEN,
   RE_ANDOR_INC_CONFLICT,
+  RE_INTER_CONST_CONST_CONFLICT,
+  RE_INTER_CONST_RE_CONFLICT,
   RE_CHAR_IN_STR_STAR,
   RE_CONCAT,
   RE_CONCAT_FLATTEN,
index 721d875d015f6ff1375f5138e31f137af6cf4c17..fbb40f2123a540dd030832861074257d89401e7f 100644 (file)
@@ -979,11 +979,17 @@ Node SequencesRewriter::rewriteAndOrRegExp(TNode node)
   Assert(nk == REGEXP_UNION || nk == REGEXP_INTER);
   Trace("strings-rewrite-debug")
       << "Strings::rewriteAndOrRegExp start " << node << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> node_vec;
   std::vector<Node> polRegExp[2];
+  // list of constant string regular expressions (str.to_re c)
+  std::vector<Node> constStrRe;
+  // list of all other regular expressions
+  std::vector<Node> otherRe;
   for (const Node& ni : node)
   {
-    if (ni.getKind() == nk)
+    Kind nik = ni.getKind();
+    if (nik == nk)
     {
       for (const Node& nic : ni)
       {
@@ -993,7 +999,7 @@ Node SequencesRewriter::rewriteAndOrRegExp(TNode node)
         }
       }
     }
-    else if (ni.getKind() == REGEXP_NONE)
+    else if (nik == REGEXP_NONE)
     {
       if (nk == REGEXP_INTER)
       {
@@ -1001,7 +1007,7 @@ Node SequencesRewriter::rewriteAndOrRegExp(TNode node)
       }
       // otherwise, can ignore
     }
-    else if (ni.getKind() == REGEXP_STAR && ni[0].getKind() == REGEXP_ALLCHAR)
+    else if (nik == REGEXP_STAR && ni[0].getKind() == REGEXP_ALLCHAR)
     {
       if (nk == REGEXP_UNION)
       {
@@ -1011,13 +1017,103 @@ Node SequencesRewriter::rewriteAndOrRegExp(TNode node)
     }
     else if (std::find(node_vec.begin(), node_vec.end(), ni) == node_vec.end())
     {
+      if (nik == STRING_TO_REGEXP && ni[0].isConst())
+      {
+        if (nk == REGEXP_INTER)
+        {
+          if (!constStrRe.empty())
+          {
+            Assert(constStrRe[0][0] != ni[0]);
+            // (re.inter .. (str.to_re c1) .. (str.to_re c2) ..) ---> re.none
+            // for distinct constant strings c1, c2.
+            Node ret = nm->mkNode(kind::REGEXP_NONE);
+            return returnRewrite(
+                node, ret, Rewrite::RE_INTER_CONST_CONST_CONFLICT);
+          }
+        }
+        else
+        {
+          Assert(nk == REGEXP_UNION);
+        }
+        constStrRe.push_back(ni);
+      }
+      else
+      {
+        otherRe.push_back(ni);
+        uint32_t pindex = nik == REGEXP_COMPLEMENT ? 1 : 0;
+        Node nia = pindex == 1 ? ni[0] : ni;
+        polRegExp[pindex].push_back(nia);
+      }
       node_vec.push_back(ni);
-      uint32_t pindex = ni.getKind() == REGEXP_COMPLEMENT ? 1 : 0;
-      Node nia = pindex == 1 ? ni[0] : ni;
-      polRegExp[pindex].push_back(nia);
     }
   }
-  NodeManager* nm = NodeManager::currentNM();
+  Trace("strings-rewrite-debug")
+      << "Partition constant components " << constStrRe.size() << " / "
+      << otherRe.size() << std::endl;
+  // go back and process constant strings against the others
+  if (!constStrRe.empty())
+  {
+    std::unordered_set<Node> toRemove;
+    for (const Node& c : constStrRe)
+    {
+      Assert(c.getKind() == STRING_TO_REGEXP && c[0].getKind() == CONST_STRING);
+      cvc5::String s = c[0].getConst<String>();
+      for (const Node& r : otherRe)
+      {
+        Trace("strings-rewrite-debug")
+            << "Check " << c << " vs " << r << std::endl;
+        // skip if already removing, or not constant
+        if (!RegExpEntail::isConstRegExp(r)
+            || toRemove.find(r) != toRemove.end())
+        {
+          Trace("strings-rewrite-debug") << "...skip" << std::endl;
+          continue;
+        }
+        // test whether c from (str.to_re c) is in r
+        if (RegExpEntail::testConstStringInRegExp(s, 0, r))
+        {
+          Trace("strings-rewrite-debug") << "...included" << std::endl;
+          if (nk == REGEXP_INTER)
+          {
+            // (re.inter .. (str.to_re c) .. R ..) --->
+            // (re.inter .. (str.to_re c) .. ..) when c in R
+            toRemove.insert(r);
+          }
+          else
+          {
+            // (re.union .. (str.to_re c) .. R ..) --->
+            // (re.union .. .. R ..) when c in R
+            toRemove.insert(c);
+            break;
+          }
+        }
+        else
+        {
+          Trace("strings-rewrite-debug") << "...not included" << std::endl;
+          if (nk == REGEXP_INTER)
+          {
+            // (re.inter .. (str.to_re c) .. R ..) ---> re.none
+            // if c is not a member of R.
+            Node ret = nm->mkNode(kind::REGEXP_NONE);
+            return returnRewrite(
+                node, ret, Rewrite::RE_INTER_CONST_RE_CONFLICT);
+          }
+        }
+      }
+    }
+    if (!toRemove.empty())
+    {
+      std::vector<Node> nodeVecTmp;
+      node_vec.swap(nodeVecTmp);
+      for (const Node& nvt : nodeVecTmp)
+      {
+        if (toRemove.find(nvt) == toRemove.end())
+        {
+          node_vec.push_back(nvt);
+        }
+      }
+    }
+  }
   // use inclusion tests
   for (const Node& negMem : polRegExp[1])
   {