Improve rewrite for string substr (#1337)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 28 Nov 2017 21:53:22 +0000 (15:53 -0600)
committerGitHub <noreply@github.com>
Tue, 28 Nov 2017 21:53:22 +0000 (15:53 -0600)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
test/regress/regress0/strings/Makefile.am
test/regress/regress0/strings/substr-rewrites.smt2 [new file with mode: 0644]

index caf143b374d47c0d3fb4daa01bafee1991e5a179..4745817c8fbf2f06db2a9ed9e8ae8c65130f0631 100644 (file)
@@ -196,6 +196,10 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren,
   return Node::null();
 }
 
+// TODO (#1180) rename this to rewriteConcat
+// TODO (#1180) add rewrite
+//  str.++( str.substr( x, n1, n2 ), str.substr( x, n1+n2, n3 ) ) --->
+//  str.substr( x, n1, n2+n3 )
 Node TheoryStringsRewriter::rewriteConcatString( TNode node ) {
   Trace("strings-prerewrite") << "Strings::rewriteConcatString start " << node << std::endl;
   Node retNode = node;
@@ -204,6 +208,7 @@ Node TheoryStringsRewriter::rewriteConcatString( TNode node ) {
   for(unsigned int i=0; i<node.getNumChildren(); ++i) {
     Node tmpNode = node[i];
     if(node[i].getKind() == kind::STRING_CONCAT) {
+      // TODO (#1180) is this necessary?
       tmpNode = rewriteConcatString(node[i]);
       if(tmpNode.getKind() == kind::STRING_CONCAT) {
         unsigned j=0;
@@ -1008,6 +1013,7 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
   if(node.getKind() == kind::STRING_CONCAT) {
     retNode = rewriteConcatString(node);
   } else if(node.getKind() == kind::EQUAL) {
+    // TODO (#1180) are these necessary?
     Node leftNode  = node[0];
     if(node[0].getKind() == kind::STRING_CONCAT) {
       leftNode = rewriteConcatString(node[0]);
@@ -1051,6 +1057,7 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
       }
     }else if( node[0].getKind()==kind::STRING_STRREPL ){
       if( node[0][1].isConst() && node[0][2].isConst() ){
+        // TODO (#1180) length entailment here
         if( node[0][1].getConst<String>().size()==node[0][2].getConst<String>().size() ){
           retNode = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, node[0][0] );
         }
@@ -1060,74 +1067,7 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
     Node one = NodeManager::currentNM()->mkConst( Rational( 1 ) );
     retNode = NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, node[0], node[1], one);
   }else if( node.getKind() == kind::STRING_SUBSTR ){
-    Node zero = NodeManager::currentNM()->mkConst( ::CVC4::Rational(0) );
-    if( node[2].isConst() && node[2].getConst<Rational>().sgn()<=0 ) {
-      retNode = NodeManager::currentNM()->mkConst( ::CVC4::String("") );
-    }else if( node[1].isConst() ){
-      // TODO (#1180) : use entailment test here
-      if( node[1].getConst<Rational>().sgn()<0 ){
-        //bring forward to start at zero?  don't use this semantics, e.g. does not compose well with error conditions for str.indexof.
-        //retNode = NodeManager::currentNM()->mkNode( kind::STRING_SUBSTR, node[0], zero, NodeManager::currentNM()->mkNode( kind::PLUS, node[1], node[2] ) );
-        retNode = NodeManager::currentNM()->mkConst( ::CVC4::String("") );
-      }else{
-        if( node[2].isConst() ){
-          Assert( node[2].getConst<Rational>().sgn()>=0);
-          CVC4::Rational v1( node[1].getConst<Rational>() );
-          CVC4::Rational v2( node[2].getConst<Rational>() );
-          std::vector< Node > children;
-          getConcat( node[0], children );
-          if( children[0].isConst() ){
-            CVC4::Rational size(children[0].getConst<String>().size());
-            if( v1 >= size ){
-              if( node[0].isConst() ){
-                retNode = NodeManager::currentNM()->mkConst( ::CVC4::String("") );
-              }else{
-                children.erase( children.begin(), children.begin()+1 );
-                retNode = NodeManager::currentNM()->mkNode( kind::STRING_SUBSTR, mkConcat( kind::STRING_CONCAT, children ),
-                                                            NodeManager::currentNM()->mkNode( kind::MINUS, node[1], NodeManager::currentNM()->mkConst( size ) ),
-                                                            node[2] );
-              }
-            }else{
-              //since size is smaller than MAX_INT, v1 is smaller than MAX_INT
-              size_t i = v1.getNumerator().toUnsignedInt();
-              CVC4::Rational sum(v1 + v2);
-              bool full_spl = false;
-              size_t j;
-              if( sum>size ){
-                j = size.getNumerator().toUnsignedInt();
-              }else{
-                //similarly, sum is smaller than MAX_INT
-                j = sum.getNumerator().toUnsignedInt();
-                full_spl = true;
-              }
-              //split the first component of the string
-              Node spl = NodeManager::currentNM()->mkConst( children[0].getConst<String>().substr(i, j-i) );
-              if( node[0].isConst() || full_spl ){
-                retNode = spl;
-              }else{
-                children[0] = spl;
-                retNode = NodeManager::currentNM()->mkNode( kind::STRING_SUBSTR, mkConcat( kind::STRING_CONCAT, children ), zero, node[2] );
-              }
-            }
-          }
-        }else{
-          if( node[1]==zero ){
-            // TODO (#1180) : use entailment test here instead of the special
-            // case
-            if( node[2].getKind() == kind::STRING_LENGTH && node[2][0]==node[0] ){
-              retNode = node[0];
-            }else{
-              //check if the length argument is always at least the length of the string
-              Node cmp = NodeManager::currentNM()->mkNode( kind::GEQ, node[2], NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, node[0] ) );
-              cmp = Rewriter::rewrite( cmp );
-              if( cmp==NodeManager::currentNM()->mkConst(true) ){
-                retNode = node[0];
-              }
-            }
-          }
-        }
-      }
-    }
+    retNode = rewriteSubstr(node);
   }else if( node.getKind() == kind::STRING_STRCTN ){
     retNode = rewriteContains( node );
   }else if( node.getKind()==kind::STRING_STRIDOF ){
@@ -1393,10 +1333,226 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
   return RewriteResponse(orig==retNode ? REWRITE_DONE : REWRITE_AGAIN_FULL, retNode);
 }
 
+Node TheoryStringsRewriter::rewriteSubstr(Node node)
+{
+  if (node[0].isConst())
+  {
+    if (node[0].getConst<String>().size() == 0)
+    {
+      Node ret = node[0];
+      return returnRewrite(node, ret, "ss-emptystr");
+    }
+    // rewriting for constant arguments
+    if (node[1].isConst() && node[2].isConst())
+    {
+      CVC4::String s = node[0].getConst<String>();
+      CVC4::Rational RMAXINT(LONG_MAX);
+      unsigned start;
+      if (node[1].getConst<Rational>() > RMAXINT)
+      {
+        // start beyond the maximum size of strings
+        // thus, it must be beyond the end point of this string
+        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        return returnRewrite(node, ret, "ss-const-start-max-oob");
+      }
+      else if (node[1].getConst<Rational>().sgn() < 0)
+      {
+        // start before the beginning of the string
+        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        return returnRewrite(node, ret, "ss-const-start-neg");
+      }
+      else
+      {
+        start = node[1].getConst<Rational>().getNumerator().toUnsignedInt();
+        if (start >= s.size())
+        {
+          // start beyond the end of the string
+          Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+          return returnRewrite(node, ret, "ss-const-start-oob");
+        }
+      }
+      if (node[2].getConst<Rational>() > RMAXINT)
+      {
+        // take up to the end of the string
+        Node ret = NodeManager::currentNM()->mkConst(
+            ::CVC4::String(s.suffix(s.size() - start)));
+        return returnRewrite(node, ret, "ss-const-len-max-oob");
+      }
+      else if (node[2].getConst<Rational>().sgn() <= 0)
+      {
+        Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+        return returnRewrite(node, ret, "ss-const-len-non-pos");
+      }
+      else
+      {
+        unsigned len =
+            node[2].getConst<Rational>().getNumerator().toUnsignedInt();
+        if (start + len > s.size())
+        {
+          // take up to the end of the string
+          Node ret = NodeManager::currentNM()->mkConst(
+              ::CVC4::String(s.suffix(s.size() - start)));
+          return returnRewrite(node, ret, "ss-const-end-oob");
+        }
+        else
+        {
+          // compute the substr using the constant string
+          Node ret = NodeManager::currentNM()->mkConst(
+              ::CVC4::String(s.substr(start, len)));
+          return returnRewrite(node, ret, "ss-const-ss");
+        }
+      }
+    }
+  }
+  Node zero = NodeManager::currentNM()->mkConst(CVC4::Rational(0));
+
+  // if entailed non-positive length or negative start point
+  if (checkEntailArith(zero, node[1], true))
+  {
+    Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+    return returnRewrite(node, ret, "ss-start-neg");
+  }
+  else if (checkEntailArith(zero, node[2]))
+  {
+    Node ret = NodeManager::currentNM()->mkConst(::CVC4::String(""));
+    return returnRewrite(node, ret, "ss-len-non-pos");
+  }
+
+  std::vector<Node> n1;
+  getConcat(node[0], n1);
+
+  // definite inclusion
+  if (node[1] == zero)
+  {
+    Node curr = node[2];
+    std::vector<Node> childrenr;
+    if (stripSymbolicLength(n1, childrenr, 1, curr))
+    {
+      if (curr != zero && !n1.empty())
+      {
+        childrenr.push_back(
+            NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR,
+                                             mkConcat(kind::STRING_CONCAT, n1),
+                                             node[1],
+                                             curr));
+      }
+      Node ret = mkConcat(kind::STRING_CONCAT, childrenr);
+      return returnRewrite(node, ret, "ss-len-include");
+    }
+  }
+
+  // symbolic length analysis
+  for (unsigned r = 0; r < 2; r++)
+  {
+    // the amount of characters we can strip
+    Node curr;
+    if (r == 0)
+    {
+      if (node[1] != zero)
+      {
+        // strip up to start point off the start of the string
+        curr = node[1];
+      }
+    }
+    else if (r == 1)
+    {
+      Node tot_len = Rewriter::rewrite(
+          NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]));
+      Node end_pt = Rewriter::rewrite(
+          NodeManager::currentNM()->mkNode(kind::PLUS, node[1], node[2]));
+      if (node[2] != tot_len)
+      {
+        if (checkEntailArith(node[2], tot_len))
+        {
+          // end point beyond end point of string, map to tot_len
+          Node ret = NodeManager::currentNM()->mkNode(
+              kind::STRING_SUBSTR, node[0], node[1], tot_len);
+          return returnRewrite(node, ret, "ss-end-pt-norm");
+        }
+        else
+        {
+          // strip up to ( str.len(node[0]) - end_pt ) off the end of the string
+          curr = Rewriter::rewrite(
+              NodeManager::currentNM()->mkNode(kind::MINUS, tot_len, end_pt));
+        }
+      }
+    }
+    if (!curr.isNull())
+    {
+      // strip off components while quantity is entailed positive
+      int dir = r == 0 ? 1 : -1;
+      std::vector<Node> childrenr;
+      if (stripSymbolicLength(n1, childrenr, dir, curr))
+      {
+        if (r == 0)
+        {
+          Node ret = NodeManager::currentNM()->mkNode(
+              kind::STRING_SUBSTR,
+              mkConcat(kind::STRING_CONCAT, n1),
+              curr,
+              node[2]);
+          return returnRewrite(node, ret, "ss-strip-start-pt");
+        }
+        else
+        {
+          Node ret = NodeManager::currentNM()->mkNode(
+              kind::STRING_SUBSTR,
+              mkConcat(kind::STRING_CONCAT, n1),
+              node[1],
+              node[2]);
+          return returnRewrite(node, ret, "ss-strip-end-pt");
+        }
+      }
+    }
+  }
+  // combine substr
+  if (node[0].getKind() == kind::STRING_SUBSTR)
+  {
+    Node start_inner = node[0][1];
+    Node start_outer = node[1];
+    if (checkEntailArith(start_outer) && checkEntailArith(start_inner))
+    {
+      // both are positive
+      // thus, start point is definitely start_inner+start_outer.
+      // We can rewrite if it for certain what the length is
+
+      // the length of a string from the inner substr subtracts the start point
+      // of the outer substr
+      Node len_from_inner = Rewriter::rewrite(NodeManager::currentNM()->mkNode(
+          kind::MINUS, node[0][2], start_outer));
+      Node len_from_outer = node[2];
+      Node new_len;
+      // take quantity that is for sure smaller than the other
+      if (len_from_inner == len_from_outer)
+      {
+        new_len = len_from_inner;
+      }
+      else if (checkEntailArith(len_from_inner, len_from_outer))
+      {
+        new_len = len_from_outer;
+      }
+      else if (checkEntailArith(len_from_outer, len_from_inner))
+      {
+        new_len = len_from_inner;
+      }
+      if (!new_len.isNull())
+      {
+        Node new_start = NodeManager::currentNM()->mkNode(
+            kind::PLUS, start_inner, start_outer);
+        Node ret = NodeManager::currentNM()->mkNode(
+            kind::STRING_SUBSTR, node[0][0], new_start, new_len);
+        return returnRewrite(node, ret, "ss-combine");
+      }
+    }
+  }
+  Trace("strings-rewrite-nf") << "No rewrites for : " << node << std::endl;
+  return node;
+}
 
 Node TheoryStringsRewriter::rewriteContains( Node node ) {
   if( node[0] == node[1] ){
-    return NodeManager::currentNM()->mkConst( true );
+    Node ret = NodeManager::currentNM()->mkConst(true);
+    return returnRewrite(node, ret, "ctn-eq");
   }
   else if (node[0].isConst())
   {
@@ -1404,18 +1560,23 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
     if (node[1].isConst())
     {
       CVC4::String t = node[1].getConst<String>();
-      return NodeManager::currentNM()->mkConst(s.find(t) != std::string::npos);
+      Node ret =
+          NodeManager::currentNM()->mkConst(s.find(t) != std::string::npos);
+      return returnRewrite(node, ret, "ctn-const");
     }else{
       if (s.size() == 0)
       {
-        return NodeManager::currentNM()->mkNode(kind::EQUAL, node[0], node[1]);
+        Node ret =
+            NodeManager::currentNM()->mkNode(kind::EQUAL, node[0], node[1]);
+        return returnRewrite(node, ret, "ctn-emptystr");
       }
       else if (node[1].getKind() == kind::STRING_CONCAT)
       {
         int firstc, lastc;
         if (!canConstantContainConcat(node[0], node[1], firstc, lastc))
         {
-          return NodeManager::currentNM()->mkConst(false);
+          Node ret = NodeManager::currentNM()->mkConst(false);
+          return returnRewrite(node, ret, "ctn-nconst-ctn-concat");
         }
       }
     }
@@ -1430,7 +1591,8 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   std::vector<Node> nc1re;
   if (componentContains(nc1, nc2, nc1rb, nc1re) != -1)
   {
-    return NodeManager::currentNM()->mkConst(true);
+    Node ret = NodeManager::currentNM()->mkConst(true);
+    return returnRewrite(node, ret, "ctn-component");
   }
 
   // strip endpoints
@@ -1438,11 +1600,10 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   std::vector<Node> ne;
   if (stripConstantEndpoints(nc1, nc2, nb, ne))
   {
-    return NodeManager::currentNM()->mkNode(
+    Node ret = NodeManager::currentNM()->mkNode(
         kind::STRING_STRCTN, mkConcat(kind::STRING_CONCAT, nc1), node[1]);
+    return returnRewrite(node, ret, "ctn-strip-endpt");
   }
-  Trace("strings-rewrite-debug2") << "No constant endpoints for " << node[0]
-                                  << " " << node[1] << std::endl;
 
   // length entailment
   Node len_n1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]);
@@ -1450,12 +1611,14 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
   if (checkEntailArith(len_n2, len_n1, true))
   {
     // len( n2 ) > len( n1 ) => contains( n1, n2 ) ---> false
-    return NodeManager::currentNM()->mkConst(false);
+    Node ret = NodeManager::currentNM()->mkConst(false);
+    return returnRewrite(node, ret, "ctn-len-ineq");
   }
   else if (checkEntailArithEq(len_n1, len_n2))
   {
     // len( n2 ) = len( n1 ) => contains( n1, n2 ) ---> n1 = n2
-    return node[0].eqNode(node[1]);
+    Node ret = node[0].eqNode(node[1]);
+    return returnRewrite(node, ret, "ctn-len-eq");
   }
 
   // splitting
@@ -1484,7 +1647,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
             spl[0].insert(spl[0].end(), nc0.begin(), nc0.begin() + i);
             Assert(i < nc0.size() - 1);
             spl[1].insert(spl[1].end(), nc0.begin() + i + 1, nc0.end());
-            return NodeManager::currentNM()->mkNode(
+            Node ret = NodeManager::currentNM()->mkNode(
                 kind::OR,
                 NodeManager::currentNM()->mkNode(
                     kind::STRING_STRCTN,
@@ -1494,11 +1657,14 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
                     kind::STRING_STRCTN,
                     mkConcat(kind::STRING_CONCAT, spl[1]),
                     node[1]));
+            return returnRewrite(node, ret, "ctn-split");
           }
         }
       }
     }
   }
+
+  Trace("strings-rewrite-nf") << "No rewrites for : " << node << std::endl;
   return node;
 }
 
@@ -1789,6 +1955,120 @@ Node TheoryStringsRewriter::collectConstantStringAt( std::vector< Node >& vec, u
   }
 }
 
+bool TheoryStringsRewriter::stripSymbolicLength(std::vector<Node>& n1,
+                                                std::vector<Node>& nr,
+                                                int dir,
+                                                Node& curr)
+{
+  Assert(dir == 1 || dir == -1);
+  Assert(nr.empty());
+  Node zero = NodeManager::currentNM()->mkConst(CVC4::Rational(0));
+  bool ret = false;
+  bool success;
+  unsigned sindex = 0;
+  do
+  {
+    Assert(!curr.isNull());
+    success = false;
+    if (curr != zero && sindex < n1.size())
+    {
+      unsigned sindex_use = dir == 1 ? sindex : ((n1.size() - 1) - sindex);
+      if (n1[sindex_use].isConst())
+      {
+        // could strip part of a constant
+        Node lowerBound = getConstantArithBound(curr);
+        if (!lowerBound.isNull())
+        {
+          Assert(lowerBound.isConst());
+          Rational lbr = lowerBound.getConst<Rational>();
+          if (lbr.sgn() > 0)
+          {
+            Assert(checkEntailArith(curr, true));
+            CVC4::String s = n1[sindex_use].getConst<String>();
+            Node ncl =
+                NodeManager::currentNM()->mkConst(CVC4::Rational(s.size()));
+            Node next_s =
+                NodeManager::currentNM()->mkNode(kind::MINUS, lowerBound, ncl);
+            next_s = Rewriter::rewrite(next_s);
+            Assert(next_s.isConst());
+            // we can remove the entire constant
+            if (next_s.getConst<Rational>().sgn() >= 0)
+            {
+              curr = Rewriter::rewrite(
+                  NodeManager::currentNM()->mkNode(kind::MINUS, curr, ncl));
+              success = true;
+              sindex++;
+            }
+            else
+            {
+              // we can remove part of the constant
+              // lower bound minus the length of a concrete string is negative,
+              // hence lowerBound cannot be larger than long max
+              Assert(lbr < Rational(LONG_MAX));
+              curr = Rewriter::rewrite(NodeManager::currentNM()->mkNode(
+                  kind::MINUS, curr, lowerBound));
+              unsigned lbsize = lbr.getNumerator().toUnsignedInt();
+              Assert(lbsize < s.size());
+              if (dir == 1)
+              {
+                // strip partially from the front
+                nr.push_back(
+                    NodeManager::currentNM()->mkConst(s.prefix(lbsize)));
+                n1[sindex_use] = NodeManager::currentNM()->mkConst(
+                    s.suffix(s.size() - lbsize));
+              }
+              else
+              {
+                // strip partially from the back
+                nr.push_back(
+                    NodeManager::currentNM()->mkConst(s.suffix(lbsize)));
+                n1[sindex_use] = NodeManager::currentNM()->mkConst(
+                    s.prefix(s.size() - lbsize));
+              }
+              ret = true;
+            }
+            Assert(checkEntailArith(curr));
+          }
+          else
+          {
+            // we cannot remove the constant
+          }
+        }
+      }
+      else
+      {
+        Node next_s = NodeManager::currentNM()->mkNode(
+            kind::MINUS,
+            curr,
+            NodeManager::currentNM()->mkNode(kind::STRING_LENGTH,
+                                             n1[sindex_use]));
+        next_s = Rewriter::rewrite(next_s);
+        if (checkEntailArith(next_s))
+        {
+          success = true;
+          curr = next_s;
+          sindex++;
+        }
+      }
+    }
+  } while (success);
+  if (sindex > 0)
+  {
+    if (dir == 1)
+    {
+      nr.insert(nr.begin(), n1.begin(), n1.begin() + sindex);
+      n1.erase(n1.begin(), n1.begin() + sindex);
+    }
+    else
+    {
+      nr.insert(nr.end(), n1.end() - sindex, n1.end());
+      n1.erase(n1.end() - sindex, n1.end());
+    }
+    ret = true;
+  }
+  return ret;
+}
+
 int TheoryStringsRewriter::componentContains(std::vector<Node>& n1,
                                              std::vector<Node>& n2,
                                              std::vector<Node>& nb,
@@ -2262,10 +2542,118 @@ bool TheoryStringsRewriter::checkEntailArith(Node a, bool strict)
       return true;
     }
     // TODO (#1180) : abstract interpretation goes here
