Handle RE intersections modulo equality (#3120)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 30 Jul 2019 14:17:00 +0000 (09:17 -0500)
committerGitHub <noreply@github.com>
Tue, 30 Jul 2019 14:17:00 +0000 (09:17 -0500)
src/theory/strings/regexp_operation.cpp
src/theory/strings/regexp_operation.h
src/theory/strings/regexp_solver.cpp
src/theory/strings/regexp_solver.h
src/theory/strings/theory_strings.cpp
test/regress/CMakeLists.txt
test/regress/regress1/strings/re-mod-eq.smt2 [new file with mode: 0644]

index bd693c6c3df3c232b7bbecb590f6c72ea8086f48..f112547948a1b64d5096c31a4c1ccf5af70e872f 100644 (file)
@@ -1493,9 +1493,12 @@ Node RegExpOpr::intersect(Node r1, Node r2, bool &spflag) {
     Node rr1 = removeIntersection(r1);
     Node rr2 = removeIntersection(r2);
     std::map< PairNodes, Node > cache;
+    Trace("regexp-intersect-node") << "Intersect (1): " << rr1 << std::endl;
+    Trace("regexp-intersect-node") << "Intersect (2): " << rr2 << std::endl;
     Trace("regexp-intersect") << "Start INTERSECTION(\n\t" << mkString(r1) << ",\n\t"<< mkString(r2) << ")" << std::endl;
     Node retNode = intersectInternal(rr1, rr2, cache, 1);
     Trace("regexp-intersect") << "End INTERSECTION(\n\t" << mkString(r1) << ",\n\t"<< mkString(r2) << ") =\n\t" << mkString(retNode) << std::endl;
+    Trace("regexp-intersect-node") << "Intersect finished." << std::endl;
     return retNode;
   } else {
     spflag = true;
index b2e3667fc40b753c191c2489f4a044ac390c3c4a..8f9541e91fe86b4f5f3bb1c8c988e799e54d2329 100644 (file)
@@ -67,7 +67,11 @@ class RegExpOpr {
   std::map<Node, std::vector<PairNodes> > d_split_cache;
   void simplifyPRegExp(Node s, Node r, std::vector<Node> &new_nodes);
   void simplifyNRegExp(Node s, Node r, std::vector<Node> &new_nodes);
-  std::string niceChar(Node r);
+  /**
+   * Helper function for mkString, pretty prints constant or variable regular
+   * expression r.
+   */
+  static std::string niceChar(Node r);
   Node mkAllExceptOne(unsigned c);
   bool isPairNodesInSet(std::set<PairNodes> &s, Node n1, Node n2);
 
@@ -86,14 +90,23 @@ class RegExpOpr {
   RegExpOpr();
   ~RegExpOpr();
 
+  /**
+   * Returns true if r is a "constant" regular expression, that is, a set
+   * of regular expression operators whose subterms of the form (str.to.re t)
+   * are such that t is a constant (or rewrites to one).
+   */
   bool checkConstRegExp( Node r );
   void simplify(Node t, std::vector< Node > &new_nodes, bool polarity);
   int delta( Node r, Node &exp );
   int derivativeS( Node r, CVC4::String c, Node &retNode );
   Node derivativeSingle( Node r, CVC4::String c );
+  /**
+   * Returns the regular expression intersection of r1 and r2. If r1 or r2 is
+   * not constant, then this method returns null and sets spflag to true.
+   */
   Node intersect(Node r1, Node r2, bool &spflag);
-
-  std::string mkString( Node r );
+  /** Get the pretty printed version of the regular expression r */
+  static std::string mkString(Node r);
 };
 
 }/* CVC4::theory::strings namespace */
index 49dd1ead6c7a75db50b0f9ddbd1010b131b3b271..c50889e782728dda2d8b0019008653da3a189562 100644 (file)
@@ -19,6 +19,7 @@
 #include <cmath>
 
 #include "options/strings_options.h"
+#include "theory/ext_theory.h"
 #include "theory/strings/theory_strings.h"
 #include "theory/strings/theory_strings_rewriter.h"
 #include "theory/strings/theory_strings_utils.h"
@@ -38,13 +39,8 @@ RegExpSolver::RegExpSolver(TheoryStrings& p,
                            context::UserContext* u)
     : d_parent(p),
       d_im(im),
-      d_regexp_memberships(c),
       d_regexp_ucached(u),
       d_regexp_ccached(c),
-      d_pos_memberships(c),
-      d_neg_memberships(c),
-      d_inter_cache(c),
-      d_inter_index(c),
       d_processed_memberships(c)
 {
   d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
@@ -54,118 +50,27 @@ RegExpSolver::RegExpSolver(TheoryStrings& p,
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
-unsigned RegExpSolver::getNumMemberships(Node n, bool isPos)
-{
-  if (isPos)
-  {
-    NodeUIntMap::const_iterator it = d_pos_memberships.find(n);
-    if (it != d_pos_memberships.end())
-    {
-      return (*it).second;
-    }
-  }
-  else
-  {
-    NodeUIntMap::const_iterator it = d_neg_memberships.find(n);
-    if (it != d_neg_memberships.end())
-    {
-      return (*it).second;
-    }
-  }
-  return 0;
-}
-
-Node RegExpSolver::getMembership(Node n, bool isPos, unsigned i)
-{
-  return isPos ? d_pos_memberships_data[n][i] : d_neg_memberships_data[n][i];
-}
-
 Node RegExpSolver::mkAnd(Node c1, Node c2)
 {
   return NodeManager::currentNM()->mkNode(AND, c1, c2);
 }
 
-void RegExpSolver::check()
+void RegExpSolver::check(const std::map<Node, std::vector<Node> >& mems)
 {
   bool addedLemma = false;
   bool changed = false;
   std::vector<Node> processed;
   std::vector<Node> cprocessed;
 
-  Trace("regexp-debug") << "Checking Memberships ... " << std::endl;
-  for (NodeUIntMap::const_iterator itr_xr = d_pos_memberships.begin();
-       itr_xr != d_pos_memberships.end();
-       ++itr_xr)
+  Trace("regexp-process") << "Checking Memberships ... " << std::endl;
+  for (const std::pair<const Node, std::vector<Node> >& mr : mems)
   {
-    bool spflag = false;
-    Node x = (*itr_xr).first;
-    Trace("regexp-debug") << "Checking Memberships for " << x << std::endl;
-    if (d_inter_index.find(x) == d_inter_index.end())
+    Trace("regexp-process")
+        << "Memberships(" << mr.first << ") = " << mr.second << std::endl;
+    if (!checkEqcIntersect(mr.second))
     {
-      d_inter_index[x] = 0;
-    }
-    int cur_inter_idx = d_inter_index[x];
-    unsigned n_pmem = (*itr_xr).second;
-    Assert(getNumMemberships(x, true) == n_pmem);
-    if (cur_inter_idx != (int)n_pmem)
-    {
-      if (n_pmem == 1)
-      {
-        d_inter_cache[x] = getMembership(x, true, 0);
-        d_inter_index[x] = 1;
-        Trace("regexp-debug") << "... only one choice " << std::endl;
-      }
-      else if (n_pmem > 1)
-      {
-        Node r;
-        if (d_inter_cache.find(x) != d_inter_cache.end())
-        {
-          r = d_inter_cache[x];
-        }
-        if (r.isNull())
-        {
-          r = getMembership(x, true, 0);
-          cur_inter_idx = 1;
-        }
-
-        unsigned k_start = cur_inter_idx;
-        Trace("regexp-debug") << "... staring from : " << cur_inter_idx
-                              << ", we have " << n_pmem << std::endl;
-        for (unsigned k = k_start; k < n_pmem; k++)
-        {
-          Node r2 = getMembership(x, true, k);
-          r = d_regexp_opr.intersect(r, r2, spflag);
-          if (spflag)
-          {
-            break;
-          }
-          else if (r == d_emptyRegexp)
-          {
-            std::vector<Node> vec_nodes;
-            for (unsigned kk = 0; kk <= k; kk++)
-            {
-              Node rr = getMembership(x, true, kk);
-              Node n =
-                  NodeManager::currentNM()->mkNode(STRING_IN_REGEXP, x, rr);
-              vec_nodes.push_back(n);
-            }
-            Node conc;
-            d_im.sendInference(vec_nodes, conc, "INTERSECT CONFLICT", true);
-            addedLemma = true;
-            break;
-          }
-          if (d_im.hasConflict())
-          {
-            break;
-          }
-        }
-        // updates
-        if (!d_im.hasConflict() && !spflag)
-        {
-          d_inter_cache[x] = r;
-          d_inter_index[x] = (int)n_pmem;
-        }
-      }
+      // conflict discovered, return
+      return;
     }
   }
 
@@ -174,6 +79,20 @@ void RegExpSolver::check()
       << std::endl;
   if (!addedLemma)
   {
+    // get all memberships
+    std::vector<Node> allMems;
+    for (const std::pair<const Node, std::vector<Node> >& mr : mems)
+    {
+      for (const Node& m : mr.second)
+      {
+        bool polarity = m.getKind() != NOT;
+        if (polarity || !options::stringIgnNegMembership())
+        {
+          allMems.push_back(m);
+        }
+      }
+    }
+
     NodeManager* nm = NodeManager::currentNM();
     // representatives of strings that are the LHS of positive memberships that
     // we unfolded
@@ -181,7 +100,7 @@ void RegExpSolver::check()
     // check positive (e=0), then negative (e=1) memberships
     for (unsigned e = 0; e < 2; e++)
     {
-      for (const Node& assertion : d_regexp_memberships)
+      for (const Node& assertion : allMems)
       {
         // check regular expression membership
         Trace("regexp-debug")
@@ -327,6 +246,90 @@ void RegExpSolver::check()
   }
 }
 
+bool RegExpSolver::checkEqcIntersect(const std::vector<Node>& mems)
+{
+  if (mems.empty())
+  {
+    // nothing to do
+    return true;
+  }
+  // the initial regular expression membership
+  Node mi;
+  NodeManager* nm = NodeManager::currentNM();
+  for (const Node& m : mems)
+  {
+    if (m.getKind() != STRING_IN_REGEXP)
+    {
+      // do not do negative
+      Assert(m.getKind() == NOT && m[0].getKind() == STRING_IN_REGEXP);
+      continue;
+    }
+    if (!d_regexp_opr.checkConstRegExp(m))
+    {
+      // cannot do intersection on RE with variables
+      continue;
+    }
+    if (mi.isNull())
+    {
+      // first regular expression seen
+      mi = m;
+      continue;
+    }
+    bool spflag = false;
+    Node resR = d_regexp_opr.intersect(mi[1], m[1], spflag);
+    // intersection should be computable
+    Assert(!resR.isNull());
+    Assert(!spflag);
+    if (resR == d_emptyRegexp)
+    {
+      // conflict, explain
+      std::vector<Node> vec_nodes;
+      vec_nodes.push_back(mi);
+      vec_nodes.push_back(m);
+      if (mi[0] != m[0])
+      {
+        vec_nodes.push_back(mi[0].eqNode(m[0]));
+      }
+      Node conc;
+      d_im.sendInference(vec_nodes, conc, "INTERSECT CONFLICT", true);
+      // conflict, return
+      return false;
+    }
+    // rewrite to ensure the equality checks below are precise
+    Node mres = Rewriter::rewrite(nm->mkNode(STRING_IN_REGEXP, mi[0], resR));
+    if (mres == mi)
+    {
+      // if R1 = intersect( R1, R2 ), then x in R1 ^ x in R2 is equivalent
+      // to x in R1, hence x in R2 can be marked redundant.
+      d_parent.getExtTheory()->markReduced(m);
+    }
+    else if (mres == m)
+    {
+      // same as above, opposite direction
+      d_parent.getExtTheory()->markReduced(mi);
+    }
+    else
+    {
+      // new conclusion
+      // (x in R ^ y in R2 ^ x = y) => (x in intersect(R1,R2))
+      std::vector<Node> vec_nodes;
+      vec_nodes.push_back(mi);
+      vec_nodes.push_back(m);
+      if (mi[0] != m[0])
+      {
+        vec_nodes.push_back(mi[0].eqNode(m[0]));
+      }
+      d_im.sendInference(vec_nodes, mres, "INTERSECT INFER", true);
+      // both are reduced
+      d_parent.getExtTheory()->markReduced(m);
+      d_parent.getExtTheory()->markReduced(mi);
+      // do not send more than one lemma for this class
+      return true;
+    }
+  }
+  return true;
+}
+
 bool RegExpSolver::checkPDerivative(
     Node x, Node r, Node atom, bool& addedLemma, std::vector<Node>& nf_exp)
 {
@@ -452,83 +455,6 @@ bool RegExpSolver::deriveRegExp(Node x,
   return false;
 }
 
-void RegExpSolver::addMembership(Node assertion)
-{
-  bool polarity = assertion.getKind() != NOT;
-  TNode atom = polarity ? assertion : assertion[0];
-  Node x = atom[0];
-  Node r = atom[1];
-  if (polarity)
-  {
-    unsigned index = 0;
-    NodeUIntMap::const_iterator it = d_pos_memberships.find(x);
-    if (it != d_pos_memberships.end())
-    {
-      index = (*it).second;
-      for (unsigned k = 0; k < index; k++)
-      {
-        if (k < d_pos_memberships_data[x].size())
-        {
-          if (d_pos_memberships_data[x][k] == r)
-          {
-            return;
-          }
-        }
-        else
-        {
-          break;
-        }
-      }
-    }
-    d_pos_memberships[x] = index + 1;
-    if (index < d_pos_memberships_data[x].size())
-    {
-      d_pos_memberships_data[x][index] = r;
-    }
-    else
-    {
-      d_pos_memberships_data[x].push_back(r);
-    }
-  }
-  else if (!options::stringIgnNegMembership())
-  {
-    unsigned index = 0;
-    NodeUIntMap::const_iterator it = d_neg_memberships.find(x);
-    if (it != d_neg_memberships.end())
-    {
-      index = (*it).second;
-      for (unsigned k = 0; k < index; k++)
-      {
-        if (k < d_neg_memberships_data[x].size())
-        {
-          if (d_neg_memberships_data[x][k] == r)
-          {
-            return;
-          }
-        }
-        else
-        {
-          break;
-        }
-      }
-    }
-    d_neg_memberships[x] = index + 1;
-    if (index < d_neg_memberships_data[x].size())
-    {
-      d_neg_memberships_data[x][index] = r;
-    }
-    else
-    {
-      d_neg_memberships_data[x].push_back(r);
-    }
-  }
-  // old
-  if (polarity || !options::stringIgnNegMembership())
-  {
-    d_regexp_memberships.push_back(assertion);
-  }
-}
-
 Node RegExpSolver::getNormalSymRegExp(Node r, std::vector<Node>& nf_exp)
 {
   Node ret = r;
index ec74d98cdcfc3e33771ea6bbd5e15b8604f741ef..f3abb2a1d78e39e046da75a4339961e6b74429eb 100644 (file)
@@ -49,23 +49,33 @@ class RegExpSolver
                context::UserContext* u);
   ~RegExpSolver() {}
 
-  /** add membership
-   *
-   * This informs this class that assertion is asserted in the current context.
-   * We expect that assertion is a (possibly negated) regular expression
-   * membership.
-   */
-  void addMembership(Node assertion);
   /** check
    *
-   * Tells this solver to check whether the regular expressions asserted to it
+   * Tells this solver to check whether the regular expressions in mems
    * are consistent. If they are not, then this class will call the
    * sendInference method of its parent TheoryString object, indicating that
    * it requires a conflict or lemma to be processed.
+   *
+   * The argument mems maps representative string terms r to memberships of the
+   * form (t in R) or ~(t in R), where t = r currently holds in the equality
+   * engine of the theory of strings.
    */
-  void check();
+  void check(const std::map<Node, std::vector<Node>>& mems);
 
  private:
+  /**
+   * Check memberships for equivalence class.
+   * The vector mems is a vector of memberships of the form:
+   *   (~) (x1 in R1 ) ... (~) (xn in Rn)
+   * where x1 = ... = xn in the current context.
+   *
+   * This method may add lemmas or conflicts via the inference manager.
+   *
+   * This method returns false if it discovered a conflict for this set of
+   * assertions, and true otherwise. It discovers a conflict e.g. if mems
+   * contains (xi in Ri) and (xj in Rj) and intersect(xi,xj) is empty.
+   */
+  bool checkEqcIntersect(const std::vector<Node>& mems);
   // Constants
   Node d_emptyString;
   Node d_emptyRegexp;
@@ -85,20 +95,11 @@ class RegExpSolver
   bool deriveRegExp(Node x, Node r, Node atom, std::vector<Node>& ant);
   Node getNormalSymRegExp(Node r, std::vector<Node>& nf_exp);
   // regular expression memberships
-  NodeList d_regexp_memberships;
   NodeSet d_regexp_ucached;
   NodeSet d_regexp_ccached;
-  // stored assertions
-  NodeUIntMap d_pos_memberships;
-  std::map<Node, std::vector<Node> > d_pos_memberships_data;
-  NodeUIntMap d_neg_memberships;
-  std::map<Node, std::vector<Node> > d_neg_memberships_data;
   // semi normal forms for symbolic expression
   std::map<Node, Node> d_nf_regexps;
   std::map<Node, std::vector<Node> > d_nf_regexps_exp;
-  // intersection
-  NodeNodeMap d_inter_cache;
-  NodeIntMap d_inter_index;
   // processed memberships
   NodeSet d_processed_memberships;
   /** regular expression operation module */
index 952d82c21c4f993cc15f92cd7d9fbb0185e93a63..164d227231e3e33e4943a7dfcfedc50a002d7239 100644 (file)
@@ -1056,16 +1056,20 @@ void TheoryStrings::checkMemberships()
 {
   // add the memberships
   std::vector<Node> mems = getExtTheory()->getActive(kind::STRING_IN_REGEXP);
+  // maps representatives to regular expression memberships in that class
+  std::map<Node, std::vector<Node> > assertedMems;
   for (unsigned i = 0; i < mems.size(); i++)
   {
     Node n = mems[i];
+    Assert(n.getKind() == STRING_IN_REGEXP);
     Assert(d_extf_info_tmp.find(n) != d_extf_info_tmp.end());
     if (!d_extf_info_tmp[n].d_const.isNull())
     {
       bool pol = d_extf_info_tmp[n].d_const.getConst<bool>();
       Trace("strings-process-debug")
           << "  add membership : " << n << ", pol = " << pol << std::endl;
-      d_regexp_solver.addMembership(pol ? n : n.negate());
+      Node r = getRepresentative(n[0]);
+      assertedMems[r].push_back(pol ? n : n.negate());
     }
     else
     {
@@ -1073,7 +1077,7 @@ void TheoryStrings::checkMemberships()
           << "  irrelevant (non-asserted) membership : " << n << std::endl;
     }
   }
-  d_regexp_solver.check();
+  d_regexp_solver.check(assertedMems);
 }
 
 TheoryStrings::EqcInfo::EqcInfo(context::Context* c)
index 9874100b3cc02569ac9a758dc1f1f00382b053bf..90641fe2fc525702e677b937cb3f554bbdc1f634 100644 (file)
@@ -1583,6 +1583,7 @@ set(regress_1_tests
   regress1/strings/re-agg-total1.smt2
   regress1/strings/re-agg-total2.smt2
   regress1/strings/re-elim-exact.smt2
+  regress1/strings/re-mod-eq.smt2
   regress1/strings/re-neg-concat-reduct.smt2
   regress1/strings/re-neg-unfold-rev-a.smt2
   regress1/strings/re-unsound-080718.smt2
diff --git a/test/regress/regress1/strings/re-mod-eq.smt2 b/test/regress/regress1/strings/re-mod-eq.smt2
new file mode 100644 (file)
index 0000000..531556e
--- /dev/null
@@ -0,0 +1,13 @@
+(set-info :smt-lib-version 2.5)
+(set-logic QF_SLIA)
+(set-option :strings-exp true)
+(set-info :status unsat)
+(declare-fun x () String)
+(declare-fun y () String)
+(declare-fun z () String)
+(assert (or (= x y)(= x z)))
+(assert (str.in.re x (re.++ (str.to.re "A") (re.* (str.to.re "BAA")))))
+(assert (str.in.re y (re.++ (str.to.re "AB") (re.* (str.to.re "AAB")) (str.to.re "A"))))
+(assert (str.in.re z (re.++ (str.to.re "AB") (re.* (str.to.re "AAB")) (str.to.re "A"))))
+; requires RE solver to reason modulo string equalties
+(check-sat)