Generalize eager length bound conflicts for regular expression memberships (#7633)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 30 Nov 2021 20:30:45 +0000 (14:30 -0600)
committerGitHub <noreply@github.com>
Tue, 30 Nov 2021 20:30:45 +0000 (20:30 +0000)
This generalizes eager length bound conflicts to take into account regular expression memberships.

For example:

If `(str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar))` is asserted, then we know `(str.len x) >= 3`.

If e.g. equivalence class of `x` is merged with `(str.substr y 0 2)`, we get the conflict

`(and (str.in_re x (re.++ (re.* re.allchar) (str.to_re "abc") (re.* re.allchar)) (= x (str.substr y 0 2))`

since `(str.len (str.substr y 0 2)) <= 2`.

This also does some minor refactoring to eager prefix conflicts to make it more analogous to our implementation of length conflicts.

src/options/strings_options.toml
src/theory/strings/arith_entail.cpp
src/theory/strings/arith_entail.h
src/theory/strings/eager_solver.cpp
src/theory/strings/eager_solver.h
src/theory/strings/eqc_info.cpp
src/theory/strings/eqc_info.h
src/theory/strings/regexp_entail.cpp
src/theory/strings/regexp_entail.h
src/theory/strings/theory_strings.cpp

index caebcec3afe92980d37d68c76f1fc640c4106526..01b92c1b7731f808a47a25f562bd8ade113cf7c4 100644 (file)
@@ -199,3 +199,11 @@ name   = "Strings Theory"
   type       = "bool"
   default    = "false"
   help       = "use extensionality for string disequalities"
+
+[[option]]
+  name       = "stringsEagerLenEntRegexp"
+  category   = "regular"
+  long       = "strings-eager-len-re"
+  type       = "bool"
+  default    = "false"
+  help       = "use regular expressions for eager length conflicts"
index d8a4e85aa85f93c39644182ed470ee6bd159caee..8eaa503541e87546ab2bf79c79eed6fbed58fd59 100644 (file)
@@ -728,7 +728,7 @@ struct ArithEntailConstantBoundUpperId
 typedef expr::Attribute<ArithEntailConstantBoundUpperId, Node>
     ArithEntailConstantBoundUpper;
 
-void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower)
+void ArithEntail::setConstantBoundCache(TNode n, Node ret, bool isLower)
 {
   if (isLower)
   {
@@ -742,14 +742,15 @@ void ArithEntail::setConstantBoundCache(Node n, Node ret, bool isLower)
   }
 }
 
-Node ArithEntail::getConstantBoundCache(Node n, bool isLower)
+bool ArithEntail::getConstantBoundCache(TNode n, bool isLower, Node& c)
 {
   if (isLower)
   {
     ArithEntailConstantBoundLower acbl;
     if (n.hasAttribute(acbl))
     {
-      return n.getAttribute(acbl);
+      c = n.getAttribute(acbl);
+      return true;
     }
   }
   else
@@ -757,17 +758,18 @@ Node ArithEntail::getConstantBoundCache(Node n, bool isLower)
     ArithEntailConstantBoundUpper acbu;
     if (n.hasAttribute(acbu))
     {
-      return n.getAttribute(acbu);
+      c = n.getAttribute(acbu);
+      return true;
     }
   }
-  return Node::null();
+  return false;
 }
 
-Node ArithEntail::getConstantBound(Node a, bool isLower)
+Node ArithEntail::getConstantBound(TNode a, bool isLower)
 {
   Assert(d_rr->rewrite(a) == a);
-  Node ret = getConstantBoundCache(a, isLower);
-  if (!ret.isNull())
+  Node ret;
+  if (getConstantBoundCache(a, isLower, ret))
   {
     return ret;
   }
@@ -852,18 +854,19 @@ Node ArithEntail::getConstantBound(Node a, bool isLower)
   return ret;
 }
 