+
+    // over approximation O/U
+
+    // O( x + y ) -> O( x ) + O( y )
+    // O( c * x ) -> O( x ) if c > 0, U( x ) if c < 0
+    // O( len( x ) ) -> len( x )
+    // O( len( int.to.str( x ) ) ) -> len( int.to.str( x ) )
+    // O( len( str.substr( x, n1, n2 ) ) ) -> O( n2 ) | O( len( x ) )
+    // O( len( str.replace( x, y, z ) ) ) ->
+    //   O( len( x ) ) + O( len( z ) ) - U( len( y ) )
+    // O( indexof( x, y, n ) ) -> O( len( x ) ) - U( len( y ) )
+    // O( str.to.int( x ) ) -> str.to.int( x )
+
+    // U( x + y ) -> U( x ) + U( y )
+    // U( c * x ) -> U( x ) if c > 0, O( x ) if c < 0
+    // U( len( x ) ) -> len( x )
+    // U( len( int.to.str( x ) ) ) -> 1
+    // U( len( str.substr( x, n1, n2 ) ) ) ->
+    //   min( U( len( x ) ) - O( n1 ), U( n2 ) )
+    // U( len( str.replace( x, y, z ) ) ) ->
+    //   U( len( x ) ) + U( len( z ) ) - O( len( y ) ) | 0
+    // U( indexof( x, y, n ) ) -> -1    ?
+    // U( str.to.int( x ) ) -> -1
+
     return false;
   }
 }
 
