Improve rewriter for string replace (#1416)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 2 Dec 2017 02:03:03 +0000 (20:03 -0600)
committerGitHub <noreply@github.com>
Sat, 2 Dec 2017 02:03:03 +0000 (20:03 -0600)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h

index 4745817c8fbf2f06db2a9ed9e8ae8c65130f0631..aab4196ccb51edfd3d757b4795aaa4c2cda82fad 100644 (file)
@@ -196,20 +196,19 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren,
   return Node::null();
 }
 
-// TODO (#1180) rename this to rewriteConcat
 // TODO (#1180) add rewrite
 //  str.++( str.substr( x, n1, n2 ), str.substr( x, n1+n2, n3 ) ) --->
 //  str.substr( x, n1, n2+n3 )
-Node TheoryStringsRewriter::rewriteConcatString( TNode node ) {
-  Trace("strings-prerewrite") << "Strings::rewriteConcatString start " << node << std::endl;
+Node TheoryStringsRewriter::rewriteConcat(Node node)
+{
+  Trace("strings-prerewrite") << "Strings::rewriteConcat start " << node
+                              << std::endl;
   Node retNode = node;
   std::vector<Node> node_vec;
   Node preNode = Node::null();
   for(unsigned int i=0; i<node.getNumChildren(); ++i) {
     Node tmpNode = node[i];
     if(node[i].getKind() == kind::STRING_CONCAT) {
-      // TODO (#1180) is this necessary?
-      tmpNode = rewriteConcatString(node[i]);
       if(tmpNode.getKind() == kind::STRING_CONCAT) {
         unsigned j=0;
         if(!preNode.isNull()) {
@@ -249,7 +248,8 @@ Node TheoryStringsRewriter::rewriteConcatString( TNode node ) {
     node_vec.push_back( preNode );
   }
   retNode = mkConcat( kind::STRING_CONCAT, node_vec );
-  Trace("strings-prerewrite") << "Strings::rewriteConcatString end " << retNode << std::endl;
+  Trace("strings-prerewrite") << "Strings::rewriteConcat end " << retNode
+                              << std::endl;
   return retNode;
 }
 
@@ -270,7 +270,7 @@ void TheoryStringsRewriter::shrinkConVec(std::vector<Node> &vec) {
       vec.erase(vec.begin() + i);
     } else if(vec[i].getKind()==kind::STRING_TO_REGEXP && i<vec.size()-1 && vec[i+1].getKind()==kind::STRING_TO_REGEXP) {
       Node tmp = NodeManager::currentNM()->mkNode(kind::STRING_CONCAT, vec[i][0], vec[i+1][0]);
-      tmp = rewriteConcatString(tmp);
+      tmp = rewriteConcat(tmp);
       vec[i] = NodeManager::currentNM()->mkNode(kind::STRING_TO_REGEXP, tmp);
       vec.erase(vec.begin() + i + 1);
     } else {
@@ -568,8 +568,8 @@ Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) {
         unsigned j=0;
         if(!preNode.isNull()) {
           if(tmpNode[0].getKind() == kind::STRING_TO_REGEXP) {
-            preNode = rewriteConcatString(
-              NodeManager::currentNM()->mkNode( kind::STRING_CONCAT, preNode, tmpNode[0][0] ) );
+            preNode = rewriteConcat(NodeManager::currentNM()->mkNode(
+                kind::STRING_CONCAT, preNode, tmpNode[0][0]));
             node_vec.push_back( NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, preNode ) );
             preNode = Node::null();
           } else {
@@ -589,8 +589,8 @@ Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) {
       if(preNode.isNull()) {
         preNode = tmpNode[0];
       } else {
-        preNode = rewriteConcatString(
-        NodeManager::currentNM()->mkNode( kind::STRING_CONCAT, preNode, tmpNode[0] ) );
+        preNode = rewriteConcat(NodeManager::currentNM()->mkNode(
+            kind::STRING_CONCAT, preNode, tmpNode[0]));
       }
     } else if( tmpNode.getKind() == kind::REGEXP_EMPTY ) {
       emptyflag = true;
@@ -899,10 +899,6 @@ Node TheoryStringsRewriter::rewriteMembership(TNode node) {
   Node x = node[0];
   Node r = node[1];//applyAX(node[1]);
 
-  if(node[0].getKind() == kind::STRING_CONCAT) {
-    x = rewriteConcatString(node[0]);
-  }
-
   if(r.getKind() == kind::REGEXP_EMPTY) {
     retNode = NodeManager::currentNM()->mkConst( false );
   } else if(x.getKind()==kind::CONST_STRING && isConstRegExp(r)) {
@@ -1011,32 +1007,22 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
   Node orig = retNode;
 
   if(node.getKind() == kind::STRING_CONCAT) {
-    retNode = rewriteConcatString(node);
+    retNode = rewriteConcat(node);
   } else if(node.getKind() == kind::EQUAL) {
-    // TODO (#1180) are these necessary?
     Node leftNode  = node[0];
-    if(node[0].getKind() == kind::STRING_CONCAT) {
-      leftNode = rewriteConcatString(node[0]);
-    }
     Node rightNode = node[1];
-    if(node[1].getKind() == kind::STRING_CONCAT) {
-      rightNode = rewriteConcatString(node[1]);
-    }
-
     if(leftNode == rightNode) {
       retNode = NodeManager::currentNM()->mkConst(true);
     } else if(leftNode.isConst() && rightNode.isConst()) {
       retNode = NodeManager::currentNM()->mkConst(false);
     } else if(leftNode > rightNode) {
       retNode = NodeManager::currentNM()->mkNode(kind::EQUAL, rightNode, leftNode);
-    } else if( leftNode != node[0] || rightNode != node[1]) {
-      retNode = NodeManager::currentNM()->mkNode(kind::EQUAL, leftNode, rightNode);
     }
   } else if(node.getKind() == kind::STRING_LENGTH) {
     if( node[0].isConst() ){
       retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( node[0].getConst<String>().size() ) );
     }else if( node[0].getKind() == kind::STRING_CONCAT ){
-      Node tmpNode = rewriteConcatString(node[0]);
+      Node tmpNode = node[0];
       if(tmpNode.isConst()) {
         retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( tmpNode.getConst<String>().size() ) );
       //} else if(tmpNode.getKind() == kind::STRING_SUBSTR) {
@@ -1171,9 +1157,8 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
   Node orig = retNode;
   Trace("strings-prerewrite") << "Strings::preRewrite start " << node << std::endl;
 
-  if(node.getKind() == kind::STRING_CONCAT) {
-    retNode = rewriteConcatString(node);
-  }else if(node.getKind() == kind::REGEXP_CONCAT) {
+  if (node.getKind() == kind::REGEXP_CONCAT)
+  {
     retNode = prerewriteConcatRegExp(node);
   } else if(node.getKind() == kind::REGEXP_UNION) {
     retNode = prerewriteOrRegExp(node);
@@ -1775,55 +1760,126 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
 
 Node TheoryStringsRewriter::rewriteReplace( Node node ) {
   if( node[1]==node[2] ){
-    return node[0];
-  }else{
-    // TODO (#1180) : try str.contains( node[0], node[1] ) ---> false
-    if( node[1].isConst() ){
-      if( node[1].getConst<String>().isEmptyString() ){
-        return node[0];
+    return returnRewrite(node, node[0], "rpl-id");
+  }
+  else if (node[0] == node[1])
+  {
+    return returnRewrite(node, node[2], "rpl-replace");
+  }
+  else if (node[1].isConst())
+  {
+    if (node[1].getConst<String>().isEmptyString())
+    {
+      return returnRewrite(node, node[0], "rpl-empty");
+    }
+    else if (node[0].isConst())
+    {
+      CVC4::String s = node[0].getConst<String>();
+      CVC4::String t = node[1].getConst<String>();
+      std::size_t p = s.find(t);
+      if (p == std::string::npos)
+      {
+        return returnRewrite(node, node[0], "rpl-const-nfind");
       }
-      std::vector< Node > children;
-      getConcat( node[0], children );
-      if( children[0].isConst() ){
-        CVC4::String s = children[0].getConst<String>();
-        CVC4::String t = node[1].getConst<String>();
-        std::size_t p = s.find(t);
-        if( p != std::string::npos ) {
-          Node retNode;
-          if( node[2].isConst() ){
-            CVC4::String r = node[2].getConst<String>();
-            CVC4::String ret = s.replace(t, r);
-            retNode = NodeManager::currentNM()->mkConst( ::CVC4::String(ret) );
-          } else {
-            CVC4::String s1 = s.substr(0, (int)p);
-            CVC4::String s3 = s.substr((int)p + (int)t.size());
-            Node ns1 = NodeManager::currentNM()->mkConst( ::CVC4::String(s1) );
-            Node ns3 = NodeManager::currentNM()->mkConst( ::CVC4::String(s3) );
-            retNode = NodeManager::currentNM()->mkNode( kind::STRING_CONCAT, ns1, node[2], ns3 );
-          }
-          if( children.size()>1 ){
-            children[0] = retNode;
-            return mkConcat( kind::STRING_CONCAT, children );
-          }else{
-            return retNode;
-          }
-        }else{
-          //could not find replacement string
-          if( node[0].isConst() ){
-            return node[0];
-          }else{
-            //check for overlap, if none, we can remove the prefix
-            if( s.overlap(t)==0 ){
-              std::vector< Node > spl;
-              spl.insert( spl.end(), children.begin()+1, children.end() );
-              return NodeManager::currentNM()->mkNode( kind::STRING_CONCAT, children[0],
-                          NodeManager::currentNM()->mkNode( kind::STRING_STRREPL, mkConcat( kind::STRING_CONCAT, spl ), node[1], node[2] ) );
-            }
-          }
+      else
+      {
+        CVC4::String s1 = s.substr(0, (int)p);
+        CVC4::String s3 = s.substr((int)p + (int)t.size());
+        Node ns1 = NodeManager::currentNM()->mkConst(::CVC4::String(s1));
+        Node ns3 = NodeManager::currentNM()->mkConst(::CVC4::String(s3));
+        Node ret = NodeManager::currentNM()->mkNode(
+            kind::STRING_CONCAT, ns1, node[2], ns3);
+        return returnRewrite(node, ret, "rpl-const-find");
+      }
+    }
+  }
+
+  std::vector<Node> children0;
+  getConcat(node[0], children0);
+  std::vector<Node> children1;
+  getConcat(node[1], children1);
+
+  // check if contains definitely does (or does not) hold
+  Node cmp_con =
+      NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, node[0], node[1]);
+  Node cmp_conr = Rewriter::rewrite(cmp_con);
+  if (cmp_conr.isConst())
+  {
+    if (cmp_conr.getConst<bool>())
+    {
+      // component-wise containment
+      std::vector<Node> cb;
+      std::vector<Node> ce;
+      int cc = componentContains(children0, children1, cb, ce, true, 1);
+      if (cc != -1)
+      {
+        if (cc == 0 && children0[0] == children1[0])
+        {
+          // definitely a prefix, can do the replace
+          // for example,
+          //   str.replace( str.++( x, "ab" ), str.++( x, "a" ), y )  --->
+          //   str.++( y, "b" )
+          std::vector<Node> cres;
+          cres.push_back(node[2]);
+          cres.insert(cres.end(), ce.begin(), ce.end());
+          Node ret = mkConcat(kind::STRING_CONCAT, cres);
+          return returnRewrite(node, ret, "rpl-cctn-rpl");
+        }
+        else if (!ce.empty())
+        {
+          // we can pull remainder past first definite containment
+          // for example,
+          //   str.replace( str.++( x, "ab" ), "a", y ) --->
+          //   str.++( str.replace( str.++( x, "a" ), "a", y ), "b" )
+          std::vector<Node> cc;
+          cc.push_back(NodeManager::currentNM()->mkNode(
+              kind::STRING_STRREPL,
+              mkConcat(kind::STRING_CONCAT, children0),
+              node[1],
+              node[2]));
+          cc.insert(cc.end(), ce.begin(), ce.end());
+          Node ret = mkConcat(kind::STRING_CONCAT, cc);
+          return returnRewrite(node, ret, "rpl-cctn");
         }
       }
     }
+    else
+    {
+      // ~contains( t, s ) => ( replace( t, s, r ) ----> t )
+      return returnRewrite(node, node[0], "rpl-nctn");
+    }
   }
+
+  if (cmp_conr != cmp_con)
+  {
+    // pull endpoints that can be stripped
+    // for example,
+    //   str.replace( str.++( "b", x, "b" ), "a", y ) --->
+    //   str.++( "b", str.replace( x, "a", y ), "b" )
+    std::vector<Node> cb;
+    std::vector<Node> ce;
+    if (stripConstantEndpoints(children0, children1, cb, ce))
+    {
+      std::vector<Node> cc;
+      cc.insert(cc.end(), cb.begin(), cb.end());
+      cc.push_back(NodeManager::currentNM()->mkNode(
+          kind::STRING_STRREPL,
+          mkConcat(kind::STRING_CONCAT, children0),
+          node[1],
+          node[2]));
+      cc.insert(cc.end(), ce.begin(), ce.end());
+      Node ret = mkConcat(kind::STRING_CONCAT, cc);
+      return returnRewrite(node, ret, "rpl-pull-endpt");
+    }
+  }
+
+  // TODO (#1180) incorporate these?
+  // contains( t, s ) =>
+  //   replace( replace( x, t, s ), s, r ) ----> replace( x, t, r )
+  // contains( t, s ) =>
+  //   contains( replace( t, s, r ), r ) ----> true
+
+  Trace("strings-rewrite-nf") << "No rewrites for : " << node << std::endl;
   return node;
 }
 
index b7712bcef68c45c567d61bd53d35ea6f989c98f6..64120eca04d20a2327f69824fde872c4b45b909f 100644 (file)
@@ -33,8 +33,6 @@ private:
   static bool isConstRegExp( TNode t );
   static bool testConstStringInRegExp( CVC4::String &s, unsigned int index_start, TNode r );
 
-  static Node rewriteConcatString(TNode node);
-
   static void mergeInto(std::vector<Node> &t, const std::vector<Node> &s);
   static void shrinkConVec(std::vector<Node> &vec);
   static Node applyAX( TNode node );
@@ -65,6 +63,12 @@ private:
 
   static inline void init() {}
   static inline void shutdown() {}
+  /** rewrite concat
+  * This is the entry point for post-rewriting terms node of the form
+  *   str.++( t1, .., tn )
+  * Returns the rewritten form of node.
+  */
+  static Node rewriteConcat(Node node);
   /** rewrite substr
   * This is the entry point for post-rewriting terms node of the form
   *   str.substr( s, i1, i2 )
@@ -82,7 +86,12 @@ private:
   */
   static Node rewriteContains(Node node);
   static Node rewriteIndexof(Node node);
-  static Node rewriteReplace(Node node);
+  /** rewrite replace
+  * This is the entry point for post-rewriting terms n of the form
+  *   str.replace( s, t, r )
+  * Returns the rewritten form of n.
+  */
+  static Node rewriteReplace(Node n);
 
   /** gets the "vector form" of term n, adds it to c.
   * For example: