Improve rewriter for regular expression concatenation (#2196)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 23 Jul 2018 23:02:36 +0000 (18:02 -0500)
committerGitHub <noreply@github.com>
Mon, 23 Jul 2018 23:02:36 +0000 (18:02 -0500)
src/theory/strings/theory_strings_rewriter.cpp

index cd7c6eeb4ac926f6a2abd0146704aa3980e38578..8c589640ca5ddb288c0daa59b0d6f54e2280f37f 100644 (file)
@@ -387,77 +387,115 @@ Node TheoryStringsRewriter::rewriteConcat(Node node)
 
 Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) {
   Assert( node.getKind() == kind::REGEXP_CONCAT );
-  Trace("strings-prerewrite") << "Strings::prerewriteConcatRegExp start " << node << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
+  Trace("strings-prerewrite")
+      << "Strings::prerewriteConcatRegExp flatten " << node << std::endl;
   Node retNode = node;
-  std::vector<Node> node_vec;
-  Node preNode = Node::null();
-  bool emptyflag = false;
-  for(unsigned int i=0; i<node.getNumChildren(); ++i) {
-    Trace("strings-prerewrite") << "Strings::prerewriteConcatRegExp preNode: " << preNode << std::endl;
-    Node tmpNode = node[i];
-    if(tmpNode.getKind() == kind::REGEXP_CONCAT) {
-      tmpNode = prerewriteConcatRegExp(node[i]);
-      if(tmpNode.getKind() == kind::REGEXP_CONCAT) {
-        unsigned j=0;
-        if(!preNode.isNull()) {
-          if(tmpNode[0].getKind() == kind::STRING_TO_REGEXP) {
-            preNode = rewriteConcat(NodeManager::currentNM()->mkNode(
-                kind::STRING_CONCAT, preNode, tmpNode[0][0]));
-            node_vec.push_back( NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, preNode ) );
-            preNode = Node::null();
-          } else {
-            node_vec.push_back( NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, preNode ) );
-            preNode = Node::null();
-            node_vec.push_back( tmpNode[0] );
-          }
-          ++j;
-        }
-        for(; j<tmpNode.getNumChildren() - 1; ++j) {
-          node_vec.push_back( tmpNode[j] );
-        }
-        tmpNode = tmpNode[j];
+  std::vector<Node> vec;
+  bool changed = false;
+  Node emptyRe;
+  for (const Node& c : node)
+  {
+    if (c.getKind() == REGEXP_CONCAT)
+    {
+      changed = true;
+      for (const Node& cc : c)
+      {
+        vec.push_back(cc);
       }
     }
-    if( tmpNode.getKind() == kind::STRING_TO_REGEXP ) {
-      if(preNode.isNull()) {
-        preNode = tmpNode[0];
-      } else {
-        preNode = rewriteConcat(NodeManager::currentNM()->mkNode(
-            kind::STRING_CONCAT, preNode, tmpNode[0]));
-      }
-    } else if( tmpNode.getKind() == kind::REGEXP_EMPTY ) {
-      emptyflag = true;
-      break;
-    } else {
-      if(!preNode.isNull()) {
-        if(preNode.getKind() == kind::CONST_STRING && preNode.getConst<String>().isEmptyString() ) {
-          preNode = Node::null();
-        } else {
-          node_vec.push_back( NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, preNode ) );
-          preNode = Node::null();
+    else if (c.getKind() == STRING_TO_REGEXP && c[0].isConst()
+             && c[0].getConst<String>().isEmptyString())
+    {
+      changed = true;
+      emptyRe = c;
+    }
+    else if (c.getKind() == REGEXP_EMPTY)
+    {
+      // re.++( ..., empty, ... ) ---> empty
+      std::vector<Node> nvec;
+      return nm->mkNode(REGEXP_EMPTY, nvec);
+    }
+    else
+    {
+      vec.push_back(c);
+    }
+  }
+  if (changed)
+  {
+    // flatten
+    // this handles nested re.++ and elimination or str.to.re(""), e.g.:
+    // re.++( re.++( R1, R2 ), str.to.re(""), R3 ) ---> re.++( R1, R2, R3 )
+    if (vec.empty())
+    {
+      Assert(!emptyRe.isNull());
+      retNode = emptyRe;
+    }
+    else
+    {
+      retNode = vec.size() == 1 ? vec[0] : nm->mkNode(REGEXP_CONCAT, vec);
+    }
+    return returnRewrite(node, retNode, "re.concat-flatten");
+  }
+  Trace("strings-prerewrite")
+      << "Strings::prerewriteConcatRegExp start " << node << std::endl;
+  std::vector<Node> cvec;
+  std::vector<Node> preReStr;
+  for (unsigned i = 0, size = vec.size(); i <= size; i++)
+  {
+    Node curr;
+    if (i < size)
+    {
+      curr = vec[i];
+      Assert(curr.getKind() != REGEXP_CONCAT);
+      if (!cvec.empty() && preReStr.empty())
+      {
+        Node cvecLast = cvec.back();
+        if (cvecLast.getKind() == REGEXP_STAR && cvecLast[0] == curr)
+        {
+          // by convention, flip the order (a*)++a ---> a++(a*)
+          cvec[cvec.size() - 1] = curr;
+          cvec.push_back(cvecLast);
+          curr = Node::null();
         }
       }
-      node_vec.push_back( tmpNode );
     }
-  }
-  if(emptyflag) {
-    std::vector< Node > nvec;
-    retNode = NodeManager::currentNM()->mkNode( kind::REGEXP_EMPTY, nvec );
-  } else {
-    if(!preNode.isNull()) {
-      bool bflag = (preNode.getKind() == kind::CONST_STRING && preNode.getConst<String>().isEmptyString() );
-      if(node_vec.empty() || !bflag ) {
-        node_vec.push_back( NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, preNode ) );
+    // update preReStr
+    if (!curr.isNull() && curr.getKind() == STRING_TO_REGEXP)
+    {
+      preReStr.push_back(curr[0]);
+      curr = Node::null();
+    }
+    else if (!preReStr.empty())
+    {
+      // this groups consecutive strings a++b ---> ab
+      Node acc =
+          nm->mkNode(STRING_TO_REGEXP, mkConcat(STRING_CONCAT, preReStr));
+      cvec.push_back(acc);
+      preReStr.clear();
+    }
+    if (!curr.isNull() && curr.getKind() == REGEXP_STAR)
+    {
+      // we can group stars (a*)++(a*) ---> a*
+      if (!cvec.empty() && cvec.back() == curr)
+      {
+        curr = Node::null();
       }
     }
-    if(node_vec.size() > 1) {
-      retNode = NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, node_vec);
-    } else {
-      retNode = node_vec[0];
+    if (!curr.isNull())
+    {
+      cvec.push_back(curr);
     }
   }
-  Trace("strings-prerewrite") << "Strings::prerewriteConcatRegExp end " << retNode << std::endl;
-  return retNode;
+  Assert(!cvec.empty());
+  retNode = mkConcat(REGEXP_CONCAT, cvec);
+  if (retNode != node)
+  {
+    // handles all cases where consecutive re constants are combined, and cases
+    // where arguments are swapped, as described in the loop above.
+    return returnRewrite(node, retNode, "re.concat");
+  }
+  return node;
 }
 
 Node TheoryStringsRewriter::prerewriteOrRegExp(TNode node) {
@@ -728,6 +766,7 @@ bool TheoryStringsRewriter::testConstStringInRegExp( CVC4::String &s, unsigned i
 }
 
 Node TheoryStringsRewriter::rewriteMembership(TNode node) {
+  NodeManager* nm = NodeManager::currentNM();
   Node retNode = node;
   Node x = node[0];
   Node r = node[1];
@@ -739,10 +778,11 @@ Node TheoryStringsRewriter::rewriteMembership(TNode node) {
     CVC4::String s = x.getConst<String>();
     retNode = NodeManager::currentNM()->mkConst( testConstStringInRegExp( s, 0, r ) );
   } else if(r.getKind() == kind::REGEXP_SIGMA) {
-    Node one = NodeManager::currentNM()->mkConst( ::CVC4::Rational(1) );
-    retNode = one.eqNode(NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, x));
+    Node one = nm->mkConst(Rational(1));
+    retNode = one.eqNode(nm->mkNode(STRING_LENGTH, x));
   } else if( r.getKind() == kind::REGEXP_STAR ) {
-    if( r[0].getKind() == kind::REGEXP_SIGMA ){
+    if (r[0].getKind() == kind::REGEXP_SIGMA)
+    {
       retNode = NodeManager::currentNM()->mkConst( true );
     }
   }else if( r.getKind() == kind::REGEXP_CONCAT ){
@@ -774,6 +814,14 @@ Node TheoryStringsRewriter::rewriteMembership(TNode node) {
     retNode = NodeManager::currentNM()->mkNode( r.getKind()==kind::REGEXP_INTER ? kind::AND : kind::OR, mvec );
   }else if(r.getKind() == kind::STRING_TO_REGEXP) {
     retNode = x.eqNode(r[0]);
+  }
+  else if (r.getKind() == REGEXP_RANGE)
+  {
+    // x in re.range( char_i, char_j ) ---> i <= str.code(x) <= j
+    Node xcode = nm->mkNode(STRING_CODE, x);
+    retNode = nm->mkNode(AND,
+                         nm->mkNode(LEQ, nm->mkNode(STRING_CODE, r[0]), xcode),
+                         nm->mkNode(LEQ, xcode, nm->mkNode(STRING_CODE, r[1])));
   }else if(x != node[0] || r != node[1]) {
     retNode = NodeManager::currentNM()->mkNode( kind::STRING_IN_REGEXP, x, r );
   }
@@ -851,16 +899,11 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
       Node tmpNode = node[0];
       if(tmpNode.isConst()) {
         retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( tmpNode.getConst<String>().size() ) );
-      //} else if(tmpNode.getKind() == kind::STRING_SUBSTR) {
-        //retNode = tmpNode[2];
       }else if( tmpNode.getKind()==kind::STRING_CONCAT ){
-        // it has to be string concat
         std::vector<Node> node_vec;
         for(unsigned int i=0; i<tmpNode.getNumChildren(); ++i) {
           if(tmpNode[i].isConst()) {
             node_vec.push_back( NodeManager::currentNM()->mkConst( ::CVC4::Rational( tmpNode[i].getConst<String>().size() ) ) );
-          //} else if(tmpNode[i].getKind() == kind::STRING_SUBSTR) {
-          //  node_vec.push_back( tmpNode[i][2] );
           } else {
             node_vec.push_back( NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, tmpNode[i]) );
           }
@@ -962,6 +1005,7 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
   Node retNode = node;
   Node orig = retNode;
   Trace("strings-prerewrite") << "Strings::preRewrite start " << node << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
 
   if (node.getKind() == kind::REGEXP_CONCAT)
   {
@@ -1009,18 +1053,6 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
     if(node[0] == node[1]) {
       retNode = NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, node[0] );
     }
-    /*std::vector< Node > vec_nodes;
-    unsigned char c = node[0].getConst<String>().getFirstChar();
-    unsigned char end = node[1].getConst<String>().getFirstChar();
-    for(; c<=end; ++c) {
-      Node n = NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst( ::CVC4::String( c ) ) );
-      vec_nodes.push_back( n );
-    }
-    if(vec_nodes.size() == 1) {
-      retNode = vec_nodes[0];
-    } else {
-      retNode = NodeManager::currentNM()->mkNode( kind::REGEXP_UNION, vec_nodes );
-    }*/
   } else if(node.getKind() == kind::REGEXP_LOOP) {
     Node r = node[0];
     if(r.getKind() == kind::REGEXP_STAR) {
@@ -1047,8 +1079,11 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
         //if(!n2.isConst()) {
         //  throw LogicException("re.loop contains non-constant integer (2).");
         //}
-        Node n = vec_nodes.size()==0 ? NodeManager::currentNM()->mkNode(kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst(CVC4::String("")))
-          : vec_nodes.size()==1 ? r : prerewriteConcatRegExp(NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vec_nodes));
+        Node n = vec_nodes.size() == 0
+                     ? nm->mkNode(STRING_TO_REGEXP, nm->mkConst(String("")))
+                     : vec_nodes.size() == 1
+                           ? r
+                           : nm->mkNode(REGEXP_CONCAT, vec_nodes);
         //Assert(n2.getConst<Rational>() <= RMAXINT, "Exceeded LONG_MAX in string REGEXP_LOOP (2)");
         unsigned u = n2.getConst<Rational>().getNumerator().toUnsignedInt();
         if(u <= l) {
@@ -1058,17 +1093,20 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
           vec2.push_back(n);
           for(unsigned j=l; j<u; j++) {
             vec_nodes.push_back(r);
-            n = vec_nodes.size()==1? r : prerewriteConcatRegExp(NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vec_nodes));
+            n = mkConcat(REGEXP_CONCAT, vec_nodes);
             vec2.push_back(n);
           }
-          retNode = prerewriteOrRegExp(NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vec2));
+          retNode = prerewriteOrRegExp(nm->mkNode(REGEXP_UNION, vec2));
         }
       } else {
-        Node rest = NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, r);
-        retNode = vec_nodes.size()==0? rest : prerewriteConcatRegExp( vec_nodes.size()==1?
-                 NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, r, rest)
-                :NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT,
-                  NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vec_nodes), rest) );
+        Node rest = nm->mkNode(REGEXP_STAR, r);
+        retNode = vec_nodes.size() == 0
+                      ? rest
+                      : vec_nodes.size() == 1
+                            ? nm->mkNode(REGEXP_CONCAT, r, rest)
+                            : nm->mkNode(REGEXP_CONCAT,
+                                         nm->mkNode(REGEXP_CONCAT, vec_nodes),
+                                         rest);
       }
     }
     Trace("strings-lp") << "Strings::lp " << node << " => " << retNode << std::endl;