+Node TheoryStringsRewriter::getConstantArithBound(Node a, bool isLower)
+{
+  Assert(Rewriter::rewrite(a) == a);
+  Node ret;
+  if (a.isConst())
+  {
+    ret = a;
+  }
+  else if (a.getKind() == kind::STRING_LENGTH)
+  {
+    if (isLower)
+    {
+      ret = NodeManager::currentNM()->mkConst(Rational(0));
+    }
+  }
+  else if (a.getKind() == kind::PLUS || a.getKind() == kind::MULT)
+  {
+    std::vector<Node> children;
+    bool success = true;
+    for (unsigned i = 0; i < a.getNumChildren(); i++)
+    {
+      Node ac = getConstantArithBound(a[i], isLower);
+      if (ac.isNull())
+      {
+        ret = ac;
+        success = false;
+        break;
+      }
+      else
+      {
+        if (ac.getConst<Rational>().sgn() == 0)
+        {
+          if (a.getKind() == kind::MULT)
+          {
+            ret = ac;
+            success = false;
+            break;
+          }
+        }
+        else
+        {
+          if (a.getKind() == kind::MULT)
+          {
+            if ((ac.getConst<Rational>().sgn() > 0) != isLower)
+            {
+              ret = Node::null();
+              success = false;
+              break;
+            }
+          }
+          children.push_back(ac);
+        }
+      }
+    }
+    if (success)
+    {
+      if (children.empty())
+      {
+        ret = NodeManager::currentNM()->mkConst(Rational(0));
+      }
+      else if (children.size() == 1)
+      {
+        ret = children[0];
+      }
+      else
+      {
+        ret = NodeManager::currentNM()->mkNode(a.getKind(), children);
+        ret = Rewriter::rewrite(ret);
+      }
+    }
+  }
+  Trace("strings-rewrite-cbound")
+      << "Constant " << (isLower ? "lower" : "upper") << " bound for " << a
+      << " is " << ret << std::endl;
+  Assert(ret.isNull() || ret.isConst());
+  Assert(!isLower
+         || (ret.isNull() || ret.getConst<Rational>().sgn() < 0)
+                != checkEntailArith(a, false));
+  Assert(!isLower
+         || (ret.isNull() || ret.getConst<Rational>().sgn() <= 0)
+                != checkEntailArith(a, true));
+  return ret;
+}
+
 bool TheoryStringsRewriter::checkEntailArithInternal(Node a)
 {
   Assert(Rewriter::rewrite(a) == a);
@@ -2294,3 +2682,10 @@ bool TheoryStringsRewriter::checkEntailArithInternal(Node a)
 
   return false;
 }
