Convert more uses of string-specific functions (#4158)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 31 Mar 2020 15:07:29 +0000 (10:07 -0500)
committerGitHub <noreply@github.com>
Tue, 31 Mar 2020 15:07:29 +0000 (10:07 -0500)
Towards theory of sequences.

src/theory/strings/sequences_rewriter.cpp
src/theory/strings/sequences_rewriter.h
src/theory/strings/theory_strings_preprocess.cpp
src/theory/strings/theory_strings_preprocess.h
src/theory/strings/word.cpp
src/theory/strings/word.h

index d7ee459c72a55272af5cf5c6efa04447fd5747f8..be1e1345920a3a659c743ca83e551f97cc2c3a28 100644 (file)
@@ -479,13 +479,12 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
     {
       Assert(cn.isConst());
       Assert(Word::getLength(cn) == 1);
-      unsigned hchar = cn.getConst<String>().front();
 
       // The operands of the concat on each side of the equality without
       // constant strings
       std::vector<Node> trimmed[2];
-      // Counts the number of `hchar`s on each side
-      size_t numHChars[2] = {0, 0};
+      // Counts the number of `cn`s on each side
+      size_t numCns[2] = {0, 0};
       for (size_t j = 0; j < 2; j++)
       {
         // Sort the operands of the concats on both sides of the equality
@@ -496,12 +495,12 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         {
           if (cc.isConst())
           {
-            // Count the number of `hchar`s in the string constant and make
-            // sure that all chars are `hchar`s
-            std::vector<unsigned> veccc = cc.getConst<String>().getVec();
-            for (size_t k = 0, size = veccc.size(); k < size; k++)
+            // Count the number of `cn`s in the string constant and make
+            // sure that all chars are `cn`s
+            std::vector<Node> veccc = Word::getChars(cc);
+            for (const Node& cv : veccc)
             {
-              if (veccc[k] != hchar)
+              if (cv != cn)
               {
                 // This conflict case should mostly should be taken care of by
                 // multiset reasoning in the strings rewriter, but we recognize
@@ -510,7 +509,7 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
                 return returnRewrite(
                     node, new_ret, Rewrite::STR_EQ_CONST_NHOMOG);
               }
-              numHChars[j]++;
+              numCns[j]++;
             }
           }
           else
@@ -520,18 +519,18 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         }
       }
 
-      // We have to remove the same number of `hchar`s from both sides, so the
-      // side with less `hchar`s determines how many we can remove
-      size_t trimmedConst = std::min(numHChars[0], numHChars[1]);
+      // We have to remove the same number of `cn`s from both sides, so the
+      // side with less `cn`s determines how many we can remove
+      size_t trimmedConst = std::min(numCns[0], numCns[1]);
       for (size_t j = 0; j < 2; j++)
       {
-        size_t diff = numHChars[j] - trimmedConst;
+        size_t diff = numCns[j] - trimmedConst;
         if (diff != 0)
         {
-          // Add a constant string to the side with more `hchar`s to restore
-          // the difference in number of `hchar`s
-          std::vector<unsigned> vec(diff, hchar);
-          trimmed[j].push_back(nm->mkConst(String(vec)));
+          // Add a constant string to the side with more `cn`s to restore
+          // the difference in number of `cn`s
+          std::vector<Node> vec(diff, cn);
+          trimmed[j].push_back(Word::mkWord(vec));
         }
       }
 
@@ -2332,7 +2331,6 @@ Node SequencesRewriter::rewriteContains(Node node)
   }
   if (node[0].isConst())
   {
-    CVC4::String s = node[0].getConst<String>();
     if (node[1].isConst())
     {
       Node ret = nm->mkConst(Word::find(node[0], node[1]) != std::string::npos);
@@ -2356,14 +2354,13 @@ Node SequencesRewriter::rewriteContains(Node node)
       }
       else if (checkEntailLengthOne(t))
       {
-        const std::vector<unsigned>& vec = s.getVec();
-
+        std::vector<Node> vec = Word::getChars(node[0]);
+        Node emp = Word::mkEmptyWord(t.getType());
         NodeBuilder<> nb(OR);
-        nb << nm->mkConst(String("")).eqNode(t);
-        for (unsigned c : vec)
+        nb << emp.eqNode(t);
+        for (const Node& c : vec)
         {
-          std::vector<unsigned> sv = {c};
-          nb << nm->mkConst(String(sv)).eqNode(t);
+          nb << c.eqNode(t);
         }
 
         // str.contains("ABCabc", t) --->
@@ -4232,7 +4229,7 @@ bool SequencesRewriter::stripConstantEndpoints(std::vector<Node>& n1,
   return changed;
 }
 
-Node SequencesRewriter::canonicalStrForSymbolicLength(Node len)
+Node SequencesRewriter::canonicalStrForSymbolicLength(Node len, TypeNode stype)
 {
   NodeManager* nm = NodeManager::currentNM();
 
@@ -4243,7 +4240,15 @@ Node SequencesRewriter::canonicalStrForSymbolicLength(Node len)
     Rational ratLen = len.getConst<Rational>();
     Assert(ratLen.getDenominator() == 1);
     Integer intLen = ratLen.getNumerator();
-    res = nm->mkConst(String(std::string(intLen.getUnsignedInt(), 'A')));
+    uint32_t u = intLen.getUnsignedInt();
+    if (stype.isString())
+    {
+      res = nm->mkConst(String(std::string(u, 'A')));
+    }
+    else
+    {
+      Unimplemented() << "canonicalStrForSymbolicLength for non-string";
+    }
   }
   else if (len.getKind() == kind::PLUS)
   {
@@ -4251,7 +4256,7 @@ Node SequencesRewriter::canonicalStrForSymbolicLength(Node len)
     NodeBuilder<> concatBuilder(kind::STRING_CONCAT);
     for (const auto& n : len)
     {
-      Node sn = canonicalStrForSymbolicLength(n);
+      Node sn = canonicalStrForSymbolicLength(n, stype);
       if (sn.isNull())
       {
         return Node::null();
@@ -4270,7 +4275,7 @@ Node SequencesRewriter::canonicalStrForSymbolicLength(Node len)
     Assert(ratReps.getDenominator() == 1);
     Integer intReps = ratReps.getNumerator();
 
-    Node nRep = canonicalStrForSymbolicLength(len[1]);
+    Node nRep = canonicalStrForSymbolicLength(len[1], stype);
     std::vector<Node> nRepChildren;
     utils::getConcat(nRep, nRepChildren);
     NodeBuilder<> concatBuilder(kind::STRING_CONCAT);
@@ -4292,7 +4297,7 @@ Node SequencesRewriter::lengthPreserveRewrite(Node n)
 {
   NodeManager* nm = NodeManager::currentNM();
   Node len = Rewriter::rewrite(nm->mkNode(kind::STRING_LENGTH, n));
-  Node res = canonicalStrForSymbolicLength(len);
+  Node res = canonicalStrForSymbolicLength(len, n.getType());
   return res.isNull() ? n : res;
 }
 
@@ -4869,8 +4874,6 @@ void SequencesRewriter::getArithApproximations(Node a,
 
 bool SequencesRewriter::checkEntailMultisetSubset(Node a, Node b)
 {
-  NodeManager* nm = NodeManager::currentNM();
-
   std::vector<Node> avec;
   utils::getConcat(getMultisetApproximation(a), avec);
   std::vector<Node> bvec;
@@ -4913,14 +4916,9 @@ bool SequencesRewriter::checkEntailMultisetSubset(Node a, Node b)
       {
         Node cn = ncp.first;
         Assert(cn.isConst());
-        std::vector<unsigned> cc_vec;
-        const std::vector<unsigned>& cvec = cn.getConst<String>().getVec();
-        for (unsigned i = 0, size = cvec.size(); i < size; i++)
+        std::vector<Node> cnChars = Word::getChars(cn);
+        for (const Node& ch : cnChars)
         {
-          // make the character
-          cc_vec.clear();
-          cc_vec.insert(cc_vec.end(), cvec.begin() + i, cvec.begin() + i + 1);
-          Node ch = nm->mkConst(String(cc_vec));
           count_const[j][ch] += ncp.second;
           if (std::find(chars.begin(), chars.end(), ch) == chars.end())
           {
@@ -4955,19 +4953,17 @@ bool SequencesRewriter::checkEntailMultisetSubset(Node a, Node b)
 
 Node SequencesRewriter::checkEntailHomogeneousString(Node a)
 {
-  NodeManager* nm = NodeManager::currentNM();
-
   std::vector<Node> avec;
   utils::getConcat(getMultisetApproximation(a), avec);
 
   bool cValid = false;
-  unsigned c = 0;
+  Node c;
   for (const Node& ac : avec)
   {
     if (ac.isConst())
     {
-      std::vector<unsigned> acv = ac.getConst<String>().getVec();
-      for (unsigned cc : acv)
+      std::vector<Node> acv = Word::getChars(ac);
+      for (const Node& cc : acv)
       {
         if (!cValid)
         {
@@ -4990,11 +4986,10 @@ Node SequencesRewriter::checkEntailHomogeneousString(Node a)
 
   if (!cValid)
   {
-    return nm->mkConst(String(""));
+    return Word::mkEmptyWord(a.getType());
   }
 
-  std::vector<unsigned> cv = {c};
-  return nm->mkConst(String(cv));
+  return c;
 }
 
 Node SequencesRewriter::getMultisetApproximation(Node a)
index afdd2c0e1c63d1f209ee5df4817463caf8b2f71e..0e5cd5705bc4b66b1257975e8fbaf9bf4b9be709 100644 (file)
@@ -512,11 +512,12 @@ class SequencesRewriter : public TheoryRewriter
                                      int dir = 0);
 
   /**
-   * Given a symbolic length n, returns the canonical string for that length.
-   * For example if n is constant, this function returns a string consisting of
-   * "A" repeated n times. Returns the null node if no such string exists.
+   * Given a symbolic length n, returns the canonical string (of type stype)
+   * for that length. For example if n is constant, this function returns a
+   * string consisting of "A" repeated n times. Returns the null node if no such
+   * string exists.
    */
-  static Node canonicalStrForSymbolicLength(Node n);
+  static Node canonicalStrForSymbolicLength(Node n, TypeNode stype);
 
   /** length preserving rewrite
    *
index b35c4a921466fad7d09486ab50c1d17ec237d6e4..7777b9bd7cd774ee6c91cc9062d85915e2c75637 100644 (file)
@@ -23,6 +23,7 @@
 #include "proof/proof_manager.h"
 #include "smt/logic_exception.h"
 #include "theory/strings/sequences_rewriter.h"
+#include "theory/strings/word.h"
 
 using namespace CVC4;
 using namespace CVC4::kind;
@@ -40,7 +41,6 @@ StringsPreprocess::StringsPreprocess(SkolemCache* sc,
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
   d_one = NodeManager::currentNM()->mkConst(Rational(1));
   d_neg_one = NodeManager::currentNM()->mkConst(Rational(-1));
-  d_empty_str = NodeManager::currentNM()->mkConst(String(""));
 }
 
 StringsPreprocess::~StringsPreprocess(){
@@ -70,11 +70,13 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
     Node c3 = nm->mkNode(GT, m, d_zero);
     Node cond = nm->mkNode(AND, c1, c2, c3);
 
-    Node sk1 = n == d_zero ? d_empty_str
+    Node emp = Word::mkEmptyWord(t.getType());
+
+    Node sk1 = n == d_zero ? emp
                            : d_sc->mkSkolemCached(
                                  s, n, SkolemCache::SK_PREFIX, "sspre");
     Node sk2 = SequencesRewriter::checkEntailArith(t12, lt0)
-                   ? d_empty_str
+                   ? emp
                    : d_sc->mkSkolemCached(
                          s, t12, SkolemCache::SK_SUFFIX_REM, "sssufr");
     Node b11 = s.eqNode(nm->mkNode(STRING_CONCAT, sk1, skt, sk2));
@@ -91,7 +93,7 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
     Node b14 = nm->mkNode(LEQ, nm->mkNode(STRING_LENGTH, skt), m);
 
     Node b1 = nm->mkNode(AND, b11, b12, b13, b14);
-    Node b2 = skt.eqNode(d_empty_str);
+    Node b2 = skt.eqNode(emp);
     Node lemma = nm->mkNode(ITE, cond, b1, b2);
 
     // assert:
@@ -151,7 +153,8 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
     Node cc1 = skk.eqNode(negone);
 
     // y = ""
-    Node cond2 = y.eqNode(d_empty_str);
+    Node emp = Word::mkEmptyWord(x.getType());
+    Node cond2 = y.eqNode(emp);
     // skk = n
     Node cc2 = skk.eqNode(t[2]);
 
@@ -239,8 +242,8 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
 
     Node nonneg = nm->mkNode(GEQ, n, d_zero);
 
-    lem = nm->mkNode(
-        ITE, nonneg, nm->mkNode(AND, conc), itost.eqNode(d_empty_str));
+    Node emp = Word::mkEmptyWord(t.getType());
+    lem = nm->mkNode(ITE, nonneg, nm->mkNode(AND, conc), itost.eqNode(emp));
     new_nodes.push_back(lem);
     // assert:
     // IF n>=0
@@ -277,7 +280,8 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
     Node lem = stoit.eqNode(d_neg_one);
     conc1.push_back(lem);
 
-    Node sEmpty = s.eqNode(d_empty_str);
+    Node emp = Word::mkEmptyWord(s.getType());
+    Node sEmpty = s.eqNode(emp);
     Node k = nm->mkSkolem("k", nm->integerType());
     Node kc1 = nm->mkNode(GEQ, k, d_zero);
     Node kc2 = nm->mkNode(LT, k, lens);
@@ -478,8 +482,9 @@ Node StringsPreprocess::simplify( Node t, std::vector< Node > &new_nodes ) {
     // the index to begin searching in x for y after the i^th occurrence of y in
     // x, and Us( i ) is the result of processing the remainder after processing
     // the i^th occurrence of y in x.
-    Node assert = nm->mkNode(
-        ITE, y.eqNode(d_empty_str), rpaw.eqNode(x), nm->mkNode(AND, lem));
+    Node emp = Word::mkEmptyWord(t.getType());
+    Node assert =
+        nm->mkNode(ITE, y.eqNode(emp), rpaw.eqNode(x), nm->mkNode(AND, lem));
     new_nodes.push_back(assert);
 
     // Thus, replaceall( x, y, z ) = rpaw
index 155b9014c03e7bcf1bf6e442b82adf83106a9b97..fb6404aa64196f7f04fe618973e7f7d48ae519b8 100644 (file)
@@ -72,7 +72,6 @@ class StringsPreprocess {
   Node d_zero;
   Node d_one;
   Node d_neg_one;
-  Node d_empty_str;
   /** pointer to the skolem cache used by this class */
   SkolemCache* d_sc;
   /** Reference to the statistics for the theory of strings/sequences. */
index 0faeffd999959580456c7f020ebb3981fd197dae..b42cf3160e0ab060c18883e1805c36aefbc8a6ac 100644 (file)
@@ -76,6 +76,29 @@ size_t Word::getLength(TNode x)
   return 0;
 }
 
+std::vector<Node> Word::getChars(TNode x)
+{
+  Kind k = x.getKind();
+  if (k == CONST_STRING)
+  {
+    std::vector<Node> ret;
+    NodeManager* nm = NodeManager::currentNM();
+    std::vector<unsigned> ccVec;
+    const std::vector<unsigned>& cvec = x.getConst<String>().getVec();
+    for (unsigned chVal : cvec)
+    {
+      ccVec.clear();
+      ccVec.push_back(chVal);
+      Node ch = nm->mkConst(String(ccVec));
+      ret.push_back(ch);
+    }
+    return ret;
+  }
+  Unimplemented();
+  std::vector<Node> ret;
+  return ret;
+}
+
 bool Word::isEmpty(TNode x) { return getLength(x) == 0; }
 
 bool Word::strncmp(TNode x, TNode y, std::size_t n)
index 7b813a0b2ebff281f9250cc6ed3abd9b8b20c3d8..8e6e7876e6cb745b9eb8b23fd688859afc1d1b7f 100644 (file)
@@ -42,6 +42,13 @@ class Word
   /** Return the length of word x */
   static size_t getLength(TNode x);
 
+  /** Get characters
+   *
+   * Given word x, this returns the vector of words of length one whose
+   * concatenation is equivalent to x.
+   */
+  static std::vector<Node> getChars(TNode x);
+
   /** Return true if x is empty */
   static bool isEmpty(TNode x);