-Node ArithEntail::getConstantBoundLength(Node s, bool isLower)
+Node ArithEntail::getConstantBoundLength(TNode s, bool isLower) const
 {
   Assert(s.getType().isStringLike());
-  Node ret = getConstantBoundCache(s, isLower);
-  if (!ret.isNull())
+  Node ret;
+  if (getConstantBoundCache(s, isLower, ret))
   {
     return ret;
   }
   NodeManager* nm = NodeManager::currentNM();
   if (s.isConst())
   {
-    ret = nm->mkConst(CONST_RATIONAL, Rational(Word::getLength(s)));
+    size_t len = Word::getLength(s);
+    ret = nm->mkConst(CONST_RATIONAL, Rational(len));
   }
   else if (s.getKind() == STRING_CONCAT)
   {
@@ -885,12 +888,12 @@ Node ArithEntail::getConstantBoundLength(Node s, bool isLower)
       Assert(b.getKind() == CONST_RATIONAL);
       sum = sum + b.getConst<Rational>();
     }
-    if (success)
+    if (success && (!isLower || sum.sgn() != 0))
     {
       ret = nm->mkConst(CONST_RATIONAL, sum);
     }
   }
-  else if (isLower)
+  if (ret.isNull() && isLower)
   {
     ret = d_zero;
   }
index 6529a81d15e9696b76b30c35252a148346a5ed22..e35dd35607aea32a9aab86c2f19f729f598ff61d 100644 (file)
@@ -141,12 +141,13 @@ class ArithEntail
    *     if and only if
    *   check( a, strict ) = true.
    */
-  Node getConstantBound(Node a, bool isLower = true);
+  Node getConstantBound(TNode a, bool isLower = true);
 
   /**
-   * get constant bound on the length of s.
+   * Get constant bound on the length of s, if it can be determined. This
+   * method will always worst case return 0 as a lower bound.
    */
-  Node getConstantBoundLength(Node s, bool isLower = true);
+  Node getConstantBoundLength(TNode s, bool isLower = true) const;
   /**
    * Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the
    * original inequality still holds. Returns true if the original inequality
@@ -192,9 +193,12 @@ class ArithEntail
                               std::vector<Node>& approx,
                               bool isOverApprox = false);
   /** Set bound cache */
-  void setConstantBoundCache(Node n, Node ret, bool isLower);
-  /** Get bound cache */
-  Node getConstantBoundCache(Node n, bool isLower);
+  static void setConstantBoundCache(TNode n, Node ret, bool isLower);
+  /**
+   * Get bound cache, store in c and return true if the bound for n has been
+   * computed. Used for getConstantBound and getConstantBoundLength.
+   */
+  static bool getConstantBoundCache(TNode n, bool isLower, Node& c);
   /** The underlying rewriter */
   Rewriter* d_rr;
   /** Constant zero */
index ce47ff4fc4673f5631e8fc43dfa1b285aab5ebae..ac8e815df64a2d414f0588430787f6fa244b75d7 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/strings/eager_solver.h"
 
+#include "options/strings_options.h"
 #include "theory/strings/theory_strings_utils.h"
 #include "util/rational.h"
 
@@ -24,11 +25,12 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
-EagerSolver::EagerSolver(Env& env,
-                         SolverState& state,
-                         TermRegistry& treg,
-                         ArithEntail& aent)
-    : EnvObj(env), d_state(state), d_treg(treg), d_aent(aent)
+EagerSolver::EagerSolver(Env& env, SolverState& state, TermRegistry& treg)
+    : EnvObj(env),
+      d_state(state),
+      d_treg(treg),
+      d_aent(env.getRewriter()),
+      d_rent(env.getRewriter())
 {
 }
 
@@ -84,26 +86,21 @@ void EagerSolver::eqNotifyMerge(EqcInfo* e1, TNode t1, EqcInfo* e2, TNode t2)
   Assert(e1 != nullptr);
   Assert(e2 != nullptr);
   // check for conflict
-  Node conf = checkForMergeConflict(t1, t2, e1, e2);
-  if (!conf.isNull())
+  if (checkForMergeConflict(t1, t2, e1, e2))
   {
-    InferenceId id = t1.getType().isStringLike()
-                         ? InferenceId::STRINGS_PREFIX_CONFLICT
-                         : InferenceId::STRINGS_ARITH_BOUND_CONFLICT;
-    d_state.setPendingMergeConflict(conf, id);
     return;
   }
 }
 
