From: Andrew Reynolds Date: Tue, 30 Nov 2021 20:30:45 +0000 (-0600) Subject: Generalize eager length bound conflicts for regular expression memberships (#7633) X-Git-Tag: cvc5-1.0.0~756 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=23f134fa7309621be513ca3c728b7a3c03473a45;p=cvc5.git Generalize eager length bound conflicts for regular expression memberships (#7633) This generalizes eager length bound conflicts to take into account regular expression memberships. For example: If `(str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar))` is asserted, then we know `(str.len x) >= 3`. If e.g. equivalence class of `x` is merged with `(str.substr y 0 2)`, we get the conflict `(and (str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar)) (= x (str.substr y 0 2))` since `(str.len (str.substr y 0 2)) <= 2`. This also does some minor refactoring to eager prefix conflicts to make it more analogous to our implementation of length conflicts. --- diff --git a/src/options/strings_options.toml b/src/options/strings_options.toml index caebcec3a..01b92c1b7 100644 --- a/src/options/strings_options.toml +++ b/src/options/strings_options.toml @@ -199,3 +199,11 @@ name = "Strings Theory" type = "bool" default = "false" help = "use extensionality for string disequalities" + +[[option]] + name = "stringsEagerLenEntRegexp" + category = "regular" + long = "strings-eager-len-re" + type = "bool" + default = "false" + help = "use regular expressions for eager length conflicts" diff --git a/src/theory/strings/arith_entail.cpp b/src/theory/strings/arith_entail.cpp index d8a4e85aa..8eaa50354 100644 --- a/src/theory/strings/arith_entail.cpp +++ b/src/theory/strings/arith_entail.cpp @@ -728,7 +728,7 @@ struct ArithEntailConstantBoundUpperId typedef expr::Attribute ArithEntailConstantBoundUpper; -void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower) +void ArithEntail::setConstantBoundCache(TNode n, Node ret, bool isLower) { if (isLower) { @@ -742,14 +742,15 @@ void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower) } } -Node ArithEntail::getConstantBoundCache(Node n, bool isLower) +bool ArithEntail::getConstantBoundCache(TNode n, bool isLower, Node& c) { if (isLower) { ArithEntailConstantBoundLower acbl; if (n.hasAttribute(acbl)) { - return n.getAttribute(acbl); + c = n.getAttribute(acbl); + return true; } } else @@ -757,17 +758,18 @@ Node ArithEntail::getConstantBoundCache(Node n, bool isLower) ArithEntailConstantBoundUpper acbu; if (n.hasAttribute(acbu)) { - return n.getAttribute(acbu); + c = n.getAttribute(acbu); + return true; } } - return Node::null(); + return false; } -Node ArithEntail::getConstantBound(Node a, bool isLower) +Node ArithEntail::getConstantBound(TNode a, bool isLower) { Assert(d_rr->rewrite(a) == a); - Node ret = getConstantBoundCache(a, isLower); - if (!ret.isNull()) + Node ret; + if (getConstantBoundCache(a, isLower, ret)) { return ret; } @@ -852,18 +854,19 @@ Node ArithEntail::getConstantBound(Node a, bool isLower) return ret; } -Node ArithEntail::getConstantBoundLength(Node s, bool isLower) +Node ArithEntail::getConstantBoundLength(TNode s, bool isLower) const { Assert(s.getType().isStringLike()); - Node ret = getConstantBoundCache(s, isLower); - if (!ret.isNull()) + Node ret; + if (getConstantBoundCache(s, isLower, ret)) { return ret; } NodeManager* nm = NodeManager::currentNM(); if (s.isConst()) { - ret = nm->mkConst(CONST_RATIONAL, Rational(Word::getLength(s))); + size_t len = Word::getLength(s); + ret = nm->mkConst(CONST_RATIONAL, Rational(len)); } else if (s.getKind() == STRING_CONCAT) { @@ -885,12 +888,12 @@ Node ArithEntail::getConstantBoundLength(Node s, bool isLower) Assert(b.getKind() == CONST_RATIONAL); sum = sum + b.getConst(); } - if (success) + if (success && (!isLower || sum.sgn() != 0)) { ret = nm->mkConst(CONST_RATIONAL, sum); } } - else if (isLower) + if (ret.isNull() && isLower) { ret = d_zero; } diff --git a/src/theory/strings/arith_entail.h b/src/theory/strings/arith_entail.h index 6529a81d1..e35dd3560 100644 --- a/src/theory/strings/arith_entail.h +++ b/src/theory/strings/arith_entail.h @@ -141,12 +141,13 @@ class ArithEntail * if and only if * check( a, strict ) = true. */ - Node getConstantBound(Node a, bool isLower = true); + Node getConstantBound(TNode a, bool isLower = true); /** - * get constant bound on the length of s. + * Get constant bound on the length of s, if it can be determined. This + * method will always worst case return 0 as a lower bound. */ - Node getConstantBoundLength(Node s, bool isLower = true); + Node getConstantBoundLength(TNode s, bool isLower = true) const; /** * Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the * original inequality still holds. Returns true if the original inequality @@ -192,9 +193,12 @@ class ArithEntail std::vector& approx, bool isOverApprox = false); /** Set bound cache */ - void setConstantBoundCache(Node n, Node ret, bool isLower); - /** Get bound cache */ - Node getConstantBoundCache(Node n, bool isLower); + static void setConstantBoundCache(TNode n, Node ret, bool isLower); + /** + * Get bound cache, store in c and return true if the bound for n has been + * computed. Used for getConstantBound and getConstantBoundLength. + */ + static bool getConstantBoundCache(TNode n, bool isLower, Node& c); /** The underlying rewriter */ Rewriter* d_rr; /** Constant zero */ diff --git a/src/theory/strings/eager_solver.cpp b/src/theory/strings/eager_solver.cpp index ce47ff4fc..ac8e815df 100644 --- a/src/theory/strings/eager_solver.cpp +++ b/src/theory/strings/eager_solver.cpp @@ -15,6 +15,7 @@ #include "theory/strings/eager_solver.h" +#include "options/strings_options.h" #include "theory/strings/theory_strings_utils.h" #include "util/rational.h" @@ -24,11 +25,12 @@ namespace cvc5 { namespace theory { namespace strings { -EagerSolver::EagerSolver(Env& env, - SolverState& state, - TermRegistry& treg, - ArithEntail& aent) - : EnvObj(env), d_state(state), d_treg(treg), d_aent(aent) +EagerSolver::EagerSolver(Env& env, SolverState& state, TermRegistry& treg) + : EnvObj(env), + d_state(state), + d_treg(treg), + d_aent(env.getRewriter()), + d_rent(env.getRewriter()) { } @@ -84,26 +86,21 @@ void EagerSolver::eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2) Assert(e1 != nullptr); Assert(e2 != nullptr); // check for conflict - Node conf = checkForMergeConflict(t1, t2, e1, e2); - if (!conf.isNull()) + if (checkForMergeConflict(t1, t2, e1, e2)) { - InferenceId id = t1.getType().isStringLike() - ? InferenceId::STRINGS_PREFIX_CONFLICT - : InferenceId::STRINGS_ARITH_BOUND_CONFLICT; - d_state.setPendingMergeConflict(conf, id); return; } } -void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc) +bool EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc) { Assert(concat.getKind() == STRING_CONCAT || concat.getKind() == REGEXP_CONCAT); EqcInfo* ei = nullptr; // check each side - for (unsigned r = 0; r < 2; r++) + for (size_t r = 0; r < 2; r++) { - unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1; + size_t index = r == 0 ? 0 : concat.getNumChildren() - 1; Node c = utils::getConstantComponent(concat[index]); if (!c.isNull()) { @@ -114,18 +111,16 @@ void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc) Trace("strings-eager-pconf-debug") << "New term: " << concat << " for " << t << " with prefix " << c << " (" << (r == 1) << ")" << std::endl; - Node conf = ei->addEndpointConst(t, c, r == 1); - if (!conf.isNull()) + if (addEndpointConst(ei, t, c, r == 1)) { - d_state.setPendingMergeConflict(conf, - InferenceId::STRINGS_PREFIX_CONFLICT); - return; + return true; } } } + return false; } -Node EagerSolver::checkForMergeConflict(Node a, +bool EagerSolver::checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb) @@ -139,25 +134,25 @@ Node EagerSolver::checkForMergeConflict(Node a, Node n = i == 0 ? eb->d_firstBound.get() : eb->d_secondBound.get(); if (!n.isNull()) { - Node conf; + bool isConflict; if (a.getType().isStringLike()) { - conf = ea->addEndpointConst(n, Node::null(), i == 1); + isConflict = addEndpointConst(ea, n, Node::null(), i == 1); } else { Trace("strings-eager-aconf-debug") << "addArithmeticBound " << n << " into " << a << " from " << b << std::endl; - conf = addArithmeticBound(ea, n, i == 0); + isConflict = addArithmeticBound(ea, n, i == 0); } - if (!conf.isNull()) + if (isConflict) { - return conf; + return true; } } } - return Node::null(); + return false; } void EagerSolver::notifyFact(TNode atom, @@ -171,12 +166,63 @@ void EagerSolver::notifyFact(TNode atom, { eq::EqualityEngine* ee = d_state.getEqualityEngine(); Node eqc = ee->getRepresentative(atom[0]); - addEndpointsToEqcInfo(atom, atom[1], eqc); + // add prefix constraints + if (addEndpointsToEqcInfo(atom, atom[1], eqc)) + { + // conflict, we are done + return; + } + else if (!options().strings.stringsEagerLenEntRegexp) + { + // do not infer length constraints if option is disabled + return; + } + // also infer length constraints if the first is a variable + if (atom[0].isVar()) + { + EqcInfo* blenEqc = nullptr; + for (size_t i = 0; i < 2; i++) + { + bool isLower = (i == 0); + Node b = d_rent.getConstantBoundLengthForRegexp(atom[1], isLower); + if (!b.isNull()) + { + if (blenEqc == nullptr) + { + Node lenTerm = + NodeManager::currentNM()->mkNode(STRING_LENGTH, atom[0]); + if (!ee->hasTerm(lenTerm)) + { + break; + } + lenTerm = ee->getRepresentative(lenTerm); + blenEqc = d_state.getOrMakeEqcInfo(lenTerm); + } + if (addArithmeticBound(blenEqc, atom, isLower)) + { + return; + } + } + } + } } } } -Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower) +bool EagerSolver::addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf) +{ + Assert(e != nullptr); + Assert(!t.isNull()); + Node conf = e->addEndpointConst(t, c, isSuf); + if (!conf.isNull()) + { + d_state.setPendingMergeConflict(conf, InferenceId::STRINGS_PREFIX_CONFLICT); + return true; + } + return false; +} + +bool EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower) { Assert(e != nullptr); Assert(!t.isNull()); @@ -195,7 +241,7 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower) if (prevbr == br || (br < prevbr) == isLower) { // subsumed - return Node::null(); + return false; } } Node prevo = isLower ? e->d_secondBound : e->d_firstBound; @@ -210,10 +256,12 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower) if (prevobr != br && (prevobr < br) == isLower) { // conflict - Node ret = EqcInfo::mkMergeConflict(t, prevo); + Node ret = EqcInfo::mkMergeConflict(t, prevo, true); Trace("strings-eager-aconf") << "String: eager arithmetic bound conflict: " << ret << std::endl; - return ret; + d_state.setPendingMergeConflict( + ret, InferenceId::STRINGS_ARITH_BOUND_CONFLICT); + return true; } } if (isLower) @@ -224,23 +272,27 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower) { e->d_secondBound = t; } - return Node::null(); + return false; } -Node EagerSolver::getBoundForLength(Node len, bool isLower) +Node EagerSolver::getBoundForLength(Node t, bool isLower) const { - Assert(len.getKind() == STRING_LENGTH); + if (t.getKind() == STRING_IN_REGEXP) + { + return d_rent.getConstantBoundLengthForRegexp(t[1]); + } + Assert(t.getKind() == STRING_LENGTH); // it is prohibitively expensive to convert to original form and rewrite, // since this may invoke the rewriter on lengths of complex terms. Instead, // we convert to original term the argument, then call the utility method // for computing the length of the argument, implicitly under an application // of length (ArithEntail::getConstantBoundLength). // convert to original form - Node olent = SkolemManager::getOriginalForm(len[0]); + Node olent = SkolemManager::getOriginalForm(t[0]); // get the bound Node c = d_aent.getConstantBoundLength(olent, isLower); Trace("strings-eager-aconf-debug") - << "Constant " << (isLower ? "lower" : "upper") << " bound for " << len + << "Constant " << (isLower ? "lower" : "upper") << " bound for " << t << " is " << c << ", from original form " << olent << std::endl; return c; } diff --git a/src/theory/strings/eager_solver.h b/src/theory/strings/eager_solver.h index 4181a15c3..a4898f52a 100644 --- a/src/theory/strings/eager_solver.h +++ b/src/theory/strings/eager_solver.h @@ -24,6 +24,7 @@ #include "smt/env_obj.h" #include "theory/strings/arith_entail.h" #include "theory/strings/eqc_info.h" +#include "theory/strings/regexp_entail.h" #include "theory/strings/solver_state.h" #include "theory/strings/term_registry.h" @@ -38,10 +39,7 @@ namespace strings { class EagerSolver : protected EnvObj { public: - EagerSolver(Env& env, - SolverState& state, - TermRegistry& treg, - ArithEntail& aent); + EagerSolver(Env& env, SolverState& state, TermRegistry& treg); ~EagerSolver(); /** called when a new equivalence class is created */ void eqNotifyNewClass(TNode t); @@ -60,23 +58,30 @@ class EagerSolver : protected EnvObj * for some eqc that is currently equal to t. Another example is: * t := (str.in.re z (re.++ r s)), concat := (re.++ r s), eqc * for some eqc that is currently equal to z. + * + * Returns true if we are in conflict, that is, a conflict was sent via the + * inference manager. */ - void addEndpointsToEqcInfo(Node t, Node concat, Node eqc); + bool addEndpointsToEqcInfo(Node t, Node concat, Node eqc); /** * Check for conflict when merging equivalence classes with the given info, - * return the node corresponding to the conflict if so. + * return true if we are in conflict. */ - Node checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb); - /** add arithmetic bound */ - Node addArithmeticBound(EqcInfo* ea, Node t, bool isLower); - /** get bound for length term */ - Node getBoundForLength(Node len, bool isLower); + bool checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb); + /** add endpoint constant, return true if in conflict */ + bool addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf); + /** add arithmetic bound, return true if in conflict */ + bool addArithmeticBound(EqcInfo* e, Node t, bool isLower); + /** get bound for length term or regular expression membership */ + Node getBoundForLength(Node t, bool isLower) const; /** Reference to the solver state */ SolverState& d_state; /** Reference to the term registry */ TermRegistry& d_treg; /** Arithmetic entailment */ - ArithEntail& d_aent; + ArithEntail d_aent; + /** Regular expression entailment */ + RegExpEntail d_rent; }; } // namespace strings diff --git a/src/theory/strings/eqc_info.cpp b/src/theory/strings/eqc_info.cpp index 5fb5e91c3..71f1f8b55 100644 --- a/src/theory/strings/eqc_info.cpp +++ b/src/theory/strings/eqc_info.cpp @@ -100,7 +100,7 @@ Node EqcInfo::addEndpointConst(Node t, Node c, bool isSuf) { Trace("strings-eager-pconf") << "Conflict for " << prevC << ", " << c << std::endl; - Node ret = mkMergeConflict(t, prev); + Node ret = mkMergeConflict(t, prev, false); Trace("strings-eager-pconf") << "String: eager prefix conflict: " << ret << std::endl; return ret; @@ -117,8 +117,11 @@ Node EqcInfo::addEndpointConst(Node t, Node c, bool isSuf) return Node::null(); } -Node EqcInfo::mkMergeConflict(Node t, Node prev) +Node EqcInfo::mkMergeConflict(Node t, Node prev, bool isArith) { + Trace("strings-eager-debug") + << "mkMergeConflict " << t << ", " << prev << std::endl; + NodeManager* nm = NodeManager::currentNM(); std::vector ccs; Node r[2]; for (unsigned i = 0; i < 2; i++) @@ -127,7 +130,7 @@ Node EqcInfo::mkMergeConflict(Node t, Node prev) if (tp.getKind() == STRING_IN_REGEXP) { ccs.push_back(tp); - r[i] = tp[0]; + r[i] = isArith ? nm->mkNode(STRING_LENGTH, tp[0]) : tp[0]; } else { @@ -139,7 +142,7 @@ Node EqcInfo::mkMergeConflict(Node t, Node prev) ccs.push_back(r[0].eqNode(r[1])); } Assert(!ccs.empty()); - return NodeManager::currentNM()->mkAnd(ccs); + return nm->mkAnd(ccs); } } // namespace strings diff --git a/src/theory/strings/eqc_info.h b/src/theory/strings/eqc_info.h index bfc753989..ff3213a8e 100644 --- a/src/theory/strings/eqc_info.h +++ b/src/theory/strings/eqc_info.h @@ -88,8 +88,14 @@ class EqcInfo * the conflict: * (and (= x (str.++ "B" y)) (str.in_re x (re.++ (str.to_re "A") R2))) * for this input. + * + * @param t The first bound term + * @param prev The second bound term + * @param isArith Whether this is an arithmetic conflict. This impacts + * whether (str.in_re x R) is processed as x or (str.len x). + * @return The node corresponding to the conflict. */ - static Node mkMergeConflict(Node t, Node prev); + static Node mkMergeConflict(Node t, Node prev, bool isArith); }; } // namespace strings diff --git a/src/theory/strings/regexp_entail.cpp b/src/theory/strings/regexp_entail.cpp index 49992f14f..645dc05cd 100644 --- a/src/theory/strings/regexp_entail.cpp +++ b/src/theory/strings/regexp_entail.cpp @@ -28,6 +28,12 @@ namespace cvc5 { namespace theory { namespace strings { +RegExpEntail::RegExpEntail(Rewriter* r) : d_rewriter(r), d_aent(r) +{ + d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0)); + d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1)); +} + Node RegExpEntail::simpleRegexpConsume(std::vector& mchildren, std::vector& children, int dir) @@ -647,10 +653,11 @@ bool RegExpEntail::hasEpsilonNode(TNode node) return false; } -Node RegExpEntail::getFixedLengthForRegexp(Node n) +Node RegExpEntail::getFixedLengthForRegexp(TNode n) { NodeManager* nm = NodeManager::currentNM(); - if (n.getKind() == STRING_TO_REGEXP) + Kind k = n.getKind(); + if (k == STRING_TO_REGEXP) { Node ret = nm->mkNode(STRING_LENGTH, n[0]); ret = Rewriter::rewrite(ret); @@ -659,11 +666,11 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n) return ret; } } - else if (n.getKind() == REGEXP_ALLCHAR || n.getKind() == REGEXP_RANGE) + else if (k == REGEXP_ALLCHAR || k == REGEXP_RANGE) { return nm->mkConst(CONST_RATIONAL, Rational(1)); } - else if (n.getKind() == REGEXP_UNION || n.getKind() == REGEXP_INTER) + else if (k == REGEXP_UNION || k == REGEXP_INTER) { Node ret; for (const Node& nc : n) @@ -681,7 +688,7 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n) } return ret; } - else if (n.getKind() == REGEXP_CONCAT) + else if (k == REGEXP_CONCAT) { NodeBuilder nb(PLUS); for (const Node& nc : n) @@ -700,6 +707,82 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n) return Node::null(); } +Node RegExpEntail::getConstantBoundLengthForRegexp(TNode n, bool isLower) const +{ + Assert(n.getType().isRegExp()); + Node ret; + if (getConstantBoundCache(n, isLower, ret)) + { + return ret; + } + Kind k = n.getKind(); + NodeManager* nm = NodeManager::currentNM(); + if (k == STRING_TO_REGEXP) + { + ret = d_aent.getConstantBoundLength(n[0], isLower); + } + else if (k == REGEXP_ALLCHAR || k == REGEXP_RANGE) + { + ret = d_one; + } + else if (k == REGEXP_UNION || k == REGEXP_INTER || k == REGEXP_CONCAT) + { + bool success = true; + bool firstTime = true; + Rational rr(0); + for (const Node& nc : n) + { + Node bc = getConstantBoundLengthForRegexp(nc, isLower); + if (bc.isNull()) + { + if (k == REGEXP_UNION || (k == REGEXP_CONCAT && !isLower)) + { + // since the bound could not be determined on the component, the + // overall bound is undetermined. + success = false; + break; + } + else + { + // if intersection, or we are computing lower bound for concat + // and the component cannot be determined, ignore it + continue; + } + } + Assert(bc.getKind() == CONST_RATIONAL); + Rational r = bc.getConst(); + if (k == REGEXP_CONCAT) + { + rr += r; + } + else if (firstTime) + { + rr = r; + } + else if ((k == REGEXP_UNION) == isLower) + { + rr = std::min(r, rr); + } + else + { + rr = std::max(r, rr); + } + firstTime = false; + } + // if we were successful and didn't ignore all components + if (success && !firstTime) + { + ret = nm->mkConst(CONST_RATIONAL, rr); + } + } + if (ret.isNull() && isLower) + { + ret = d_zero; + } + setConstantBoundCache(n, ret, isLower); + return ret; +} + bool RegExpEntail::regExpIncludes(Node r1, Node r2) { Assert(Rewriter::rewrite(r1) == r1); @@ -803,6 +886,55 @@ bool RegExpEntail::regExpIncludes(Node r1, Node r2) return result; } +struct RegExpEntailConstantBoundLowerId +{ +}; +typedef expr::Attribute + RegExpEntailConstantBoundLower; + +struct RegExpEntailConstantBoundUpperId +{ +}; +typedef expr::Attribute + RegExpEntailConstantBoundUpper; + +void RegExpEntail::setConstantBoundCache(TNode n, Node ret, bool isLower) +{ + if (isLower) + { + RegExpEntailConstantBoundLower rcbl; + n.setAttribute(rcbl, ret); + } + else + { + RegExpEntailConstantBoundUpper rcbu; + n.setAttribute(rcbu, ret); + } +} + +bool RegExpEntail::getConstantBoundCache(TNode n, bool isLower, Node& c) +{ + if (isLower) + { + RegExpEntailConstantBoundLower rcbl; + if (n.hasAttribute(rcbl)) + { + c = n.getAttribute(rcbl); + return true; + } + } + else + { + RegExpEntailConstantBoundUpper rcbu; + if (n.hasAttribute(rcbu)) + { + c = n.getAttribute(rcbu); + return true; + } + } + return false; +} + } // namespace strings } // namespace theory } // namespace cvc5 diff --git a/src/theory/strings/regexp_entail.h b/src/theory/strings/regexp_entail.h index d8bcda4d9..b0511bd53 100644 --- a/src/theory/strings/regexp_entail.h +++ b/src/theory/strings/regexp_entail.h @@ -23,6 +23,7 @@ #include #include "expr/attribute.h" +#include "theory/strings/arith_entail.h" #include "theory/strings/rewrites.h" #include "theory/theory_rewriter.h" #include "theory/type_enumerator.h" @@ -34,6 +35,7 @@ namespace strings { class RegExpEntail { public: + RegExpEntail(Rewriter* r); /** simple regular expression consume * * This method is called when we are rewriting a membership of the form @@ -113,8 +115,14 @@ class RegExpEntail * Given regular expression n, if this method returns a non-null value c, then * x in n entails len( x ) = c. */ - static Node getFixedLengthForRegexp(Node n); + static Node getFixedLengthForRegexp(TNode n); + /** + * Get constant lower or upper bound on the lengths of strings that occur in + * regular expression n. Return null if a constant bound cannot be determined. + * This method will always worst case return 0 as a lower bound. + */ + Node getConstantBoundLengthForRegexp(TNode n, bool isLower = true) const; /** * Returns true if we can show that the regular expression `r1` includes * the regular expression `r2` (i.e. `r1` matches a superset of sequences @@ -129,6 +137,22 @@ class RegExpEntail * @return True if the inclusion can be shown, false otherwise */ static bool regExpIncludes(Node r1, Node r2); + + private: + /** Set bound cache, used for getConstantBoundLengthForRegexp */ + static void setConstantBoundCache(TNode n, Node ret, bool isLower); + /** + * Get bound cache, store in c and return true if the bound for n has been + * computed. Used for getConstantBoundLengthForRegexp. + */ + static bool getConstantBoundCache(TNode n, bool isLower, Node& c); + /** The underlying rewriter */ + Rewriter* d_rewriter; + /** Arithmetic entailment module */ + ArithEntail d_aent; + /** Common constants */ + Node d_zero; + Node d_one; }; } // namespace strings diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 7e79190e6..4306718be 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -59,7 +59,7 @@ TheoryStrings::TheoryStrings(Env& env, OutputChannel& out, Valuation valuation) d_rewriter(env.getRewriter(), &d_statistics.d_rewrites, d_termReg.getAlphabetCardinality()), - d_eagerSolver(env, d_state, d_termReg, d_rewriter.getArithEntail()), + d_eagerSolver(env, d_state, d_termReg), d_extTheoryCb(), d_im(env, *this, d_state, d_termReg, d_extTheory, d_statistics), d_extTheory(env, d_extTheoryCb, d_im),