This generalizes eager length bound conflicts to take into account regular expression memberships.
For example:
If `(str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar)))` is asserted, then we know `(str.len x) >= 3`.
If, for example, the equivalence class of `x` is then merged with `(str.substr y 0 2)`, we get the conflict
`(and (str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar))) (= x (str.substr y 0 2)))`
since `(str.len (str.substr y 0 2)) <= 2`.
This also does some minor refactoring of eager prefix conflicts to make them more analogous to our implementation of length conflicts.
type = "bool"
default = "false"
help = "use extensionality for string disequalities"
+
+[[option]]
+ name = "stringsEagerLenEntRegexp"
+ category = "regular"
+ long = "strings-eager-len-re"
+ type = "bool"
+ default = "false"
+ help = "use regular expressions for eager length conflicts"
typedef expr::Attribute<ArithEntailConstantBoundUpperId, Node>
ArithEntailConstantBoundUpper;
-void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower)
+void ArithEntail::setConstantBoundCache(TNode n, Node ret, bool isLower)
{
if (isLower)
{
}
}
-Node ArithEntail::getConstantBoundCache(Node n, bool isLower)
+bool ArithEntail::getConstantBoundCache(TNode n, bool isLower, Node& c)
{
if (isLower)
{
ArithEntailConstantBoundLower acbl;
if (n.hasAttribute(acbl))
{
- return n.getAttribute(acbl);
+ c = n.getAttribute(acbl);
+ return true;
}
}
else
ArithEntailConstantBoundUpper acbu;
if (n.hasAttribute(acbu))
{
- return n.getAttribute(acbu);
+ c = n.getAttribute(acbu);
+ return true;
}
}
- return Node::null();
+ return false;
}
-Node ArithEntail::getConstantBound(Node a, bool isLower)
+Node ArithEntail::getConstantBound(TNode a, bool isLower)
{
Assert(d_rr->rewrite(a) == a);
- Node ret = getConstantBoundCache(a, isLower);
- if (!ret.isNull())
+ Node ret;
+ if (getConstantBoundCache(a, isLower, ret))
{
return ret;
}
return ret;
}
-Node ArithEntail::getConstantBoundLength(Node s, bool isLower)
+Node ArithEntail::getConstantBoundLength(TNode s, bool isLower) const
{
Assert(s.getType().isStringLike());
- Node ret = getConstantBoundCache(s, isLower);
- if (!ret.isNull())
+ Node ret;
+ if (getConstantBoundCache(s, isLower, ret))
{
return ret;
}
NodeManager* nm = NodeManager::currentNM();
if (s.isConst())
{
- ret = nm->mkConst(CONST_RATIONAL, Rational(Word::getLength(s)));
+ size_t len = Word::getLength(s);
+ ret = nm->mkConst(CONST_RATIONAL, Rational(len));
}
else if (s.getKind() == STRING_CONCAT)
{
Assert(b.getKind() == CONST_RATIONAL);
sum = sum + b.getConst<Rational>();
}
- if (success)
+ if (success && (!isLower || sum.sgn() != 0))
{
ret = nm->mkConst(CONST_RATIONAL, sum);
}
}
- else if (isLower)
+ if (ret.isNull() && isLower)
{
ret = d_zero;
}
* if and only if
* check( a, strict ) = true.
*/
- Node getConstantBound(Node a, bool isLower = true);
+ Node getConstantBound(TNode a, bool isLower = true);
/**
- * get constant bound on the length of s.
+ * Get constant bound on the length of s, if it can be determined. This
+ * method will always worst case return 0 as a lower bound.
*/
- Node getConstantBoundLength(Node s, bool isLower = true);
+ Node getConstantBoundLength(TNode s, bool isLower = true) const;
/**
* Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the
* original inequality still holds. Returns true if the original inequality
std::vector<Node>& approx,
bool isOverApprox = false);
/** Set bound cache */
- void setConstantBoundCache(Node n, Node ret, bool isLower);
- /** Get bound cache */
- Node getConstantBoundCache(Node n, bool isLower);
+ static void setConstantBoundCache(TNode n, Node ret, bool isLower);
+ /**
+ * Get bound cache, store in c and return true if the bound for n has been
+ * computed. Used for getConstantBound and getConstantBoundLength.
+ */
+ static bool getConstantBoundCache(TNode n, bool isLower, Node& c);
/** The underlying rewriter */
Rewriter* d_rr;
/** Constant zero */
#include "theory/strings/eager_solver.h"
+#include "options/strings_options.h"
#include "theory/strings/theory_strings_utils.h"
#include "util/rational.h"
namespace theory {
namespace strings {
-EagerSolver::EagerSolver(Env& env,
- SolverState& state,
- TermRegistry& treg,
- ArithEntail& aent)
- : EnvObj(env), d_state(state), d_treg(treg), d_aent(aent)
+EagerSolver::EagerSolver(Env& env, SolverState& state, TermRegistry& treg)
+ : EnvObj(env),
+ d_state(state),
+ d_treg(treg),
+ d_aent(env.getRewriter()),  // arithmetic entailment utility, now built here from the environment's rewriter
+ d_rent(env.getRewriter())  // regular expression entailment utility, built from the same rewriter
{
}
Assert(e1 != nullptr);
Assert(e2 != nullptr);
// check for conflict
- Node conf = checkForMergeConflict(t1, t2, e1, e2);
- if (!conf.isNull())
+ if (checkForMergeConflict(t1, t2, e1, e2))
{
- InferenceId id = t1.getType().isStringLike()
- ? InferenceId::STRINGS_PREFIX_CONFLICT
- : InferenceId::STRINGS_ARITH_BOUND_CONFLICT;
- d_state.setPendingMergeConflict(conf, id);
return;
}
}
-void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
+bool EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
{
Assert(concat.getKind() == STRING_CONCAT
|| concat.getKind() == REGEXP_CONCAT);
EqcInfo* ei = nullptr;
// check each side
- for (unsigned r = 0; r < 2; r++)
+ for (size_t r = 0; r < 2; r++)
{
- unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1;
+ size_t index = r == 0 ? 0 : concat.getNumChildren() - 1;
Node c = utils::getConstantComponent(concat[index]);
if (!c.isNull())
{
Trace("strings-eager-pconf-debug")
<< "New term: " << concat << " for " << t << " with prefix " << c
<< " (" << (r == 1) << ")" << std::endl;
- Node conf = ei->addEndpointConst(t, c, r == 1);
- if (!conf.isNull())
+ if (addEndpointConst(ei, t, c, r == 1))
{
- d_state.setPendingMergeConflict(conf,
- InferenceId::STRINGS_PREFIX_CONFLICT);
- return;
+ return true;
}
}
}
+ return false;
}
-Node EagerSolver::checkForMergeConflict(Node a,
+bool EagerSolver::checkForMergeConflict(Node a,
Node b,
EqcInfo* ea,
EqcInfo* eb)
Node n = i == 0 ? eb->d_firstBound.get() : eb->d_secondBound.get();
if (!n.isNull())
{
- Node conf;
+ bool isConflict;
if (a.getType().isStringLike())
{
- conf = ea->addEndpointConst(n, Node::null(), i == 1);
+ isConflict = addEndpointConst(ea, n, Node::null(), i == 1);
}
else
{
Trace("strings-eager-aconf-debug")
<< "addArithmeticBound " << n << " into " << a << " from " << b
<< std::endl;
- conf = addArithmeticBound(ea, n, i == 0);
+ isConflict = addArithmeticBound(ea, n, i == 0);
}
- if (!conf.isNull())
+ if (isConflict)
{
- return conf;
+ return true;
}
}
}
- return Node::null();
+ return false;
}
void EagerSolver::notifyFact(TNode atom,
{
eq::EqualityEngine* ee = d_state.getEqualityEngine();
Node eqc = ee->getRepresentative(atom[0]);
- addEndpointsToEqcInfo(atom, atom[1], eqc);
+ // add prefix constraints
+ if (addEndpointsToEqcInfo(atom, atom[1], eqc))
+ {
+ // conflict, we are done
+ return;
+ }
+ else if (!options().strings.stringsEagerLenEntRegexp)
+ {
+ // do not infer length constraints if option is disabled
+ return;
+ }
+ // also infer length constraints if the first is a variable
+ if (atom[0].isVar())
+ {
+ EqcInfo* blenEqc = nullptr;
+ for (size_t i = 0; i < 2; i++)
+ {
+ bool isLower = (i == 0);
+ Node b = d_rent.getConstantBoundLengthForRegexp(atom[1], isLower);
+ if (!b.isNull())
+ {
+ if (blenEqc == nullptr)
+ {
+ Node lenTerm =
+ NodeManager::currentNM()->mkNode(STRING_LENGTH, atom[0]);
+ if (!ee->hasTerm(lenTerm))
+ {
+ break;
+ }
+ lenTerm = ee->getRepresentative(lenTerm);
+ blenEqc = d_state.getOrMakeEqcInfo(lenTerm);
+ }
+ if (addArithmeticBound(blenEqc, atom, isLower))
+ {
+ return;
+ }
+ }
+ }
+ }
}
}
}
-Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
+bool EagerSolver::addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf)
+{
+  Assert(e != nullptr);
+  Assert(!t.isNull());
+  // the equivalence class info returns a non-null node exactly when adding
+  // the endpoint constant produced a conflict
+  Node conflict = e->addEndpointConst(t, c, isSuf);
+  if (conflict.isNull())
+  {
+    return false;
+  }
+  // record the conflict as a pending merge conflict and report it
+  d_state.setPendingMergeConflict(conflict,
+                                  InferenceId::STRINGS_PREFIX_CONFLICT);
+  return true;
+}
+
+bool EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
{
Assert(e != nullptr);
Assert(!t.isNull());
if (prevbr == br || (br < prevbr) == isLower)
{
// subsumed
- return Node::null();
+ return false;
}
}
Node prevo = isLower ? e->d_secondBound : e->d_firstBound;
if (prevobr != br && (prevobr < br) == isLower)
{
// conflict
- Node ret = EqcInfo::mkMergeConflict(t, prevo);
+ Node ret = EqcInfo::mkMergeConflict(t, prevo, true);
Trace("strings-eager-aconf")
<< "String: eager arithmetic bound conflict: " << ret << std::endl;
- return ret;
+ d_state.setPendingMergeConflict(
+ ret, InferenceId::STRINGS_ARITH_BOUND_CONFLICT);
+ return true;
}
}
if (isLower)
{
e->d_secondBound = t;
}
- return Node::null();
+ return false;
}
-Node EagerSolver::getBoundForLength(Node len, bool isLower)
+// Returns the constant lower (isLower) or upper (!isLower) bound associated
+// with t. t is either a regular expression membership, in which case the
+// bound comes from its regular expression t[1], or a length term, in which
+// case the bound comes from the original form of its argument t[0].
+Node EagerSolver::getBoundForLength(Node t, bool isLower) const
{
- Assert(len.getKind() == STRING_LENGTH);
+ if (t.getKind() == STRING_IN_REGEXP)
+ {
+ // isLower must be forwarded: the callee defaults to the lower bound, so
+ // omitting it would report the lower bound when the upper one is requested
+ return d_rent.getConstantBoundLengthForRegexp(t[1], isLower);
+ }
+ Assert(t.getKind() == STRING_LENGTH);
// it is prohibitively expensive to convert to original form and rewrite,
// since this may invoke the rewriter on lengths of complex terms. Instead,
// we convert to original term the argument, then call the utility method
// for computing the length of the argument, implicitly under an application
// of length (ArithEntail::getConstantBoundLength).
// convert to original form
- Node olent = SkolemManager::getOriginalForm(len[0]);
+ Node olent = SkolemManager::getOriginalForm(t[0]);
// get the bound
Node c = d_aent.getConstantBoundLength(olent, isLower);
Trace("strings-eager-aconf-debug")
- << "Constant " << (isLower ? "lower" : "upper") << " bound for " << len
+ << "Constant " << (isLower ? "lower" : "upper") << " bound for " << t
<< " is " << c << ", from original form " << olent << std::endl;
return c;
}
#include "smt/env_obj.h"
#include "theory/strings/arith_entail.h"
#include "theory/strings/eqc_info.h"
+#include "theory/strings/regexp_entail.h"
#include "theory/strings/solver_state.h"
#include "theory/strings/term_registry.h"
class EagerSolver : protected EnvObj
{
public:
- EagerSolver(Env& env,
- SolverState& state,
- TermRegistry& treg,
- ArithEntail& aent);
+ EagerSolver(Env& env, SolverState& state, TermRegistry& treg);
~EagerSolver();
/** called when a new equivalence class is created */
void eqNotifyNewClass(TNode t);
* for some eqc that is currently equal to t. Another example is:
* t := (str.in.re z (re.++ r s)), concat := (re.++ r s), eqc
* for some eqc that is currently equal to z.
+ *
+ * Returns true if we are in conflict, that is, a conflict was sent via the
+ * inference manager.
*/
- void addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
+ bool addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
/**
* Check for conflict when merging equivalence classes with the given info,
- * return the node corresponding to the conflict if so.
+ * return true if we are in conflict.
*/
- Node checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb);
- /** add arithmetic bound */
- Node addArithmeticBound(EqcInfo* ea, Node t, bool isLower);
- /** get bound for length term */
- Node getBoundForLength(Node len, bool isLower);
+ bool checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb);
+ /** add endpoint constant, return true if in conflict */
+ bool addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf);
+ /** add arithmetic bound, return true if in conflict */
+ bool addArithmeticBound(EqcInfo* e, Node t, bool isLower);
+ /** get bound for length term or regular expression membership */
+ Node getBoundForLength(Node t, bool isLower) const;
/** Reference to the solver state */
SolverState& d_state;
/** Reference to the term registry */
TermRegistry& d_treg;
/** Arithmetic entailment */
- ArithEntail& d_aent;
+ ArithEntail d_aent;
+ /** Regular expression entailment */
+ RegExpEntail d_rent;
};
} // namespace strings
{
Trace("strings-eager-pconf")
<< "Conflict for " << prevC << ", " << c << std::endl;
- Node ret = mkMergeConflict(t, prev);
+ Node ret = mkMergeConflict(t, prev, false);
Trace("strings-eager-pconf")
<< "String: eager prefix conflict: " << ret << std::endl;
return ret;
return Node::null();
}
-Node EqcInfo::mkMergeConflict(Node t, Node prev)
+Node EqcInfo::mkMergeConflict(Node t, Node prev, bool isArith)
{
+ Trace("strings-eager-debug")
+ << "mkMergeConflict " << t << ", " << prev << std::endl;
+ NodeManager* nm = NodeManager::currentNM();
std::vector<Node> ccs;
Node r[2];
for (unsigned i = 0; i < 2; i++)
if (tp.getKind() == STRING_IN_REGEXP)
{
ccs.push_back(tp);
- r[i] = tp[0];
+ r[i] = isArith ? nm->mkNode(STRING_LENGTH, tp[0]) : tp[0];
}
else
{
ccs.push_back(r[0].eqNode(r[1]));
}
Assert(!ccs.empty());
- return NodeManager::currentNM()->mkAnd(ccs);
+ return nm->mkAnd(ccs);
}
} // namespace strings
* the conflict:
* (and (= x (str.++ "B" y)) (str.in_re x (re.++ (str.to_re "A") R2)))
* for this input.
+ *
+ * @param t The first bound term
+ * @param prev The second bound term
+ * @param isArith Whether this is an arithmetic conflict. This impacts
+ * whether (str.in_re x R) is processed as x or (str.len x).
+ * @return The node corresponding to the conflict.
*/
- static Node mkMergeConflict(Node t, Node prev);
+ static Node mkMergeConflict(Node t, Node prev, bool isArith);
};
} // namespace strings
namespace theory {
namespace strings {
+RegExpEntail::RegExpEntail(Rewriter* r) : d_rewriter(r), d_aent(r)
+{
+  // build the rational constants 0 and 1 once; they are handed out as
+  // length bounds by getConstantBoundLengthForRegexp
+  NodeManager* nm = NodeManager::currentNM();
+  d_zero = nm->mkConst(CONST_RATIONAL, Rational(0));
+  d_one = nm->mkConst(CONST_RATIONAL, Rational(1));
+}
+
Node RegExpEntail::simpleRegexpConsume(std::vector<Node>& mchildren,
std::vector<Node>& children,
int dir)
return false;
}
-Node RegExpEntail::getFixedLengthForRegexp(Node n)
+Node RegExpEntail::getFixedLengthForRegexp(TNode n)
{
NodeManager* nm = NodeManager::currentNM();
- if (n.getKind() == STRING_TO_REGEXP)
+ Kind k = n.getKind();
+ if (k == STRING_TO_REGEXP)
{
Node ret = nm->mkNode(STRING_LENGTH, n[0]);
ret = Rewriter::rewrite(ret);
return ret;
}
}
- else if (n.getKind() == REGEXP_ALLCHAR || n.getKind() == REGEXP_RANGE)
+ else if (k == REGEXP_ALLCHAR || k == REGEXP_RANGE)
{
return nm->mkConst(CONST_RATIONAL, Rational(1));
}
- else if (n.getKind() == REGEXP_UNION || n.getKind() == REGEXP_INTER)
+ else if (k == REGEXP_UNION || k == REGEXP_INTER)
{
Node ret;
for (const Node& nc : n)
}
return ret;
}
- else if (n.getKind() == REGEXP_CONCAT)
+ else if (k == REGEXP_CONCAT)
{
NodeBuilder nb(PLUS);
for (const Node& nc : n)
return Node::null();
}
+Node RegExpEntail::getConstantBoundLengthForRegexp(TNode n, bool isLower) const
+{
+  Assert(n.getType().isRegExp());
+  // reuse the bound if one was already recorded for n
+  Node bound;
+  if (getConstantBoundCache(n, isLower, bound))
+  {
+    return bound;
+  }
+  const Kind k = n.getKind();
+  NodeManager* nm = NodeManager::currentNM();
+  switch (k)
+  {
+    case STRING_TO_REGEXP:
+      // the bound is the one computed for the length of the string argument
+      bound = d_aent.getConstantBoundLength(n[0], isLower);
+      break;
+    case REGEXP_ALLCHAR:
+    case REGEXP_RANGE:
+      // one is used as both the lower and the upper bound
+      bound = d_one;
+      break;
+    case REGEXP_UNION:
+    case REGEXP_INTER:
+    case REGEXP_CONCAT:
+    {
+      const bool isConcat = (k == REGEXP_CONCAT);
+      const bool isUnion = (k == REGEXP_UNION);
+      // a component without a bound makes the overall bound undetermined for
+      // union and for the upper bound of concatenation; in the remaining
+      // cases the component is simply ignored
+      const bool needAll = isUnion || (isConcat && !isLower);
+      // union/lower and intersection/upper keep the minimum of the component
+      // bounds, the other two combinations keep the maximum
+      const bool keepMin = (isUnion == isLower);
+      bool undetermined = false;
+      bool sawBound = false;
+      Rational acc(0);
+      for (const Node& child : n)
+      {
+        Node cb = getConstantBoundLengthForRegexp(child, isLower);
+        if (cb.isNull())
+        {
+          if (needAll)
+          {
+            undetermined = true;
+            break;
+          }
+          continue;
+        }
+        Assert(cb.getKind() == CONST_RATIONAL);
+        Rational cr = cb.getConst<Rational>();
+        if (isConcat)
+        {
+          // bounds of concatenated components add up
+          acc += cr;
+        }
+        else if (!sawBound)
+        {
+          acc = cr;
+        }
+        else
+        {
+          acc = keepMin ? std::min(cr, acc) : std::max(cr, acc);
+        }
+        sawBound = true;
+      }
+      // only commit a bound when nothing was undetermined and at least one
+      // component contributed
+      if (!undetermined && sawBound)
+      {
+        bound = nm->mkConst(CONST_RATIONAL, acc);
+      }
+      break;
+    }
+    default: break;
+  }
+  // fall back to zero as the lower bound when nothing better was found
+  if (bound.isNull() && isLower)
+  {
+    bound = d_zero;
+  }
+  setConstantBoundCache(n, bound, isLower);
+  return bound;
+}
+
bool RegExpEntail::regExpIncludes(Node r1, Node r2)
{
Assert(Rewriter::rewrite(r1) == r1);
return result;
}
+/** Attribute id under which the lower bound for a node is cached */
+struct RegExpEntailConstantBoundLowerId
+{
+};
+using RegExpEntailConstantBoundLower =
+    expr::Attribute<RegExpEntailConstantBoundLowerId, Node>;
+
+/** Attribute id under which the upper bound for a node is cached */
+struct RegExpEntailConstantBoundUpperId
+{
+};
+using RegExpEntailConstantBoundUpper =
+    expr::Attribute<RegExpEntailConstantBoundUpperId, Node>;
+
+void RegExpEntail::setConstantBoundCache(TNode n, Node ret, bool isLower)
+{
+  // lower and upper bounds are stored on n under two distinct attributes
+  if (!isLower)
+  {
+    RegExpEntailConstantBoundUpper upperAttr;
+    n.setAttribute(upperAttr, ret);
+    return;
+  }
+  RegExpEntailConstantBoundLower lowerAttr;
+  n.setAttribute(lowerAttr, ret);
+}
+
+bool RegExpEntail::getConstantBoundCache(TNode n, bool isLower, Node& c)
+{
+  // c is only written when the requested attribute is present on n
+  if (isLower)
+  {
+    RegExpEntailConstantBoundLower lowerAttr;
+    if (!n.hasAttribute(lowerAttr))
+    {
+      return false;
+    }
+    c = n.getAttribute(lowerAttr);
+    return true;
+  }
+  RegExpEntailConstantBoundUpper upperAttr;
+  if (!n.hasAttribute(upperAttr))
+  {
+    return false;
+  }
+  c = n.getAttribute(upperAttr);
+  return true;
+}
+
} // namespace strings
} // namespace theory
} // namespace cvc5
#include <vector>
#include "expr/attribute.h"
+#include "theory/strings/arith_entail.h"
#include "theory/strings/rewrites.h"
#include "theory/theory_rewriter.h"
#include "theory/type_enumerator.h"
class RegExpEntail
{
public:
+ RegExpEntail(Rewriter* r);
/** simple regular expression consume
*
* This method is called when we are rewriting a membership of the form
* Given regular expression n, if this method returns a non-null value c, then
* x in n entails len( x ) = c.
*/
- static Node getFixedLengthForRegexp(Node n);
+ static Node getFixedLengthForRegexp(TNode n);
+ /**
+ * Get constant lower or upper bound on the lengths of strings that occur in
+ * regular expression n. Return null if a constant bound cannot be determined.
+ * This method will always worst case return 0 as a lower bound.
+ */
+ Node getConstantBoundLengthForRegexp(TNode n, bool isLower = true) const;
/**
* Returns true if we can show that the regular expression `r1` includes
* the regular expression `r2` (i.e. `r1` matches a superset of sequences
* @return True if the inclusion can be shown, false otherwise
*/
static bool regExpIncludes(Node r1, Node r2);
+
+ private:
+ /** Set bound cache, used for getConstantBoundLengthForRegexp */
+ static void setConstantBoundCache(TNode n, Node ret, bool isLower);
+ /**
+ * Get bound cache, store in c and return true if the bound for n has been
+ * computed. Used for getConstantBoundLengthForRegexp.
+ */
+ static bool getConstantBoundCache(TNode n, bool isLower, Node& c);
+ /** The underlying rewriter */
+ Rewriter* d_rewriter;
+ /** Arithmetic entailment module */
+ ArithEntail d_aent;
+ /** Common constants */
+ Node d_zero;
+ Node d_one;
};
} // namespace strings
d_rewriter(env.getRewriter(),
&d_statistics.d_rewrites,
d_termReg.getAlphabetCardinality()),
- d_eagerSolver(env, d_state, d_termReg, d_rewriter.getArithEntail()),
+ d_eagerSolver(env, d_state, d_termReg),
d_extTheoryCb(),
d_im(env, *this, d_state, d_termReg, d_extTheory, d_statistics),
d_extTheory(env, d_extTheoryCb, d_im),