-void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
+bool EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
 {
   Assert(concat.getKind() == STRING_CONCAT
          || concat.getKind() == REGEXP_CONCAT);
   EqcInfo* ei = nullptr;
   // check each side
-  for (unsigned r = 0; r < 2; r++)
+  for (size_t r = 0; r < 2; r++)
   {
-    unsigned index = r == 0 ? 0 : concat.getNumChildren() - 1;
+    size_t index = r == 0 ? 0 : concat.getNumChildren() - 1;
     Node c = utils::getConstantComponent(concat[index]);
     if (!c.isNull())
     {
@@ -114,18 +111,16 @@ void EagerSolver::addEndpointsToEqcInfo(Node t, Node concat, Node eqc)
       Trace("strings-eager-pconf-debug")
           << "New term: " << concat << " for " << t << " with prefix " << c
           << " (" << (r == 1) << ")" << std::endl;
-      Node conf = ei->addEndpointConst(t, c, r == 1);
-      if (!conf.isNull())
+      if (addEndpointConst(ei, t, c, r == 1))
       {
-        d_state.setPendingMergeConflict(conf,
-                                        InferenceId::STRINGS_PREFIX_CONFLICT);
-        return;
+        return true;
       }
     }
   }
+  return false;
 }
 
-Node EagerSolver::checkForMergeConflict(Node a,
+bool EagerSolver::checkForMergeConflict(Node a,
                                         Node b,
                                         EqcInfo* ea,
                                         EqcInfo* eb)
@@ -139,25 +134,25 @@ Node EagerSolver::checkForMergeConflict(Node a,
     Node n = i == 0 ? eb->d_firstBound.get() : eb->d_secondBound.get();
     if (!n.isNull())
     {
-      Node conf;
+      bool isConflict;
       if (a.getType().isStringLike())
       {
-        conf = ea->addEndpointConst(n, Node::null(), i == 1);
+        isConflict = addEndpointConst(ea, n, Node::null(), i == 1);
       }
       else
       {
         Trace("strings-eager-aconf-debug")
             << "addArithmeticBound " << n << " into " << a << " from " << b
             << std::endl;
-        conf = addArithmeticBound(ea, n, i == 0);
+        isConflict = addArithmeticBound(ea, n, i == 0);
       }
-      if (!conf.isNull())
+      if (isConflict)
       {
-        return conf;
+        return true;
       }
     }
   }
-  return Node::null();
+  return false;
 }
 
 void EagerSolver::notifyFact(TNode atom,
@@ -171,12 +166,63 @@ void EagerSolver::notifyFact(TNode atom,
     {
       eq::EqualityEngine* ee = d_state.getEqualityEngine();
       Node eqc = ee->getRepresentative(atom[0]);
-      addEndpointsToEqcInfo(atom, atom[1], eqc);
+      // add prefix constraints
+      if (addEndpointsToEqcInfo(atom, atom[1], eqc))
+      {
+        // conflict, we are done
+        return;
+      }
+      else if (!options().strings.stringsEagerLenEntRegexp)
+      {
+        // do not infer length constraints if option is disabled
+        return;
+      }
+      // also infer length constraints if the first is a variable
+      if (atom[0].isVar())
+      {
+        EqcInfo* blenEqc = nullptr;
+        for (size_t i = 0; i < 2; i++)
+        {
+          bool isLower = (i == 0);
+          Node b = d_rent.getConstantBoundLengthForRegexp(atom[1], isLower);
+          if (!b.isNull())
+          {
+            if (blenEqc == nullptr)
+            {
+              Node lenTerm =
+                  NodeManager::currentNM()->mkNode(STRING_LENGTH, atom[0]);
+              if (!ee->hasTerm(lenTerm))
+              {
+                break;
+              }
+              lenTerm = ee->getRepresentative(lenTerm);
+              blenEqc = d_state.getOrMakeEqcInfo(lenTerm);
+            }
+            if (addArithmeticBound(blenEqc, atom, isLower))
+            {
+              return;
+            }
+          }
+        }
+      }
     }
   }
 }
 
-Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
+bool EagerSolver::addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf)
+{
+  Assert(e != nullptr);
+  Assert(!t.isNull());
+  Node conf = e->addEndpointConst(t, c, isSuf);
+  if (!conf.isNull())
+  {
+    d_state.setPendingMergeConflict(conf, InferenceId::STRINGS_PREFIX_CONFLICT);
+    return true;
+  }
+  return false;
+}
+
+bool EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
 {
   Assert(e != nullptr);
   Assert(!t.isNull());
@@ -195,7 +241,7 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
     if (prevbr == br || (br < prevbr) == isLower)
     {
       // subsumed
-      return Node::null();
+      return false;
     }
   }
   Node prevo = isLower ? e->d_secondBound : e->d_firstBound;
@@ -210,10 +256,12 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
     if (prevobr != br && (prevobr < br) == isLower)
     {
       // conflict
-      Node ret = EqcInfo::mkMergeConflict(t, prevo);
+      Node ret = EqcInfo::mkMergeConflict(t, prevo, true);
       Trace("strings-eager-aconf")
           << "String: eager arithmetic bound conflict: " << ret << std::endl;
-      return ret;
+      d_state.setPendingMergeConflict(
+          ret, InferenceId::STRINGS_ARITH_BOUND_CONFLICT);
+      return true;
     }
   }
   if (isLower)
@@ -224,23 +272,27 @@ Node EagerSolver::addArithmeticBound(EqcInfo* e, Node t, bool isLower)
   {
     e->d_secondBound = t;
   }
-  return Node::null();
+  return false;
 }
 
-Node EagerSolver::getBoundForLength(Node len, bool isLower)
+Node EagerSolver::getBoundForLength(Node t, bool isLower) const
 {
-  Assert(len.getKind() == STRING_LENGTH);
+  if (t.getKind() == STRING_IN_REGEXP)
+  {
+    return d_rent.getConstantBoundLengthForRegexp(t[1]);
+  }
+  Assert(t.getKind() == STRING_LENGTH);
   // it is prohibitively expensive to convert to original form and rewrite,
   // since this may invoke the rewriter on lengths of complex terms. Instead,
   // we convert to original term the argument, then call the utility method
   // for computing the length of the argument, implicitly under an application
   // of length (ArithEntail::getConstantBoundLength).
   // convert to original form
-  Node olent = SkolemManager::getOriginalForm(len[0]);
+  Node olent = SkolemManager::getOriginalForm(t[0]);
   // get the bound
   Node c = d_aent.getConstantBoundLength(olent, isLower);
   Trace("strings-eager-aconf-debug")
-      << "Constant " << (isLower ? "lower" : "upper") << " bound for " << len
+      << "Constant " << (isLower ? "lower" : "upper") << " bound for " << t
       << " is " << c << ", from original form " << olent << std::endl;
   return c;
 }
index 4181a15c366fee70b2052aa4b5bde04687249c2d..a4898f52ad5b3aeacb0c36bc6ac5713a053b1e2b 100644 (file)
@@ -24,6 +24,7 @@
 #include "smt/env_obj.h"
 #include "theory/strings/arith_entail.h"
 #include "theory/strings/eqc_info.h"
+#include "theory/strings/regexp_entail.h"
 #include "theory/strings/solver_state.h"
 #include "theory/strings/term_registry.h"
 
@@ -38,10 +39,7 @@ namespace strings {
 class EagerSolver : protected EnvObj
 {
  public:
-  EagerSolver(Env& env,
-              SolverState& state,
-              TermRegistry& treg,
-              ArithEntail& aent);
+  EagerSolver(Env& env, SolverState& state, TermRegistry& treg);
   ~EagerSolver();
   /** called when a new equivalence class is created */
   void eqNotifyNewClass(TNode t);
@@ -60,23 +58,30 @@ class EagerSolver : protected EnvObj
    * for some eqc that is currently equal to t. Another example is:
    *   t := (str.in.re z (re.++ r s)), concat := (re.++ r s), eqc
    * for some eqc that is currently equal to z.
+   *
+   * Returns true if we are in conflict, that is, a conflict was sent via the
+   * inference manager.
    */
-  void addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
+  bool addEndpointsToEqcInfo(Node t, Node concat, Node eqc);
   /**
    * Check for conflict when merging equivalence classes with the given info,
-   * return the node corresponding to the conflict if so.
+   * return true if we are in conflict.
    */
-  Node checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb);
-  /** add arithmetic bound */
-  Node addArithmeticBound(EqcInfo* ea, Node t, bool isLower);
-  /** get bound for length term */
-  Node getBoundForLength(Node len, bool isLower);
+  bool checkForMergeConflict(Node a, Node b, EqcInfo* ea, EqcInfo* eb);
+  /** add endpoint constant, return true if in conflict */
+  bool addEndpointConst(EqcInfo* e, Node t, Node c, bool isSuf);
+  /** add arithmetic bound, return true if in conflict */
+  bool addArithmeticBound(EqcInfo* e, Node t, bool isLower);
+  /** get bound for length term or regular expression membership */
+  Node getBoundForLength(Node t, bool isLower) const;
   /** Reference to the solver state */
   SolverState& d_state;
   /** Reference to the term registry */
   TermRegistry& d_treg;
   /** Arithmetic entailment */
-  ArithEntail& d_aent;
+  ArithEntail d_aent;
+  /** Regular expression entailment */
+  RegExpEntail d_rent;
 };
 
 }  // namespace strings