+
+Node TheoryStringsRewriter::returnRewrite(Node node, Node ret, const char* c)
+{
+  Trace("strings-rewrite") << "Rewrite " << node << " to " << ret << " by " << c
+                           << "." << std::endl;
+  return ret;
+}
index 593458843d22dc8356b5a3e5ca4e960a62ce516d..b7712bcef68c45c567d61bd53d35ea6f989c98f6 100644 (file)
@@ -50,6 +50,14 @@ private:
    * a is in rewritten form.
    */
   static bool checkEntailArithInternal(Node a);
+  /** return rewrite
+   * Called when node rewrites to ret.
+   * The string c indicates the justification
+   * for the rewrite, which is printed by this
+   * function for debugging.
+   * This function returns ret.
+   */
+  static Node returnRewrite(Node node, Node ret, const char* c);
 
  public:
   static RewriteResponse postRewrite(TNode node);
@@ -57,18 +65,24 @@ private:
 
   static inline void init() {}
   static inline void shutdown() {}
+  /** rewrite substr
+  * This is the entry point for post-rewriting terms node of the form
+  *   str.substr( s, i1, i2 )
+  * Returns the rewritten form of node.
+  */
+  static Node rewriteSubstr(Node node);
   /** rewrite contains
-  * This is the entry point for post-rewriting terms n of the form 
+  * This is the entry point for post-rewriting terms node of the form
   *   str.contains( t, s )
-  * Returns the rewritten form of n.
+  * Returns the rewritten form of node.
   *
   * For details on some of the basic rewrites done in this function, see Figure
-  * 7 of Reynolds et al "Scaling Up DPLL(T) String Solvers Using 
+  * 7 of Reynolds et al "Scaling Up DPLL(T) String Solvers Using
   * Context-Dependent Rewriting", CAV 2017.
   */
