Eliminate calls to Rewriter::rewrite from strings entailment checks (#7203)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 24 Sep 2021 05:11:24 +0000 (00:11 -0500)
committerGitHub <noreply@github.com>
Fri, 24 Sep 2021 05:11:24 +0000 (05:11 +0000)
There are a few further circular references that prevent us from not passing Rewriter to the strings TheoryRewriter constructor, this can be cleaned in future PRs.

14 files changed:
src/preprocessing/passes/foreign_theory_rewrite.cpp
src/preprocessing/passes/foreign_theory_rewrite.h
src/theory/quantifiers/extended_rewrite.cpp
src/theory/strings/arith_entail.cpp
src/theory/strings/arith_entail.h
src/theory/strings/sequences_rewriter.cpp
src/theory/strings/sequences_rewriter.h
src/theory/strings/strings_entail.cpp
src/theory/strings/strings_entail.h
src/theory/strings/strings_rewriter.cpp
src/theory/strings/strings_rewriter.h
src/theory/strings/theory_strings.cpp
test/unit/preprocessing/pass_foreign_theory_rewrite_white.cpp
test/unit/theory/sequences_rewriter_white.cpp

index 24edf150948afc92addae9845e1b1fab026cedad..70ad0fea33e224a29f232da676c1ffd0b82e7973 100644 (file)
@@ -20,6 +20,7 @@
 #include "expr/node_traversal.h"
 #include "preprocessing/assertion_pipeline.h"
 #include "preprocessing/preprocessing_pass_context.h"
+#include "smt/env.h"
 #include "theory/rewriter.h"
 #include "theory/strings/arith_entail.h"
 
@@ -28,12 +29,13 @@ namespace preprocessing {
 namespace passes {
 
 using namespace cvc5::theory;
-ForeignTheoryRewrite::ForeignTheoryRewrite(
-    PreprocessingPassContext* preprocContext)
-    : PreprocessingPass(preprocContext, "foreign-theory-rewrite"),
-      d_cache(userContext()){};
 
-Node ForeignTheoryRewrite::simplify(Node n)
+ForeignTheoryRewriter::ForeignTheoryRewriter(Env& env)
+    : EnvObj(env), d_cache(userContext())
+{
+}
+
+Node ForeignTheoryRewriter::simplify(Node n)
 {
   std::vector<Node> toVisit;
   n = rewrite(n);
@@ -87,7 +89,7 @@ Node ForeignTheoryRewrite::simplify(Node n)
   return d_cache[n];
 }
 
-Node ForeignTheoryRewrite::foreignRewrite(Node n)
+Node ForeignTheoryRewriter::foreignRewrite(Node n)
 {
   // n is a rewritten node, and so GT, LT, LEQ
   // should have been eliminated
@@ -102,18 +104,19 @@ Node ForeignTheoryRewrite::foreignRewrite(Node n)
   return n;
 }
 
-Node ForeignTheoryRewrite::rewriteStringsGeq(Node n)
+Node ForeignTheoryRewriter::rewriteStringsGeq(Node n)
 {
+  theory::strings::ArithEntail ae(d_env.getRewriter());
   // check if the node can be simplified to true
-  if (theory::strings::ArithEntail::check(n[0], n[1], false))
+  if (ae.check(n[0], n[1], false))
   {
     return NodeManager::currentNM()->mkConst(true);
   }
   return n;
 }
 
-Node ForeignTheoryRewrite::reconstructNode(Node originalNode,
-                                           std::vector<Node> newChildren)
+Node ForeignTheoryRewriter::reconstructNode(Node originalNode,
+                                            std::vector<Node> newChildren)
 {
   // Nodes with no children are reconstructed to themselves
   if (originalNode.getNumChildren() == 0)
@@ -137,15 +140,22 @@ Node ForeignTheoryRewrite::reconstructNode(Node originalNode,
   return builder.constructNode();
 }
 
+ForeignTheoryRewrite::ForeignTheoryRewrite(
+    PreprocessingPassContext* preprocContext)
+    : PreprocessingPass(preprocContext, "foreign-theory-rewrite"),
+      d_ftr(preprocContext->getEnv())
+{
+}
+
 PreprocessingPassResult ForeignTheoryRewrite::applyInternal(
     AssertionPipeline* assertionsToPreprocess)
 {
-  for (unsigned i = 0; i < assertionsToPreprocess->size(); ++i)
+  for (size_t i = 0, nasserts = assertionsToPreprocess->size(); i < nasserts;
+       ++i)
   {
     assertionsToPreprocess->replace(
-        i, rewrite(simplify((*assertionsToPreprocess)[i])));
+        i, rewrite(d_ftr.simplify((*assertionsToPreprocess)[i])));
   }
-
   return PreprocessingPassResult::NO_CONFLICT;
 }
 
index 4940f326c37b1651df26562b4dac0859a9089c48..81f5282efbaecd180a5c59c26468793a2a80044e 100644 (file)
@@ -23,6 +23,7 @@
 #include "context/cdhashmap.h"
 #include "expr/node.h"
 #include "preprocessing/preprocessing_pass.h"
+#include "smt/env_obj.h"
 
 namespace cvc5 {
 namespace preprocessing {
@@ -30,14 +31,10 @@ namespace passes {
 
 using CDNodeMap = context::CDHashMap<Node, Node>;
 
-class ForeignTheoryRewrite : public PreprocessingPass
+class ForeignTheoryRewriter : protected EnvObj
 {
  public:
-  ForeignTheoryRewrite(PreprocessingPassContext* preprocContext);
-
- protected:
-  PreprocessingPassResult applyInternal(
-      AssertionPipeline* assertionsToPreprocess) override;
+  ForeignTheoryRewriter(Env& env);
   /** the main function that simplifies n.
    * does a traversal on n and call rewriting fucntions.
    */
@@ -45,14 +42,14 @@ class ForeignTheoryRewrite : public PreprocessingPass
   /** A specific simplification function specific for GEQ
    * constraints in strings.
    */
-  static Node rewriteStringsGeq(Node n);
+  Node rewriteStringsGeq(Node n);
   /** invoke rewrite functions for n.
    * based on the structure of n (typically its kind)
    * we invoke rewrites from other theories.
    * For example: when encountering a `>=` node,
    * we invoke rewrites from the theory of strings.
    */
-  static Node foreignRewrite(Node n);
+  Node foreignRewrite(Node n);
   /** construct a node with the same operator as originalNode whose children are
    * processedChildren
    */
@@ -61,6 +58,18 @@ class ForeignTheoryRewrite : public PreprocessingPass
   CDNodeMap d_cache;
 };
 
+class ForeignTheoryRewrite : public PreprocessingPass
+{
+ public:
+  ForeignTheoryRewrite(PreprocessingPassContext* preprocContext);
+
+ protected:
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+  /** Foreign theory rewriter */
+  ForeignTheoryRewriter d_ftr;
+};
+
 }  // namespace passes
 }  // namespace preprocessing
 }  // namespace cvc5
index 40e28eb78952c9c4bce520e5c8584b396dba4410..f5883c2655c548c41aa5de15366331b20206ecd5 100644 (file)
@@ -1710,7 +1710,8 @@ Node ExtendedRewriter::extendedRewriteStrings(Node ret) const
 
   if (ret.getKind() == EQUAL)
   {
-    new_ret = strings::SequencesRewriter(nullptr).rewriteEqualityExt(ret);
+    strings::SequencesRewriter sr(&d_rew, nullptr);
+    new_ret = sr.rewriteEqualityExt(ret);
   }
 
   return new_ret;
index 6a0eea41a65a6eb16416a3d8e0e75806861ef5ad..d9cbc4c4048fe67b125cb7af47911a4beb5a5380 100644 (file)
@@ -30,14 +30,16 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
+ArithEntail::ArithEntail(Rewriter* r) : d_rr(r) {}
+
 bool ArithEntail::checkEq(Node a, Node b)
 {
   if (a == b)
   {
     return true;
   }
-  Node ar = Rewriter::rewrite(a);
-  Node br = Rewriter::rewrite(b);
+  Node ar = d_rr->rewrite(a);
+  Node br = d_rr->rewrite(b);
   return ar == br;
 }
 
@@ -72,7 +74,7 @@ bool ArithEntail::check(Node a, bool strict)
   Node ar = strict ? NodeManager::currentNM()->mkNode(
                 kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1)))
                    : a;
-  ar = Rewriter::rewrite(ar);
+  ar = d_rr->rewrite(ar);
 
   if (ar.getAttribute(StrCheckEntailArithComputedAttr()))
   {
@@ -93,7 +95,7 @@ bool ArithEntail::check(Node a, bool strict)
 
 bool ArithEntail::checkApprox(Node ar)
 {
-  Assert(Rewriter::rewrite(ar) == ar);
+  Assert(d_rr->rewrite(ar) == ar);
   NodeManager* nm = NodeManager::currentNM();
   std::map<Node, Node> msum;
   Trace("strings-ent-approx-debug")
@@ -139,7 +141,7 @@ bool ArithEntail::checkApprox(Node ar)
       {
         Node curr = toProcess.back();
         Trace("strings-ent-approx-debug") << "  process " << curr << std::endl;
-        curr = Rewriter::rewrite(curr);
+        curr = d_rr->rewrite(curr);
         toProcess.pop_back();
         if (visited.find(curr) == visited.end())
         {
@@ -195,7 +197,7 @@ bool ArithEntail::checkApprox(Node ar)
   Node aar = aarSum.empty()
                  ? nm->mkConst(Rational(0))
                  : (aarSum.size() == 1 ? aarSum[0] : nm->mkNode(PLUS, aarSum));
-  aar = Rewriter::rewrite(aar);
+  aar = d_rr->rewrite(aar);
   Trace("strings-ent-approx-debug")
       << "...processed fixed sum " << aar << " with " << mApprox.size()
       << " approximated monomials." << std::endl;
@@ -266,8 +268,7 @@ bool ArithEntail::checkApprox(Node ar)
             Node ci = aam.second;
             if (!cr.isNull())
             {
-              ci = ci.isNull() ? cr
-                               : Rewriter::rewrite(nm->mkNode(MULT, ci, cr));
+              ci = ci.isNull() ? cr : d_rr->rewrite(nm->mkNode(MULT, ci, cr));
             }
             Trace("strings-ent-approx-debug") << ci << "*" << ti << " ";
             int ciSgn = ci.isNull() ? 1 : ci.getConst<Rational>().sgn();
@@ -324,7 +325,7 @@ bool ArithEntail::checkApprox(Node ar)
       Node mn = ArithMSum::mkCoeffTerm(msum[v], vapprox);
       aar = nm->mkNode(PLUS, aar, mn);
       // update the msumAar map
-      aar = Rewriter::rewrite(aar);
+      aar = d_rr->rewrite(aar);
       msumAar.clear();
       if (!ArithMSum::getMonomialSum(aar, msumAar))
       {
@@ -557,7 +558,7 @@ void ArithEntail::getArithApproximations(Node a,
 bool ArithEntail::checkWithEqAssumption(Node assumption, Node a, bool strict)
 {
   Assert(assumption.getKind() == kind::EQUAL);
-  Assert(Rewriter::rewrite(assumption) == assumption);
+  Assert(d_rr->rewrite(assumption) == assumption);
   Trace("strings-entail") << "checkWithEqAssumption: " << assumption << " " << a
                           << ", strict=" << strict << std::endl;
 
@@ -633,7 +634,7 @@ bool ArithEntail::checkWithAssumption(Node assumption,
                                       Node b,
                                       bool strict)
 {
-  Assert(Rewriter::rewrite(assumption) == assumption);
+  Assert(d_rr->rewrite(assumption) == assumption);
 
   NodeManager* nm = NodeManager::currentNM();
 
@@ -659,7 +660,7 @@ bool ArithEntail::checkWithAssumption(Node assumption,
 
     Node s = nm->mkBoundVar("slackVal", nm->stringType());
     Node slen = nm->mkNode(kind::STRING_LENGTH, s);
-    assumption = Rewriter::rewrite(
+    assumption = d_rr->rewrite(
         nm->mkNode(kind::EQUAL, x, nm->mkNode(kind::PLUS, y, slen)));
   }
 
@@ -695,7 +696,7 @@ bool ArithEntail::checkWithAssumptions(std::vector<Node> assumptions,
   bool res = false;
   for (const auto& assumption : assumptions)
   {
-    Assert(Rewriter::rewrite(assumption) == assumption);
+    Assert(d_rr->rewrite(assumption) == assumption);
 
     if (checkWithAssumption(assumption, a, b, strict))
     {
@@ -708,7 +709,7 @@ bool ArithEntail::checkWithAssumptions(std::vector<Node> assumptions,
 
 Node ArithEntail::getConstantBound(Node a, bool isLower)
 {
-  Assert(Rewriter::rewrite(a) == a);
+  Assert(d_rr->rewrite(a) == a);
   Node ret;
   if (a.isConst())
   {
@@ -773,7 +774,7 @@ Node ArithEntail::getConstantBound(Node a, bool isLower)
       else
       {
         ret = NodeManager::currentNM()->mkNode(a.getKind(), children);
-        ret = Rewriter::rewrite(ret);
+        ret = d_rr->rewrite(ret);
       }
     }
   }
@@ -791,7 +792,7 @@ Node ArithEntail::getConstantBound(Node a, bool isLower)
 
 bool ArithEntail::checkInternal(Node a)
 {
-  Assert(Rewriter::rewrite(a) == a);
+  Assert(d_rr->rewrite(a) == a);
   // check whether a >= 0
   if (a.isConst())
   {
index 64e76e5b66f6a1aaa6902173d229c11539f62c36..e2b3d0af61d0da2a5788896aea9b5af54fbf2342 100644 (file)
@@ -24,6 +24,9 @@
 
 namespace cvc5 {
 namespace theory {
+
+class Rewriter;
+
 namespace strings {
 
 /**
@@ -34,19 +37,20 @@ namespace strings {
 class ArithEntail
 {
  public:
+  ArithEntail(Rewriter* r);
   /** check arithmetic entailment equal
    * Returns true if it is always the case that a = b.
    */
-  static bool checkEq(Node a, Node b);
+  bool checkEq(Node a, Node b);
   /** check arithmetic entailment
    * Returns true if it is always the case that a >= b,
    * and a>b if strict is true.
    */
-  static bool check(Node a, Node b, bool strict = false);
+  bool check(Node a, Node b, bool strict = false);
   /** check arithmetic entailment
    * Returns true if it is always the case that a >= 0.
    */
-  static bool check(Node a, bool strict = false);
+  bool check(Node a, bool strict = false);
   /** check arithmetic entailment with approximations
    *
    * Returns true if it is always the case that a >= 0. We expect that a is in
@@ -61,7 +65,7 @@ class ArithEntail
    * and thus the entailment len( x ) - len( substr( y, 0, len( x ) ) ) >= 0
    * holds.
    */
-  static bool checkApprox(Node a);
+  bool checkApprox(Node a);
 
   /**
    * Checks whether assumption |= a >= 0 (if strict is false) or
@@ -74,9 +78,7 @@ class ArithEntail
    *
    * Because: x = -(str.len y), so -x >= 0 --> (str.len y) >= 0 --> true
    */
-  static bool checkWithEqAssumption(Node assumption,
-                                    Node a,
-                                    bool strict = false);
+  bool checkWithEqAssumption(Node assumption, Node a, bool strict = false);
 
   /**
    * Checks whether assumption |= a >= b (if strict is false) or
@@ -90,10 +92,10 @@ class ArithEntail
    *
    * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true
    */
-  static bool checkWithAssumption(Node assumption,
-                                  Node a,
-                                  Node b,
-                                  bool strict = false);
+  bool checkWithAssumption(Node assumption,
+                           Node a,
+                           Node b,
+                           bool strict = false);
 
   /**
    * Checks whether assumptions |= a >= b (if strict is false) or
@@ -108,10 +110,10 @@ class ArithEntail
    *
    * Because: x = -(str.len y), so 0 >= x --> 0 >= -(str.len y) --> true
    */
-  static bool checkWithAssumptions(std::vector<Node> assumptions,
-                                   Node a,
-                                   Node b,
-                                   bool strict = false);
+  bool checkWithAssumptions(std::vector<Node> assumptions,
+                            Node a,
+                            Node b,
+                            bool strict = false);
 
   /** get arithmetic lower bound
    * If this function returns a non-null Node ret,
@@ -126,7 +128,7 @@ class ArithEntail
    *     if and only if
    *   check( a, strict ) = true.
    */
-  static Node getConstantBound(Node a, bool isLower = true);
+  Node getConstantBound(Node a, bool isLower = true);
 
   /**
    * Given an inequality y1 + ... + yn >= x, removes operands yi s.t. the
@@ -144,16 +146,16 @@ class ArithEntail
    * --> returns false because it is not possible to show
    *     str.len(y) >= str.len(x)
    */
-  static bool inferZerosInSumGeq(Node x,
-                                 std::vector<Node>& ys,
-                                 std::vector<Node>& zeroYs);
+  bool inferZerosInSumGeq(Node x,
+                          std::vector<Node>& ys,
+                          std::vector<Node>& zeroYs);
 
  private:
   /** check entail arithmetic internal
    * Returns true if we can show a >= 0 always.
    * a is in rewritten form.
    */
-  static bool checkInternal(Node a);
+  bool checkInternal(Node a);
   /** Get arithmetic approximations
    *
    * This gets the (set of) arithmetic approximations for term a and stores
@@ -169,9 +171,11 @@ class ArithEntail
    * function might be len( substr( x, 0, n ) ) - len( y ), where we don't
    * consider (recursively) the approximations for len( substr( x, 0, n ) ).
    */
-  static void getArithApproximations(Node a,
-                                     std::vector<Node>& approx,
-                                     bool isOverApprox = false);
+  void getArithApproximations(Node a,
+                              std::vector<Node>& approx,
+                              bool isOverApprox = false);
+  /** The underlying rewriter */
+  Rewriter* d_rr;
 };
 
 }  // namespace strings
index 7885c857ea28ccca0b64f0369cd5239f163b3fce..bd8a4d8df42fbd184cc5608cb66fd79dde2df9e6 100644 (file)
@@ -37,11 +37,18 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
-SequencesRewriter::SequencesRewriter(HistogramStat<Rewrite>* statistics)
-    : d_statistics(statistics), d_stringsEntail(*this)
+SequencesRewriter::SequencesRewriter(Rewriter* r,
+                                     HistogramStat<Rewrite>* statistics)
+    : d_statistics(statistics),
+      d_arithEntail(r),
+      d_stringsEntail(r, d_arithEntail, *this)
 {
 }
 
+ArithEntail& SequencesRewriter::getArithEntail() { return d_arithEntail; }
+
+StringsEntail& SequencesRewriter::getStringsEntail() { return d_stringsEntail; }
+
 Node SequencesRewriter::rewriteEquality(Node node)
 {
   Assert(node.getKind() == kind::EQUAL);
@@ -216,7 +223,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
   // ------- homogeneous constants
   for (unsigned i = 0; i < 2; i++)
   {
-    Node cn = StringsEntail::checkHomogeneousString(node[i]);
+    Node cn = d_stringsEntail.checkHomogeneousString(node[i]);
     if (!cn.isNull() && !Word::isEmpty(cn))
     {
       Assert(cn.isConst());
@@ -311,7 +318,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         }
 
         // (= "" (str.replace x y "A")) ---> (and (= x "") (not (= y "")))
-        if (StringsEntail::checkNonEmpty(ne[2]))
+        if (d_stringsEntail.checkNonEmpty(ne[2]))
         {
           Node ret =
               nm->mkNode(AND,
@@ -321,7 +328,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         }
 
         // (= "" (str.replace x "A" "")) ---> (str.prefix x "A")
-        if (StringsEntail::checkLengthOne(ne[1], true) && ne[2] == empty)
+        if (d_stringsEntail.checkLengthOne(ne[1], true) && ne[2] == empty)
         {
           Node ret = nm->mkNode(STRING_PREFIX, ne[0], ne[1]);
           return returnRewrite(node, ret, Rewrite::STR_EMP_REPL_EMP);
@@ -331,7 +338,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
       {
         Node zero = nm->mkConst(Rational(0));
 
-        if (ArithEntail::check(ne[1], false) && ArithEntail::check(ne[2], true))
+        if (d_arithEntail.check(ne[1], false)
+            && d_arithEntail.check(ne[2], true))
         {
           // (= "" (str.substr x 0 m)) ---> (= "" x) if m > 0
           if (ne[1] == zero)
@@ -347,7 +355,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         }
 
         // (= "" (str.substr "A" 0 z)) ---> (<= z 0)
-        if (StringsEntail::checkNonEmpty(ne[0]) && ne[1] == zero)
+        if (d_stringsEntail.checkNonEmpty(ne[0]) && ne[1] == zero)
         {
           Node ret = nm->mkNode(LEQ, ne[2], zero);
           return returnRewrite(node, ret, Rewrite::STR_EMP_SUBSTR_LEQ_Z);
@@ -365,7 +373,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
       Node x = node[1 - i];
 
       // (= "A" (str.replace "" x y)) ---> (= "" (str.replace "A" y x))
-      if (StringsEntail::checkNonEmpty(x) && repl[0] == empty)
+      if (d_stringsEntail.checkNonEmpty(x) && repl[0] == empty)
       {
         Node ret = nm->mkNode(
             EQUAL, empty, nm->mkNode(STRING_REPLACE, x, repl[2], repl[1]));
@@ -396,7 +404,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
       {
         Node lenY = nm->mkNode(STRING_LENGTH, repl[1]);
         Node lenZ = nm->mkNode(STRING_LENGTH, repl[2]);
-        if (ArithEntail::checkEq(lenY, lenZ))
+        if (d_arithEntail.checkEq(lenY, lenZ))
         {
           Node ret = nm->mkNode(OR,
                                 nm->mkNode(EQUAL, repl[0], repl[1]),
@@ -419,7 +427,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
   {
     if (node[1 - i].getKind() == STRING_CONCAT)
     {
-      new_ret = StringsEntail::inferEqsFromContains(node[i], node[1 - i]);
+      new_ret = d_stringsEntail.inferEqsFromContains(node[i], node[1 - i]);
       if (!new_ret.isNull())
       {
         return returnRewrite(node, new_ret, Rewrite::STR_EQ_CONJ_LEN_ENTAIL);
@@ -456,7 +464,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
           Node lenPfx0 = nm->mkNode(STRING_LENGTH, pfx0);
           Node lenPfx1 = nm->mkNode(STRING_LENGTH, pfx1);
 
-          if (ArithEntail::checkEq(lenPfx0, lenPfx1))
+          if (d_arithEntail.checkEq(lenPfx0, lenPfx1))
           {
             std::vector<Node> sfxv0(v0.begin() + i, v0.end());
             std::vector<Node> sfxv1(v1.begin() + j, v1.end());
@@ -466,7 +474,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
                                       .eqNode(utils::mkConcat(sfxv1, stype)));
             return returnRewrite(node, ret, Rewrite::SPLIT_EQ);
           }
-          else if (ArithEntail::check(lenPfx1, lenPfx0, true))
+          else if (d_arithEntail.check(lenPfx1, lenPfx0, true))
           {
             // The prefix on the right-hand side is strictly longer than the
             // prefix on the left-hand side, so we try to strip the right-hand
@@ -476,7 +484,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
             // (= (str.++ "A" x y) (str.++ x "AB" z)) --->
             //   (and (= (str.++ "A" x) (str.++ x "A")) (= y (str.++ "B" z)))
             std::vector<Node> rpfxv1;
-            if (StringsEntail::stripSymbolicLength(
+            if (d_stringsEntail.stripSymbolicLength(
                     pfxv1, rpfxv1, 1, lenPfx0, true))
             {
               // The rewrite requires the full left-hand prefix length to be
@@ -501,7 +509,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
             // in the inner loop)
             break;
           }
-          else if (ArithEntail::check(lenPfx0, lenPfx1, true))
+          else if (d_arithEntail.check(lenPfx0, lenPfx1, true))
           {
             // The prefix on the left-hand side is strictly longer than the
             // prefix on the right-hand side, so we try to strip the left-hand
@@ -512,7 +520,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
             //   (and (= (str.++ x "A") (str.++ "A" x)) (= (str.++ "B" z) y))
             std::vector<Node> sfxv0 = pfxv0;
             std::vector<Node> rpfxv0;
-            if (StringsEntail::stripSymbolicLength(
+            if (d_stringsEntail.stripSymbolicLength(
                     sfxv0, rpfxv0, 1, lenPfx1, true))
             {
               // The rewrite requires the full right-hand prefix length to be
@@ -698,7 +706,7 @@ Node SequencesRewriter::rewriteConcat(Node node)
   Node lastX;
   for (size_t i = 0, nsize = node_vec.size(); i < nsize; i++)
   {
-    Node s = StringsEntail::getStringOrEmpty(node_vec[i]);
+    Node s = d_stringsEntail.getStringOrEmpty(node_vec[i]);
     bool nextX = false;
     if (s != lastX)
     {
@@ -1711,12 +1719,12 @@ Node SequencesRewriter::rewriteSubstr(Node node)
   Node zero = nm->mkConst(cvc5::Rational(0));
 
   // if entailed non-positive length or negative start point
-  if (ArithEntail::check(zero, node[1], true))
+  if (d_arithEntail.check(zero, node[1], true))
   {
     Node ret = Word::mkEmptyWord(node.getType());
     return returnRewrite(node, ret, Rewrite::SS_START_NEG);
   }
-  else if (ArithEntail::check(zero, node[2]))
+  else if (d_arithEntail.check(zero, node[2]))
   {
     Node ret = Word::mkEmptyWord(node.getType());
     return returnRewrite(node, ret, Rewrite::SS_LEN_NON_POS);
@@ -1742,7 +1750,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
     // over-approximation of the length of (str.substr x a a), which
     // then allows us to reason that the result of the whole term must
     // be empty.
-    if (ArithEntail::check(node[1], node[0][2]))
+    if (d_arithEntail.check(node[1], node[0][2]))
     {
       Node ret = Word::mkEmptyWord(node.getType());
       return returnRewrite(node, ret, Rewrite::SS_START_GEQ_LEN);
@@ -1755,8 +1763,8 @@ Node SequencesRewriter::rewriteSubstr(Node node)
     // if (str.len y) = 1 and (str.len z) = 1
     if (node[1] == zero)
     {
-      if (StringsEntail::checkLengthOne(node[0][1], true)
-          && StringsEntail::checkLengthOne(node[0][2], true))
+      if (d_stringsEntail.checkLengthOne(node[0][1], true)
+          && d_stringsEntail.checkLengthOne(node[0][2], true))
       {
         Node ret = nm->mkNode(
             kind::STRING_REPLACE,
@@ -1777,7 +1785,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
   {
     Node curr = node[2];
     std::vector<Node> childrenr;
-    if (StringsEntail::stripSymbolicLength(n1, childrenr, 1, curr))
+    if (d_stringsEntail.stripSymbolicLength(n1, childrenr, 1, curr))
     {
       if (curr != zero && !n1.empty())
       {
@@ -1809,7 +1817,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       Node end_pt = Rewriter::rewrite(nm->mkNode(kind::PLUS, node[1], node[2]));
       if (node[2] != tot_len)
       {
-        if (ArithEntail::check(node[2], tot_len))
+        if (d_arithEntail.check(node[2], tot_len))
         {
           // end point beyond end point of string, map to tot_len
           Node ret = nm->mkNode(kind::STRING_SUBSTR, node[0], node[1], tot_len);
@@ -1825,7 +1833,8 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       // (str.substr s x y) --> "" if x < len(s) |= 0 >= y
       Node n1_lt_tot_len =
           Rewriter::rewrite(nm->mkNode(kind::LT, node[1], tot_len));
-      if (ArithEntail::checkWithAssumption(n1_lt_tot_len, zero, node[2], false))
+      if (d_arithEntail.checkWithAssumption(
+              n1_lt_tot_len, zero, node[2], false))
       {
         Node ret = Word::mkEmptyWord(node.getType());
         return returnRewrite(node, ret, Rewrite::SS_START_ENTAILS_ZERO_LEN);
@@ -1834,7 +1843,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       // (str.substr s x y) --> "" if 0 < y |= x >= str.len(s)
       Node non_zero_len =
           Rewriter::rewrite(nm->mkNode(kind::LT, zero, node[2]));
-      if (ArithEntail::checkWithAssumption(
+      if (d_arithEntail.checkWithAssumption(
               non_zero_len, node[1], tot_len, false))
       {
         Node ret = Word::mkEmptyWord(node.getType());
@@ -1844,7 +1853,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       // (str.substr s x y) --> "" if x >= 0 |= 0 >= str.len(s)
       Node geq_zero_start =
           Rewriter::rewrite(nm->mkNode(kind::GEQ, node[1], zero));
-      if (ArithEntail::checkWithAssumption(
+      if (d_arithEntail.checkWithAssumption(
               geq_zero_start, zero, tot_len, false))
       {
         Node ret = Word::mkEmptyWord(node.getType());
@@ -1853,7 +1862,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       }
 
       // (str.substr s x x) ---> "" if (str.len s) <= 1
-      if (node[1] == node[2] && StringsEntail::checkLengthOne(node[0]))
+      if (node[1] == node[2] && d_stringsEntail.checkLengthOne(node[0]))
       {
         Node ret = Word::mkEmptyWord(node.getType());
         return returnRewrite(node, ret, Rewrite::SS_LEN_ONE_Z_Z);
@@ -1864,7 +1873,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       // strip off components while quantity is entailed positive
       int dir = r == 0 ? 1 : -1;
       std::vector<Node> childrenr;
-      if (StringsEntail::stripSymbolicLength(n1, childrenr, dir, curr))
+      if (d_stringsEntail.stripSymbolicLength(n1, childrenr, dir, curr))
       {
         if (r == 0)
         {
@@ -1888,7 +1897,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
   {
     Node start_inner = node[0][1];
     Node start_outer = node[1];
-    if (ArithEntail::check(start_outer) && ArithEntail::check(start_inner))
+    if (d_arithEntail.check(start_outer) && d_arithEntail.check(start_inner))
     {
       // both are positive
       // thus, start point is definitely start_inner+start_outer.
@@ -1905,11 +1914,11 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       {
         new_len = len_from_inner;
       }
-      else if (ArithEntail::check(len_from_inner, len_from_outer))
+      else if (d_arithEntail.check(len_from_inner, len_from_outer))
       {
         new_len = len_from_outer;
       }
-      else if (ArithEntail::check(len_from_outer, len_from_inner))
+      else if (d_arithEntail.check(len_from_outer, len_from_inner))
       {
         new_len = len_from_inner;
       }
@@ -1993,7 +2002,7 @@ Node SequencesRewriter::rewriteContains(Node node)
       {
         Node len1 =
             NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]);
-        if (ArithEntail::check(len1, true))
+        if (d_arithEntail.check(len1, true))
         {
           // we handle the false case here since the rewrite for equality
           // uses this function, hence we want to conclude false if possible.
@@ -2002,7 +2011,7 @@ Node SequencesRewriter::rewriteContains(Node node)
           return returnRewrite(node, ret, Rewrite::CTN_LHS_EMPTYSTR);
         }
       }
-      else if (StringsEntail::checkLengthOne(t))
+      else if (d_stringsEntail.checkLengthOne(t))
       {
         std::vector<Node> vec = Word::getChars(node[0]);
         Node emp = Word::mkEmptyWord(t.getType());
@@ -2023,7 +2032,7 @@ Node SequencesRewriter::rewriteContains(Node node)
       else if (node[1].getKind() == kind::STRING_CONCAT)
       {
         int firstc, lastc;
-        if (!StringsEntail::canConstantContainConcat(
+        if (!d_stringsEntail.canConstantContainConcat(
                 node[0], node[1], firstc, lastc))
         {
           Node ret = NodeManager::currentNM()->mkConst(false);
@@ -2098,7 +2107,7 @@ Node SequencesRewriter::rewriteContains(Node node)
   // strip endpoints
   std::vector<Node> nb;
   std::vector<Node> ne;
-  if (StringsEntail::stripConstantEndpoints(nc1, nc2, nb, ne))
+  if (d_stringsEntail.stripConstantEndpoints(nc1, nc2, nb, ne))
   {
     Node ret = NodeManager::currentNM()->mkNode(
         kind::STRING_CONTAINS, utils::mkConcat(nc1, stype), node[1]);
@@ -2163,7 +2172,7 @@ Node SequencesRewriter::rewriteContains(Node node)
   // length entailment
   Node len_n1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]);
   Node len_n2 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]);
-  if (ArithEntail::check(len_n2, len_n1, true))
+  if (d_arithEntail.check(len_n2, len_n1, true))
   {
     // len( n2 ) > len( n1 ) => contains( n1, n2 ) ---> false
     Node ret = NodeManager::currentNM()->mkConst(false);
@@ -2174,13 +2183,13 @@ Node SequencesRewriter::rewriteContains(Node node)
   //   For example, contains( str.++( x, "b" ), str.++( "a", x ) ) ---> false
   //   since the number of a's in the second argument is greater than the number
   //   of a's in the first argument
-  if (StringsEntail::checkMultisetSubset(node[0], node[1]))
+  if (d_stringsEntail.checkMultisetSubset(node[0], node[1]))
   {
     Node ret = nm->mkConst(false);
     return returnRewrite(node, ret, Rewrite::CTN_MSET_NSS);
   }
 
-  if (ArithEntail::check(len_n2, len_n1, false))
+  if (d_arithEntail.check(len_n2, len_n1, false))
   {
     // len( n2 ) >= len( n1 ) => contains( n1, n2 ) ---> n1 = n2
     Node ret = node[0].eqNode(node[1]);
@@ -2264,7 +2273,7 @@ Node SequencesRewriter::rewriteContains(Node node)
 
       // (str.contains (str.replace x y x) z) ---> (str.contains x z)
       // if (str.len z) <= 1
-      if (StringsEntail::checkLengthOne(node[1]))
+      if (d_stringsEntail.checkLengthOne(node[1]))
       {
         Node ret = nm->mkNode(kind::STRING_CONTAINS, node[0][0], node[1]);
         return returnRewrite(node, ret, Rewrite::CTN_REPL_LEN_ONE_TO_CTN);
@@ -2285,7 +2294,7 @@ Node SequencesRewriter::rewriteContains(Node node)
     // (str.contains (str.replace x y z) w) --->
     //   (str.contains (str.replace x y "") w)
     // if (str.contains z w) ---> false and (str.len w) = 1
-    if (StringsEntail::checkLengthOne(node[1]))
+    if (d_stringsEntail.checkLengthOne(node[1]))
     {
       Node ctn = d_stringsEntail.checkContains(node[0][2], node[1]);
       if (!ctn.isNull() && !ctn.getConst<bool>())
@@ -2385,7 +2394,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         return returnRewrite(node, zero, Rewrite::IDOF_EQ_CST_START);
       }
     }
-    if (ArithEntail::check(node[2], true))
+    if (d_arithEntail.check(node[2], true))
     {
       // y>0  implies  indexof( x, x, y ) --> -1
       Node negone = nm->mkConst(Rational(-1));
@@ -2408,7 +2417,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
   {
     if (Word::isEmpty(node[1]))
     {
-      if (ArithEntail::check(len0, node[2]) && ArithEntail::check(node[2]))
+      if (d_arithEntail.check(len0, node[2]) && d_arithEntail.check(node[2]))
       {
         // len(x)>=z ^ z >=0 implies indexof( x, "", z ) ---> z
         return returnRewrite(node, node[2], Rewrite::IDOF_EMP_IDOF);
@@ -2416,7 +2425,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
     }
   }
 
-  if (ArithEntail::check(len1, len0m2, true))
+  if (d_arithEntail.check(len1, len0m2, true))
   {
     // len(x)-z < len(y)  implies  indexof( x, y, z ) ----> -1
     Node negone = nm->mkConst(Rational(-1));
@@ -2457,7 +2466,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         }
 
         // Strip components from the beginning that are guaranteed not to match
-        if (StringsEntail::stripConstantEndpoints(
+        if (d_stringsEntail.stripConstantEndpoints(
                 children0, children1, nb, ne, 1))
         {
           // str.indexof(str.++("AB", x, "C"), "C", 0) --->
@@ -2477,12 +2486,12 @@ Node SequencesRewriter::rewriteIndexof(Node node)
       // (str.indexof t "" n) is not rewritten to something other than -1 when n
       // is beyond the length of t. This is not required for the above rewrites,
       // which only apply when n=0.
-      if (ArithEntail::check(node[2]) && ArithEntail::check(len0, node[2]))
+      if (d_arithEntail.check(node[2]) && d_arithEntail.check(len0, node[2]))
       {
         // strip symbolic length
         Node new_len = node[2];
         std::vector<Node> nr;
-        if (StringsEntail::stripSymbolicLength(children0, nr, 1, new_len))
+        if (d_stringsEntail.stripSymbolicLength(children0, nr, 1, new_len))
         {
           // For example:
           // z>=0 and z>str.len( x1 ) and str.contains( x2, y )-->true
@@ -2509,7 +2518,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
   {
     Node new_len = node[2];
     std::vector<Node> nr;
-    if (StringsEntail::stripSymbolicLength(children0, nr, 1, new_len))
+    if (d_stringsEntail.stripSymbolicLength(children0, nr, 1, new_len))
     {
       // Normalize the string before the start index.
       //
@@ -2535,7 +2544,8 @@ Node SequencesRewriter::rewriteIndexof(Node node)
   {
     std::vector<Node> cb;
     std::vector<Node> ce;
-    if (StringsEntail::stripConstantEndpoints(children0, children1, cb, ce, -1))
+    if (d_stringsEntail.stripConstantEndpoints(
+            children0, children1, cb, ce, -1))
     {
       Node ret = utils::mkConcat(children0, stype);
       ret = nm->mkNode(STRING_INDEXOF, ret, node[1], node[2]);
@@ -2558,7 +2568,7 @@ Node SequencesRewriter::rewriteIndexofRe(Node node)
   Node zero = nm->mkConst(Rational(0));
   Node slen = nm->mkNode(STRING_LENGTH, s);
 
-  if (ArithEntail::check(zero, n, true) || ArithEntail::check(n, slen, true))
+  if (d_arithEntail.check(zero, n, true) || d_arithEntail.check(n, slen, true))
   {
     Node ret = nm->mkConst(Rational(-1));
     return returnRewrite(node, ret, Rewrite::INDEXOF_RE_INVALID_INDEX);
@@ -2589,7 +2599,7 @@ Node SequencesRewriter::rewriteIndexofRe(Node node)
       return returnRewrite(node, ret, Rewrite::INDEXOF_RE_EVAL);
     }
 
-    if (ArithEntail::check(n, zero) && ArithEntail::check(slen, n))
+    if (d_arithEntail.check(n, zero) && d_arithEntail.check(slen, n))
     {
       String emptyStr("");
       if (RegExpEntail::testConstStringInRegExp(emptyStr, 0, r))
@@ -2662,14 +2672,14 @@ Node SequencesRewriter::rewriteReplace(Node node)
     // ( len( y )>=len(x) ) => str.replace( x, y, x ) ---> x
     Node l0 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[0]);
     Node l1 = NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, node[1]);
-    if (ArithEntail::check(l1, l0))
+    if (d_arithEntail.check(l1, l0))
     {
       return returnRewrite(node, node[0], Rewrite::RPL_RPL_LEN_ID);
     }
 
     // (str.replace x y x) ---> (str.replace x (str.++ y1 ... yn) x)
     // if 1 >= (str.len x) and (= y "") ---> (= y1 "") ... (= yn "")
-    if (StringsEntail::checkLengthOne(node[0]))
+    if (d_stringsEntail.checkLengthOne(node[0]))
     {
       Node empty = Word::mkEmptyWord(stype);
       Node rn1 = Rewriter::rewrite(
@@ -2802,7 +2812,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
 
   if (cmp_conr != cmp_con)
   {
-    if (StringsEntail::checkNonEmpty(node[1]))
+    if (d_stringsEntail.checkNonEmpty(node[1]))
     {
       // pull endpoints that can be stripped
       // for example,
@@ -2810,7 +2820,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
       //   str.++( "b", str.replace( x, "a", y ), "b" )
       std::vector<Node> cb;
       std::vector<Node> ce;
-      if (StringsEntail::stripConstantEndpoints(children0, children1, cb, ce))
+      if (d_stringsEntail.stripConstantEndpoints(children0, children1, cb, ce))
       {
         std::vector<Node> cc;
         cc.insert(cc.end(), cb.begin(), cb.end());
@@ -2851,7 +2861,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
     Node len0 = nm->mkNode(kind::STRING_LENGTH, node[0]);
     Node len0_1 = nm->mkNode(kind::PLUS, len0, one);
     // Check len(t) + j > len(x) + 1
-    if (ArithEntail::check(maxLen1, len0_1, true))
+    if (d_arithEntail.check(maxLen1, len0_1, true))
     {
       children1.push_back(nm->mkNode(
           kind::STRING_SUBSTR,
@@ -2901,7 +2911,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
       // (str.len w) >= (str.len z)
       Node wlen = nm->mkNode(kind::STRING_LENGTH, w);
       Node zlen = nm->mkNode(kind::STRING_LENGTH, z);
-      if (ArithEntail::check(wlen, zlen))
+      if (d_arithEntail.check(wlen, zlen))
       {
         // w != z
         Node wEqZ = Rewriter::rewrite(nm->mkNode(kind::EQUAL, w, z));
@@ -3037,7 +3047,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
   //   str.replace( x ++ y ++ x ++ y, "A", z ) -->
   //   str.replace( x ++ y, "A", z ) ++ x ++ y
   // since if "A" occurs in x ++ y ++ x ++ y, then it must occur in x ++ y.
-  if (StringsEntail::checkLengthOne(node[1]))
+  if (d_stringsEntail.checkLengthOne(node[1]))
   {
     Node lastLhs;
     unsigned lastCheckIndex = 0;
@@ -3153,7 +3163,7 @@ Node SequencesRewriter::rewriteReplaceInternal(Node node)
   if (node[0] == node[1])
   {
     // only holds for replaceall if non-empty
-    if (nk == STRING_REPLACE || StringsEntail::checkNonEmpty(node[1]))
+    if (nk == STRING_REPLACE || d_stringsEntail.checkNonEmpty(node[1]))
     {
       return returnRewrite(node, node[2], Rewrite::RPL_REPLACE);
     }
@@ -3379,7 +3389,7 @@ Node SequencesRewriter::rewritePrefixSuffix(Node n)
 
   // Check if we can turn the prefix/suffix into equalities by showing that the
   // prefix/suffix is at least as long as the string
-  Node eqs = StringsEntail::inferEqsFromContains(n[1], n[0]);
+  Node eqs = d_stringsEntail.inferEqsFromContains(n[1], n[0]);
   if (!eqs.isNull())
   {
     return returnRewrite(n, eqs, Rewrite::SUF_PREFIX_TO_EQS);
index 0068c72c1bbfefd4e7f2976286b4d6c5d17342a7..854e3fb81bed85647473d4ba7de32411ac71e498 100644 (file)
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "theory/strings/arith_entail.h"
 #include "theory/strings/rewrites.h"
 #include "theory/strings/sequences_stats.h"
 #include "theory/strings/strings_entail.h"
@@ -33,7 +34,10 @@ namespace strings {
 class SequencesRewriter : public TheoryRewriter
 {
  public:
-  SequencesRewriter(HistogramStat<Rewrite>* statistics);
+  SequencesRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics);
+  /** The underlying entailment utilities */
+  ArithEntail& getArithEntail();
+  StringsEntail& getStringsEntail();
 
  protected:
   /** rewrite regular expression concatenation
@@ -301,7 +305,8 @@ class SequencesRewriter : public TheoryRewriter
   Node postProcessRewrite(Node node, Node ret);
   /** Reference to the rewriter statistics. */
   HistogramStat<Rewrite>* d_statistics;
-
+  /** The arithmetic entailment module */
+  ArithEntail d_arithEntail;
   /** Instance of the entailment checker for strings. */
   StringsEntail d_stringsEntail;
 }; /* class SequencesRewriter */
index 3c7800f8fe99d36d9e7a88537cc8c9afef1e1aa5..3b90338fc82ec64ab275948c3260997fbba8a60b 100644 (file)
@@ -30,7 +30,10 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
-StringsEntail::StringsEntail(SequencesRewriter& rewriter) : d_rewriter(rewriter)
+StringsEntail::StringsEntail(Rewriter* r,
+                             ArithEntail& aent,
+                             SequencesRewriter& rewriter)
+    : d_rr(r), d_arithEntail(aent), d_rewriter(rewriter)
 {
 }
 
@@ -61,7 +64,7 @@ bool StringsEntail::canConstantContainConcat(Node c,
         pos = new_pos + Word::getLength(n[i]);
       }
     }
-    else if (n[i].getKind() == STRING_ITOS && ArithEntail::check(n[i][0]))
+    else if (n[i].getKind() == STRING_ITOS && d_arithEntail.check(n[i][0]))
     {
       Assert(c.getType().isString());  // string-only
       const std::vector<unsigned>& tvec = c.getConst<String>().getVec();
@@ -132,24 +135,24 @@ bool StringsEntail::stripSymbolicLength(std::vector<Node>& n1,
     if (n1[sindex_use].isConst())
     {
       // could strip part of a constant
-      Node lowerBound = ArithEntail::getConstantBound(Rewriter::rewrite(curr));
+      Node lowerBound = d_arithEntail.getConstantBound(d_rr->rewrite(curr));
       if (!lowerBound.isNull())
       {
         Assert(lowerBound.isConst());
         Rational lbr = lowerBound.getConst<Rational>();
         if (lbr.sgn() > 0)
         {
-          Assert(ArithEntail::check(curr, true));
+          Assert(d_arithEntail.check(curr, true));
           Node s = n1[sindex_use];
           size_t slen = Word::getLength(s);
           Node ncl = nm->mkConst(cvc5::Rational(slen));
           Node next_s = nm->mkNode(MINUS, lowerBound, ncl);
-          next_s = Rewriter::rewrite(next_s);
+          next_s = d_rr->rewrite(next_s);
           Assert(next_s.isConst());
           // we can remove the entire constant
           if (next_s.getConst<Rational>().sgn() >= 0)
           {
-            curr = Rewriter::rewrite(nm->mkNode(MINUS, curr, ncl));
+            curr = d_rr->rewrite(nm->mkNode(MINUS, curr, ncl));
             success = true;
             sindex++;
           }
@@ -159,7 +162,7 @@ bool StringsEntail::stripSymbolicLength(std::vector<Node>& n1,
             // lower bound minus the length of a concrete string is negative,
             // hence lowerBound cannot be larger than long max
             Assert(lbr < Rational(String::maxSize()));
-            curr = Rewriter::rewrite(nm->mkNode(MINUS, curr, lowerBound));
+            curr = d_rr->rewrite(nm->mkNode(MINUS, curr, lowerBound));
             uint32_t lbsize = lbr.getNumerator().toUnsignedInt();
             Assert(lbsize < slen);
             if (dir == 1)
@@ -176,7 +179,7 @@ bool StringsEntail::stripSymbolicLength(std::vector<Node>& n1,
             }
             ret = true;
           }
-          Assert(ArithEntail::check(curr));
+          Assert(d_arithEntail.check(curr));
         }
         else
         {
@@ -190,8 +193,8 @@ bool StringsEntail::stripSymbolicLength(std::vector<Node>& n1,
           MINUS,
           curr,
           NodeManager::currentNM()->mkNode(STRING_LENGTH, n1[sindex_use]));
-      next_s = Rewriter::rewrite(next_s);
-      if (ArithEntail::check(next_s))
+      next_s = d_rr->rewrite(next_s);
+      if (d_arithEntail.check(next_s))
       {
         success = true;
         curr = next_s;
@@ -251,7 +254,7 @@ int StringsEntail::componentContains(std::vector<Node>& n1,
           }
           else if (!n1re.isNull())
           {
-            n1[i] = Rewriter::rewrite(
+            n1[i] = d_rr->rewrite(
                 NodeManager::currentNM()->mkNode(STRING_CONCAT, n1[i], n1re));
           }
           if (remainderDir != 1)
@@ -265,7 +268,7 @@ int StringsEntail::componentContains(std::vector<Node>& n1,
           }
           else if (!n1rb.isNull())
           {
-            n1[i] = Rewriter::rewrite(
+            n1[i] = d_rr->rewrite(
                 NodeManager::currentNM()->mkNode(STRING_CONCAT, n1rb, n1[i]));
           }
         }
@@ -432,7 +435,7 @@ bool StringsEntail::componentContainsBase(
           {
             // To be a suffix, start + length must be greater than
             // or equal to the length of the string.
-            success = ArithEntail::check(end_pos, len_n2s);
+            success = d_arithEntail.check(end_pos, len_n2s);
           }
           else if (dir == -1)
           {
@@ -449,8 +452,8 @@ bool StringsEntail::componentContainsBase(
             {
               // we can only compute the remainder if start_pos and end_pos
               // are known to be non-negative.
-              if (!ArithEntail::check(start_pos)
-                  || !ArithEntail::check(end_pos))
+              if (!d_arithEntail.check(start_pos)
+                  || !d_arithEntail.check(end_pos))
               {
                 return false;
               }
@@ -679,7 +682,7 @@ Node StringsEntail::checkContains(Node a, Node b, bool fullRewriter)
 
   if (fullRewriter)
   {
-    ctn = Rewriter::rewrite(ctn);
+    ctn = d_rr->rewrite(ctn);
   }
   else
   {
@@ -702,8 +705,8 @@ Node StringsEntail::checkContains(Node a, Node b, bool fullRewriter)
 bool StringsEntail::checkNonEmpty(Node a)
 {
   Node len = NodeManager::currentNM()->mkNode(STRING_LENGTH, a);
-  len = Rewriter::rewrite(len);
-  return ArithEntail::check(len, true);
+  len = d_rr->rewrite(len);
+  return d_arithEntail.check(len, true);
 }
 
 bool StringsEntail::checkLengthOne(Node s, bool strict)
@@ -711,9 +714,9 @@ bool StringsEntail::checkLengthOne(Node s, bool strict)
   NodeManager* nm = NodeManager::currentNM();
   Node one = nm->mkConst(Rational(1));
   Node len = nm->mkNode(STRING_LENGTH, s);
-  len = Rewriter::rewrite(len);
-  return ArithEntail::check(one, len)
-         && (!strict || ArithEntail::check(len, true));
+  len = d_rr->rewrite(len);
+  return d_arithEntail.check(one, len)
+         && (!strict || d_arithEntail.check(len, true));
 }
 
 bool StringsEntail::checkMultisetSubset(Node a, Node b)
@@ -877,7 +880,6 @@ Node StringsEntail::getStringOrEmpty(Node n)
           n = n[2];
           break;
         }
-
         if (checkLengthOne(n[0]) && Word::isEmpty(n[2]))
         {
           // (str.replace "A" x "") --> "A"
@@ -945,7 +947,7 @@ Node StringsEntail::inferEqsFromContains(Node x, Node y)
     // str.len(yn) (where y = y1 ++ ... ++ yn) while keeping the inequality
     // true. The terms that can have length zero without making the inequality
     // false must be all be empty if (str.contains x y) is true.
-    if (!ArithEntail::inferZerosInSumGeq(xLen, yLens, zeroLens))
+    if (!d_arithEntail.inferZerosInSumGeq(xLen, yLens, zeroLens))
     {
       // We could not prove that the inequality holds
       return Node::null();
index 7547bf80957e7917c032946707de3f54804fdfd6..1ff65a5b414846a4599d1a5ce63962f9a625aaf0 100644 (file)
 #include <vector>
 
 #include "expr/node.h"
+#include "theory/strings/arith_entail.h"
 
 namespace cvc5 {
 namespace theory {
+
+class Rewriter;
+
 namespace strings {
 
 class SequencesRewriter;
@@ -36,7 +40,7 @@ class SequencesRewriter;
 class StringsEntail
 {
  public:
-  StringsEntail(SequencesRewriter& rewriter);
+  StringsEntail(Rewriter* r, ArithEntail& aent, SequencesRewriter& rewriter);
 
   /** can constant contain list
    * return true if constant c can contain the list l in order
@@ -64,7 +68,7 @@ class StringsEntail
   /** can constant contain concat
    * same as above but with n = str.++( l ) instead of l
    */
-  static bool canConstantContainConcat(Node c, Node n, int& firstc, int& lastc);
+  bool canConstantContainConcat(Node c, Node n, int& firstc, int& lastc);
 
   /** strip symbolic length
    *
@@ -106,11 +110,11 @@ class StringsEntail
    *    nr is updated to { "abc", y }
    *    curr is updated to str.len(y)+1
    */
-  static bool stripSymbolicLength(std::vector<Node>& n1,
-                                  std::vector<Node>& nr,
-                                  int dir,
-                                  Node& curr,
-                                  bool strict = false);
+  bool stripSymbolicLength(std::vector<Node>& n1,
+                           std::vector<Node>& nr,
+                           int dir,
+                           Node& curr,
+                           bool strict = false);
   /** component contains
    * This function is used when rewriting str.contains( t1, t2 ), where
    * n1 is the vector form of t1
@@ -222,7 +226,7 @@ class StringsEntail
    * Checks whether string a is entailed to be non-empty. Is equivalent to
    * the call checkArithEntail( len( a ), true ).
    */
-  static bool checkNonEmpty(Node a);
+  bool checkNonEmpty(Node a);
 
   /**
    * Checks whether string has at most/exactly length one. Length one strings
@@ -234,7 +238,7 @@ class StringsEntail
    * at most length one
    * @return True if the string has at most/exactly length one, false otherwise
    */
-  static bool checkLengthOne(Node s, bool strict = false);
+  bool checkLengthOne(Node s, bool strict = false);
 
   /**
    * Checks whether it is always true that `a` is a strict subset of `b` in the
@@ -282,7 +286,7 @@ class StringsEntail
    * getStringOrEmpty( (str.substr "ABC" x y) ) --> (str.substr "ABC" x y)
    * because the function could not compute a simpler
    */
-  static Node getStringOrEmpty(Node n);
+  Node getStringOrEmpty(Node n);
 
   /**
    * Infers a conjunction of equalities that correspond to (str.contains x y)
@@ -298,7 +302,7 @@ class StringsEntail
    * y) if the function can infer that str.len(y) >= str.len(x) but cannot
    * infer that any of the yi must be empty.
    */
-  static Node inferEqsFromContains(Node x, Node y);
+  Node inferEqsFromContains(Node x, Node y);
 
  private:
   /** component contains base
@@ -371,6 +375,10 @@ class StringsEntail
   static Node getMultisetApproximation(Node a);
 
  private:
+  /** Pointer to the full rewriter */
+  Rewriter* d_rr;
+  /** The arithmetic entailment module */
+  ArithEntail& d_arithEntail;
   /**
    * Reference to the sequences rewriter that owns this `StringsEntail`
    * instance.
index b455d8a9b9012b02bde2b31f85105b78eddf25a0..9204bfab655d3987ece86d3b13a9c2b5eedd6630 100644 (file)
@@ -27,8 +27,9 @@ namespace cvc5 {
 namespace theory {
 namespace strings {
 
-StringsRewriter::StringsRewriter(HistogramStat<Rewrite>* statistics)
-    : SequencesRewriter(statistics)
+StringsRewriter::StringsRewriter(Rewriter* r,
+                                 HistogramStat<Rewrite>* statistics)
+    : SequencesRewriter(r, statistics)
 {
 }
 
index 70a1cccf0c3207a25cdf8a4df41baef9ccc00be3..65c0b67ab945d1b2345b5addd60410ec0ac85f6a 100644 (file)
@@ -32,7 +32,7 @@ namespace strings {
 class StringsRewriter : public SequencesRewriter
 {
  public:
-  StringsRewriter(HistogramStat<Rewrite>* statistics);
+  StringsRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics);
 
   RewriteResponse postRewrite(TNode node) override;
 
index 1b315447e6502bf80ab135aeb5f85776663111c7..3eac3ca1a5f815e95760250b61821b8a6afaca72 100644 (file)
@@ -60,7 +60,7 @@ TheoryStrings::TheoryStrings(Env& env, OutputChannel& out, Valuation valuation)
       d_extTheoryCb(),
       d_im(env, *this, d_state, d_termReg, d_extTheory, d_statistics, d_pnm),
       d_extTheory(d_extTheoryCb, context(), userContext(), d_im),
-      d_rewriter(&d_statistics.d_rewrites),
+      d_rewriter(env.getRewriter(), &d_statistics.d_rewrites),
       d_bsolver(env, d_state, d_im),
       d_csolver(env, d_state, d_im, d_termReg, d_bsolver),
       d_esolver(env,
index 223cef13b0de1e32e5df326fd46e7f4348fe2cda..c10d8f363f110e2877876026fd8a148471262b80 100644 (file)
@@ -31,19 +31,20 @@ class TestPPWhiteForeignTheoryRewrite : public TestSmt
 
 TEST_F(TestPPWhiteForeignTheoryRewrite, simplify)
 {
+  ForeignTheoryRewriter ftr(d_smtEngine->getEnv());
   std::cout << "len(x) >= 0 is simplified to true" << std::endl;
   Node x = d_nodeManager->mkVar("x", d_nodeManager->stringType());
   Node len_x = d_nodeManager->mkNode(kind::STRING_LENGTH, x);
   Node zero = d_nodeManager->mkConst<Rational>(0);
   Node geq1 = d_nodeManager->mkNode(kind::GEQ, len_x, zero);
   Node tt = d_nodeManager->mkConst<bool>(true);
-  Node simplified1 = ForeignTheoryRewrite::foreignRewrite(geq1);
+  Node simplified1 = ftr.foreignRewrite(geq1);
   ASSERT_EQ(simplified1, tt);
 
   std::cout << "len(x) >= n is not simplified to true" << std::endl;
   Node n = d_nodeManager->mkVar("n", d_nodeManager->integerType());
   Node geq2 = d_nodeManager->mkNode(kind::GEQ, len_x, n);
-  Node simplified2 = ForeignTheoryRewrite::foreignRewrite(geq2);
+  Node simplified2 = ftr.foreignRewrite(geq2);
   ASSERT_NE(simplified2, tt);
 }
 
index b7339942efc63d0cb0fafee27241beead8bdeda8..99454a0147321d47432a81ef202dede359b94069 100644 (file)
@@ -43,9 +43,11 @@ class TestTheoryWhiteSequencesRewriter : public TestSmt
     TestSmt::SetUp();
     Options opts;
     d_rewriter = d_smtEngine->getRewriter();
+    d_seqRewriter.reset(new SequencesRewriter(d_rewriter, nullptr));
   }
 
   Rewriter* d_rewriter;
+  std::unique_ptr<SequencesRewriter> d_seqRewriter;
 
   void inNormalForm(Node t)
   {
@@ -81,6 +83,7 @@ class TestTheoryWhiteSequencesRewriter : public TestSmt
 
 TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_length_one)
 {
+  StringsEntail& se = d_seqRewriter->getStringsEntail();
   TypeNode intType = d_nodeManager->integerType();
   TypeNode strType = d_nodeManager->stringType();
 
@@ -97,28 +100,29 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_length_one)
   Node three = d_nodeManager->mkConst(Rational(3));
   Node i = d_nodeManager->mkVar("i", intType);
 
-  ASSERT_TRUE(StringsEntail::checkLengthOne(a));
-  ASSERT_TRUE(StringsEntail::checkLengthOne(a, true));
+  ASSERT_TRUE(se.checkLengthOne(a));
+  ASSERT_TRUE(se.checkLengthOne(a, true));
 
   Node substr = d_nodeManager->mkNode(kind::STRING_SUBSTR, x, zero, one);
-  ASSERT_TRUE(StringsEntail::checkLengthOne(substr));
-  ASSERT_FALSE(StringsEntail::checkLengthOne(substr, true));
+  ASSERT_TRUE(se.checkLengthOne(substr));
+  ASSERT_FALSE(se.checkLengthOne(substr, true));
 
   substr =
       d_nodeManager->mkNode(kind::STRING_SUBSTR,
                             d_nodeManager->mkNode(kind::STRING_CONCAT, a, x),
                             zero,
                             one);
-  ASSERT_TRUE(StringsEntail::checkLengthOne(substr));
-  ASSERT_TRUE(StringsEntail::checkLengthOne(substr, true));
+  ASSERT_TRUE(se.checkLengthOne(substr));
+  ASSERT_TRUE(se.checkLengthOne(substr, true));
 
   substr = d_nodeManager->mkNode(kind::STRING_SUBSTR, x, zero, two);
-  ASSERT_FALSE(StringsEntail::checkLengthOne(substr));
-  ASSERT_FALSE(StringsEntail::checkLengthOne(substr, true));
+  ASSERT_FALSE(se.checkLengthOne(substr));
+  ASSERT_FALSE(se.checkLengthOne(substr, true));
 }
 
 TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_arith)
 {
+  ArithEntail& ae = d_seqRewriter->getArithEntail();
   TypeNode intType = d_nodeManager->integerType();
   TypeNode strType = d_nodeManager->stringType();
 
@@ -130,14 +134,15 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_arith)
   Node substr_z = d_nodeManager->mkNode(
       kind::STRING_LENGTH,
       d_nodeManager->mkNode(kind::STRING_SUBSTR, z, n, one));
-  ASSERT_TRUE(ArithEntail::check(one, substr_z));
+  ASSERT_TRUE(ae.check(one, substr_z));
 
   // (str.len (str.substr z n 1)) >= 1 ---> false
-  ASSERT_FALSE(ArithEntail::check(substr_z, one));
+  ASSERT_FALSE(ae.check(substr_z, one));
 }
 
 TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
 {
+  ArithEntail& ae = d_seqRewriter->getArithEntail();
   TypeNode intType = d_nodeManager->integerType();
   TypeNode strType = d_nodeManager->stringType();
 
@@ -157,19 +162,17 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
       d_nodeManager->mkNode(kind::EQUAL, x_plus_slen_y, zero));
 
   // x + (str.len y) = 0 |= 0 >= x --> true
-  ASSERT_TRUE(
-      ArithEntail::checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, false));
+  ASSERT_TRUE(ae.checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, false));
 
   // x + (str.len y) = 0 |= 0 > x --> false
-  ASSERT_FALSE(
-      ArithEntail::checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, true));
+  ASSERT_FALSE(ae.checkWithAssumption(x_plus_slen_y_eq_zero, zero, x, true));
 
   Node x_plus_slen_y_plus_z_eq_zero = d_rewriter->rewrite(d_nodeManager->mkNode(
       kind::EQUAL, d_nodeManager->mkNode(kind::PLUS, x_plus_slen_y, z), zero));
 
   // x + (str.len y) + z = 0 |= 0 > x --> false
-  ASSERT_FALSE(ArithEntail::checkWithAssumption(
-      x_plus_slen_y_plus_z_eq_zero, zero, x, true));
+  ASSERT_FALSE(
+      ae.checkWithAssumption(x_plus_slen_y_plus_z_eq_zero, zero, x, true));
 
   Node x_plus_slen_y_plus_slen_y_eq_zero =
       d_rewriter->rewrite(d_nodeManager->mkNode(
@@ -178,7 +181,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
           zero));
 
   // x + (str.len y) + (str.len y) = 0 |= 0 >= x --> true
-  ASSERT_TRUE(ArithEntail::checkWithAssumption(
+  ASSERT_TRUE(ae.checkWithAssumption(
       x_plus_slen_y_plus_slen_y_eq_zero, zero, x, false));
 
   Node five = d_nodeManager->mkConst(Rational(5));
@@ -188,28 +191,24 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
       d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, six));
 
   // x + 5 < 6 |= 0 >= x --> true
-  ASSERT_TRUE(
-      ArithEntail::checkWithAssumption(x_plus_five_lt_six, zero, x, false));
+  ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_six, zero, x, false));
 
   // x + 5 < 6 |= 0 > x --> false
-  ASSERT_TRUE(
-      !ArithEntail::checkWithAssumption(x_plus_five_lt_six, zero, x, true));
+  ASSERT_TRUE(!ae.checkWithAssumption(x_plus_five_lt_six, zero, x, true));
 
   Node neg_x = d_nodeManager->mkNode(kind::UMINUS, x);
   Node x_plus_five_lt_five =
       d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, x_plus_five, five));
 
   // x + 5 < 5 |= -x >= 0 --> true
-  ASSERT_TRUE(ArithEntail::checkWithAssumption(
-      x_plus_five_lt_five, neg_x, zero, false));
+  ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_five, neg_x, zero, false));
 
   // x + 5 < 5 |= 0 > x --> true
-  ASSERT_TRUE(
-      ArithEntail::checkWithAssumption(x_plus_five_lt_five, zero, x, false));
+  ASSERT_TRUE(ae.checkWithAssumption(x_plus_five_lt_five, zero, x, false));
 
   // 0 < x |= x >= (str.len (int.to.str x))
   Node assm = d_rewriter->rewrite(d_nodeManager->mkNode(kind::LT, zero, x));
-  ASSERT_TRUE(ArithEntail::checkWithAssumption(
+  ASSERT_TRUE(ae.checkWithAssumption(
       assm,
       x,
       d_nodeManager->mkNode(kind::STRING_LENGTH,
@@ -219,6 +218,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, check_entail_with_with_assumption)
 
 TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
 {
+  StringsRewriter sr(d_rewriter, nullptr);
   TypeNode intType = d_nodeManager->integerType();
   TypeNode strType = d_nodeManager->stringType();
 
@@ -239,7 +239,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
 
   // (str.substr "A" x x) --> ""
   Node n = d_nodeManager->mkNode(kind::STRING_SUBSTR, a, x, x);
-  Node res = StringsRewriter(nullptr).rewriteSubstr(n);
+  Node res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, empty);
 
   // (str.substr "A" (+ x 1) x) -> ""
@@ -248,7 +248,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
       a,
       d_nodeManager->mkNode(kind::PLUS, x, d_nodeManager->mkConst(Rational(1))),
       x);
-  res = StringsRewriter(nullptr).rewriteSubstr(n);
+  res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, empty);
 
   // (str.substr "A" (+ x (str.len s2)) x) -> ""
@@ -258,12 +258,12 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
       d_nodeManager->mkNode(
           kind::PLUS, x, d_nodeManager->mkNode(kind::STRING_LENGTH, s)),
       x);
-  res = StringsRewriter(nullptr).rewriteSubstr(n);
+  res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, empty);
 
   // (str.substr "A" x y) -> (str.substr "A" x y)
   n = d_nodeManager->mkNode(kind::STRING_SUBSTR, a, x, y);
-  res = StringsRewriter(nullptr).rewriteSubstr(n);
+  res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, n);
 
   // (str.substr "ABCD" (+ x 3) x) -> ""
@@ -271,13 +271,13 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_substr)
                             abcd,
                             d_nodeManager->mkNode(kind::PLUS, x, three),
                             x);
-  res = StringsRewriter(nullptr).rewriteSubstr(n);
+  res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, empty);
 
   // (str.substr "ABCD" (+ x 2) x) -> (str.substr "ABCD" (+ x 2) x)
   n = d_nodeManager->mkNode(
       kind::STRING_SUBSTR, abcd, d_nodeManager->mkNode(kind::PLUS, x, two), x);
-  res = StringsRewriter(nullptr).rewriteSubstr(n);
+  res = sr.rewriteSubstr(n);
   ASSERT_EQ(res, n);
 
   // (str.substr (str.substr s x x) x x) -> ""
@@ -1303,6 +1303,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_contains)
 
 TEST_F(TestTheoryWhiteSequencesRewriter, infer_eqs_from_contains)
 {
+  StringsEntail& se = d_seqRewriter->getStringsEntail();
   TypeNode strType = d_nodeManager->stringType();
 
   Node empty = d_nodeManager->mkConst(::cvc5::String(""));
@@ -1319,30 +1320,30 @@ TEST_F(TestTheoryWhiteSequencesRewriter, infer_eqs_from_contains)
       d_nodeManager->mkNode(kind::AND,
                             d_nodeManager->mkNode(kind::EQUAL, empty, x),
                             d_nodeManager->mkNode(kind::EQUAL, empty, y));
-  sameNormalForm(StringsEntail::inferEqsFromContains(empty, xy), empty_x_y);
+  sameNormalForm(se.inferEqsFromContains(empty, xy), empty_x_y);
 
   // inferEqsFromContains(x, (str.++ x y)) returns false
   Node bxya = d_nodeManager->mkNode(kind::STRING_CONCAT, {b, y, x, a});
-  sameNormalForm(StringsEntail::inferEqsFromContains(x, bxya), f);
+  sameNormalForm(se.inferEqsFromContains(x, bxya), f);
 
   // inferEqsFromContains(x, y) returns null
-  Node n = StringsEntail::inferEqsFromContains(x, y);
+  Node n = se.inferEqsFromContains(x, y);
   ASSERT_TRUE(n.isNull());
 
   // inferEqsFromContains(x, x) returns something equivalent to (= x x)
   Node eq_x_x = d_nodeManager->mkNode(kind::EQUAL, x, x);
-  sameNormalForm(StringsEntail::inferEqsFromContains(x, x), eq_x_x);
+  sameNormalForm(se.inferEqsFromContains(x, x), eq_x_x);
 
   // inferEqsFromContains((str.replace x "B" "A"), x) returns something
   // equivalent to (= (str.replace x "B" "A") x)
   Node repl = d_nodeManager->mkNode(kind::STRING_REPLACE, x, b, a);
   Node eq_repl_x = d_nodeManager->mkNode(kind::EQUAL, repl, x);
-  sameNormalForm(StringsEntail::inferEqsFromContains(repl, x), eq_repl_x);
+  sameNormalForm(se.inferEqsFromContains(repl, x), eq_repl_x);
 
   // inferEqsFromContains(x, (str.replace x "B" "A")) returns something
   // equivalent to (= (str.replace x "B" "A") x)
   Node eq_x_repl = d_nodeManager->mkNode(kind::EQUAL, x, repl);
-  sameNormalForm(StringsEntail::inferEqsFromContains(x, repl), eq_x_repl);
+  sameNormalForm(se.inferEqsFromContains(x, repl), eq_x_repl);
 }
 
 TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_prefix_suffix)
@@ -1672,6 +1673,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, rewrite_equality_ext)
 
 TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
 {
+  StringsEntail& se = d_seqRewriter->getStringsEntail();
   TypeNode intType = d_nodeManager->integerType();
   TypeNode strType = d_nodeManager->stringType();
 
@@ -1693,7 +1695,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> n2 = {a};
     std::vector<Node> nb;
     std::vector<Node> ne;
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 0);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 0);
     ASSERT_FALSE(res);
   }
 
@@ -1704,7 +1706,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> n2 = {a, d_nodeManager->mkNode(kind::STRING_ITOS, n)};
     std::vector<Node> nb;
     std::vector<Node> ne;
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 0);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 0);
     ASSERT_FALSE(res);
   }
 
@@ -1719,7 +1721,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> ne;
     std::vector<Node> n1r = {cd};
     std::vector<Node> nbr = {ab};
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 1);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 1);
     ASSERT_TRUE(res);
     ASSERT_EQ(n1, n1r);
     ASSERT_EQ(nb, nbr);
@@ -1736,7 +1738,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> ne;
     std::vector<Node> n1r = {c, x};
     std::vector<Node> nbr = {ab};
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, 1);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, 1);
     ASSERT_TRUE(res);
     ASSERT_EQ(n1, n1r);
     ASSERT_EQ(nb, nbr);
@@ -1753,7 +1755,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> ne;
     std::vector<Node> n1r = {a};
     std::vector<Node> ner = {bc};
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, -1);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, -1);
     ASSERT_TRUE(res);
     ASSERT_EQ(n1, n1r);
     ASSERT_EQ(ne, ner);
@@ -1770,7 +1772,7 @@ TEST_F(TestTheoryWhiteSequencesRewriter, strip_constant_endpoints)
     std::vector<Node> ne;
     std::vector<Node> n1r = {x, a};
     std::vector<Node> ner = {bc};
-    bool res = StringsEntail::stripConstantEndpoints(n1, n2, nb, ne, -1);
+    bool res = se.stripConstantEndpoints(n1, n2, nb, ne, -1);
     ASSERT_TRUE(res);
     ASSERT_EQ(n1, n1r);
     ASSERT_EQ(ne, ner);