index 5fb5e91c342b43d03879d8406bc97f894ffdfc22..71f1f8b550e25f3c9fd238828c0a0bbb25385202 100644 (file)
@@ -100,7 +100,7 @@ Node EqcInfo::addEndpointConst(Node t, Node c, bool isSuf)
     {
       Trace("strings-eager-pconf")
           << "Conflict for " << prevC << ", " << c << std::endl;
-      Node ret = mkMergeConflict(t, prev);
+      Node ret = mkMergeConflict(t, prev, false);
       Trace("strings-eager-pconf")
           << "String: eager prefix conflict: " << ret << std::endl;
       return ret;
@@ -117,8 +117,11 @@ Node EqcInfo::addEndpointConst(Node t, Node c, bool isSuf)
   return Node::null();
 }
 
-Node EqcInfo::mkMergeConflict(Node t, Node prev)
+Node EqcInfo::mkMergeConflict(Node t, Node prev, bool isArith)
 {
+  Trace("strings-eager-debug")
+      << "mkMergeConflict " << t << ", " << prev << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> ccs;
   Node r[2];
   for (unsigned i = 0; i < 2; i++)
@@ -127,7 +130,7 @@ Node EqcInfo::mkMergeConflict(Node t, Node prev)
     if (tp.getKind() == STRING_IN_REGEXP)
     {
       ccs.push_back(tp);
-      r[i] = tp[0];
+      r[i] = isArith ? nm->mkNode(STRING_LENGTH, tp[0]) : tp[0];
     }
     else
     {
@@ -139,7 +142,7 @@ Node EqcInfo::mkMergeConflict(Node t, Node prev)
     ccs.push_back(r[0].eqNode(r[1]));
   }
   Assert(!ccs.empty());
-  return NodeManager::currentNM()->mkAnd(ccs);
+  return nm->mkAnd(ccs);
 }
 
 }  // namespace strings
index bfc753989b3b04a32e4278916d7b299ff6f4b016..ff3213a8e54dcab94610db9dabdc18ec1d78f858 100644 (file)
@@ -88,8 +88,14 @@ class EqcInfo
    * the conflict:
    *   (and (= x (str.++ "B" y)) (str.in_re x (re.++ (str.to_re "A") R2)))
    * for this input.
+   *
+   * @param t The first bound term
+   * @param prev The second bound term
+   * @param isArith Whether this is an arithmetic conflict. This impacts
+   * whether (str.in_re x R) is processed as x or (str.len x).
+   * @return The node corresponding to the conflict.
    */
-  static Node mkMergeConflict(Node t, Node prev);
+  static Node mkMergeConflict(Node t, Node prev, bool isArith);
 };
 
 }  // namespace strings
index 49992f14f8ae0f0cae01acfb1e31b86d82993571..645dc05cd36ce7dabbddc050519fbc50d2fa15b6 100644 (file)
@@ -28,6 +28,12 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
+RegExpEntail::RegExpEntail(Rewriter* r) : d_rewriter(r), d_aent(r)
+{
+  d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0));
+  d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1));
+}
+
 Node RegExpEntail::simpleRegexpConsume(std::vector<Node>& mchildren,
                                        std::vector<Node>& children,
                                        int dir)
