From: Andrew Reynolds Date: Tue, 2 Jan 2018 17:43:00 +0000 (-0600) Subject: Improve rewriter for string equality (#1427) X-Git-Tag: cvc5-1.0.0~5397 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ac73ef6098ccdbf59623171bcd4837ddd0afc38f;p=cvc5.git Improve rewriter for string equality (#1427) --- diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index b463a319a..956822303 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -118,47 +118,6 @@ Node ExtendedRewriter::extendedRewrite(Node n) Node new_ret; if (ret.getKind() == kind::EQUAL) { - // string equalities with disequal prefix or suffix - if (ret[0].getType().isString()) - { - std::vector c[2]; - for (unsigned i = 0; i < 2; i++) - { - strings::TheoryStringsRewriter::getConcat(ret[i], c[i]); - } - if (c[0].empty() == c[1].empty()) - { - if (!c[0].empty()) - { - for (unsigned i = 0; i < 2; i++) - { - unsigned index1 = i == 0 ? 0 : c[0].size() - 1; - unsigned index2 = i == 0 ? 0 : c[1].size() - 1; - if (c[0][index1].isConst() && c[1][index2].isConst()) - { - CVC4::String s = c[0][index1].getConst(); - CVC4::String t = c[1][index2].getConst(); - unsigned len_short = s.size() <= t.size() ? s.size() : t.size(); - bool isSameFix = - i == 1 ? s.rstrncmp(t, len_short) : s.strncmp(t, len_short); - if (!isSameFix) - { - Trace("q-ext-rewrite") << "sygus-extr : " << ret - << " rewrites to false due to " - "disequal string prefix/suffix." - << std::endl; - new_ret = d_false; - break; - } - } - } - } - } - else - { - new_ret = d_false; - } - } if (new_ret.isNull()) { // simple ITE pulling diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index e6b8807e9..30a5f0fbc 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -1448,7 +1448,7 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef if( ( in.d_pol==1 && nr[1].getKind()==kind::STRING_CONCAT ) || ( in.d_pol==-1 && nr[0].getKind()==kind::STRING_CONCAT ) ){ if( d_extf_infer_cache.find( nr )==d_extf_infer_cache.end() ){ d_extf_infer_cache.insert( nr ); - + //one argument does (not) contain each of the components of the other argument int index = in.d_pol==1 ? 1 : 0; std::vector< Node > children; @@ -1458,9 +1458,21 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef for( unsigned i=0; imkNode( kind::STRING_STRCTN, children ); - //can mark as reduced, since model for n => model for conc - getExtTheory()->markReduced( conc ); - sendInference( in.d_exp, in.d_pol==1 ? conc : conc.negate(), "CTN_Decompose" ); + conc = Rewriter::rewrite(in.d_pol == 1 ? conc : conc.negate()); + // check if it already (does not) hold + if (hasTerm(conc)) + { + if (areEqual(conc, d_false)) + { + // should be a conflict + sendInference(in.d_exp, conc, "CTN_Decompose"); + } + else if (getExtTheory()->hasFunctionKind(conc.getKind())) + { + // can mark as reduced, since model for n => model for conc + getExtTheory()->markReduced(conc); + } + } } } @@ -2978,11 +2990,11 @@ void TheoryStrings::processDeq( Node ni, Node nj ) { return; }else if( !areEqual( firstChar, nconst_k ) ){ //splitting on demand : try to make them disequal - Node eq = firstChar.eqNode( nconst_k ); - sendSplit( firstChar, nconst_k, "S-Split(DEQL-Const)" ); - eq = Rewriter::rewrite( eq ); - d_pending_req_phase[ eq ] = false; - return; + if (sendSplit( + firstChar, nconst_k, "S-Split(DEQL-Const)", false)) + { + return; + } } }else{ Node sk = mkSkolemCached( nconst_k, firstChar, sk_id_dc_spt, "dc_spt", 2 ); @@ -3032,18 +3044,16 @@ void TheoryStrings::processDeq( Node ni, Node nj ) { }else if( areEqual( li, lj ) ){ Assert( !areDisequal( i, j ) ); //splitting on demand : try to make them disequal - Node eq = i.eqNode( j ); - sendSplit( i, j, "S-Split(DEQL)" ); - eq = Rewriter::rewrite( eq ); - d_pending_req_phase[ eq ] = false; - return; + if (sendSplit(i, j, "S-Split(DEQL)", false)) + { + return; + } }else{ //splitting on demand : try to make lengths equal - Node eq = li.eqNode( lj ); - sendSplit( li, lj, "D-Split" ); - eq = Rewriter::rewrite( eq ); - d_pending_req_phase[ eq ] = true; - return; + if (sendSplit(li, lj, "D-Split")) + { + return; + } } } index++; @@ -3361,15 +3371,22 @@ void TheoryStrings::sendInfer( Node eq_exp, Node eq, const char * c ) { d_infer_exp.push_back( eq_exp ); } -void TheoryStrings::sendSplit( Node a, Node b, const char * c, bool preq ) { +bool TheoryStrings::sendSplit(Node a, Node b, const char* c, bool preq) +{ Node eq = a.eqNode( b ); eq = Rewriter::rewrite( eq ); - Node neq = NodeManager::currentNM()->mkNode( kind::NOT, eq ); - Node lemma_or = NodeManager::currentNM()->mkNode( kind::OR, eq, neq ); - Trace("strings-lemma") << "Strings::Lemma " << c << " SPLIT : " << lemma_or << std::endl; - d_lemma_cache.push_back(lemma_or); - d_pending_req_phase[eq] = preq; - ++(d_statistics.d_splits); + if (!eq.isConst()) + { + Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq); + Node lemma_or = NodeManager::currentNM()->mkNode(kind::OR, eq, neq); + Trace("strings-lemma") << "Strings::Lemma " << c << " SPLIT : " << lemma_or + << std::endl; + d_lemma_cache.push_back(lemma_or); + d_pending_req_phase[eq] = preq; + ++(d_statistics.d_splits); + return true; + } + return false; } @@ -3767,8 +3784,10 @@ void TheoryStrings::checkCardinality() { itr2 != cols[i].end(); ++itr2) { if(!areDisequal( *itr1, *itr2 )) { // add split lemma - sendSplit( *itr1, *itr2, "CARD-SP" ); - return; + if (sendSplit(*itr1, *itr2, "CARD-SP")) + { + return; + } } } } diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index 70706bbd4..f07057444 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -409,7 +409,7 @@ protected: void sendInference( std::vector< Node >& exp, Node eq, const char * c, bool asLemma = false ); void sendLemma( Node ant, Node conc, const char * c ); void sendInfer( Node eq_exp, Node eq, const char * c ); - void sendSplit( Node a, Node b, const char * c, bool preq = true ); + bool sendSplit(Node a, Node b, const char* c, bool preq = true); void sendLengthLemma( Node n ); /** mkConcat **/ inline Node mkConcat( Node n1, Node n2 ); diff --git a/src/theory/strings/theory_strings_rewriter.cpp b/src/theory/strings/theory_strings_rewriter.cpp index 5cb58729e..a478667e9 100644 --- a/src/theory/strings/theory_strings_rewriter.cpp +++ b/src/theory/strings/theory_strings_rewriter.cpp @@ -196,11 +196,98 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren, return Node::null(); } +Node TheoryStringsRewriter::rewriteEquality(Node node) +{ + Assert(node.getKind() == kind::EQUAL); + if (node[0] == node[1]) + { + return NodeManager::currentNM()->mkConst(true); + } + else if (node[0].isConst() && node[1].isConst()) + { + return NodeManager::currentNM()->mkConst(false); + } + // ( ~contains( s, t ) V ~contains( t, s ) ) => ( s == t ---> false ) + for (unsigned r = 0; r < 2; r++) + { + Node ctn = NodeManager::currentNM()->mkNode( + kind::STRING_STRCTN, node[r], node[1 - r]); + // must call rewrite contains directly to avoid infinite loop + // we do a fix point since we may rewrite contains terms to simpler + // contains terms. + Node prev; + do + { + prev = ctn; + ctn = rewriteContains(ctn); + } while (prev != ctn && ctn.getKind() == kind::STRING_STRCTN); + if (ctn.isConst()) + { + if (!ctn.getConst()) + { + return returnRewrite(node, ctn, "eq-nctn"); + } + else + { + // definitely contains but not syntactically equal + // We may be able to simplify, e.g. + // str.++( x, "a" ) == "a" ----> x = "" + } + } + } + + std::vector c[2]; + for (unsigned i = 0; i < 2; i++) + { + strings::TheoryStringsRewriter::getConcat(node[i], c[i]); + } + + // check if the prefix, suffix mismatches + // For example, str.++( x, "a", y ) == str.++( x, "bc", z ) ---> false + unsigned minsize = std::min(c[0].size(), c[1].size()); + for (unsigned r = 0; r < 2; r++) + { + for (unsigned i = 0; i < minsize; i++) + { + unsigned index1 = r == 0 ? i : (c[0].size() - 1) - i; + unsigned index2 = r == 0 ? i : (c[1].size() - 1) - i; + if (c[0][index1].isConst() && c[1][index2].isConst()) + { + CVC4::String s = c[0][index1].getConst(); + CVC4::String t = c[1][index2].getConst(); + unsigned len_short = s.size() <= t.size() ? s.size() : t.size(); + bool isSameFix = + r == 1 ? s.rstrncmp(t, len_short) : s.strncmp(t, len_short); + if (!isSameFix) + { + Node ret = NodeManager::currentNM()->mkConst(false); + return returnRewrite(node, ret, "eq-nfix"); + } + } + if (c[0][index1] != c[1][index2]) + { + break; + } + } + } + + // standard ordering + if (node[0] > node[1]) + { + return NodeManager::currentNM()->mkNode(kind::EQUAL, node[1], node[0]); + } + else + { + return node; + } +} + // TODO (#1180) add rewrite // str.++( str.substr( x, n1, n2 ), str.substr( x, n1+n2, n3 ) ) ---> // str.substr( x, n1, n2+n3 ) Node TheoryStringsRewriter::rewriteConcat(Node node) { + Assert(node.getKind() == kind::STRING_CONCAT); Trace("strings-prerewrite") << "Strings::rewriteConcat start " << node << std::endl; Node retNode = node; @@ -1009,15 +1096,7 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) { if(node.getKind() == kind::STRING_CONCAT) { retNode = rewriteConcat(node); } else if(node.getKind() == kind::EQUAL) { - Node leftNode = node[0]; - Node rightNode = node[1]; - if(leftNode == rightNode) { - retNode = NodeManager::currentNM()->mkConst(true); - } else if(leftNode.isConst() && rightNode.isConst()) { - retNode = NodeManager::currentNM()->mkConst(false); - } else if(leftNode > rightNode) { - retNode = NodeManager::currentNM()->mkNode(kind::EQUAL, rightNode, leftNode); - } + retNode = rewriteEquality(node); } else if(node.getKind() == kind::STRING_LENGTH) { if( node[0].isConst() ){ retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( node[0].getConst().size() ) ); @@ -1320,6 +1399,7 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) { Node TheoryStringsRewriter::rewriteSubstr(Node node) { + Assert(node.getKind() == kind::STRING_SUBSTR); if (node[0].isConst()) { if (node[0].getConst().size() == 0) @@ -1535,6 +1615,7 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node) } Node TheoryStringsRewriter::rewriteContains( Node node ) { + Assert(node.getKind() == kind::STRING_STRCTN); if( node[0] == node[1] ){ Node ret = NodeManager::currentNM()->mkConst(true); return returnRewrite(node, ret, "ctn-eq"); @@ -1599,7 +1680,88 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) { Node ret = NodeManager::currentNM()->mkConst(false); return returnRewrite(node, ret, "ctn-len-ineq"); } - else if (checkEntailArithEq(len_n1, len_n2)) + + // multi-set reasoning + // For example, contains( str.++( x, "b" ), str.++( "a", x ) ) ---> false + // since the number of a's in the second argument is greater than the number + // of a's in the first argument + std::map num_nconst[2]; + std::map num_const[2]; + for (unsigned j = 0; j < 2; j++) + { + std::vector& ncj = j == 0 ? nc1 : nc2; + for (const Node& cc : ncj) + { + if (cc.isConst()) + { + num_const[j][cc]++; + } + else + { + num_nconst[j][cc]++; + } + } + } + bool ms_success = true; + for (std::pair& nncp : num_nconst[0]) + { + if (nncp.second > num_nconst[1][nncp.first]) + { + ms_success = false; + break; + } + } + if (ms_success) + { + // count the number of constant characters in the first argument + std::map count_const[2]; + std::vector chars; + for (unsigned j = 0; j < 2; j++) + { + for (std::pair& ncp : num_const[j]) + { + Node cn = ncp.first; + Assert(cn.isConst()); + std::vector cc_vec; + const std::vector& cvec = cn.getConst().getVec(); + for (unsigned i = 0, size = cvec.size(); i < size; i++) + { + // make the character + cc_vec.clear(); + cc_vec.insert(cc_vec.end(), cvec.begin() + i, cvec.begin() + i + 1); + Node ch = NodeManager::currentNM()->mkConst(String(cc_vec)); + count_const[j][ch] += ncp.second; + if (std::find(chars.begin(), chars.end(), ch) == chars.end()) + { + chars.push_back(ch); + } + } + } + } + Trace("strings-rewrite-multiset") << "For " << node << " : " << std::endl; + for (const Node& ch : chars) + { + Trace("strings-rewrite-multiset") << " # occurrences of substring "; + Trace("strings-rewrite-multiset") << ch << " in arguments is "; + Trace("strings-rewrite-multiset") << count_const[0][ch] << " / " + << count_const[1][ch] << std::endl; + if (count_const[0][ch] < count_const[1][ch]) + { + Node ret = NodeManager::currentNM()->mkConst(false); + return returnRewrite(node, ret, "ctn-mset-nss"); + } + } + // TODO (#1180): count the number of 2,3,4,.. character substrings + // for example: + // str.contains( str.++( x, "cbabc" ), str.++( "cabbc", x ) ) ---> false + // since the second argument contains more occurrences of "bb". + // note this is orthogonal reasoning to inductive reasoning + // via regular membership reduction in Liang et al CAV 2015. + } + // TODO (#1180): abstract interpretation with multi-set domain + // to show first argument is a strict subset of second argument + + if (checkEntailArithEq(len_n1, len_n2)) { // len( n2 ) = len( n1 ) => contains( n1, n2 ) ---> n1 = n2 Node ret = node[0].eqNode(node[1]); @@ -1654,6 +1816,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) { } Node TheoryStringsRewriter::rewriteIndexof( Node node ) { + Assert(node.getKind() == kind::STRING_STRIDOF); std::vector< Node > children; getConcat( node[0], children ); //std::vector< Node > children1; @@ -1759,6 +1922,7 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) { } Node TheoryStringsRewriter::rewriteReplace( Node node ) { + Assert(node.getKind() == kind::STRING_STRREPL); if( node[1]==node[2] ){ return returnRewrite(node, node[0], "rpl-id"); } diff --git a/src/theory/strings/theory_strings_rewriter.h b/src/theory/strings/theory_strings_rewriter.h index 64120eca0..194e9bbe5 100644 --- a/src/theory/strings/theory_strings_rewriter.h +++ b/src/theory/strings/theory_strings_rewriter.h @@ -63,6 +63,12 @@ private: static inline void init() {} static inline void shutdown() {} + /** rewrite equality + * + * This method returns a formula that is equivalent to the equality between + * two strings, given by node. + */ + static Node rewriteEquality(Node node); /** rewrite concat * This is the entry point for post-rewriting terms node of the form * str.++( t1, .., tn )