-  static Node rewriteContains( Node n );
-  static Node rewriteIndexof( Node n );
-  static Node rewriteReplace( Node n );
+  static Node rewriteContains(Node node);
+  static Node rewriteIndexof(Node node);
+  static Node rewriteReplace(Node node);
 
   /** gets the "vector form" of term n, adds it to c.
   * For example:
@@ -108,6 +122,49 @@ private:
   static Node getNextConstantAt( std::vector< Node >& vec, unsigned& start_index, unsigned& end_index, bool isRev );
   static Node collectConstantStringAt( std::vector< Node >& vec, unsigned& end_index, bool isRev );
 
+  /** strip symbolic length
+   *
+   * This function strips off components of n1 whose length is less than
+   * or equal to argument curr, and stores them in nr. The direction
+   * dir determines whether the components are removed from the start
+   * or end of n1.
+   *
+   * In detail, this function updates n1 to n1' such that:
+   *   If dir=1,
+   *     n1 = str.++( nr, n1' )
+   *   If dir=-1
+   *     n1 = str.++( n1', nr )
+   * It updates curr to curr' such that:
+   *   curr' = curr - str.len( str.++( nr ) ), and
+   *   curr' >= 0
+   * where the latter fact is determined by checkArithEntail.
+   *
+   * This function returns true if n1 is modified.
+   *
+   * For example:
+   *
+   *  stripSymbolicLength( { x, "abc", y }, {}, 1, str.len(x)+1 )
+   *    returns true
+   *    n1 is updated to { "bc", y }
+   *    nr is updated to { x, "a" }
+   *    curr is updated to 0   *
+   *
+   * stripSymbolicLength( { x, "abc", y }, {}, 1, str.len(x)-1 )
+   *    returns false
+   *
+   *  stripSymbolicLength( { y, "abc", x }, {}, 1, str.len(x)+1 )
+   *    returns false
+   *
+   *  stripSymbolicLength( { x, "abc", y }, {}, -1, 2*str.len(y)+4 )
+   *    returns true
+   *    n1 is updated to { x }
+   *    nr is updated to { "abc", y }
+   *    curr is updated to str.len(y)+1
+   */
+  static bool stripSymbolicLength(std::vector<Node>& n1,
+                                  std::vector<Node>& nr,
+                                  int dir,
+                                  Node& curr);
   /** component contains
   * This function is used when rewriting str.contains( t1, t2 ), where
   * n1 is the vector form of t1
@@ -240,6 +297,20 @@ private:
    * Returns true if it is always the case that a >= 0.
    */
   static bool checkEntailArith(Node a, bool strict = false);