@@ -647,10 +653,11 @@ bool RegExpEntail::hasEpsilonNode(TNode node)
   return false;
 }
 
-Node RegExpEntail::getFixedLengthForRegexp(Node n)
+Node RegExpEntail::getFixedLengthForRegexp(TNode n)
 {
   NodeManager* nm = NodeManager::currentNM();
-  if (n.getKind() == STRING_TO_REGEXP)
+  Kind k = n.getKind();
+  if (k == STRING_TO_REGEXP)
   {
     Node ret = nm->mkNode(STRING_LENGTH, n[0]);
     ret = Rewriter::rewrite(ret);
@@ -659,11 +666,11 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n)
       return ret;
     }
   }
-  else if (n.getKind() == REGEXP_ALLCHAR || n.getKind() == REGEXP_RANGE)
+  else if (k == REGEXP_ALLCHAR || k == REGEXP_RANGE)
   {
     return nm->mkConst(CONST_RATIONAL, Rational(1));
   }
-  else if (n.getKind() == REGEXP_UNION || n.getKind() == REGEXP_INTER)
+  else if (k == REGEXP_UNION || k == REGEXP_INTER)
   {
     Node ret;
     for (const Node& nc : n)
@@ -681,7 +688,7 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n)
     }
     return ret;
   }
-  else if (n.getKind() == REGEXP_CONCAT)
+  else if (k == REGEXP_CONCAT)
   {
     NodeBuilder nb(PLUS);
     for (const Node& nc : n)
@@ -700,6 +707,82 @@ Node RegExpEntail::getFixedLengthForRegexp(Node n)
   return Node::null();
 }
 
+Node RegExpEntail::getConstantBoundLengthForRegexp(TNode n, bool isLower) const
+{
+  Assert(n.getType().isRegExp());
+  Node ret;
+  if (getConstantBoundCache(n, isLower, ret))
+  {
+    return ret;
+  }
+  Kind k = n.getKind();
+  NodeManager* nm = NodeManager::currentNM();
+  if (k == STRING_TO_REGEXP)
+  {
+    ret = d_aent.getConstantBoundLength(n[0], isLower);
+  }
+  else if (k == REGEXP_ALLCHAR || k == REGEXP_RANGE)
+  {
+    ret = d_one;
+  }
+  else if (k == REGEXP_UNION || k == REGEXP_INTER || k == REGEXP_CONCAT)
+  {
+    bool success = true;
+    bool firstTime = true;
+    Rational rr(0);
+    for (const Node& nc : n)
+    {
+      Node bc = getConstantBoundLengthForRegexp(nc, isLower);
+      if (bc.isNull())
+      {
+        if (k == REGEXP_UNION || (k == REGEXP_CONCAT && !isLower))
+        {
+          // since the bound could not be determined on the component, the
+          // overall bound is undetermined.
+          success = false;
+          break;
+        }
+        else
+        {
+          // if intersection, or we are computing lower bound for concat
+          // and the component cannot be determined, ignore it
+          continue;
+        }
+      }
+      Assert(bc.getKind() == CONST_RATIONAL);
+      Rational r = bc.getConst<Rational>();
+      if (k == REGEXP_CONCAT)
+      {
+        rr += r;
+      }
+      else if (firstTime)
+      {
+        rr = r;
+      }
+      else if ((k == REGEXP_UNION) == isLower)
+      {
+        rr = std::min(r, rr);
+      }
+      else
+      {
+        rr = std::max(r, rr);
+      }
+      firstTime = false;
+    }
+    // if we were successful and didn't ignore all components
+    if (success && !firstTime)
+    {
+      ret = nm->mkConst(CONST_RATIONAL, rr);
+    }
+  }
+  if (ret.isNull() && isLower)
+  {
+    ret = d_zero;
+  }
+  setConstantBoundCache(n, ret, isLower);
+  return ret;
+}
+
 bool RegExpEntail::regExpIncludes(Node r1, Node r2)
 {
   Assert(Rewriter::rewrite(r1) == r1);
@@ -803,6 +886,55 @@ bool RegExpEntail::regExpIncludes(Node r1, Node r2)
   return result;
 }
 
