From: Andrew Reynolds Date: Fri, 20 Mar 2020 23:09:27 +0000 (-0500) Subject: Generalize mkConcat for types (#4123) X-Git-Tag: cvc5-1.0.0~3461 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0e62e42f739e467f61f5c3d10e7b1c7356db6406;p=cvc5.git Generalize mkConcat for types (#4123) Towards theory of sequences. The utility function mkConcat needs a type to know what to construct in the empty string case. --- diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 187c765d1..231c81bbf 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -818,6 +818,7 @@ Node QuantifiersRewriter::getVarElimLitString(Node lit, { if (lit[i].getKind() == STRING_CONCAT) { + TypeNode stype = lit[i].getType(); for (unsigned j = 0, nchildren = lit[i].getNumChildren(); j < nchildren; j++) { @@ -827,8 +828,8 @@ Node QuantifiersRewriter::getVarElimLitString(Node lit, Node slv = lit[1 - i]; std::vector preL(lit[i].begin(), lit[i].begin() + j); std::vector postL(lit[i].begin() + j + 1, lit[i].end()); - Node tpre = strings::utils::mkConcat(STRING_CONCAT, preL); - Node tpost = strings::utils::mkConcat(STRING_CONCAT, postL); + Node tpre = strings::utils::mkConcat(preL, stype); + Node tpost = strings::utils::mkConcat(postL, stype); Node slvL = nm->mkNode(STRING_LENGTH, slv); Node tpreL = nm->mkNode(STRING_LENGTH, tpre); Node tpostL = nm->mkNode(STRING_LENGTH, tpost); diff --git a/src/theory/strings/base_solver.cpp b/src/theory/strings/base_solver.cpp index 6958d2528..128893cf0 100644 --- a/src/theory/strings/base_solver.cpp +++ b/src/theory/strings/base_solver.cpp @@ -35,6 +35,7 @@ BaseSolver::BaseSolver(context::Context* c, d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String("")); d_false = NodeManager::currentNM()->mkConst(false); d_cardSize = utils::getAlphabetCardinality(); + d_type = NodeManager::currentNM()->stringType(); } BaseSolver::~BaseSolver() {} @@ -253,7 +254,7 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti, if (!n.isNull()) { // construct the constant - Node c = utils::mkNConcat(vecc); + Node c = utils::mkNConcat(vecc, d_type); if (!d_state.areEqual(n, c)) { if (Trace.isOn("strings-debug")) diff --git a/src/theory/strings/base_solver.h b/src/theory/strings/base_solver.h index 3681b49a4..bf223bc0a 100644 --- a/src/theory/strings/base_solver.h +++ b/src/theory/strings/base_solver.h @@ -191,6 +191,8 @@ class BaseSolver std::map d_termIndex; /** the cardinality of the alphabet */ uint32_t d_cardSize; + /** The string-like type for this base solver */ + TypeNode d_type; }; /* class BaseSolver */ } // namespace strings diff --git a/src/theory/strings/core_solver.cpp b/src/theory/strings/core_solver.cpp index 2a95b41ba..f250647af 100644 --- a/src/theory/strings/core_solver.cpp +++ b/src/theory/strings/core_solver.cpp @@ -40,6 +40,7 @@ d_nf_pairs(c) d_emptyString = Word::mkEmptyWord(CONST_STRING); d_true = NodeManager::currentNM()->mkConst( true ); d_false = NodeManager::currentNM()->mkConst( false ); + d_type = NodeManager::currentNM()->stringType(); } CoreSolver::~CoreSolver() { @@ -559,7 +560,7 @@ void CoreSolver::checkNormalFormsEq() return; } NormalForm& nfe = getNormalForm(eqc); - Node nf_term = utils::mkNConcat(nfe.d_nf); + Node nf_term = utils::mkNConcat(nfe.d_nf, d_type); std::map::iterator itn = nf_to_eqc.find(nf_term); if (itn != nf_to_eqc.end()) { @@ -682,7 +683,7 @@ Node CoreSolver::getNormalString(Node x, std::vector& nf_exp) if (it != d_normal_form.end()) { NormalForm& nf = it->second; - Node ret = utils::mkNConcat(nf.d_nf); + Node ret = utils::mkNConcat(nf.d_nf, d_type); nf_exp.insert(nf_exp.end(), nf.d_exp.begin(), nf.d_exp.end()); d_im.addToExplanation(x, nf.d_base, nf_exp); Trace("strings-debug") @@ -700,7 +701,7 @@ Node CoreSolver::getNormalString(Node x, std::vector& nf_exp) Node nc = getNormalString(x[i], nf_exp); vec_nodes.push_back(nc); } - return utils::mkNConcat(vec_nodes); + return utils::mkNConcat(vec_nodes, d_type); } } return x; @@ -1117,7 +1118,7 @@ void CoreSolver::processSimpleNEq(NormalForm& nfi, eqnc.push_back(nfkv[i]); } } - eqn[r] = utils::mkNConcat(eqnc); + eqn[r] = utils::mkNConcat(eqnc, d_type); } if (!d_state.areEqual(eqn[0], eqn[1])) { @@ -1540,15 +1541,15 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi, Trace("strings-loop") << " ... (X)= " << vecoi[index] << std::endl; Trace("strings-loop") << " ... T(Y.Z)= "; std::vector vec_t(veci.begin() + index, veci.begin() + loop_index); - Node t_yz = utils::mkNConcat(vec_t); + Node t_yz = utils::mkNConcat(vec_t, d_type); Trace("strings-loop") << " (" << t_yz << ")" << std::endl; Trace("strings-loop") << " ... S(Z.Y)= "; std::vector vec_s(vecoi.begin() + index + 1, vecoi.end()); - Node s_zy = utils::mkNConcat(vec_s); + Node s_zy = utils::mkNConcat(vec_s, d_type); Trace("strings-loop") << s_zy << std::endl; Trace("strings-loop") << " ... R= "; std::vector vec_r(veci.begin() + loop_index + 1, veci.end()); - Node r = utils::mkNConcat(vec_r); + Node r = utils::mkNConcat(vec_r, d_type); Trace("strings-loop") << r << std::endl; if (s_zy.isConst() && r.isConst() && r != d_emptyString) @@ -1640,7 +1641,7 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi, v2.insert(v2.begin(), y); v2.insert(v2.begin(), z); restr = utils::mkNConcat(z, y); - cc = Rewriter::rewrite(s_zy.eqNode(utils::mkNConcat(v2))); + cc = Rewriter::rewrite(s_zy.eqNode(utils::mkNConcat(v2, d_type))); } else { @@ -1690,7 +1691,7 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi, // s1 * ... * sk = z * y * r vec_r.insert(vec_r.begin(), sk_y); vec_r.insert(vec_r.begin(), sk_z); - Node conc2 = s_zy.eqNode(utils::mkNConcat(vec_r)); + Node conc2 = s_zy.eqNode(utils::mkNConcat(vec_r, d_type)); Node conc3 = vecoi[index].eqNode(utils::mkNConcat(sk_y, sk_w)); Node restr = r == d_emptyString ? s_zy : utils::mkNConcat(sk_z, sk_y); str_in_re = @@ -2147,7 +2148,7 @@ void CoreSolver::checkLengthsEqc() { // now, check if length normalization has occurred if (ei->d_normalizedLength.get().isNull()) { - Node nf = utils::mkNConcat(nfi.d_nf); + Node nf = utils::mkNConcat(nfi.d_nf, d_type); if (Trace.isOn("strings-process-debug")) { Trace("strings-process-debug") diff --git a/src/theory/strings/core_solver.h b/src/theory/strings/core_solver.h index d18f109b2..c549fa886 100644 --- a/src/theory/strings/core_solver.h +++ b/src/theory/strings/core_solver.h @@ -368,6 +368,8 @@ class CoreSolver * the argument number of the t1 ... tn they were generated from. */ std::map > d_flat_form_index; + /** The string-like type for this solver */ + TypeNode d_type; }; /* class CoreSolver */ } // namespace strings diff --git a/src/theory/strings/normal_form.cpp b/src/theory/strings/normal_form.cpp index ac28be245..7a2323d89 100644 --- a/src/theory/strings/normal_form.cpp +++ b/src/theory/strings/normal_form.cpp @@ -150,7 +150,7 @@ Node NormalForm::collectConstantStringAt(size_t& index) { std::reverse(c.begin(), c.end()); } - Node cc = Rewriter::rewrite(utils::mkConcat(STRING_CONCAT, c)); + Node cc = Rewriter::rewrite(utils::mkConcat(c, c[0].getType())); Assert(cc.isConst()); return cc; } diff --git a/src/theory/strings/regexp_elim.cpp b/src/theory/strings/regexp_elim.cpp index 976efad3c..259588789 100644 --- a/src/theory/strings/regexp_elim.cpp +++ b/src/theory/strings/regexp_elim.cpp @@ -30,6 +30,7 @@ RegExpElimination::RegExpElimination() d_zero = NodeManager::currentNM()->mkConst(Rational(0)); d_one = NodeManager::currentNM()->mkConst(Rational(1)); d_neg_one = NodeManager::currentNM()->mkConst(Rational(-1)); + d_regExpType = NodeManager::currentNM()->regExpType(); } Node RegExpElimination::eliminate(Node atom) @@ -382,7 +383,7 @@ Node RegExpElimination::eliminateConcat(Node atom) Assert(rexpElimChildren.size() + sConstraints.size() == nchildren); Node ss = nm->mkNode(STRING_SUBSTR, x, sStartIndex, sLength); Assert(!rexpElimChildren.empty()); - Node regElim = utils::mkConcat(REGEXP_CONCAT, rexpElimChildren); + Node regElim = utils::mkConcat(rexpElimChildren, d_regExpType); sConstraints.push_back(nm->mkNode(STRING_IN_REGEXP, ss, regElim)); Node ret = nm->mkNode(AND, sConstraints); // e.g. @@ -422,7 +423,7 @@ Node RegExpElimination::eliminateConcat(Node atom) { std::vector rprefix; rprefix.insert(rprefix.end(), children.begin(), children.begin() + i); - Node rpn = utils::mkConcat(REGEXP_CONCAT, rprefix); + Node rpn = utils::mkConcat(rprefix, d_regExpType); Node substrPrefix = nm->mkNode( STRING_IN_REGEXP, nm->mkNode(STRING_SUBSTR, x, d_zero, k), rpn); echildren.push_back(substrPrefix); @@ -431,7 +432,7 @@ Node RegExpElimination::eliminateConcat(Node atom) { std::vector rsuffix; rsuffix.insert(rsuffix.end(), children.begin() + i + 1, children.end()); - Node rps = utils::mkConcat(REGEXP_CONCAT, rsuffix); + Node rps = utils::mkConcat(rsuffix, d_regExpType); Node ks = nm->mkNode(PLUS, k, lens); Node substrSuffix = nm->mkNode( STRING_IN_REGEXP, diff --git a/src/theory/strings/regexp_elim.h b/src/theory/strings/regexp_elim.h index dbd4102b6..61ce8a920 100644 --- a/src/theory/strings/regexp_elim.h +++ b/src/theory/strings/regexp_elim.h @@ -47,6 +47,8 @@ class RegExpElimination Node d_zero; Node d_one; Node d_neg_one; + /** The type of regular expressions */ + TypeNode d_regExpType; /** return elimination * * This method is called when atom is rewritten to atomElim, and returns diff --git a/src/theory/strings/regexp_solver.cpp b/src/theory/strings/regexp_solver.cpp index cd66c0ebf..30f9c4a73 100644 --- a/src/theory/strings/regexp_solver.cpp +++ b/src/theory/strings/regexp_solver.cpp @@ -621,7 +621,7 @@ bool RegExpSolver::deriveRegExp(Node x, { vec_nodes.push_back(x[i]); } - Node left = utils::mkConcat(STRING_CONCAT, vec_nodes); + Node left = utils::mkConcat(vec_nodes, x.getType()); left = Rewriter::rewrite(left); conc = NodeManager::currentNM()->mkNode(STRING_IN_REGEXP, left, dc); } diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index f4a1cd411..200d7a734 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -410,6 +410,7 @@ Node SequencesRewriter::rewriteEqualityExt(Node node) Node SequencesRewriter::rewriteStrEqualityExt(Node node) { Assert(node.getKind() == EQUAL && node[0].getType().isString()); + TypeNode stype = node[0].getType(); NodeManager* nm = NodeManager::currentNM(); std::vector c[2]; @@ -461,8 +462,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) if (changed) { // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y - Node s1 = utils::mkConcat(STRING_CONCAT, c[0]); - Node s2 = utils::mkConcat(STRING_CONCAT, c[1]); + Node s1 = utils::mkConcat(c[0], stype); + Node s2 = utils::mkConcat(c[1], stype); new_ret = s1.eqNode(s2); node = returnRewrite(node, new_ret, "str-eq-unify"); } @@ -531,8 +532,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) } } - Node lhs = utils::mkConcat(STRING_CONCAT, trimmed[i]); - Node ss = utils::mkConcat(STRING_CONCAT, trimmed[1 - i]); + Node lhs = utils::mkConcat(trimmed[i], stype); + Node ss = utils::mkConcat(trimmed[1 - i], stype); if (lhs != node[i] || ss != node[1 - i]) { // e.g. @@ -696,13 +697,13 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) for (size_t i = 0, size0 = v0.size(); i <= size0; i++) { std::vector pfxv0(v0.begin(), v0.begin() + i); - Node pfx0 = utils::mkConcat(STRING_CONCAT, pfxv0); + Node pfx0 = utils::mkConcat(pfxv0, stype); for (size_t j = startRhs, size1 = v1.size(); j <= size1; j++) { if (!(i == 0 && j == 0) && !(i == v0.size() && j == v1.size())) { std::vector pfxv1(v1.begin(), v1.begin() + j); - Node pfx1 = utils::mkConcat(STRING_CONCAT, pfxv1); + Node pfx1 = utils::mkConcat(pfxv1, stype); Node lenPfx0 = nm->mkNode(STRING_LENGTH, pfx0); Node lenPfx1 = nm->mkNode(STRING_LENGTH, pfx1); @@ -710,11 +711,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { std::vector sfxv0(v0.begin() + i, v0.end()); std::vector sfxv1(v1.begin() + j, v1.end()); - Node ret = - nm->mkNode(kind::AND, - pfx0.eqNode(pfx1), - utils::mkConcat(STRING_CONCAT, sfxv0) - .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1))); + Node ret = nm->mkNode(kind::AND, + pfx0.eqNode(pfx1), + utils::mkConcat(sfxv0, stype) + .eqNode(utils::mkConcat(sfxv1, stype))); return returnRewrite(node, ret, "split-eq"); } else if (checkEntailArith(lenPfx1, lenPfx0, true)) @@ -731,11 +731,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { std::vector sfxv0(v0.begin() + i, v0.end()); pfxv1.insert(pfxv1.end(), v1.begin() + j, v1.end()); - Node ret = nm->mkNode( - kind::AND, - pfx0.eqNode(utils::mkConcat(STRING_CONCAT, rpfxv1)), - utils::mkConcat(STRING_CONCAT, sfxv0) - .eqNode(utils::mkConcat(STRING_CONCAT, pfxv1))); + Node ret = nm->mkNode(kind::AND, + pfx0.eqNode(utils::mkConcat(rpfxv1, stype)), + utils::mkConcat(sfxv0, stype) + .eqNode(utils::mkConcat(pfxv1, stype))); return returnRewrite(node, ret, "split-eq-strip-r"); } @@ -759,11 +758,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { pfxv0.insert(pfxv0.end(), v0.begin() + i, v0.end()); std::vector sfxv1(v1.begin() + j, v1.end()); - Node ret = nm->mkNode( - kind::AND, - utils::mkConcat(STRING_CONCAT, rpfxv0).eqNode(pfx1), - utils::mkConcat(STRING_CONCAT, pfxv0) - .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1))); + Node ret = nm->mkNode(kind::AND, + utils::mkConcat(rpfxv0, stype).eqNode(pfx1), + utils::mkConcat(pfxv0, stype) + .eqNode(utils::mkConcat(sfxv1, stype))); return returnRewrite(node, ret, "split-eq-strip-l"); } @@ -891,7 +889,8 @@ Node SequencesRewriter::rewriteConcat(Node node) } std::sort(node_vec.begin() + lastIdx, node_vec.end()); - retNode = utils::mkConcat(STRING_CONCAT, node_vec); + TypeNode tn = node.getType(); + retNode = utils::mkConcat(node_vec, tn); Trace("strings-rewrite-debug") << "Strings::rewriteConcat end " << retNode << std::endl; return retNode; @@ -907,6 +906,20 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node) std::vector vec; bool changed = false; Node emptyRe; + + // get the string type that are members of this regular expression + TypeNode rtype = node.getType(); + TypeNode stype; + if (rtype.isRegExp()) + { + // standard regular expressions are for strings + stype = nm->stringType(); + } + else + { + Unimplemented(); + } + for (const Node& c : node) { if (c.getKind() == REGEXP_CONCAT) @@ -978,8 +991,7 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node) { Assert(!lastAllStar); // this groups consecutive strings a++b ---> ab - Node acc = nm->mkNode(STRING_TO_REGEXP, - utils::mkConcat(STRING_CONCAT, preReStr)); + Node acc = nm->mkNode(STRING_TO_REGEXP, utils::mkConcat(preReStr, stype)); cvec.push_back(acc); preReStr.clear(); } @@ -1022,11 +1034,11 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node) } } Assert(!cvec.empty()); - retNode = utils::mkConcat(REGEXP_CONCAT, cvec); + retNode = utils::mkConcat(cvec, rtype); if (retNode != node) { - // handles all cases where consecutive re constants are combined or dropped - // as described in the loop above. + // handles all cases where consecutive re constants are combined or + // dropped as described in the loop above. return returnRewrite(node, retNode, "re.concat"); } @@ -1043,7 +1055,7 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node) } if (changed) { - retNode = utils::mkConcat(REGEXP_CONCAT, cvec); + retNode = utils::mkConcat(cvec, rtype); return returnRewrite(node, retNode, "re.concat.opt"); } return node; @@ -1215,10 +1227,11 @@ Node SequencesRewriter::rewriteLoopRegExp(TNode node) { std::vector vec2; vec2.push_back(n); + TypeNode rtype = nm->regExpType(); for (unsigned j = l; j < u; j++) { vec_nodes.push_back(r); - n = utils::mkConcat(REGEXP_CONCAT, vec_nodes); + n = utils::mkConcat(vec_nodes, rtype); vec2.push_back(n); } retNode = nm->mkNode(REGEXP_UNION, vec2); @@ -1518,9 +1531,12 @@ Node SequencesRewriter::rewriteMembership(TNode node) Node x = node[0]; Node r = node[1]; - if (r.getKind() == kind::REGEXP_EMPTY) + TypeNode stype = x.getType(); + TypeNode rtype = r.getType(); + + if(r.getKind() == kind::REGEXP_EMPTY) { - retNode = NodeManager::currentNM()->mkConst(false); + retNode = NodeManager::currentNM()->mkConst( false ); } else if (x.isConst() && isConstRegExp(r)) { @@ -1708,7 +1724,7 @@ Node SequencesRewriter::rewriteMembership(TNode node) else { retNode = nm->mkNode(STRING_IN_REGEXP, - utils::mkConcat(STRING_CONCAT, mchildren), + utils::mkConcat(mchildren, stype), r); success = true; } @@ -1743,7 +1759,7 @@ Node SequencesRewriter::rewriteMembership(TNode node) // Given a membership (str.++ x1 ... xn) in (re.++ r1 ... rm), // above, we strip components to construct an equivalent membership: // (str.++ xi .. xj) in (re.++ rk ... rl). - Node xn = utils::mkConcat(STRING_CONCAT, mchildren); + Node xn = utils::mkConcat(mchildren, stype); Node emptyStr = nm->mkConst(String("")); if (children.empty()) { @@ -1756,7 +1772,7 @@ Node SequencesRewriter::rewriteMembership(TNode node) { // otherwise, construct the updated regular expression retNode = nm->mkNode( - STRING_IN_REGEXP, xn, utils::mkConcat(REGEXP_CONCAT, children)); + STRING_IN_REGEXP, xn, utils::mkConcat(children, rtype)); } Trace("regexp-ext-rewrite") << "Regexp : rewrite : " << node << " -> " << retNode << std::endl; @@ -2115,6 +2131,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) std::vector n1; utils::getConcat(node[0], n1); + TypeNode stype = node.getType(); // definite inclusion if (node[1] == zero) @@ -2125,12 +2142,10 @@ Node SequencesRewriter::rewriteSubstr(Node node) { if (curr != zero && !n1.empty()) { - childrenr.push_back(nm->mkNode(kind::STRING_SUBSTR, - utils::mkConcat(STRING_CONCAT, n1), - node[1], - curr)); + childrenr.push_back(nm->mkNode( + kind::STRING_SUBSTR, utils::mkConcat(n1, stype), node[1], curr)); } - Node ret = utils::mkConcat(STRING_CONCAT, childrenr); + Node ret = utils::mkConcat(childrenr, stype); return returnRewrite(node, ret, "ss-len-include"); } } @@ -2211,16 +2226,14 @@ Node SequencesRewriter::rewriteSubstr(Node node) { if (r == 0) { - Node ret = nm->mkNode(kind::STRING_SUBSTR, - utils::mkConcat(STRING_CONCAT, n1), - curr, - node[2]); + Node ret = nm->mkNode( + kind::STRING_SUBSTR, utils::mkConcat(n1, stype), curr, node[2]); return returnRewrite(node, ret, "ss-strip-start-pt"); } else { Node ret = nm->mkNode(kind::STRING_SUBSTR, - utils::mkConcat(STRING_CONCAT, n1), + utils::mkConcat(n1, stype), node[1], node[2]); return returnRewrite(node, ret, "ss-strip-end-pt"); @@ -2395,6 +2408,7 @@ Node SequencesRewriter::rewriteContains(Node node) Node ret = NodeManager::currentNM()->mkConst(true); return returnRewrite(node, ret, "ctn-component"); } + TypeNode stype = node[0].getType(); // strip endpoints std::vector nb; @@ -2402,7 +2416,7 @@ Node SequencesRewriter::rewriteContains(Node node) if (stripConstantEndpoints(nc1, nc2, nb, ne)) { Node ret = NodeManager::currentNM()->mkNode( - kind::STRING_STRCTN, utils::mkConcat(STRING_CONCAT, nc1), node[1]); + kind::STRING_STRCTN, utils::mkConcat(nc1, stype), node[1]); return returnRewrite(node, ret, "ctn-strip-endpt"); } @@ -2517,14 +2531,12 @@ Node SequencesRewriter::rewriteContains(Node node) spl[1].insert(spl[1].end(), nc0.begin() + i + 1, nc0.end()); Node ret = NodeManager::currentNM()->mkNode( kind::OR, - NodeManager::currentNM()->mkNode( - kind::STRING_STRCTN, - utils::mkConcat(STRING_CONCAT, spl[0]), - node[1]), - NodeManager::currentNM()->mkNode( - kind::STRING_STRCTN, - utils::mkConcat(STRING_CONCAT, spl[1]), - node[1])); + NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, + utils::mkConcat(spl[0], stype), + node[1]), + NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, + utils::mkConcat(spl[1], stype), + node[1])); return returnRewrite(node, ret, "ctn-split"); } } @@ -2643,6 +2655,9 @@ Node SequencesRewriter::rewriteIndexof(Node node) return returnRewrite(node, negone, "idof-neg"); } + // the string type + TypeNode stype = node[0].getType(); + // evaluation and simple cases std::vector children0; utils::getConcat(node[0], children0); @@ -2751,7 +2766,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) { // For example: // str.indexof(str.++(x,y,z),y,0) ---> str.indexof(str.++(x,y),y,0) - Node nn = utils::mkConcat(STRING_CONCAT, children0); + Node nn = utils::mkConcat(children0, stype); Node ret = nm->mkNode(kind::STRING_STRIDOF, nn, node[1], node[2]); return returnRewrite(node, ret, "idof-def-ctn"); } @@ -2761,14 +2776,13 @@ Node SequencesRewriter::rewriteIndexof(Node node) { // str.indexof(str.++("AB", x, "C"), "C", 0) ---> // 2 + str.indexof(str.++(x, "C"), "C", 0) - Node ret = - nm->mkNode(kind::PLUS, - nm->mkNode(kind::STRING_LENGTH, - utils::mkConcat(STRING_CONCAT, nb)), - nm->mkNode(kind::STRING_STRIDOF, - utils::mkConcat(STRING_CONCAT, children0), - node[1], - node[2])); + Node ret = nm->mkNode( + kind::PLUS, + nm->mkNode(kind::STRING_LENGTH, utils::mkConcat(nb, stype)), + nm->mkNode(kind::STRING_STRIDOF, + utils::mkConcat(children0, stype), + node[1], + node[2])); return returnRewrite(node, ret, "idof-strip-cnst-endpts"); } } @@ -2783,7 +2797,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) // implies // str.indexof( str.++( x1, x2 ), y, z ) ---> // str.len( x1 ) + str.indexof( x2, y, z-str.len(x1) ) - Node nn = utils::mkConcat(STRING_CONCAT, children0); + Node nn = utils::mkConcat(children0, stype); Node ret = nm->mkNode(kind::PLUS, nm->mkNode(kind::MINUS, node[2], new_len), @@ -2809,7 +2823,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) // For example: // str.indexof(str.++("ABCD", x), y, 3) ---> // str.indexof(str.++("AAAD", x), y, 3) - Node nodeNr = utils::mkConcat(STRING_CONCAT, nr); + Node nodeNr = utils::mkConcat(nr, stype); Node normNr = lengthPreserveRewrite(nodeNr); if (normNr != nodeNr) { @@ -2817,7 +2831,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) utils::getConcat(normNr, normNrChildren); std::vector children(normNrChildren); children.insert(children.end(), children0.begin(), children0.end()); - Node nn = utils::mkConcat(STRING_CONCAT, children); + Node nn = utils::mkConcat(children, stype); Node res = nm->mkNode(kind::STRING_STRIDOF, nn, node[1], node[2]); return returnRewrite(node, res, "idof-norm-prefix"); } @@ -2830,7 +2844,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) std::vector ce; if (stripConstantEndpoints(children0, children1, cb, ce, -1)) { - Node ret = utils::mkConcat(STRING_CONCAT, children0); + Node ret = utils::mkConcat(children0, stype); ret = nm->mkNode(STRING_STRIDOF, ret, node[1], node[2]); // For example: // str.indexof( str.++( x, "A" ), "B", 0 ) ---> str.indexof( x, "B", 0 ) @@ -2852,6 +2866,8 @@ Node SequencesRewriter::rewriteReplace(Node node) Node ret = nm->mkNode(STRING_CONCAT, node[2], node[0]); return returnRewrite(node, ret, "rpl-rpl-empty"); } + // the string type + TypeNode stype = node.getType(); std::vector children0; utils::getConcat(node[0], children0); @@ -2883,7 +2899,7 @@ Node SequencesRewriter::rewriteReplace(Node node) children.push_back(s3); } children.insert(children.end(), children0.begin() + 1, children0.end()); - Node ret = utils::mkConcat(STRING_CONCAT, children); + Node ret = utils::mkConcat(children, stype); return returnRewrite(node, ret, "rpl-const-find"); } } @@ -2921,7 +2937,7 @@ Node SequencesRewriter::rewriteReplace(Node node) if (allEmptyEqs) { - Node nn1 = utils::mkConcat(STRING_CONCAT, emptyNodes); + Node nn1 = utils::mkConcat(emptyNodes, stype); if (node[1] != nn1) { Node ret = nm->mkNode(STRING_STRREPL, node[0], nn1, node[2]); @@ -2957,7 +2973,7 @@ Node SequencesRewriter::rewriteReplace(Node node) std::vector cres; cres.push_back(node[2]); cres.insert(cres.end(), ce.begin(), ce.end()); - Node ret = utils::mkConcat(STRING_CONCAT, cres); + Node ret = utils::mkConcat(cres, stype); return returnRewrite(node, ret, "rpl-cctn-rpl"); } else if (!ce.empty()) @@ -2970,11 +2986,11 @@ Node SequencesRewriter::rewriteReplace(Node node) std::vector scc; scc.push_back(NodeManager::currentNM()->mkNode( kind::STRING_STRREPL, - utils::mkConcat(STRING_CONCAT, children0), + utils::mkConcat(children0, stype), node[1], node[2])); scc.insert(scc.end(), ce.begin(), ce.end()); - Node ret = utils::mkConcat(STRING_CONCAT, scc); + Node ret = utils::mkConcat(scc, stype); return returnRewrite(node, ret, "rpl-cctn"); } } @@ -3022,7 +3038,7 @@ Node SequencesRewriter::rewriteReplace(Node node) if (node[0] == empty && allEmptyEqs) { std::vector emptyNodesList(emptyNodes.begin(), emptyNodes.end()); - Node nn1 = utils::mkConcat(STRING_CONCAT, emptyNodesList); + Node nn1 = utils::mkConcat(emptyNodesList, stype); if (nn1 != node[1] || nn2 != node[2]) { Node res = nm->mkNode(kind::STRING_STRREPL, node[0], nn1, nn2); @@ -3052,13 +3068,13 @@ Node SequencesRewriter::rewriteReplace(Node node) { std::vector cc; cc.insert(cc.end(), cb.begin(), cb.end()); - cc.push_back(NodeManager::currentNM()->mkNode( - kind::STRING_STRREPL, - utils::mkConcat(STRING_CONCAT, children0), - node[1], - node[2])); + cc.push_back( + NodeManager::currentNM()->mkNode(kind::STRING_STRREPL, + utils::mkConcat(children0, stype), + node[1], + node[2])); cc.insert(cc.end(), ce.begin(), ce.end()); - Node ret = utils::mkConcat(STRING_CONCAT, cc); + Node ret = utils::mkConcat(cc, stype); return returnRewrite(node, ret, "rpl-pull-endpt"); } } @@ -3080,8 +3096,8 @@ Node SequencesRewriter::rewriteReplace(Node node) children1.pop_back(); // Length of the non-substr components in the second argument - Node partLen1 = nm->mkNode(kind::STRING_LENGTH, - utils::mkConcat(STRING_CONCAT, children1)); + Node partLen1 = + nm->mkNode(kind::STRING_LENGTH, utils::mkConcat(children1, stype)); Node maxLen1 = nm->mkNode(kind::PLUS, partLen1, lastChild1[2]); Node zero = nm->mkConst(Rational(0)); @@ -3099,7 +3115,7 @@ Node SequencesRewriter::rewriteReplace(Node node) kind::PLUS, len0, one, nm->mkNode(kind::UMINUS, partLen1)))); Node res = nm->mkNode(kind::STRING_STRREPL, node[0], - utils::mkConcat(STRING_CONCAT, children1), + utils::mkConcat(children1, stype), node[2]); return returnRewrite(node, res, "repl-subst-idx"); } @@ -3285,7 +3301,7 @@ Node SequencesRewriter::rewriteReplace(Node node) std::vector checkLhs; checkLhs.insert( checkLhs.end(), children0.begin(), children0.begin() + checkIndex); - Node lhs = utils::mkConcat(STRING_CONCAT, checkLhs); + Node lhs = utils::mkConcat(checkLhs, stype); Node rhs = children0[checkIndex]; Node ctn = checkEntailContains(lhs, rhs); if (!ctn.isNull() && ctn.getConst()) @@ -3302,7 +3318,7 @@ Node SequencesRewriter::rewriteReplace(Node node) { std::vector remc(children0.begin() + lastCheckIndex, children0.end()); - Node rem = utils::mkConcat(STRING_CONCAT, remc); + Node rem = utils::mkConcat(remc, stype); Node ret = nm->mkNode(STRING_CONCAT, nm->mkNode(STRING_STRREPL, lastLhs, node[1], node[2]), @@ -3331,6 +3347,8 @@ Node SequencesRewriter::rewriteReplaceAll(Node node) { Assert(node.getKind() == STRING_STRREPLALL); + TypeNode stype = node.getType(); + if (node[0].isConst() && node[1].isConst()) { std::vector children; @@ -3362,7 +3380,7 @@ Node SequencesRewriter::rewriteReplaceAll(Node node) } } while (curr != std::string::npos && curr < sizeS); // constant evaluation - Node res = utils::mkConcat(STRING_CONCAT, children); + Node res = utils::mkConcat(children, stype); return returnRewrite(node, res, "replall-const"); } @@ -5411,6 +5429,8 @@ Node SequencesRewriter::inferEqsFromContains(Node x, Node y) { NodeManager* nm = NodeManager::currentNM(); Node emp = nm->mkConst(String("")); + Assert(x.getType() == y.getType()); + TypeNode stype = x.getType(); Node xLen = nm->mkNode(STRING_LENGTH, x); std::vector yLens; @@ -5481,7 +5501,7 @@ Node SequencesRewriter::inferEqsFromContains(Node x, Node y) // (= x (str.++ y1' ... ym')) if (!cs.empty()) { - nb << nm->mkNode(EQUAL, x, utils::mkConcat(STRING_CONCAT, cs)); + nb << nm->mkNode(EQUAL, x, utils::mkConcat(cs, stype)); } // (= y1'' "") ... (= yk'' "") for (const Node& zeroLen : zeroLens) diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index e6e0f8557..a26669fbf 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -519,7 +519,7 @@ bool TheoryStrings::collectModelInfoType( Assert(r.isConst() || processed.find(r) != processed.end()); nc.push_back(r.isConst() ? r : processed[r]); } - Node cc = utils::mkNConcat(nc); + Node cc = utils::mkNConcat(nc, tn); Assert(cc.isConst()); Trace("strings-model") << "*** Determined constant " << cc << " for " << nodes[i] << std::endl; processed[nodes[i]] = cc; @@ -1115,7 +1115,7 @@ void TheoryStrings::checkRegisterTermsNormalForms() Node lt = ei ? ei->d_lengthTerm : Node::null(); if (lt.isNull()) { - Node c = utils::mkNConcat(nfi.d_nf); + Node c = utils::mkNConcat(nfi.d_nf, eqc.getType()); registerTerm(c, 3); } } diff --git a/src/theory/strings/theory_strings_utils.cpp b/src/theory/strings/theory_strings_utils.cpp index a325108e4..5d27b8e2b 100644 --- a/src/theory/strings/theory_strings_utils.cpp +++ b/src/theory/strings/theory_strings_utils.cpp @@ -18,6 +18,7 @@ #include "options/strings_options.h" #include "theory/rewriter.h" +#include "theory/strings/word.h" using namespace CVC4::kind; @@ -115,12 +116,20 @@ void getConcat(Node n, std::vector& c) } } -Node mkConcat(Kind k, const std::vector& c) +Node mkConcat(const std::vector& c, TypeNode tn) { - Assert(!c.empty() || k == STRING_CONCAT); - NodeManager* nm = NodeManager::currentNM(); - return c.size() > 1 ? nm->mkNode(k, c) - : (c.size() == 1 ? c[0] : nm->mkConst(String(""))); + Assert(tn.isStringLike() || tn.isRegExp()); + if (c.empty()) + { + Assert(tn.isStringLike()); + return Word::mkEmptyWord(tn); + } + else if (c.size() == 1) + { + return c[0]; + } + Kind k = tn.isStringLike() ? STRING_CONCAT : REGEXP_CONCAT; + return NodeManager::currentNM()->mkNode(k, c); } Node mkNConcat(Node n1, Node n2) @@ -135,9 +144,9 @@ Node mkNConcat(Node n1, Node n2, Node n3) NodeManager::currentNM()->mkNode(STRING_CONCAT, n1, n2, n3)); } -Node mkNConcat(const std::vector& c) +Node mkNConcat(const std::vector& c, TypeNode tn) { - return Rewriter::rewrite(mkConcat(STRING_CONCAT, c)); + return Rewriter::rewrite(mkConcat(c, tn)); } Node mkNLength(Node t) @@ -147,12 +156,11 @@ Node mkNLength(Node t) Node getConstantComponent(Node t) { - Kind tk = t.getKind(); - if (tk == STRING_TO_REGEXP) + if (t.getKind() == STRING_TO_REGEXP) { return t[0].isConst() ? t[0] : Node::null(); } - return tk == CONST_STRING ? t : Node::null(); + return t.isConst() ? t : Node::null(); } Node getConstantEndpoint(Node e, bool isSuf) diff --git a/src/theory/strings/theory_strings_utils.h b/src/theory/strings/theory_strings_utils.h index 51fe8cfc7..5f18d3936 100644 --- a/src/theory/strings/theory_strings_utils.h +++ b/src/theory/strings/theory_strings_utils.h @@ -56,10 +56,10 @@ void flattenOp(Kind k, Node n, std::vector& conj); void getConcat(Node n, std::vector& c); /** - * Make the concatentation from vector c - * The kind k is either STRING_CONCAT or REGEXP_CONCAT. + * Make the concatentation from vector c of (string-like or regular + * expression) type tn. */ -Node mkConcat(Kind k, const std::vector& c); +Node mkConcat(const std::vector& c, TypeNode tn); /** * Returns the rewritten form of the string concatenation of n1 and n2. @@ -72,9 +72,10 @@ Node mkNConcat(Node n1, Node n2); Node mkNConcat(Node n1, Node n2, Node n3); /** - * Returns the rewritten form of the string concatenation of nodes in c. + * Returns the rewritten form of the concatentation from vector c of + * (string-like) type tn. */ -Node mkNConcat(const std::vector& c); +Node mkNConcat(const std::vector& c, TypeNode tn); /** * Returns the rewritten form of the length of string term t.