+  /** get arithmetic lower bound
+   * If this function returns a non-null Node ret,
+   * then ret is a rational constant and
+   * we know that n >= ret always if isLower is true,
+   * or n <= ret if isLower is false.
+   *
+   * Notice the following invariant.
+   * If getConstantArithBound(a, true) = ret where ret is non-null, then for
+   * strict = { true, false } :
+   *   ret >= strict ? 1 : 0
+   *     if and only if
+   *   checkEntailArith( a, strict ) = true.
+   */
+  static Node getConstantArithBound(Node a, bool isLower = true);
 };/* class TheoryStringsRewriter */
 
 }/* CVC4::theory::strings namespace */
index c4fb8dd94d34838ee04de5732f0d65f9ccff37c8..99fd2b6307bec4ee192afb8de14cd43622af8dac 100644 (file)
@@ -92,6 +92,7 @@ TESTS = \
   issue1105.smt2 \
   issue1189.smt2 \
   rewrites-v2.smt2 \
+  substr-rewrites.smt2 \
   norn-ab.smt2 \
   type002.smt2
 
diff --git a/test/regress/regress0/strings/substr-rewrites.smt2 b/test/regress/regress0/strings/substr-rewrites.smt2
new file mode 100644 (file)
index 0000000..c4f19b7
--- /dev/null
@@ -0,0 +1,21 @@
+; COMMAND-LINE: --strings-exp
+; EXPECT: unsat
+(set-logic SLIA)
+(set-info :status unsat)
+
+(declare-fun x () String)
+(declare-fun y () String)
+(declare-fun z () String)
+
+; these should all rewrite to false
+(assert (or 
+
+(not (= (str.substr (str.++ x "abc" y) 0 (+ (str.len x) 1)) (str.++ x "a")))
+(not (= (str.substr (str.++ x "abc" y) (+ (str.len x) 1) (+ (* 2 (str.len y)) 7)) (str.++ "bc" y)))
+(not (= (str.substr (str.++ x y) 0 (+ (str.len z) (* 2 (str.len x)) (* 2 (str.len y)))) (str.substr (str.++ x y) 0 (+ (str.len x) (str.len y)))))
+(not (= (str.substr x (+ (str.len x) 1) 5) (str.substr y (- (- 1) (str.len z)) 5)))
+(not (= (str.substr "abc" 100000000000000000000000000000000000000000000000000000000000000000000000000 5) ""))
+
+))
+
+(check-sat)