From 7179be03b049d3046140316c4c5987efbdbd09b8 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 24 Sep 2021 00:11:24 -0500 Subject: [PATCH] Eliminate calls to Rewriter::rewrite from strings entailment checks (#7203) There are a few further circular references that prevent us from not passing Rewriter to the strings TheoryRewriter constructor, this can be cleaned in future PRs. --- .../passes/foreign_theory_rewrite.cpp | 36 +++-- .../passes/foreign_theory_rewrite.h | 25 ++-- src/theory/quantifiers/extended_rewrite.cpp | 3 +- src/theory/strings/arith_entail.cpp | 33 ++--- src/theory/strings/arith_entail.h | 50 +++---- src/theory/strings/sequences_rewriter.cpp | 128 ++++++++++-------- src/theory/strings/sequences_rewriter.h | 9 +- src/theory/strings/strings_entail.cpp | 48 +++---- src/theory/strings/strings_entail.h | 30 ++-- src/theory/strings/strings_rewriter.cpp | 5 +- src/theory/strings/strings_rewriter.h | 2 +- src/theory/strings/theory_strings.cpp | 2 +- .../pass_foreign_theory_rewrite_white.cpp | 5 +- test/unit/theory/sequences_rewriter_white.cpp | 90 ++++++------ 14 files changed, 260 insertions(+), 206 deletions(-) diff --git a/src/preprocessing/passes/foreign_theory_rewrite.cpp b/src/preprocessing/passes/foreign_theory_rewrite.cpp index 24edf1509..70ad0fea3 100644 --- a/src/preprocessing/passes/foreign_theory_rewrite.cpp +++ b/src/preprocessing/passes/foreign_theory_rewrite.cpp @@ -20,6 +20,7 @@ #include "expr/node_traversal.h" #include "preprocessing/assertion_pipeline.h" #include "preprocessing/preprocessing_pass_context.h" +#include "smt/env.h" #include "theory/rewriter.h" #include "theory/strings/arith_entail.h" @@ -28,12 +29,13 @@ namespace preprocessing { namespace passes { using namespace cvc5::theory; -ForeignTheoryRewrite::ForeignTheoryRewrite( - PreprocessingPassContext* preprocContext) - : PreprocessingPass(preprocContext, "foreign-theory-rewrite"), - d_cache(userContext()){}; -Node ForeignTheoryRewrite::simplify(Node n) +ForeignTheoryRewriter::ForeignTheoryRewriter(Env& env) + : EnvObj(env), d_cache(userContext()) +{ +} + +Node ForeignTheoryRewriter::simplify(Node n) { std::vector toVisit; n = rewrite(n); @@ -87,7 +89,7 @@ Node ForeignTheoryRewrite::simplify(Node n) return d_cache[n]; } -Node ForeignTheoryRewrite::foreignRewrite(Node n) +Node ForeignTheoryRewriter::foreignRewrite(Node n) { // n is a rewritten node, and so GT, LT, LEQ // should have been eliminated @@ -102,18 +104,19 @@ Node ForeignTheoryRewrite::foreignRewrite(Node n) return n; } -Node ForeignTheoryRewrite::rewriteStringsGeq(Node n) +Node ForeignTheoryRewriter::rewriteStringsGeq(Node n) { + theory::strings::ArithEntail ae(d_env.getRewriter()); // check if the node can be simplified to true - if (theory::strings::ArithEntail::check(n[0], n[1], false)) + if (ae.check(n[0], n[1], false)) { return NodeManager::currentNM()->mkConst(true); } return n; } -Node ForeignTheoryRewrite::reconstructNode(Node originalNode, - std::vector newChildren) +Node ForeignTheoryRewriter::reconstructNode(Node originalNode, + std::vector newChildren) { // Nodes with no children are reconstructed to themselves if (originalNode.getNumChildren() == 0) @@ -137,15 +140,22 @@ Node ForeignTheoryRewrite::reconstructNode(Node originalNode, return builder.constructNode(); } +ForeignTheoryRewrite::ForeignTheoryRewrite( + PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "foreign-theory-rewrite"), + d_ftr(preprocContext->getEnv()) +{ +} + PreprocessingPassResult ForeignTheoryRewrite::applyInternal( AssertionPipeline* assertionsToPreprocess) { - for (unsigned i = 0; i < assertionsToPreprocess->size(); ++i) + for (size_t i = 0, nasserts = assertionsToPreprocess->size(); i < nasserts; + ++i) { assertionsToPreprocess->replace( - i, rewrite(simplify((*assertionsToPreprocess)[i]))); + i, rewrite(d_ftr.simplify((*assertionsToPreprocess)[i]))); } - return PreprocessingPassResult::NO_CONFLICT; } diff --git a/src/preprocessing/passes/foreign_theory_rewrite.h b/src/preprocessing/passes/foreign_theory_rewrite.h index 4940f326c..81f5282ef 100644 --- a/src/preprocessing/passes/foreign_theory_rewrite.h +++ b/src/preprocessing/passes/foreign_theory_rewrite.h @@ -23,6 +23,7 @@ #include "context/cdhashmap.h" #include "expr/node.h" #include "preprocessing/preprocessing_pass.h" +#include "smt/env_obj.h" namespace cvc5 { namespace preprocessing { @@ -30,14 +31,10 @@ namespace passes { using CDNodeMap = context::CDHashMap; -class ForeignTheoryRewrite : public PreprocessingPass +class ForeignTheoryRewriter : protected EnvObj { public: - ForeignTheoryRewrite(PreprocessingPassContext* preprocContext); - - protected: - PreprocessingPassResult applyInternal( - AssertionPipeline* assertionsToPreprocess) override; + ForeignTheoryRewriter(Env& env); /** the main function that simplifies n. * does a traversal on n and call rewriting fucntions. */ @@ -45,14 +42,14 @@ class ForeignTheoryRewrite : public PreprocessingPass /** A specific simplification function specific for GEQ * constraints in strings. */ - static Node rewriteStringsGeq(Node n); + Node rewriteStringsGeq(Node n); /** invoke rewrite functions for n. * based on the structure of n (typically its kind) * we invoke rewrites from other theories. * For example: when encountering a `>=` node, * we invoke rewrites from the theory of strings. */ - static Node foreignRewrite(Node n); + Node foreignRewrite(Node n); /** construct a node with the same operator as originalNode whose children are * processedChildren */ @@ -61,6 +58,18 @@ class ForeignTheoryRewrite : public PreprocessingPass CDNodeMap d_cache; }; +class ForeignTheoryRewrite : public PreprocessingPass +{ + public: + ForeignTheoryRewrite(PreprocessingPassContext* preprocContext); + + protected: + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + /** Foreign theory rewriter */ + ForeignTheoryRewriter d_ftr; +}; + } // namespace passes } // namespace preprocessing } // namespace cvc5 diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 40e28eb78..f5883c265 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -1710,7 +1710,8 @@ Node ExtendedRewriter::extendedRewriteStrings(Node ret) const if (ret.getKind() == EQUAL) { - new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret); + strings::SequencesRewriter sr(&d_rew, nullptr); + new_ret = sr.rewriteEqualityExt(ret); } return new_ret; diff --git a/src/theory/strings/arith_entail.cpp b/src/theory/strings/arith_entail.cpp index 6a0eea41a..d9cbc4c40 100644 --- a/src/theory/strings/arith_entail.cpp +++ b/src/theory/strings/arith_entail.cpp @@ -30,14 +30,16 @@ namespace cvc5 { namespace theory { namespace strings { +ArithEntail::ArithEntail(Rewriter* r) : d_rr(r) {} + bool ArithEntail::checkEq(Node a, Node b) { if (a == b) { return true; } - Node ar = Rewriter::rewrite(a); - Node br = Rewriter::rewrite(b); + Node ar = d_rr->rewrite(a); + Node br = d_rr->rewrite(b); return ar == br; } @@ -72,7 +74,7 @@ bool ArithEntail::check(Node a, bool strict) Node ar = strict ? NodeManager::currentNM()->mkNode( kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1))) : a; - ar = Rewriter::rewrite(ar); + ar = d_rr->rewrite(ar); if (ar.getAttribute(StrCheckEntailArithComputedAttr())) { @@ -93,7 +95,7 @@ bool ArithEntail::check(Node a, bool strict) bool ArithEntail::checkApprox(Node ar) { - Assert(Rewriter::rewrite(ar) == ar); + Assert(d_rr->rewrite(ar) == ar); NodeManager* nm = NodeManager::currentNM(); std::map msum; Trace("strings-ent-approx-debug") @@ -139,7 +141,7 @@ bool ArithEntail::checkApprox(Node ar) { Node curr = toProcess.back(); Trace("strings-ent-approx-debug") << " process " << curr << std::endl; - curr = Rewriter::rewrite(curr); + curr = d_rr->rewrite(curr); toProcess.pop_back(); if (visited.find(curr) == visited.end()) { @@ -195,7 +197,7 @@ bool ArithEntail::checkApprox(Node ar) Node aar = aarSum.empty() ? nm->mkConst(Rational(0)) : (aarSum.size() == 1 ? aarSum[0] : nm->mkNode(PLUS, aarSum)); - aar = Rewriter::rewrite(aar); + aar = d_rr->rewrite(aar); Trace("strings-ent-approx-debug") << "...processed fixed sum " << aar << " with " << mApprox.size() << " approximated monomials." << std::endl; @@ -266,8 +268,7 @@ bool ArithEntail::checkApprox(Node ar) Node ci = aam.second; if (!cr.isNull()) { - ci = ci.isNull() ? cr - : Rewriter::rewrite(nm->mkNode(MULT, ci, cr)); + ci = ci.isNull() ? cr : d_rr->rewrite(nm->mkNode(MULT, ci, cr)); } Trace("strings-ent-approx-debug") << ci << "*" << ti << " "; int ciSgn = ci.isNull() ? 1 : ci.getConst().sgn(); @@ -324,7 +325,7 @@ bool ArithEntail::checkApprox(Node ar) Node mn = ArithMSum::mkCoeffTerm(msum[v], vapprox); aar = nm->mkNode(PLUS, aar, mn); // update the msumAar map - aar = Rewriter::rewrite(aar); + aar = d_rr->rewrite(aar); msumAar.clear(); if (!ArithMSum::getMonomialSum(aar, msumAar)) { @@ -557,7 +558,7 @@ void ArithEntail::getArithApproximations(Node a, bool ArithEntail::checkWithEqAssumption(Node assumption, Node a, bool strict) { Assert(assumption.getKind() == kind::EQUAL); - Assert(Rewriter::rewrite(assumption) == assumption); + Assert(d_rr->rewrite(assumption) == assumption); Trace("strings-entail") << "checkWithEqAssumption: " << assumption << " " << a << ", strict=" << strict << std::endl; @@ -633,7 +634,7 @@ bool ArithEntail::checkWithAssumption(Node assumption, Node b, bool strict) { - Assert(Rewriter::rewrite(assumption) == assumption); + Assert(d_rr->rewrite(assumption) == assumption); NodeManager* nm = NodeManager::currentNM(); @@ -659,7 +660,7 @@ bool ArithEntail::checkWithAssumption(Node assumption, Node s = nm->mkBoundVar("slackVal", nm->stringType()); Node slen = nm->mkNode(kind::STRING_LENGTH, s); - assumption = Rewriter::rewrite( + assumption = d_rr->rewrite( nm->mkNode(kind::EQUAL, x, nm->mkNode(kind::PLUS, y, slen))); } @@ -695,7 +696,7 @@ bool ArithEntail::checkWithAssumptions(std::vector assumptions, bool res = false; for (const auto& assumption : assumptions) { - Assert(Rewriter::rewrite(assumption) == assumption); + Assert(d_rr->rewrite(assumption) == assumption); if (checkWithAssumption(assumption, a, b, strict)) { @@ -708,7 +709,7 @@ bool ArithEntail::checkWithAssumptions(std::vector assumptions, Node ArithEntail::getConstantBound(Node a, bool isLower) { - Assert(Rewriter::rewrite(a) == a); + Assert(d_rr->rewrite(a) == a); Node ret; if (a.isConst()) { @@ -773,7 +774,7 @@ Node ArithEntail::getConstantBound(Node a, bool isLower) else { ret = NodeManager::currentNM()->mkNode(a.getKind(), children); - ret = Rewriter::rewrite(ret); + ret = d_rr->rewrite(ret); } } } @@ -791,7 +792,7 @@ Node ArithEntail::getConstantBound(Node a, bool isLower) bool ArithEntail::checkInternal(Node a) { - Assert(Rewriter::rewrite(a) == a); + Assert(d_rr->rewrite(a) == a); // check whether a >= 0 if (a.isConst()) { diff --git a/src/theory/strings/arith_entail.h b/src/theory/strings/arith_entail.h index 64e76e5b6..e2b3d0af6 100644 --- a/src/theory/strings/arith_entail.h +++ b/src/theory/strings/arith_entail.h @@ -24,6 +24,9 @@ namespace cvc5 { namespace theory { + +class Rewriter; + namespace strings { /** @@ -34,19 +37,20 @@ namespace strings { class ArithEntail { public: + ArithEntail(Rewriter* r); /** check arithmetic entailment equal * Returns true if it is always the case that a = b. */ - static bool checkEq(Node a, Node b); + bool checkEq(Node a, Node b); /** check arithmetic entailment * Returns true if it is always the case that a >= b, * and a>b if strict is true. */ - static bool check(Node a, Node b, bool strict = false); + bool check(Node a, Node b, bool strict = false); /** check arithmetic entailment * Returns true if it is always the case that a >= 0. */ - static bool check(Node a, bool strict = false); + bool check(Node a, bool strict = false); /** check arithmetic entailment with approximations * * Returns true if it is always the case that a >= 0. We expect that a is in @@ -61,7 +65,7 @@ class ArithEntail * and thus the entailment len( x ) - len( substr( y, 0, len( x ) ) ) >= 0 * holds. */ - static bool checkApprox(Node a); + bool checkApprox(Node a); /** * Checks whether assumption |= a >= 0 (if strict is false) or @@ -74,9 +78,7 @@ class ArithEntail * * Because: x = -(str.len y), so -x >= 0 --> (str.len y) >= 0 --> true */ - static bool checkWithEqAssumption(Node assumption, - Node a, - bool strict = false); + bool checkWithEqAssumption(Node assumption, Node a, bool strict = false); /** * Checks whether assumption |= a >= b (if strict is false) or @@ -90,10 +92,10 @@ class ArithEntail * * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true */ - static bool checkWithAssumption(Node assumption, - Node a, - Node b, - bool strict = false); + bool checkWithAssumption(Node assumption, + Node a, + Node b, + bool strict = false); /** * Checks whether assumptions |= a >= b (if strict is false) or @@ -108,10 +110,10 @@ class ArithEntail * * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true */ - static bool checkWithAssumptions(std::vector assumptions, - Node a, - Node b, - bool strict = false); + bool checkWithAssumptions(std::vector assumptions, + Node a, + Node b, + bool strict = false); /** get arithmetic lower bound * If this function returns a non-null Node ret, @@ -126,7 +128,7 @@ class ArithEntail * if and only if * check( a, strict ) = true. */ - static Node getConstantBound(Node a, bool isLower = true); + Node getConstantBound(Node a, bool isLower = true); /** * Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the @@ -144,16 +146,16 @@ class ArithEntail * --> returns false because it is not possible to show * str.len(y) >= str.len(x) */ - static bool inferZerosInSumGeq(Node x, - std::vector& ys, - std::vector& zeroYs); + bool inferZerosInSumGeq(Node x, + std::vector& ys, + std::vector& zeroYs); private: /** check entail arithmetic internal * Returns true if we can show a >= 0 always. * a is in rewritten form. */ - static bool checkInternal(Node a); + bool checkInternal(Node a); /** Get arithmetic approximations * * This gets the (set of) arithmetic approximations for term a and stores @@ -169,9 +171,11 @@ class ArithEntail * function might be len( substr( x, 0, n ) ) - len( y ), where we don't * consider (recursively) the approximations for len( substr( x, 0, n ) ). */ - static void getArithApproximations(Node a, - std::vector& approx, - bool isOverApprox = false); + void getArithApproximations(Node a, + std::vector& approx, + bool isOverApprox = false); + /** The underlying rewriter */ + Rewriter* d_rr; }; } // namespace strings diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index 7885c857e..bd8a4d8df 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -37,11 +37,18 @@ namespace cvc5 { namespace theory { namespace strings { -SequencesRewriter::SequencesRewriter(HistogramStat* statistics) - : d_statistics(statistics), d_stringsEntail(*this) +SequencesRewriter::SequencesRewriter(Rewriter* r, + HistogramStat* statistics) + : d_statistics(statistics), + d_arithEntail(r), + d_stringsEntail(r, d_arithEntail, *this) { } +ArithEntail& SequencesRewriter::getArithEntail() { return d_arithEntail; } + +StringsEntail& SequencesRewriter::getStringsEntail() { return d_stringsEntail; } + Node SequencesRewriter::rewriteEquality(Node node) { Assert(node.getKind() == kind::EQUAL); @@ -216,7 +223,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) // ------- homogeneous constants for (unsigned i = 0; i < 2; i++) { - Node cn = StringsEntail::checkHomogeneousString(node[i]); + Node cn = d_stringsEntail.checkHomogeneousString(node[i]); if (!cn.isNull() && !Word::isEmpty(cn)) { Assert(cn.isConst()); @@ -311,7 +318,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) } // (= "" (str.replace x y "A")) ---> (and (= x "") (not (= y ""))) - if (StringsEntail::checkNonEmpty(ne[2])) + if (d_stringsEntail.checkNonEmpty(ne[2])) { Node ret = nm->mkNode(AND, @@ -321,7 +328,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) } // (= "" (str.replace x "A" "")) ---> (str.prefix x "A") - if (StringsEntail::checkLengthOne(ne[1], true) && ne[2] == empty) + if (d_stringsEntail.checkLengthOne(ne[1], true) && ne[2] == empty) { Node ret = nm->mkNode(STRING_PREFIX, ne[0], ne[1]); return returnRewrite(node, ret, Rewrite::STR_EMP_REPL_EMP); @@ -331,7 +338,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { Node zero = nm->mkConst(Rational(0)); - if (ArithEntail::check(ne[1], false) && ArithEntail::check(ne[2], true)) + if (d_arithEntail.check(ne[1], false) + && d_arithEntail.check(ne[2], true)) { // (= "" (str.substr x 0 m)) ---> (= "" x) if m > 0 if (ne[1] == zero) @@ -347,7 +355,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) } // (= "" (str.substr "A" 0 z)) ---> (<= z 0) - if (StringsEntail::checkNonEmpty(ne[0]) && ne[1] == zero) + if (d_stringsEntail.checkNonEmpty(ne[0]) && ne[1] == zero) { Node ret = nm->mkNode(LEQ, ne[2], zero); return returnRewrite(node, ret, Rewrite::STR_EMP_SUBSTR_LEQ_Z); @@ -365,7 +373,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) Node x = node[1 - i]; // (= "A" (str.replace "" x y)) ---> (= "" (str.replace "A" y x)) - if (StringsEntail::checkNonEmpty(x) && repl[0] == empty) + if (d_stringsEntail.checkNonEmpty(x) && repl[0] == empty) { Node ret = nm->mkNode( EQUAL, empty, nm->mkNode(STRING_REPLACE, x, repl[2], repl[1])); @@ -396,7 +404,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { Node lenY = nm->mkNode(STRING_LENGTH, repl[1]); Node lenZ = nm->mkNode(STRING_LENGTH, repl[2]); - if (ArithEntail::checkEq(lenY, lenZ)) + if (d_arithEntail.checkEq(lenY, lenZ)) { Node ret = nm->mkNode(OR, nm->mkNode(EQUAL, repl[0], repl[1]), @@ -419,7 +427,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) { if (node[1 - i].getKind() == STRING_CONCAT) { - new_ret = StringsEntail::inferEqsFromContains(node[i], node[1 - i]); + new_ret = d_stringsEntail.inferEqsFromContains(node[i], node[1 - i]); if (!new_ret.isNull()) { return returnRewrite(node, new_ret, Rewrite::STR_EQ_CONJ_LEN_ENTAIL); @@ -456,7 +464,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) Node lenPfx0 = nm->mkNode(STRING_LENGTH, pfx0); Node lenPfx1 = nm->mkNode(STRING_LENGTH, pfx1); - if (ArithEntail::checkEq(lenPfx0, lenPfx1)) + if (d_arithEntail.checkEq(lenPfx0, lenPfx1)) { std::vector sfxv0(v0.begin() + i, v0.end()); std::vector sfxv1(v1.begin() + j, v1.end()); @@ -466,7 +474,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) .eqNode(utils::mkConcat(sfxv1, stype))); return returnRewrite(node, ret, Rewrite::SPLIT_EQ); } - else if (ArithEntail::check(lenPfx1, lenPfx0, true)) + else if (d_arithEntail.check(lenPfx1, lenPfx0, true)) { // The prefix on the right-hand side is strictly longer than the // prefix on the left-hand side, so we try to strip the right-hand @@ -476,7 +484,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) // (= (str.++ "A" x y) (str.++ x "AB" z)) ---> // (and (= (str.++ "A" x) (str.++ x "A")) (= y (str.++ "B" z))) std::vector rpfxv1; - if (StringsEntail::stripSymbolicLength( + if (d_stringsEntail.stripSymbolicLength( pfxv1, rpfxv1, 1, lenPfx0, true)) { // The rewrite requires the full left-hand prefix length to be @@ -501,7 +509,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) // in the inner loop) break; } - else if (ArithEntail::check(lenPfx0, lenPfx1, true)) + else if (d_arithEntail.check(lenPfx0, lenPfx1, true)) { // The prefix on the left-hand side is strictly longer than the // prefix on the right-hand side, so we try to strip the left-hand @@ -512,7 +520,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node) // (and (= (str.++ x "A") (str.++ "A" x)) (= (str.++ "B" z) y)) std::vector sfxv0 = pfxv0; std::vector rpfxv0; - if (StringsEntail::stripSymbolicLength( + if (d_stringsEntail.stripSymbolicLength( sfxv0, rpfxv0, 1, lenPfx1, true)) { // The rewrite requires the full right-hand prefix length to be @@ -698,7 +706,7 @@ Node SequencesRewriter::rewriteConcat(Node node) Node lastX; for (size_t i = 0, nsize = node_vec.size(); i < nsize; i++) { - Node s = StringsEntail::getStringOrEmpty(node_vec[i]); + Node s = d_stringsEntail.getStringOrEmpty(node_vec[i]); bool nextX = false; if (s != lastX) { @@ -1711,12 +1719,12 @@ Node SequencesRewriter::rewriteSubstr(Node node) Node zero = nm->mkConst(cvc5::Rational(0)); // if entailed non-positive length or negative start point - if (ArithEntail::check(zero, node[1], true)) + if (d_arithEntail.check(zero, node[1], true)) { Node ret = Word::mkEmptyWord(node.getType()); return returnRewrite(node, ret, Rewrite::SS_START_NEG); } - else if (ArithEntail::check(zero, node[2])) + else if (d_arithEntail.check(zero, node[2])) { Node ret = Word::mkEmptyWord(node.getType()); return returnRewrite(node, ret, Rewrite::SS_LEN_NON_POS); @@ -1742,7 +1750,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) // over-approximation of the length of (str.substr x a a), which // then allows us to reason that the result of the whole term must // be empty. - if (ArithEntail::check(node[1], node[0][2])) + if (d_arithEntail.check(node[1], node[0][2])) { Node ret = Word::mkEmptyWord(node.getType()); return returnRewrite(node, ret, Rewrite::SS_START_GEQ_LEN); @@ -1755,8 +1763,8 @@ Node SequencesRewriter::rewriteSubstr(Node node) // if (str.len y) = 1 and (str.len z) = 1 if (node[1] == zero) { - if (StringsEntail::checkLengthOne(node[0][1], true) - && StringsEntail::checkLengthOne(node[0][2], true)) + if (d_stringsEntail.checkLengthOne(node[0][1], true) + && d_stringsEntail.checkLengthOne(node[0][2], true)) { Node ret = nm->mkNode( kind::STRING_REPLACE, @@ -1777,7 +1785,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) { Node curr = node[2]; std::vector childrenr; - if (StringsEntail::stripSymbolicLength(n1, childrenr, 1, curr)) + if (d_stringsEntail.stripSymbolicLength(n1, childrenr, 1, curr)) { if (curr != zero && !n1.empty()) { @@ -1809,7 +1817,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) Node end_pt = Rewriter::rewrite(nm->mkNode(kind::PLUS, node[1], node[2])); if (node[2] != tot_len) { - if (ArithEntail::check(node[2], tot_len)) + if (d_arithEntail.check(node[2], tot_len)) { // end point beyond end point of string, map to tot_len Node ret = nm->mkNode(kind::STRING_SUBSTR, node[0], node[1], tot_len); @@ -1825,7 +1833,8 @@ Node SequencesRewriter::rewriteSubstr(Node node) // (str.substr s x y) --> "" if x < len(s) |= 0 >= y Node n1_lt_tot_len = Rewriter::rewrite(nm->mkNode(kind::LT, node[1], tot_len)); - if (ArithEntail::checkWithAssumption(n1_lt_tot_len, zero, node[2], false)) + if (d_arithEntail.checkWithAssumption( + n1_lt_tot_len, zero, node[2], false)) { Node ret = Word::mkEmptyWord(node.getType()); return returnRewrite(node, ret, Rewrite::SS_START_ENTAILS_ZERO_LEN); @@ -1834,7 +1843,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) // (str.substr s x y) --> "" if 0 < y |= x >= str.len(s) Node non_zero_len = Rewriter::rewrite(nm->mkNode(kind::LT, zero, node[2])); - if (ArithEntail::checkWithAssumption( + if (d_arithEntail.checkWithAssumption( non_zero_len, node[1], tot_len, false)) { Node ret = Word::mkEmptyWord(node.getType()); @@ -1844,7 +1853,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) // (str.substr s x y) --> "" if x >= 0 |= 0 >= str.len(s) Node geq_zero_start = Rewriter::rewrite(nm->mkNode(kind::GEQ, node[1], zero)); - if (ArithEntail::checkWithAssumption( + if (d_arithEntail.checkWithAssumption( geq_zero_start, zero, tot_len, false)) { Node ret = Word::mkEmptyWord(node.getType()); @@ -1853,7 +1862,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) } // (str.substr s x x) ---> "" if (str.len s) <= 1 - if (node[1] == node[2] && StringsEntail::checkLengthOne(node[0])) + if (node[1] == node[2] && d_stringsEntail.checkLengthOne(node[0])) { Node ret = Word::mkEmptyWord(node.getType()); return returnRewrite(node, ret, Rewrite::SS_LEN_ONE_Z_Z); @@ -1864,7 +1873,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) // strip off components while quantity is entailed positive int dir = r == 0 ? 1 : -1; std::vector childrenr; - if (StringsEntail::stripSymbolicLength(n1, childrenr, dir, curr)) + if (d_stringsEntail.stripSymbolicLength(n1, childrenr, dir, curr)) { if (r == 0) { @@ -1888,7 +1897,7 @@ Node SequencesRewriter::rewriteSubstr(Node node) { Node start_inner = node[0][1]; Node start_outer = node[1]; - if (ArithEntail::check(start_outer) && ArithEntail::check(start_inner)) + if (d_arithEntail.check(start_outer) && d_arithEntail.check(start_inner)) { // both are positive // thus, start point is definitely start_inner+start_outer. @@ -1905,11 +1914,11 @@ Node SequencesRewriter::rewriteSubstr(Node node) { new_len = len_from_inner; } - else if (ArithEntail::check(len_from_inner, len_from_outer)) + else if (d_arithEntail.check(len_from_inner, len_from_outer)) { new_len = len_from_outer; } - else if (ArithEntail::check(len_from_outer, len_from_inner)) + else if (d_arithEntail.check(len_from_outer, len_from_inner)) { new_len = len_from_inner; } @@ -1993,7 +2002,7 @@ Node SequencesRewriter::rewriteContains(Node node) { Node len1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]); - if (ArithEntail::check(len1, true)) + if (d_arithEntail.check(len1, true)) { // we handle the false case here since the rewrite for equality // uses this function, hence we want to conclude false if possible. @@ -2002,7 +2011,7 @@ Node SequencesRewriter::rewriteContains(Node node) return returnRewrite(node, ret, Rewrite::CTN_LHS_EMPTYSTR); } } - else if (StringsEntail::checkLengthOne(t)) + else if (d_stringsEntail.checkLengthOne(t)) { std::vector vec = Word::getChars(node[0]); Node emp = Word::mkEmptyWord(t.getType()); @@ -2023,7 +2032,7 @@ Node SequencesRewriter::rewriteContains(Node node) else if (node[1].getKind() == kind::STRING_CONCAT) { int firstc, lastc; - if (!StringsEntail::canConstantContainConcat( + if (!d_stringsEntail.canConstantContainConcat( node[0], node[1], firstc, lastc)) { Node ret = NodeManager::currentNM()->mkConst(false); @@ -2098,7 +2107,7 @@ Node SequencesRewriter::rewriteContains(Node node) // strip endpoints std::vector nb; std::vector ne; - if (StringsEntail::stripConstantEndpoints(nc1, nc2, nb, ne)) + if (d_stringsEntail.stripConstantEndpoints(nc1, nc2, nb, ne)) { Node ret = NodeManager::currentNM()->mkNode( kind::STRING_CONTAINS, utils::mkConcat(nc1, stype), node[1]); @@ -2163,7 +2172,7 @@ Node SequencesRewriter::rewriteContains(Node node) // length entailment Node len_n1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]); Node len_n2 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]); - if (ArithEntail::check(len_n2, len_n1, true)) + if (d_arithEntail.check(len_n2, len_n1, true)) { // len( n2 ) > len( n1 ) => contains( n1, n2 ) ---> false Node ret = NodeManager::currentNM()->mkConst(false); @@ -2174,13 +2183,13 @@ Node SequencesRewriter::rewriteContains(Node node) // For example, contains( str.++( x, "b" ), str.++( "a", x ) ) ---> false // since the number of a's in the second argument is greater than the number // of a's in the first argument - if (StringsEntail::checkMultisetSubset(node[0], node[1])) + if (d_stringsEntail.checkMultisetSubset(node[0], node[1])) { Node ret = nm->mkConst(false); return returnRewrite(node, ret, Rewrite::CTN_MSET_NSS); } - if (ArithEntail::check(len_n2, len_n1, false)) + if (d_arithEntail.check(len_n2, len_n1, false)) { // len( n2 ) >= len( n1 ) => contains( n1, n2 ) ---> n1 = n2 Node ret = node[0].eqNode(node[1]); @@ -2264,7 +2273,7 @@ Node SequencesRewriter::rewriteContains(Node node) // (str.contains (str.replace x y x) z) ---> (str.contains x z) // if (str.len z) <= 1 - if (StringsEntail::checkLengthOne(node[1])) + if (d_stringsEntail.checkLengthOne(node[1])) { Node ret = nm->mkNode(kind::STRING_CONTAINS, node[0][0], node[1]); return returnRewrite(node, ret, Rewrite::CTN_REPL_LEN_ONE_TO_CTN); @@ -2285,7 +2294,7 @@ Node SequencesRewriter::rewriteContains(Node node) // (str.contains (str.replace x y z) w) ---> // (str.contains (str.replace x y "") w) // if (str.contains z w) ---> false and (str.len w) = 1 - if (StringsEntail::checkLengthOne(node[1])) + if (d_stringsEntail.checkLengthOne(node[1])) { Node ctn = d_stringsEntail.checkContains(node[0][2], node[1]); if (!ctn.isNull() && !ctn.getConst()) @@ -2385,7 +2394,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) return returnRewrite(node, zero, Rewrite::IDOF_EQ_CST_START); } } - if (ArithEntail::check(node[2], true)) + if (d_arithEntail.check(node[2], true)) { // y>0 implies indexof( x, x, y ) --> -1 Node negone = nm->mkConst(Rational(-1)); @@ -2408,7 +2417,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) { if (Word::isEmpty(node[1])) { - if (ArithEntail::check(len0, node[2]) && ArithEntail::check(node[2])) + if (d_arithEntail.check(len0, node[2]) && d_arithEntail.check(node[2])) { // len(x)>=z ^ z >=0 implies indexof( x, "", z ) ---> z return returnRewrite(node, node[2], Rewrite::IDOF_EMP_IDOF); @@ -2416,7 +2425,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) } } - if (ArithEntail::check(len1, len0m2, true)) + if (d_arithEntail.check(len1, len0m2, true)) { // len(x)-z < len(y) implies indexof( x, y, z ) ----> -1 Node negone = nm->mkConst(Rational(-1)); @@ -2457,7 +2466,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) } // Strip components from the beginning that are guaranteed not to match - if (StringsEntail::stripConstantEndpoints( + if (d_stringsEntail.stripConstantEndpoints( children0, children1, nb, ne, 1)) { // str.indexof(str.++("AB", x, "C"), "C", 0) ---> @@ -2477,12 +2486,12 @@ Node SequencesRewriter::rewriteIndexof(Node node) // (str.indexof t "" n) is not rewritten to something other than -1 when n // is beyond the length of t. This is not required for the above rewrites, // which only apply when n=0. - if (ArithEntail::check(node[2]) && ArithEntail::check(len0, node[2])) + if (d_arithEntail.check(node[2]) && d_arithEntail.check(len0, node[2])) { // strip symbolic length Node new_len = node[2]; std::vector nr; - if (StringsEntail::stripSymbolicLength(children0, nr, 1, new_len)) + if (d_stringsEntail.stripSymbolicLength(children0, nr, 1, new_len)) { // For example: // z>=0 and z>str.len( x1 ) and str.contains( x2, y )-->true @@ -2509,7 +2518,7 @@ Node SequencesRewriter::rewriteIndexof(Node node) { Node new_len = node[2]; std::vector nr; - if (StringsEntail::stripSymbolicLength(children0, nr, 1, new_len)) + if (d_stringsEntail.stripSymbolicLength(children0, nr, 1, new_len)) { // Normalize the string before the start index. // @@ -2535,7 +2544,8 @@ Node SequencesRewriter::rewriteIndexof(Node node) { std::vector cb; std::vector ce; - if (StringsEntail::stripConstantEndpoints(children0, children1, cb, ce, -1)) + if (d_stringsEntail.stripConstantEndpoints( + children0, children1, cb, ce, -1)) { Node ret = utils::mkConcat(children0, stype); ret = nm->mkNode(STRING_INDEXOF, ret, node[1], node[2]); @@ -2558,7 +2568,7 @@ Node SequencesRewriter::rewriteIndexofRe(Node node) Node zero = nm->mkConst(Rational(0)); Node slen = nm->mkNode(STRING_LENGTH, s); - if (ArithEntail::check(zero, n, true) || ArithEntail::check(n, slen, true)) + if (d_arithEntail.check(zero, n, true) || d_arithEntail.check(n, slen, true)) { Node ret = nm->mkConst(Rational(-1)); return returnRewrite(node, ret, Rewrite::INDEXOF_RE_INVALID_INDEX); @@ -2589,7 +2599,7 @@ Node SequencesRewriter::rewriteIndexofRe(Node node) return returnRewrite(node, ret, Rewrite::INDEXOF_RE_EVAL); } - if (ArithEntail::check(n, zero) && ArithEntail::check(slen, n)) + if (d_arithEntail.check(n, zero) && d_arithEntail.check(slen, n)) { String emptyStr(""); if (RegExpEntail::testConstStringInRegExp(emptyStr, 0, r)) @@ -2662,14 +2672,14 @@ Node SequencesRewriter::rewriteReplace(Node node) // ( len( y )>=len(x) ) => str.replace( x, y, x ) ---> x Node l0 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]); Node l1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]); - if (ArithEntail::check(l1, l0)) + if (d_arithEntail.check(l1, l0)) { return returnRewrite(node, node[0], Rewrite::RPL_RPL_LEN_ID); } // (str.replace x y x) ---> (str.replace x (str.++ y1 ... yn) x) // if 1 >= (str.len x) and (= y "") ---> (= y1 "") ... (= yn "") - if (StringsEntail::checkLengthOne(node[0])) + if (d_stringsEntail.checkLengthOne(node[0])) { Node empty = Word::mkEmptyWord(stype); Node rn1 = Rewriter::rewrite( @@ -2802,7 +2812,7 @@ Node SequencesRewriter::rewriteReplace(Node node) if (cmp_conr != cmp_con) { - if (StringsEntail::checkNonEmpty(node[1])) + if (d_stringsEntail.checkNonEmpty(node[1])) { // pull endpoints that can be stripped // for example, @@ -2810,7 +2820,7 @@ Node SequencesRewriter::rewriteReplace(Node node) // str.++( "b", str.replace( x, "a", y ), "b" ) std::vector cb; std::vector ce; - if (StringsEntail::stripConstantEndpoints(children0, children1, cb, ce)) + if (d_stringsEntail.stripConstantEndpoints(children0, children1, cb, ce)) { std::vector cc; cc.insert(cc.end(), cb.begin(), cb.end()); @@ -2851,7 +2861,7 @@ Node SequencesRewriter::rewriteReplace(Node node) Node len0 = nm->mkNode(kind::STRING_LENGTH, node[0]); Node len0_1 = nm->mkNode(kind::PLUS, len0, one); // Check len(t) + j > len(x) + 1 - if (ArithEntail::check(maxLen1, len0_1, true)) + if (d_arithEntail.check(maxLen1, len0_1, true)) { children1.push_back(nm->mkNode( kind::STRING_SUBSTR, @@ -2901,7 +2911,7 @@ Node SequencesRewriter::rewriteReplace(Node node) // (str.len w) >= (str.len z) Node wlen = nm->mkNode(kind::STRING_LENGTH, w); Node zlen = nm->mkNode(kind::STRING_LENGTH, z); - if (ArithEntail::check(wlen, zlen)) + if (d_arithEntail.check(wlen, zlen)) { // w != z Node wEqZ = Rewriter::rewrite(nm->mkNode(kind::EQUAL, w, z)); @@ -3037,7 +3047,7 @@ Node SequencesRewriter::rewriteReplace(Node node) // str.replace( x ++ y ++ x ++ y, "A", z ) --> // str.replace( x ++ y, "A", z ) ++ x ++ y // since if "A" occurs in x ++ y ++ x ++ y, then it must occur in x ++ y. - if (StringsEntail::checkLengthOne(node[1])) + if (d_stringsEntail.checkLengthOne(node[1])) { Node lastLhs; unsigned lastCheckIndex = 0; @@ -3153,7 +3163,7 @@ Node SequencesRewriter::rewriteReplaceInternal(Node node) if (node[0] == node[1]) { // only holds for replaceall if non-empty - if (nk == STRING_REPLACE || StringsEntail::checkNonEmpty(node[1])) + if (nk == STRING_REPLACE || d_stringsEntail.checkNonEmpty(node[1])) { return returnRewrite(node, node[2], Rewrite::RPL_REPLACE); } @@ -3379,7 +3389,7 @@ Node SequencesRewriter::rewritePrefixSuffix(Node n) // Check if we can turn the prefix/suffix into equalities by showing that the // prefix/suffix is at least as long as the string - Node eqs = StringsEntail::inferEqsFromContains(n[1], n[0]); + Node eqs = d_stringsEntail.inferEqsFromContains(n[1], n[0]); if (!eqs.isNull()) { return returnRewrite(n, eqs, Rewrite::SUF_PREFIX_TO_EQS); diff --git a/src/theory/strings/sequences_rewriter.h b/src/theory/strings/sequences_rewriter.h index 0068c72c1..854e3fb81 100644 --- a/src/theory/strings/sequences_rewriter.h +++ b/src/theory/strings/sequences_rewriter.h @@ -21,6 +21,7 @@ #include #include "expr/node.h" +#include "theory/strings/arith_entail.h" #include "theory/strings/rewrites.h" #include "theory/strings/sequences_stats.h" #include "theory/strings/strings_entail.h" @@ -33,7 +34,10 @@ namespace strings { class SequencesRewriter : public TheoryRewriter { public: - SequencesRewriter(HistogramStat* statistics); + SequencesRewriter(Rewriter* r, HistogramStat* statistics); + /** The underlying entailment utilities */ + ArithEntail& getArithEntail(); + StringsEntail& getStringsEntail(); protected: /** rewrite regular expression concatenation @@ -301,7 +305,8 @@ class SequencesRewriter : public TheoryRewriter Node postProcessRewrite(Node node, Node ret); /** Reference to the rewriter statistics. */ HistogramStat* d_statistics; - + /** The arithmetic entailment module */ + ArithEntail d_arithEntail; /** Instance of the entailment checker for strings. */ StringsEntail d_stringsEntail; }; /* class SequencesRewriter */ diff --git a/src/theory/strings/strings_entail.cpp b/src/theory/strings/strings_entail.cpp index 3c7800f8f..3b90338fc 100644 --- a/src/theory/strings/strings_entail.cpp +++ b/src/theory/strings/strings_entail.cpp @@ -30,7 +30,10 @@ namespace cvc5 { namespace theory { namespace strings { -StringsEntail::StringsEntail(SequencesRewriter& rewriter) : d_rewriter(rewriter) +StringsEntail::StringsEntail(Rewriter* r, + ArithEntail& aent, + SequencesRewriter& rewriter) + : d_rr(r), d_arithEntail(aent), d_rewriter(rewriter) { } @@ -61,7 +64,7 @@ bool StringsEntail::canConstantContainConcat(Node c, pos = new_pos + Word::getLength(n[i]); } } - else if (n[i].getKind() == STRING_ITOS && ArithEntail::check(n[i][0])) + else if (n[i].getKind() == STRING_ITOS && d_arithEntail.check(n[i][0])) { Assert(c.getType().isString()); // string-only const std::vector& tvec = c.getConst().getVec(); @@ -132,24 +135,24 @@ bool StringsEntail::stripSymbolicLength(std::vector& n1, if (n1[sindex_use].isConst()) { // could strip part of a constant - Node lowerBound = ArithEntail::getConstantBound(Rewriter::rewrite(curr)); + Node lowerBound = d_arithEntail.getConstantBound(d_rr->rewrite(curr)); if (!lowerBound.isNull()) { Assert(lowerBound.isConst()); Rational lbr = lowerBound.getConst(); if (lbr.sgn() > 0) { - Assert(ArithEntail::check(curr, true)); + Assert(d_arithEntail.check(curr, true)); Node s = n1[sindex_use]; size_t slen = Word::getLength(s); Node ncl = nm->mkConst(cvc5::Rational(slen)); Node next_s = nm->mkNode(MINUS, lowerBound, ncl); - next_s = Rewriter::rewrite(next_s); + next_s = d_rr->rewrite(next_s); Assert(next_s.isConst()); // we can remove the entire constant if (next_s.getConst().sgn() >= 0) { - curr = Rewriter::rewrite(nm->mkNode(MINUS, curr, ncl)); + curr = d_rr->rewrite(nm->mkNode(MINUS, curr, ncl)); success = true; sindex++; } @@ -159,7 +162,7 @@ bool StringsEntail::stripSymbolicLength(std::vector& n1, // lower bound minus the length of a concrete string is negative, // hence lowerBound cannot be larger than long max Assert(lbr < Rational(String::maxSize())); - curr = Rewriter::rewrite(nm->mkNode(MINUS, curr, lowerBound)); + curr = d_rr->rewrite(nm->mkNode(MINUS, curr, lowerBound)); uint32_t lbsize = lbr.getNumerator().toUnsignedInt(); Assert(lbsize < slen); if (dir == 1) @@ -176,7 +179,7 @@ bool StringsEntail::stripSymbolicLength(std::vector& n1, } ret = true; } - Assert(ArithEntail::check(curr)); + Assert(d_arithEntail.check(curr)); } else { @@ -190,8 +193,8 @@ bool StringsEntail::stripSymbolicLength(std::vector& n1, MINUS, curr, NodeManager::currentNM()->mkNode(STRING_LENGTH, n1[sindex_use])); - next_s = Rewriter::rewrite(next_s); - if (ArithEntail::check(next_s)) + next_s = d_rr->rewrite(next_s); + if (d_arithEntail.check(next_s)) { success = true; curr = next_s; @@ -251,7 +254,7 @@ int StringsEntail::componentContains(std::vector& n1, } else if (!n1re.isNull()) { - n1[i] = Rewriter::rewrite( + n1[i] = d_rr->rewrite( NodeManager::currentNM()->mkNode(STRING_CONCAT, n1[i], n1re)); } if (remainderDir != 1) @@ -265,7 +268,7 @@ int StringsEntail::componentContains(std::vector& n1, } else if (!n1rb.isNull()) { - n1[i] = Rewriter::rewrite( + n1[i] = d_rr->rewrite( NodeManager::currentNM()->mkNode(STRING_CONCAT, n1rb, n1[i])); } } @@ -432,7 +435,7 @@ bool StringsEntail::componentContainsBase( { // To be a suffix, start + length must be greater than // or equal to the length of the string. - success = ArithEntail::check(end_pos, len_n2s); + success = d_arithEntail.check(end_pos, len_n2s); } else if (dir == -1) { @@ -449,8 +452,8 @@ bool StringsEntail::componentContainsBase( { // we can only compute the remainder if start_pos and end_pos // are known to be non-negative. - if (!ArithEntail::check(start_pos) - || !ArithEntail::check(end_pos)) + if (!d_arithEntail.check(start_pos) + || !d_arithEntail.check(end_pos)) { return false; } @@ -679,7 +682,7 @@ Node StringsEntail::checkContains(Node a, Node b, bool fullRewriter) if (fullRewriter) { - ctn = Rewriter::rewrite(ctn); + ctn = d_rr->rewrite(ctn); } else { @@ -702,8 +705,8 @@ Node StringsEntail::checkContains(Node a, Node b, bool fullRewriter) bool StringsEntail::checkNonEmpty(Node a) { Node len = NodeManager::currentNM()->mkNode(STRING_LENGTH, a); - len = Rewriter::rewrite(len); - return ArithEntail::check(len, true); + len = d_rr->rewrite(len); + return d_arithEntail.check(len, true); } bool StringsEntail::checkLengthOne(Node s, bool strict) @@ -711,9 +714,9 @@ bool StringsEntail::checkLengthOne(Node s, bool strict) NodeManager* nm = NodeManager::currentNM(); Node one = nm->mkConst(Rational(1)); Node len = nm->mkNode(STRING_LENGTH, s); - len = Rewriter::rewrite(len); - return ArithEntail::check(one, len) - && (!strict || ArithEntail::check(len, true)); + len = d_rr->rewrite(len); + return d_arithEntail.check(one, len) + && (!strict || d_arithEntail.check(len, true)); } bool StringsEntail::checkMultisetSubset(Node a, Node b) @@ -877,7 +880,6 @@ Node StringsEntail::getStringOrEmpty(Node n) n = n[2]; break; } - if (checkLengthOne(n[0]) && Word::isEmpty(n[2])) { // (str.replace "A" x "") --> "A" @@ -945,7 +947,7 @@ Node StringsEntail::inferEqsFromContains(Node x, Node y) // str.len(yn) (where y = y1 ++ ... ++ yn) while keeping the inequality // true. The terms that can have length zero without making the inequality // false must be all be empty if (str.contains x y) is true. - if (!ArithEntail::inferZerosInSumGeq(xLen, yLens, zeroLens)) + if (!d_arithEntail.inferZerosInSumGeq(xLen, yLens, zeroLens)) { // We could not prove that the inequality holds return Node::null(); diff --git a/src/theory/strings/strings_entail.h b/src/theory/strings/strings_entail.h index 7547bf809..1ff65a5b4 100644 --- a/src/theory/strings/strings_entail.h +++ b/src/theory/strings/strings_entail.h @@ -21,9 +21,13 @@ #include #include "expr/node.h" +#include "theory/strings/arith_entail.h" namespace cvc5 { namespace theory { + +class Rewriter; + namespace strings { class SequencesRewriter; @@ -36,7 +40,7 @@ class SequencesRewriter; class StringsEntail { public: - StringsEntail(SequencesRewriter& rewriter); + StringsEntail(Rewriter* r, ArithEntail& aent, SequencesRewriter& rewriter); /** can constant contain list * return true if constant c can contain the list l in order @@ -64,7 +68,7 @@ class StringsEntail /** can constant contain concat * same as above but with n = str.++( l ) instead of l */ - static bool canConstantContainConcat(Node c, Node n, int& firstc, int& lastc); + bool canConstantContainConcat(Node c, Node n, int& firstc, int& lastc); /** strip symbolic length * @@ -106,11 +110,11 @@ class StringsEntail * nr is updated to { "abc", y } * curr is updated to str.len(y)+1 */ - static bool stripSymbolicLength(std::vector& n1, - std::vector& nr, - int dir, - Node& curr, - bool strict = false); + bool stripSymbolicLength(std::vector& n1, + std::vector& nr, + int dir, + Node& curr, + bool strict = false); /** component contains * This function is used when rewriting str.contains( t1, t2 ), where * n1 is the vector form of t1 @@ -222,7 +226,7 @@ class StringsEntail * Checks whether string a is entailed to be non-empty. Is equivalent to * the call checkArithEntail( len( a ), true ). */ - static bool checkNonEmpty(Node a); + bool checkNonEmpty(Node a); /** * Checks whether string has at most/exactly length one. Length one strings @@ -234,7 +238,7 @@ class StringsEntail * at most length one * @return True if the string has at most/exactly length one, false otherwise */ - static bool checkLengthOne(Node s, bool strict = false); + bool checkLengthOne(Node s, bool strict = false); /** * Checks whether it is always true that `a` is a strict subset of `b` in the @@ -282,7 +286,7 @@ class StringsEntail * getStringOrEmpty( (str.substr "ABC" x y) ) --> (str.substr "ABC" x y) * because the function could not compute a simpler */ - static Node getStringOrEmpty(Node n); + Node getStringOrEmpty(Node n); /** * Infers a conjunction of equalities that correspond to (str.contains x y) @@ -298,7 +302,7 @@ class StringsEntail * y) if the function can infer that str.len(y) >= str.len(x) but cannot * infer that any of the yi must be empty. */ - static Node inferEqsFromContains(Node x, Node y); + Node inferEqsFromContains(Node x, Node y); private: /** component contains base @@ -371,6 +375,10 @@ class StringsEntail static Node getMultisetApproximation(Node a); private: + /** Pointer to the full rewriter */ + Rewriter* d_rr; + /** The arithmetic entailment module */ + ArithEntail& d_arithEntail; /** * Reference to the sequences rewriter that owns this `StringsEntail` * instance. diff --git a/src/theory/strings/strings_rewriter.cpp b/src/theory/strings/strings_rewriter.cpp index b455d8a9b..9204bfab6 100644 --- a/src/theory/strings/strings_rewriter.cpp +++ b/src/theory/strings/strings_rewriter.cpp @@ -27,8 +27,9 @@ namespace cvc5 { namespace theory { namespace strings { -StringsRewriter::StringsRewriter(HistogramStat* statistics) - : SequencesRewriter(statistics) +StringsRewriter::StringsRewriter(Rewriter* r, + HistogramStat* statistics) + : SequencesRewriter(r, statistics) { } diff --git a/src/theory/strings/strings_rewriter.h b/src/theory/strings/strings_rewriter.h index 70a1cccf0..65c0b67ab 100644 --- a/src/theory/strings/strings_rewriter.h +++ b/src/theory/strings/strings_rewriter.h @@ -32,7 +32,7 @@ namespace strings { class StringsRewriter : public SequencesRewriter { public: - StringsRewriter(HistogramStat* statistics); + StringsRewriter(Rewriter* r, HistogramStat* statistics); RewriteResponse postRewrite(TNode node) override; diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 1b315447e..3eac3ca1a 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -60,7 +60,7 @@ TheoryStrings::TheoryStrings(Env& env, OutputChannel& out, Valuation valuation) d_extTheoryCb(), d_im(env, *this, d_state, d_termReg, d_extTheory, d_statistics, d_pnm), d_extTheory(d_extTheoryCb, context(), userContext(), d_im), - d_rewriter(&d_statistics.d_rewrites), + d_rewriter(env.getRewriter(), &d_statistics.d_rewrites), d_bsolver(env, d_state, d_im), d_csolver(env, d_state, d_im, d_termReg, d_bsolver), d_esolver(env, diff --git a/test/unit/preprocessing/pass_foreign_theory_rewrite_white.cpp b/test/unit/preprocessing/pass_foreign_theory_rewrite_white.cpp index 223cef13b..c10d8f363 100644 --- a/test/unit/preprocessing/pass_foreign_theory_rewrite_white.cpp +++ b/test/unit/preprocessing/pass_foreign_theory_rewrite_white.cpp @@ -31,19 +31,20 @@ class TestPPWhiteForeignTheoryRewrite : public TestSmt TEST_F(TestPPWhiteForeignTheoryRewrite, simplify) { + ForeignTheoryRewriter ftr(d_smtEngine->getEnv()); std::cout << "len(x) >= 0 is simplified to true" << std::endl; Node x = d_nodeManager->mkVar("x", d_nodeManager->stringType()); Node len_x = d_nodeManager->mkNode(kind::STRING_LENGTH, x); Node zero = d_nodeManager->mkConst(0); Node geq1 = d_nodeManager->mkNode(kind::GEQ, len_x, zero); Node tt = d_nodeManager->mkConst(true); - Node simplified1 = ForeignTheoryRewrite::foreignRewrite(geq1); + Node simplified1 = ftr.foreignRewrite(geq1); ASSERT_EQ(simplified1, tt); std::cout << "len(x) >= n is not simplified to true" << std::endl; Node n = d_nodeManager->mkVar("n", d_nodeManager->integerType()); Node geq2 = d_nodeManager->mkNode(kind::GEQ, len_x, n); - Node simplified2 = ForeignTheoryRewrite::foreignRewrite(geq2); + Node simplified2 = ftr.foreignRewrite(geq2); ASSERT_NE(simplified2, tt); } diff --git a/test/unit/theory/sequences_rewriter_white.cpp b/test/unit/theory/sequences_rewriter_white.cpp index b7339942e..99454a014 100644 --- a/test/unit/theory/sequences_rewriter_white.cpp +++ b/test/unit/theory/sequences_rewriter_white.cpp @@ -43,9 +43,11 @@ class TestTheoryWhiteSequencesRewriter : public TestSmt TestSmt::SetUp(); Options opts; d_rewriter = d_smtEngine->getRewriter(); + d_seqRewriter.reset(new SequencesRewriter(d_rewriter, nullptr)); } Rewriter* d_rewriter; + std::unique_ptr d_seqRewriter; void inNormalForm(Node t) { @@ -81,6 +83,7 @@ class TestTheoryWhiteSequencesRewriter : public TestSmt TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_length_one) { + StringsEntail& se = d_seqRewriter->getStringsEntail(); TypeNode intType = d_nodeManager->integerType(); TypeNode strType = d_nodeManager->stringType(); @@ -97,28 +100,29 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_length_one) Node three = d_nodeManager->mkConst(Rational(3)); Node i = d_nodeManager->mkVar("i", intType); - ASSERT_TRUE(StringsEntail::checkLengthOne(a)); - ASSERT_TRUE(StringsEntail::checkLengthOne(a, true)); + ASSERT_TRUE(se.checkLengthOne(a)); + ASSERT_TRUE(se.checkLengthOne(a, true)); Node substr = d_nodeManager->mkNode(kind::STRING_SUBSTR, x, zero, one); - ASSERT_TRUE(StringsEntail::checkLengthOne(substr)); - ASSERT_FALSE(StringsEntail::checkLengthOne(substr, true)); + ASSERT_TRUE(se.checkLengthOne(substr)); + ASSERT_FALSE(se.checkLengthOne(substr, true)); substr = d_nodeManager->mkNode(kind::STRING_SUBSTR, d_nodeManager->mkNode(kind::STRING_CONCAT, a, x), zero, one); - ASSERT_TRUE(StringsEntail::checkLengthOne(substr)); - ASSERT_TRUE(StringsEntail::checkLengthOne(substr, true)); + ASSERT_TRUE(se.checkLengthOne(substr)); + ASSERT_TRUE(se.checkLengthOne(substr, true)); substr = d_nodeManager->mkNode(kind::STRING_SUBSTR, x, zero, two); - ASSERT_FALSE(StringsEntail::checkLengthOne(substr)); - ASSERT_FALSE(StringsEntail::checkLengthOne(substr, true)); + ASSERT_FALSE(se.checkLengthOne(substr)); + ASSERT_FALSE(se.checkLengthOne(substr, true)); } TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_arith) { + ArithEntail& ae = d_seqRewriter->getArithEntail(); TypeNode intType = d_nodeManager->integerType(); TypeNode strType = d_nodeManager->stringType(); @@ -130,14 +134,15 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_arith) Node substr_z = d_nodeManager->mkNode( kind::STRING_LENGTH, d_nodeManager->mkNode(kind::STRING_SUBSTR, z, n, one)); - ASSERT_TRUE(ArithEntail::check(one, substr_z)); + ASSERT_TRUE(ae.check(one, substr_z)); // (str.len (str.substr z n 1)) >= 1 ---> false - ASSERT_FALSE(ArithEntail::check(substr_z, one)); + ASSERT_FALSE(ae.check(substr_z, one)); } TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) { + ArithEntail& ae = d_seqRewriter->getArithEntail(); TypeNode intType = d_nodeManager->integerType(); TypeNode strType = d_nodeManager->stringType(); @@ -157,19 +162,17 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) d_nodeManager->mkNode(kind::EQUAL, x_plus_slen_y, zero)); // x + (str.len y) = 0 |= 0 >= x --> true - ASSERT_TRUE( - ArithEntail::checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, false)); + ASSERT_TRUE(ae.checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, false)); // x + (str.len y) = 0 |= 0 > x --> false - ASSERT_FALSE( - ArithEntail::checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, true)); + ASSERT_FALSE(ae.checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, true)); Node x_plus_slen_y_plus_z_eq_zero = d_rewriter->rewrite(d_nodeManager->mkNode( kind::EQUAL, d_nodeManager->mkNode(kind::PLUS, x_plus_slen_y, z), zero)); // x + (str.len y) + z = 0 |= 0 > x --> false - ASSERT_FALSE(ArithEntail::checkWithAssumption( - x_plus_slen_y_plus_z_eq_zero, zero, x, true)); + ASSERT_FALSE( + ae.checkWithAssumption(x_plus_slen_y_plus_z_eq_zero, zero, x, true)); Node x_plus_slen_y_plus_slen_y_eq_zero = d_rewriter->rewrite(d_nodeManager->mkNode( @@ -178,7 +181,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) zero)); // x + (str.len y) + (str.len y) = 0 |= 0 >= x --> true - ASSERT_TRUE(ArithEntail::checkWithAssumption( + ASSERT_TRUE(ae.checkWithAssumption( x_plus_slen_y_plus_slen_y_eq_zero, zero, x, false)); Node five = d_nodeManager->mkConst(Rational(5)); @@ -188,28 +191,24 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, six)); // x + 5 < 6 |= 0 >= x --> true - ASSERT_TRUE( - ArithEntail::checkWithAssumption(x_plus_five_lt_six, zero, x, false)); + ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_six, zero, x, false)); // x + 5 < 6 |= 0 > x --> false - ASSERT_TRUE( - !ArithEntail::checkWithAssumption(x_plus_five_lt_six, zero, x, true)); + ASSERT_TRUE(!ae.checkWithAssumption(x_plus_five_lt_six, zero, x, true)); Node neg_x = d_nodeManager->mkNode(kind::UMINUS, x); Node x_plus_five_lt_five = d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, five)); // x + 5 < 5 |= -x >= 0 --> true - ASSERT_TRUE(ArithEntail::checkWithAssumption( - x_plus_five_lt_five, neg_x, zero, false)); + ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_five, neg_x, zero, false)); // x + 5 < 5 |= 0 > x --> true - ASSERT_TRUE( - ArithEntail::checkWithAssumption(x_plus_five_lt_five, zero, x, false)); + ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_five, zero, x, false)); // 0 < x |= x >= (str.len (int.to.str x)) Node assm = d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, zero, x)); - ASSERT_TRUE(ArithEntail::checkWithAssumption( + ASSERT_TRUE(ae.checkWithAssumption( assm, x, d_nodeManager->mkNode(kind::STRING_LENGTH, @@ -219,6 +218,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption) TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) { + StringsRewriter sr(d_rewriter, nullptr); TypeNode intType = d_nodeManager->integerType(); TypeNode strType = d_nodeManager->stringType(); @@ -239,7 +239,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) // (str.substr "A" x x) --> "" Node n = d_nodeManager->mkNode(kind::STRING_SUBSTR, a, x, x); - Node res = StringsRewriter(nullptr).rewriteSubstr(n); + Node res = sr.rewriteSubstr(n); ASSERT_EQ(res, empty); // (str.substr "A" (+ x 1) x) -> "" @@ -248,7 +248,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) a, d_nodeManager->mkNode(kind::PLUS, x, d_nodeManager->mkConst(Rational(1))), x); - res = StringsRewriter(nullptr).rewriteSubstr(n); + res = sr.rewriteSubstr(n); ASSERT_EQ(res, empty); // (str.substr "A" (+ x (str.len s2)) x) -> "" @@ -258,12 +258,12 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) d_nodeManager->mkNode( kind::PLUS, x, d_nodeManager->mkNode(kind::STRING_LENGTH, s)), x); - res = StringsRewriter(nullptr).rewriteSubstr(n); + res = sr.rewriteSubstr(n); ASSERT_EQ(res, empty); // (str.substr "A" x y) -> (str.substr "A" x y) n = d_nodeManager->mkNode(kind::STRING_SUBSTR, a, x, y); - res = StringsRewriter(nullptr).rewriteSubstr(n); + res = sr.rewriteSubstr(n); ASSERT_EQ(res, n); // (str.substr "ABCD" (+ x 3) x) -> "" @@ -271,13 +271,13 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr) abcd, d_nodeManager->mkNode(kind::PLUS, x, three), x); - res = StringsRewriter(nullptr).rewriteSubstr(n); + res = sr.rewriteSubstr(n); ASSERT_EQ(res, empty); // (str.substr "ABCD" (+ x 2) x) -> (str.substr "ABCD" (+ x 2) x) n = d_nodeManager->mkNode( kind::STRING_SUBSTR, abcd, d_nodeManager->mkNode(kind::PLUS, x, two), x); - res = StringsRewriter(nullptr).rewriteSubstr(n); + res = sr.rewriteSubstr(n); ASSERT_EQ(res, n); // (str.substr (str.substr s x x) x x) -> "" @@ -1303,6 +1303,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_contains) TEST_F(TestTheoryWhiteSequencesRewriter, infer_eqs_from_contains) { + StringsEntail& se = d_seqRewriter->getStringsEntail(); TypeNode strType = d_nodeManager->stringType(); Node empty = d_nodeManager->mkConst(::cvc5::String("")); @@ -1319,30 +1320,30 @@ TEST_F(TestTheoryWhiteSequencesRewriter, infer_eqs_from_contains) d_nodeManager->mkNode(kind::AND, d_nodeManager->mkNode(kind::EQUAL, empty, x), d_nodeManager->mkNode(kind::EQUAL, empty, y)); - sameNormalForm(StringsEntail::inferEqsFromContains(empty, xy), empty_x_y); + sameNormalForm(se.inferEqsFromContains(empty, xy), empty_x_y); // inferEqsFromContains(x, (str.++ x y)) returns false Node bxya = d_nodeManager->mkNode(kind::STRING_CONCAT, {b, y, x, a}); - sameNormalForm(StringsEntail::inferEqsFromContains(x, bxya), f); + sameNormalForm(se.inferEqsFromContains(x, bxya), f); // inferEqsFromContains(x, y) returns null - Node n = StringsEntail::inferEqsFromContains(x, y); + Node n = se.inferEqsFromContains(x, y); ASSERT_TRUE(n.isNull()); // inferEqsFromContains(x, x) returns something equivalent to (= x x) Node eq_x_x = d_nodeManager->mkNode(kind::EQUAL, x, x); - sameNormalForm(StringsEntail::inferEqsFromContains(x, x), eq_x_x); + sameNormalForm(se.inferEqsFromContains(x, x), eq_x_x); // inferEqsFromContains((str.replace x "B" "A"), x) returns something // equivalent to (= (str.replace x "B" "A") x) Node repl = d_nodeManager->mkNode(kind::STRING_REPLACE, x, b, a); Node eq_repl_x = d_nodeManager->mkNode(kind::EQUAL, repl, x); - sameNormalForm(StringsEntail::inferEqsFromContains(repl, x), eq_repl_x); + sameNormalForm(se.inferEqsFromContains(repl, x), eq_repl_x); // inferEqsFromContains(x, (str.replace x "B" "A")) returns something // equivalent to (= (str.replace x "B" "A") x) Node eq_x_repl = d_nodeManager->mkNode(kind::EQUAL, x, repl); - sameNormalForm(StringsEntail::inferEqsFromContains(x, repl), eq_x_repl); + sameNormalForm(se.inferEqsFromContains(x, repl), eq_x_repl); } TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_prefix_suffix) @@ -1672,6 +1673,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_equality_ext) TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) { + StringsEntail& se = d_seqRewriter->getStringsEntail(); TypeNode intType = d_nodeManager->integerType(); TypeNode strType = d_nodeManager->stringType(); @@ -1693,7 +1695,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector n2 = {a}; std::vector nb; std::vector ne; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 0); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 0); ASSERT_FALSE(res); } @@ -1704,7 +1706,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector n2 = {a, d_nodeManager->mkNode(kind::STRING_ITOS, n)}; std::vector nb; std::vector ne; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 0); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 0); ASSERT_FALSE(res); } @@ -1719,7 +1721,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector ne; std::vector n1r = {cd}; std::vector nbr = {ab}; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 1); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 1); ASSERT_TRUE(res); ASSERT_EQ(n1, n1r); ASSERT_EQ(nb, nbr); @@ -1736,7 +1738,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector ne; std::vector n1r = {c, x}; std::vector nbr = {ab}; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 1); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 1); ASSERT_TRUE(res); ASSERT_EQ(n1, n1r); ASSERT_EQ(nb, nbr); @@ -1753,7 +1755,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector ne; std::vector n1r = {a}; std::vector ner = {bc}; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, -1); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, -1); ASSERT_TRUE(res); ASSERT_EQ(n1, n1r); ASSERT_EQ(ne, ner); @@ -1770,7 +1772,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints) std::vector ne; std::vector n1r = {x, a}; std::vector ner = {bc}; - bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, -1); + bool res = se.stripConstantEndpoints(n1, n2, nb, ne, -1); ASSERT_TRUE(res); ASSERT_EQ(n1, n1r); ASSERT_EQ(ne, ner); -- 2.30.2