+struct RegExpEntailConstantBoundLowerId
+{
+};
+typedef expr::Attribute<RegExpEntailConstantBoundLowerId, Node>
+    RegExpEntailConstantBoundLower;
+
+struct RegExpEntailConstantBoundUpperId
+{
+};
+typedef expr::Attribute<RegExpEntailConstantBoundUpperId, Node>
+    RegExpEntailConstantBoundUpper;
+
+void RegExpEntail::setConstantBoundCache(TNode n, Node ret, bool isLower)
+{
+  if (isLower)
+  {
+    RegExpEntailConstantBoundLower rcbl;
+    n.setAttribute(rcbl, ret);
+  }
+  else
+  {
+    RegExpEntailConstantBoundUpper rcbu;
+    n.setAttribute(rcbu, ret);
+  }
+}
+
+bool RegExpEntail::getConstantBoundCache(TNode n, bool isLower, Node& c)
+{
+  if (isLower)
+  {
+    RegExpEntailConstantBoundLower rcbl;
+    if (n.hasAttribute(rcbl))
+    {
+      c = n.getAttribute(rcbl);
+      return true;
+    }
+  }
+  else
+  {
+    RegExpEntailConstantBoundUpper rcbu;
+    if (n.hasAttribute(rcbu))
+    {
+      c = n.getAttribute(rcbu);
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace strings
 }  // namespace theory
 }  // namespace cvc5
index d8bcda4d97cfc32492a544fa2ab89fbaafcf1081..b0511bd53ee9dcef877115e31ccc830f952d1861 100644 (file)
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "expr/attribute.h"
+#include "theory/strings/arith_entail.h"
 #include "theory/strings/rewrites.h"
 #include "theory/theory_rewriter.h"
 #include "theory/type_enumerator.h"
@@ -34,6 +35,7 @@ namespace strings {
 class RegExpEntail
 {
  public:
+  RegExpEntail(Rewriter* r);
   /** simple regular expression consume
    *
    * This method is called when we are rewriting a membership of the form
@@ -113,8 +115,14 @@ class RegExpEntail
    * Given regular expression n, if this method returns a non-null value c, then
    * x in n entails len( x ) = c.
    */
-  static Node getFixedLengthForRegexp(Node n);
+  static Node getFixedLengthForRegexp(TNode n);
 
+  /**
+   * Get constant lower or upper bound on the lengths of strings that occur in
+   * regular expression n. Return null if a constant bound cannot be determined.
+   * This method will always worst case return 0 as a lower bound.
+   */
+  Node getConstantBoundLengthForRegexp(TNode n, bool isLower = true) const;
   /**
    * Returns true if we can show that the regular expression `r1` includes
    * the regular expression `r2` (i.e. `r1` matches a superset of sequences
@@ -129,6 +137,22 @@ class RegExpEntail
    * @return True if the inclusion can be shown, false otherwise
    */
   static bool regExpIncludes(Node r1, Node r2);
+
+ private:
+  /** Set bound cache, used for getConstantBoundLengthForRegexp */
+  static void setConstantBoundCache(TNode n, Node ret, bool isLower);
+  /**
+   * Get bound cache, store in c and return true if the bound for n has been
+   * computed. Used for getConstantBoundLengthForRegexp.
+   */
+  static bool getConstantBoundCache(TNode n, bool isLower, Node& c);
+  /** The underlying rewriter */
+  Rewriter* d_rewriter;
+  /** Arithmetic entailment module */
+  ArithEntail d_aent;
+  /** Common constants */
+  Node d_zero;
+  Node d_one;
 };
 
 }  // namespace strings
index 7e79190e6ea3da9501efe4949666c3a9e417039c..4306718be98437c910ea551b137c2e7630ec398d 100644 (file)
@@ -59,7 +59,7 @@ TheoryStrings::TheoryStrings(Env& env, OutputChannel& out, Valuation valuation)
       d_rewriter(env.getRewriter(),
                  &d_statistics.d_rewrites,
                  d_termReg.getAlphabetCardinality()),
-      d_eagerSolver(env, d_state, d_termReg, d_rewriter.getArithEntail()),
+      d_eagerSolver(env, d_state, d_termReg),
       d_extTheoryCb(),
       d_im(env, *this, d_state, d_termReg, d_extTheory, d_statistics),
       d_extTheory(env, d_extTheoryCb, d_im),