From: Andrew Reynolds Date: Wed, 25 Jul 2018 17:34:32 +0000 (-0500) Subject: Move reg exp rewrites from prerewrite to postrewrite (#2204) X-Git-Tag: cvc5-1.0.0~4867 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=aae5e18cb1bc8a6774fa4293cc0b5016fab7c46e;p=cvc5.git Move reg exp rewrites from prerewrite to postrewrite (#2204) --- diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index 8c589640c..9651fe980 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -307,8 +307,8 @@ Node TheoryStringsRewriter::rewriteEquality(Node node) Node TheoryStringsRewriter::rewriteConcat(Node node) { Assert(node.getKind() == kind::STRING_CONCAT); - Trace("strings-prerewrite") << "Strings::rewriteConcat start " << node - << std::endl; + Trace("strings-rewrite-debug") + << "Strings::rewriteConcat start " << node << std::endl; Node retNode = node; std::vector node_vec; Node preNode = Node::null(); @@ -380,16 +380,17 @@ Node TheoryStringsRewriter::rewriteConcat(Node node) std::sort(node_vec.begin() + lastIdx, node_vec.end()); retNode = mkConcat( kind::STRING_CONCAT, node_vec ); - Trace("strings-prerewrite") << "Strings::rewriteConcat end " << retNode - << std::endl; + Trace("strings-rewrite-debug") + << "Strings::rewriteConcat end " << retNode << std::endl; return retNode; } -Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) { +Node TheoryStringsRewriter::rewriteConcatRegExp(TNode node) +{ Assert( node.getKind() == kind::REGEXP_CONCAT ); NodeManager* nm = NodeManager::currentNM(); - Trace("strings-prerewrite") - << "Strings::prerewriteConcatRegExp flatten " << node << std::endl; + Trace("strings-rewrite-debug") + << "Strings::rewriteConcatRegExp flatten " << node << std::endl; Node retNode = node; std::vector vec; bool changed = false; @@ -437,8 +438,8 @@ Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) { } return returnRewrite(node, retNode, "re.concat-flatten"); } - Trace("strings-prerewrite") - << "Strings::prerewriteConcatRegExp start " << node << std::endl; + Trace("strings-rewrite-debug") + << "Strings::rewriteConcatRegExp start " << node << std::endl; std::vector cvec; std::vector preReStr; for (unsigned i = 0, size = vec.size(); i <= size; i++) @@ -498,99 +499,200 @@ Node TheoryStringsRewriter::prerewriteConcatRegExp( TNode node ) { return node; } -Node TheoryStringsRewriter::prerewriteOrRegExp(TNode node) { - Assert( node.getKind() == kind::REGEXP_UNION ); - Trace("strings-prerewrite") << "Strings::prerewriteOrRegExp start " << node << std::endl; +Node TheoryStringsRewriter::rewriteStarRegExp(TNode node) +{ + Assert(node.getKind() == REGEXP_STAR); + NodeManager* nm = NodeManager::currentNM(); Node retNode = node; - std::vector node_vec; - bool allflag = false; - for(unsigned i=0; i R* + return returnRewrite(node, node[0], "re-star-nested-star"); + } + else if (node[0].getKind() == STRING_TO_REGEXP + && node[0][0].getKind() == CONST_STRING + && node[0][0].getConst().isEmptyString()) + { + // ("")* ---> "" + return returnRewrite(node, node[0], "re-star-empty-string"); + } + else if (node[0].getKind() == REGEXP_EMPTY) + { + // (empty)* ---> "" + retNode = nm->mkNode(STRING_TO_REGEXP, nm->mkConst(String(""))); + return returnRewrite(node, retNode, "re-star-empty"); + } + else if (node[0].getKind() == REGEXP_UNION) + { + // simplification of unions under star + if (hasEpsilonNode(node[0])) + { + bool changed = false; + std::vector node_vec; + for (const Node& nc : node[0]) + { + if (nc.getKind() == STRING_TO_REGEXP && nc[0].getKind() == CONST_STRING + && nc[0].getConst().isEmptyString()) + { + // can be removed + changed = true; } - } else if(tmpNode.getKind() == kind::REGEXP_EMPTY) { - //nothing - } else if(tmpNode.getKind() == kind::REGEXP_STAR && tmpNode[0].getKind() == kind::REGEXP_SIGMA) { - allflag = true; - retNode = tmpNode; - break; - } else { - if(std::find(node_vec.begin(), node_vec.end(), tmpNode) == node_vec.end()) { - node_vec.push_back( tmpNode ); + else + { + node_vec.push_back(nc); } } - } else if(node[i].getKind() == kind::REGEXP_EMPTY) { - //nothing - } else if(node[i].getKind() == kind::REGEXP_STAR && node[i][0].getKind() == kind::REGEXP_SIGMA) { - allflag = true; - retNode = node[i]; - break; - } else { - if(std::find(node_vec.begin(), node_vec.end(), node[i]) == node_vec.end()) { - node_vec.push_back( node[i] ); + if (changed) + { + retNode = node_vec.size() == 1 ? node_vec[0] + : nm->mkNode(REGEXP_UNION, node_vec); + retNode = nm->mkNode(REGEXP_STAR, retNode); + // simplification of union beneath star based on loop above + // for example, ( "" | "a" )* ---> ("a")* + return returnRewrite(node, retNode, "re-star-union"); } } } - if(!allflag) { - std::vector< Node > nvec; - retNode = node_vec.size() == 0 ? NodeManager::currentNM()->mkNode( kind::REGEXP_EMPTY, nvec ) : - node_vec.size() == 1 ? node_vec[0] : NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, node_vec); - } - Trace("strings-prerewrite") << "Strings::prerewriteOrRegExp end " << retNode << std::endl; - return retNode; + return node; } -Node TheoryStringsRewriter::prerewriteAndRegExp(TNode node) { - Assert( node.getKind() == kind::REGEXP_INTER ); - Trace("strings-prerewrite") << "Strings::prerewriteOrRegExp start " << node << std::endl; - Node retNode = node; +Node TheoryStringsRewriter::rewriteAndOrRegExp(TNode node) +{ + Kind nk = node.getKind(); + Assert(nk == REGEXP_UNION || nk == REGEXP_INTER); + Trace("strings-rewrite-debug") + << "Strings::rewriteAndOrRegExp start " << node << std::endl; std::vector node_vec; - //Node allNode = Node::null(); - for(unsigned i=0; i nvec; - retNode = node_vec.size() == 0 ? - NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, NodeManager::currentNM()->mkNode(kind::REGEXP_SIGMA, nvec)) : - node_vec.size() == 1 ? node_vec[0] : NodeManager::currentNM()->mkNode(kind::REGEXP_INTER, node_vec); + NodeManager* nm = NodeManager::currentNM(); + std::vector nvec; + Node retNode; + if (node_vec.empty()) + { + if (nk == REGEXP_INTER) + { + retNode = nm->mkNode(REGEXP_STAR, nm->mkNode(REGEXP_SIGMA, nvec)); + } + else + { + retNode = nm->mkNode(kind::REGEXP_EMPTY, nvec); + } } - Trace("strings-prerewrite") << "Strings::prerewriteOrRegExp end " << retNode << std::endl; - return retNode; + else + { + retNode = node_vec.size() == 1 ? node_vec[0] : nm->mkNode(nk, node_vec); + } + if (retNode != node) + { + // flattening and removing children, based on loop above + return returnRewrite(node, retNode, "re.andor-flatten"); + } + return node; +} + +Node TheoryStringsRewriter::rewriteLoopRegExp(TNode node) +{ + Assert(node.getKind() == REGEXP_LOOP); + Node retNode = node; + Node r = node[0]; + if (r.getKind() == REGEXP_STAR) + { + return returnRewrite(node, r, "re.loop-star"); + } + TNode n1 = node[1]; + NodeManager* nm = NodeManager::currentNM(); + CVC4::Rational RMAXINT(LONG_MAX); + AlwaysAssert(n1.isConst(), "re.loop contains non-constant integer (1)."); + AlwaysAssert(n1.getConst().sgn() >= 0, + "Negative integer in string REGEXP_LOOP (1)"); + Assert(n1.getConst() <= RMAXINT, + "Exceeded LONG_MAX in string REGEXP_LOOP (1)"); + unsigned l = n1.getConst().getNumerator().toUnsignedInt(); + std::vector vec_nodes; + for (unsigned i = 0; i < l; i++) + { + vec_nodes.push_back(r); + } + if (node.getNumChildren() == 3) + { + TNode n2 = Rewriter::rewrite(node[2]); + Node n = + vec_nodes.size() == 0 + ? nm->mkNode(STRING_TO_REGEXP, nm->mkConst(String(""))) + : vec_nodes.size() == 1 ? r : nm->mkNode(REGEXP_CONCAT, vec_nodes); + AlwaysAssert(n2.isConst(), "re.loop contains non-constant integer (2)."); + AlwaysAssert(n2.getConst().sgn() >= 0, + "Negative integer in string REGEXP_LOOP (2)"); + Assert(n2.getConst() <= RMAXINT, + "Exceeded LONG_MAX in string REGEXP_LOOP (2)"); + unsigned u = n2.getConst().getNumerator().toUnsignedInt(); + if (u <= l) + { + retNode = n; + } + else + { + std::vector vec2; + vec2.push_back(n); + for (unsigned j = l; j < u; j++) + { + vec_nodes.push_back(r); + n = mkConcat(REGEXP_CONCAT, vec_nodes); + vec2.push_back(n); + } + retNode = nm->mkNode(REGEXP_UNION, vec2); + } + } + else + { + Node rest = nm->mkNode(REGEXP_STAR, r); + retNode = vec_nodes.size() == 0 + ? rest + : vec_nodes.size() == 1 + ? nm->mkNode(REGEXP_CONCAT, r, rest) + : nm->mkNode(REGEXP_CONCAT, + nm->mkNode(REGEXP_CONCAT, vec_nodes), + rest); + } + Trace("strings-lp") << "Strings::lp " << node << " => " << retNode + << std::endl; + if (retNode != node) + { + return returnRewrite(node, retNode, "re.loop"); + } + return node; } bool TheoryStringsRewriter::isConstRegExp( TNode t ) { @@ -887,12 +989,17 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) { NodeManager* nm = NodeManager::currentNM(); Node retNode = node; Node orig = retNode; - - if(node.getKind() == kind::STRING_CONCAT) { + Kind nk = node.getKind(); + if (nk == kind::STRING_CONCAT) + { retNode = rewriteConcat(node); - } else if(node.getKind() == kind::EQUAL) { + } + else if (nk == kind::EQUAL) + { retNode = rewriteEquality(node); - } else if(node.getKind() == kind::STRING_LENGTH) { + } + else if (nk == kind::STRING_LENGTH) + { if( node[0].isConst() ){ retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( node[0].getConst().size() ) ); }else if( node[0].getKind() == kind::STRING_CONCAT ){ @@ -921,34 +1028,45 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) { retNode = nm->mkNode(STRING_LENGTH, node[0][0]); } } - }else if( node.getKind() == kind::STRING_CHARAT ){ + } + else if (nk == kind::STRING_CHARAT) + { Node one = NodeManager::currentNM()->mkConst( Rational( 1 ) ); retNode = NodeManager::currentNM()->mkNode(kind::STRING_SUBSTR, node[0], node[1], one); - }else if( node.getKind() == kind::STRING_SUBSTR ){ + } + else if (nk == kind::STRING_SUBSTR) + { retNode = rewriteSubstr(node); - }else if( node.getKind() == kind::STRING_STRCTN ){ + } + else if (nk == kind::STRING_STRCTN) + { retNode = rewriteContains( node ); } - else if (node.getKind() == kind::STRING_LT) + else if (nk == kind::STRING_LT) { // eliminate s < t ---> s != t AND s <= t retNode = nm->mkNode(AND, node[0].eqNode(node[1]).negate(), nm->mkNode(STRING_LEQ, node[0], node[1])); } - else if (node.getKind() == kind::STRING_LEQ) + else if (nk == kind::STRING_LEQ) { retNode = rewriteStringLeq(node); - }else if( node.getKind()==kind::STRING_STRIDOF ){ + } + else if (nk == kind::STRING_STRIDOF) + { retNode = rewriteIndexof( node ); - }else if( node.getKind() == kind::STRING_STRREPL ){ + } + else if (nk == kind::STRING_STRREPL) + { retNode = rewriteReplace( node ); } - else if (node.getKind() == kind::STRING_PREFIX - || node.getKind() == kind::STRING_SUFFIX) + else if (nk == kind::STRING_PREFIX || nk == kind::STRING_SUFFIX) { retNode = rewritePrefixSuffix(node); - }else if(node.getKind() == kind::STRING_ITOS) { + } + else if (nk == kind::STRING_ITOS) + { if(node[0].isConst()) { if( node[0].getConst().sgn()==-1 ){ retNode = NodeManager::currentNM()->mkConst( ::CVC4::String("") ); @@ -958,7 +1076,9 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) { retNode = NodeManager::currentNM()->mkConst( ::CVC4::String(stmp) ); } } - }else if(node.getKind() == kind::STRING_STOI) { + } + else if (nk == kind::STRING_STOI) + { if(node[0].isConst()) { CVC4::String s = node[0].getConst(); if(s.isNumber()) { @@ -977,13 +1097,49 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) { } } } - } else if(node.getKind() == kind::STRING_IN_REGEXP) { + } + else if (nk == kind::STRING_IN_REGEXP) + { retNode = rewriteMembership(node); } - else if (node.getKind() == STRING_CODE) + else if (nk == STRING_CODE) { retNode = rewriteStringCode(node); } + else if (nk == REGEXP_CONCAT) + { + retNode = rewriteConcatRegExp(node); + } + else if (nk == REGEXP_UNION || nk == REGEXP_INTER) + { + retNode = rewriteAndOrRegExp(node); + } + else if (nk == REGEXP_STAR) + { + retNode = rewriteStarRegExp(node); + } + else if (nk == REGEXP_PLUS) + { + retNode = + nm->mkNode(REGEXP_CONCAT, node[0], nm->mkNode(REGEXP_STAR, node[0])); + } + else if (nk == REGEXP_OPT) + { + retNode = nm->mkNode(REGEXP_UNION, + nm->mkNode(STRING_TO_REGEXP, nm->mkConst(String(""))), + node[0]); + } + else if (nk == REGEXP_RANGE) + { + if (node[0] == node[1]) + { + retNode = nm->mkNode(STRING_TO_REGEXP, node[0]); + } + } + else if (nk == REGEXP_LOOP) + { + retNode = rewriteLoopRegExp(node); + } Trace("strings-postrewrite") << "Strings::postRewrite returning " << retNode << std::endl; if( orig!=retNode ){ @@ -1002,121 +1158,7 @@ bool TheoryStringsRewriter::hasEpsilonNode(TNode node) { } RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) { - Node retNode = node; - Node orig = retNode; - Trace("strings-prerewrite") << "Strings::preRewrite start " << node << std::endl; - NodeManager* nm = NodeManager::currentNM(); - - if (node.getKind() == kind::REGEXP_CONCAT) - { - retNode = prerewriteConcatRegExp(node); - } else if(node.getKind() == kind::REGEXP_UNION) { - retNode = prerewriteOrRegExp(node); - } else if(node.getKind() == kind::REGEXP_INTER) { - retNode = prerewriteAndRegExp(node); - } - else if(node.getKind() == kind::REGEXP_STAR) { - if(node[0].getKind() == kind::REGEXP_STAR) { - retNode = node[0]; - } else if(node[0].getKind() == kind::STRING_TO_REGEXP && node[0][0].getKind() == kind::CONST_STRING && node[0][0].getConst().isEmptyString()) { - retNode = node[0]; - } else if(node[0].getKind() == kind::REGEXP_EMPTY) { - retNode = NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst( ::CVC4::String("") ) ); - } else if(node[0].getKind() == kind::REGEXP_UNION) { - Node tmpNode = prerewriteOrRegExp(node[0]); - if(tmpNode.getKind() == kind::REGEXP_UNION) { - if(hasEpsilonNode(node[0])) { - std::vector< Node > node_vec; - for(unsigned int i=0; i().isEmptyString()) { - //return true; - } else { - node_vec.push_back(node[0][i]); - } - } - retNode = node_vec.size()==1 ? node_vec[0] : NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, node_vec); - retNode = NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, retNode); - } - } else if(tmpNode.getKind() == kind::STRING_TO_REGEXP && tmpNode[0].getKind() == kind::CONST_STRING && tmpNode[0].getConst().isEmptyString()) { - retNode = tmpNode; - } else { - retNode = NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, tmpNode); - } - } - } else if(node.getKind() == kind::REGEXP_PLUS) { - retNode = NodeManager::currentNM()->mkNode( kind::REGEXP_CONCAT, node[0], NodeManager::currentNM()->mkNode( kind::REGEXP_STAR, node[0])); - } else if(node.getKind() == kind::REGEXP_OPT) { - retNode = NodeManager::currentNM()->mkNode( kind::REGEXP_UNION, - NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst( ::CVC4::String("") ) ), - node[0]); - } else if(node.getKind() == kind::REGEXP_RANGE) { - if(node[0] == node[1]) { - retNode = NodeManager::currentNM()->mkNode( kind::STRING_TO_REGEXP, node[0] ); - } - } else if(node.getKind() == kind::REGEXP_LOOP) { - Node r = node[0]; - if(r.getKind() == kind::REGEXP_STAR) { - retNode = r; - } else { - // eager - TNode n1 = Rewriter::rewrite( node[1] ); - // - if(!n1.isConst()) { - throw LogicException("re.loop contains non-constant integer (1)."); - } - CVC4::Rational rz(0); - CVC4::Rational RMAXINT(LONG_MAX); - AlwaysAssert(rz <= n1.getConst(), "Negative integer in string REGEXP_LOOP (1)"); - Assert(n1.getConst() <= RMAXINT, "Exceeded LONG_MAX in string REGEXP_LOOP (1)"); - // - unsigned l = n1.getConst().getNumerator().toUnsignedInt(); - std::vector< Node > vec_nodes; - for(unsigned i=0; imkNode(STRING_TO_REGEXP, nm->mkConst(String(""))) - : vec_nodes.size() == 1 - ? r - : nm->mkNode(REGEXP_CONCAT, vec_nodes); - //Assert(n2.getConst() <= RMAXINT, "Exceeded LONG_MAX in string REGEXP_LOOP (2)"); - unsigned u = n2.getConst().getNumerator().toUnsignedInt(); - if(u <= l) { - retNode = n; - } else { - std::vector< Node > vec2; - vec2.push_back(n); - for(unsigned j=l; jmkNode(REGEXP_UNION, vec2)); - } - } else { - Node rest = nm->mkNode(REGEXP_STAR, r); - retNode = vec_nodes.size() == 0 - ? rest - : vec_nodes.size() == 1 - ? nm->mkNode(REGEXP_CONCAT, r, rest) - : nm->mkNode(REGEXP_CONCAT, - nm->mkNode(REGEXP_CONCAT, vec_nodes), - rest); - } - } - Trace("strings-lp") << "Strings::lp " << node << " => " << retNode << std::endl; - } - - Trace("strings-prerewrite") << "Strings::preRewrite returning " << retNode << std::endl; - if( orig!=retNode ){ - Trace("strings-rewrite-debug") << "Strings: pre-rewrite " << orig << " to " << retNode << std::endl; - } - return RewriteResponse(orig==retNode ? REWRITE_DONE : REWRITE_AGAIN_FULL, retNode); + return RewriteResponse(REWRITE_DONE, node); } Node TheoryStringsRewriter::rewriteSubstr(Node node) diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index 732d64095..1a3f388ba 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -61,9 +61,35 @@ class TheoryStringsRewriter { static bool isConstRegExp( TNode t ); static bool testConstStringInRegExp( CVC4::String &s, unsigned int index_start, TNode r ); - static Node prerewriteConcatRegExp(TNode node); - static Node prerewriteOrRegExp(TNode node); - static Node prerewriteAndRegExp(TNode node); + /** rewrite regular expression concatenation + * + * This is the entry point for post-rewriting applications of re.++. + * Returns the rewritten form of node. + */ + static Node rewriteConcatRegExp(TNode node); + /** rewrite regular expression star + * + * This is the entry point for post-rewriting applications of re.*. + * Returns the rewritten form of node. + */ + static Node rewriteStarRegExp(TNode node); + /** rewrite regular expression intersection/union + * + * This is the entry point for post-rewriting applications of re.inter and + * re.union. Returns the rewritten form of node. + */ + static Node rewriteAndOrRegExp(TNode node); + /** rewrite regular expression loop + * + * This is the entry point for post-rewriting applications of re.loop. + * Returns the rewritten form of node. + */ + static Node rewriteLoopRegExp(TNode node); + /** rewrite regular expression membership + * + * This is the entry point for post-rewriting applications of str.in.re + * Returns the rewritten form of node. + */ static Node rewriteMembership(TNode node); static bool hasEpsilonNode(TNode node);