Generalize mkConcat for types (#4123)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 20 Mar 2020 23:09:27 +0000 (18:09 -0500)
committerGitHub <noreply@github.com>
Fri, 20 Mar 2020 23:09:27 +0000 (18:09 -0500)
Towards theory of sequences.

The utility function mkConcat needs a type to know what to construct in the empty string case.

13 files changed:
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/strings/base_solver.cpp
src/theory/strings/base_solver.h
src/theory/strings/core_solver.cpp
src/theory/strings/core_solver.h
src/theory/strings/normal_form.cpp
src/theory/strings/regexp_elim.cpp
src/theory/strings/regexp_elim.h
src/theory/strings/regexp_solver.cpp
src/theory/strings/sequences_rewriter.cpp
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings_utils.cpp
src/theory/strings/theory_strings_utils.h

index 187c765d10defcd7e76e6cf9beeb40277754dea6..231c81bbf2b03dc73dd7d057c021df80df32904d 100644 (file)
@@ -818,6 +818,7 @@ Node QuantifiersRewriter::getVarElimLitString(Node lit,
   {
     if (lit[i].getKind() == STRING_CONCAT)
     {
+      TypeNode stype = lit[i].getType();
       for (unsigned j = 0, nchildren = lit[i].getNumChildren(); j < nchildren;
            j++)
       {
@@ -827,8 +828,8 @@ Node QuantifiersRewriter::getVarElimLitString(Node lit,
           Node slv = lit[1 - i];
           std::vector<Node> preL(lit[i].begin(), lit[i].begin() + j);
           std::vector<Node> postL(lit[i].begin() + j + 1, lit[i].end());
-          Node tpre = strings::utils::mkConcat(STRING_CONCAT, preL);
-          Node tpost = strings::utils::mkConcat(STRING_CONCAT, postL);
+          Node tpre = strings::utils::mkConcat(preL, stype);
+          Node tpost = strings::utils::mkConcat(postL, stype);
           Node slvL = nm->mkNode(STRING_LENGTH, slv);
           Node tpreL = nm->mkNode(STRING_LENGTH, tpre);
           Node tpostL = nm->mkNode(STRING_LENGTH, tpost);
index 6958d2528c756e9b7e5ba452a6f4e08226e46b63..128893cf0d0055a5b4cb477b224a2a95f2dc67db 100644 (file)
@@ -35,6 +35,7 @@ BaseSolver::BaseSolver(context::Context* c,
   d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
   d_false = NodeManager::currentNM()->mkConst(false);
   d_cardSize = utils::getAlphabetCardinality();
+  d_type = NodeManager::currentNM()->stringType();
 }
 
 BaseSolver::~BaseSolver() {}
@@ -253,7 +254,7 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
   if (!n.isNull())
   {
     // construct the constant
-    Node c = utils::mkNConcat(vecc);
+    Node c = utils::mkNConcat(vecc, d_type);
     if (!d_state.areEqual(n, c))
     {
       if (Trace.isOn("strings-debug"))
index 3681b49a4a1d7e5310efd3c86bcd068416dc947e..bf223bc0ae495373a9ccf8e914a502c229864d19 100644 (file)
@@ -191,6 +191,8 @@ class BaseSolver
   std::map<Kind, TermIndex> d_termIndex;
   /** the cardinality of the alphabet */
   uint32_t d_cardSize;
+  /** The string-like type for this base solver */
+  TypeNode d_type;
 }; /* class BaseSolver */
 
 }  // namespace strings
index 2a95b41baab6ced62990d44e592d92e7c9724213..f250647af7818c00bbfe82ae4572aa5730c637aa 100644 (file)
@@ -40,6 +40,7 @@ d_nf_pairs(c)
   d_emptyString = Word::mkEmptyWord(CONST_STRING);
   d_true = NodeManager::currentNM()->mkConst( true );
   d_false = NodeManager::currentNM()->mkConst( false );
+  d_type = NodeManager::currentNM()->stringType();
 }
 
 CoreSolver::~CoreSolver() {
@@ -559,7 +560,7 @@ void CoreSolver::checkNormalFormsEq()
       return;
     }
     NormalForm& nfe = getNormalForm(eqc);
-    Node nf_term = utils::mkNConcat(nfe.d_nf);
+    Node nf_term = utils::mkNConcat(nfe.d_nf, d_type);
     std::map<Node, Node>::iterator itn = nf_to_eqc.find(nf_term);
     if (itn != nf_to_eqc.end())
     {
@@ -682,7 +683,7 @@ Node CoreSolver::getNormalString(Node x, std::vector<Node>& nf_exp)
     if (it != d_normal_form.end())
     {
       NormalForm& nf = it->second;
-      Node ret = utils::mkNConcat(nf.d_nf);
+      Node ret = utils::mkNConcat(nf.d_nf, d_type);
       nf_exp.insert(nf_exp.end(), nf.d_exp.begin(), nf.d_exp.end());
       d_im.addToExplanation(x, nf.d_base, nf_exp);
       Trace("strings-debug")
@@ -700,7 +701,7 @@ Node CoreSolver::getNormalString(Node x, std::vector<Node>& nf_exp)
         Node nc = getNormalString(x[i], nf_exp);
         vec_nodes.push_back(nc);
       }
-      return utils::mkNConcat(vec_nodes);
+      return utils::mkNConcat(vec_nodes, d_type);
     }
   }
   return x;
@@ -1117,7 +1118,7 @@ void CoreSolver::processSimpleNEq(NormalForm& nfi,
             eqnc.push_back(nfkv[i]);
           }
         }
-        eqn[r] = utils::mkNConcat(eqnc);
+        eqn[r] = utils::mkNConcat(eqnc, d_type);
       }
       if (!d_state.areEqual(eqn[0], eqn[1]))
       {
@@ -1540,15 +1541,15 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi,
   Trace("strings-loop") << " ... (X)= " << vecoi[index] << std::endl;
   Trace("strings-loop") << " ... T(Y.Z)= ";
   std::vector<Node> vec_t(veci.begin() + index, veci.begin() + loop_index);
-  Node t_yz = utils::mkNConcat(vec_t);
+  Node t_yz = utils::mkNConcat(vec_t, d_type);
   Trace("strings-loop") << " (" << t_yz << ")" << std::endl;
   Trace("strings-loop") << " ... S(Z.Y)= ";
   std::vector<Node> vec_s(vecoi.begin() + index + 1, vecoi.end());
-  Node s_zy = utils::mkNConcat(vec_s);
+  Node s_zy = utils::mkNConcat(vec_s, d_type);
   Trace("strings-loop") << s_zy << std::endl;
   Trace("strings-loop") << " ... R= ";
   std::vector<Node> vec_r(veci.begin() + loop_index + 1, veci.end());
-  Node r = utils::mkNConcat(vec_r);
+  Node r = utils::mkNConcat(vec_r, d_type);
   Trace("strings-loop") << r << std::endl;
 
   if (s_zy.isConst() && r.isConst() && r != d_emptyString)
@@ -1640,7 +1641,7 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi,
         v2.insert(v2.begin(), y);
         v2.insert(v2.begin(), z);
         restr = utils::mkNConcat(z, y);
-        cc = Rewriter::rewrite(s_zy.eqNode(utils::mkNConcat(v2)));
+        cc = Rewriter::rewrite(s_zy.eqNode(utils::mkNConcat(v2, d_type)));
       }
       else
       {
@@ -1690,7 +1691,7 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi,
     // s1 * ... * sk = z * y * r
     vec_r.insert(vec_r.begin(), sk_y);
     vec_r.insert(vec_r.begin(), sk_z);
-    Node conc2 = s_zy.eqNode(utils::mkNConcat(vec_r));
+    Node conc2 = s_zy.eqNode(utils::mkNConcat(vec_r, d_type));
     Node conc3 = vecoi[index].eqNode(utils::mkNConcat(sk_y, sk_w));
     Node restr = r == d_emptyString ? s_zy : utils::mkNConcat(sk_z, sk_y);
     str_in_re =
@@ -2147,7 +2148,7 @@ void CoreSolver::checkLengthsEqc() {
     // now, check if length normalization has occurred
     if (ei->d_normalizedLength.get().isNull())
     {
-      Node nf = utils::mkNConcat(nfi.d_nf);
+      Node nf = utils::mkNConcat(nfi.d_nf, d_type);
       if (Trace.isOn("strings-process-debug"))
       {
         Trace("strings-process-debug")
index d18f109b2d7f65683712412d21dc2686ede8b52c..c549fa88617e38309be189202de3e0c1efcd537e 100644 (file)
@@ -368,6 +368,8 @@ class CoreSolver
    * the argument number of the t1 ... tn they were generated from.
    */
   std::map<Node, std::vector<int> > d_flat_form_index;
+  /** The string-like type for this solver */
+  TypeNode d_type;
 }; /* class CoreSolver */
 
 }  // namespace strings
index ac28be2453bc4a20243538022b9738458f2a1bb6..7a2323d891696578dbb1dc97cd4e0bc8e7884051 100644 (file)
@@ -150,7 +150,7 @@ Node NormalForm::collectConstantStringAt(size_t& index)
     {
       std::reverse(c.begin(), c.end());
     }
-    Node cc = Rewriter::rewrite(utils::mkConcat(STRING_CONCAT, c));
+    Node cc = Rewriter::rewrite(utils::mkConcat(c, c[0].getType()));
     Assert(cc.isConst());
     return cc;
   }
index 976efad3c028ef5c8a86fa10aaa7b98895b02cac..259588789f69cc64c03031d664a1f7b033f0d0df 100644 (file)
@@ -30,6 +30,7 @@ RegExpElimination::RegExpElimination()
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
   d_one = NodeManager::currentNM()->mkConst(Rational(1));
   d_neg_one = NodeManager::currentNM()->mkConst(Rational(-1));
+  d_regExpType = NodeManager::currentNM()->regExpType();
 }
 
 Node RegExpElimination::eliminate(Node atom)
@@ -382,7 +383,7 @@ Node RegExpElimination::eliminateConcat(Node atom)
     Assert(rexpElimChildren.size() + sConstraints.size() == nchildren);
     Node ss = nm->mkNode(STRING_SUBSTR, x, sStartIndex, sLength);
     Assert(!rexpElimChildren.empty());
-    Node regElim = utils::mkConcat(REGEXP_CONCAT, rexpElimChildren);
+    Node regElim = utils::mkConcat(rexpElimChildren, d_regExpType);
     sConstraints.push_back(nm->mkNode(STRING_IN_REGEXP, ss, regElim));
     Node ret = nm->mkNode(AND, sConstraints);
     // e.g.
@@ -422,7 +423,7 @@ Node RegExpElimination::eliminateConcat(Node atom)
       {
         std::vector<Node> rprefix;
         rprefix.insert(rprefix.end(), children.begin(), children.begin() + i);
-        Node rpn = utils::mkConcat(REGEXP_CONCAT, rprefix);
+        Node rpn = utils::mkConcat(rprefix, d_regExpType);
         Node substrPrefix = nm->mkNode(
             STRING_IN_REGEXP, nm->mkNode(STRING_SUBSTR, x, d_zero, k), rpn);
         echildren.push_back(substrPrefix);
@@ -431,7 +432,7 @@ Node RegExpElimination::eliminateConcat(Node atom)
       {
         std::vector<Node> rsuffix;
         rsuffix.insert(rsuffix.end(), children.begin() + i + 1, children.end());
-        Node rps = utils::mkConcat(REGEXP_CONCAT, rsuffix);
+        Node rps = utils::mkConcat(rsuffix, d_regExpType);
         Node ks = nm->mkNode(PLUS, k, lens);
         Node substrSuffix = nm->mkNode(
             STRING_IN_REGEXP,
index dbd4102b663b18c211ebc1b64032b20227d650df..61ce8a920eb58594732400537d999f15d9409402 100644 (file)
@@ -47,6 +47,8 @@ class RegExpElimination
   Node d_zero;
   Node d_one;
   Node d_neg_one;
+  /** The type of regular expressions */
+  TypeNode d_regExpType;
   /** return elimination
    *
    * This method is called when atom is rewritten to atomElim, and returns
index cd66c0ebfc9d0d3aa30cd4e46438f3cf01cdf07e..30f9c4a7303d9c72001df0f91a873cceb52c845b 100644 (file)
@@ -621,7 +621,7 @@ bool RegExpSolver::deriveRegExp(Node x,
         {
           vec_nodes.push_back(x[i]);
         }
-        Node left = utils::mkConcat(STRING_CONCAT, vec_nodes);
+        Node left = utils::mkConcat(vec_nodes, x.getType());
         left = Rewriter::rewrite(left);
         conc = NodeManager::currentNM()->mkNode(STRING_IN_REGEXP, left, dc);
       }
index f4a1cd411e70fa316bb6df297934635cbd293828..200d7a7342d56d3431ca513bd8d5c9094300ed52 100644 (file)
@@ -410,6 +410,7 @@ Node SequencesRewriter::rewriteEqualityExt(Node node)
 Node SequencesRewriter::rewriteStrEqualityExt(Node node)
 {
   Assert(node.getKind() == EQUAL && node[0].getType().isString());
+  TypeNode stype = node[0].getType();
 
   NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> c[2];
@@ -461,8 +462,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
   if (changed)
   {
     // e.g. x++y = x++z ---> y = z, "AB" ++ x = "A" ++ y --> "B" ++ x = y
-    Node s1 = utils::mkConcat(STRING_CONCAT, c[0]);
-    Node s2 = utils::mkConcat(STRING_CONCAT, c[1]);
+    Node s1 = utils::mkConcat(c[0], stype);
+    Node s2 = utils::mkConcat(c[1], stype);
     new_ret = s1.eqNode(s2);
     node = returnRewrite(node, new_ret, "str-eq-unify");
   }
@@ -531,8 +532,8 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
         }
       }
 
-      Node lhs = utils::mkConcat(STRING_CONCAT, trimmed[i]);
-      Node ss = utils::mkConcat(STRING_CONCAT, trimmed[1 - i]);
+      Node lhs = utils::mkConcat(trimmed[i], stype);
+      Node ss = utils::mkConcat(trimmed[1 - i], stype);
       if (lhs != node[i] || ss != node[1 - i])
       {
         // e.g.
@@ -696,13 +697,13 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
     for (size_t i = 0, size0 = v0.size(); i <= size0; i++)
     {
       std::vector<Node> pfxv0(v0.begin(), v0.begin() + i);
-      Node pfx0 = utils::mkConcat(STRING_CONCAT, pfxv0);
+      Node pfx0 = utils::mkConcat(pfxv0, stype);
       for (size_t j = startRhs, size1 = v1.size(); j <= size1; j++)
       {
         if (!(i == 0 && j == 0) && !(i == v0.size() && j == v1.size()))
         {
           std::vector<Node> pfxv1(v1.begin(), v1.begin() + j);
-          Node pfx1 = utils::mkConcat(STRING_CONCAT, pfxv1);
+          Node pfx1 = utils::mkConcat(pfxv1, stype);
           Node lenPfx0 = nm->mkNode(STRING_LENGTH, pfx0);
           Node lenPfx1 = nm->mkNode(STRING_LENGTH, pfx1);
 
@@ -710,11 +711,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
           {
             std::vector<Node> sfxv0(v0.begin() + i, v0.end());
             std::vector<Node> sfxv1(v1.begin() + j, v1.end());
-            Node ret =
-                nm->mkNode(kind::AND,
-                           pfx0.eqNode(pfx1),
-                           utils::mkConcat(STRING_CONCAT, sfxv0)
-                               .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1)));
+            Node ret = nm->mkNode(kind::AND,
+                                  pfx0.eqNode(pfx1),
+                                  utils::mkConcat(sfxv0, stype)
+                                      .eqNode(utils::mkConcat(sfxv1, stype)));
             return returnRewrite(node, ret, "split-eq");
           }
           else if (checkEntailArith(lenPfx1, lenPfx0, true))
@@ -731,11 +731,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
             {
               std::vector<Node> sfxv0(v0.begin() + i, v0.end());
               pfxv1.insert(pfxv1.end(), v1.begin() + j, v1.end());
-              Node ret = nm->mkNode(
-                  kind::AND,
-                  pfx0.eqNode(utils::mkConcat(STRING_CONCAT, rpfxv1)),
-                  utils::mkConcat(STRING_CONCAT, sfxv0)
-                      .eqNode(utils::mkConcat(STRING_CONCAT, pfxv1)));
+              Node ret = nm->mkNode(kind::AND,
+                                    pfx0.eqNode(utils::mkConcat(rpfxv1, stype)),
+                                    utils::mkConcat(sfxv0, stype)
+                                        .eqNode(utils::mkConcat(pfxv1, stype)));
               return returnRewrite(node, ret, "split-eq-strip-r");
             }
 
@@ -759,11 +758,10 @@ Node SequencesRewriter::rewriteStrEqualityExt(Node node)
             {
               pfxv0.insert(pfxv0.end(), v0.begin() + i, v0.end());
               std::vector<Node> sfxv1(v1.begin() + j, v1.end());
-              Node ret = nm->mkNode(
-                  kind::AND,
-                  utils::mkConcat(STRING_CONCAT, rpfxv0).eqNode(pfx1),
-                  utils::mkConcat(STRING_CONCAT, pfxv0)
-                      .eqNode(utils::mkConcat(STRING_CONCAT, sfxv1)));
+              Node ret = nm->mkNode(kind::AND,
+                                    utils::mkConcat(rpfxv0, stype).eqNode(pfx1),
+                                    utils::mkConcat(pfxv0, stype)
+                                        .eqNode(utils::mkConcat(sfxv1, stype)));
               return returnRewrite(node, ret, "split-eq-strip-l");
             }
 
@@ -891,7 +889,8 @@ Node SequencesRewriter::rewriteConcat(Node node)
   }
   std::sort(node_vec.begin() + lastIdx, node_vec.end());
 
-  retNode = utils::mkConcat(STRING_CONCAT, node_vec);
+  TypeNode tn = node.getType();
+  retNode = utils::mkConcat(node_vec, tn);
   Trace("strings-rewrite-debug")
       << "Strings::rewriteConcat end " << retNode << std::endl;
   return retNode;
@@ -907,6 +906,20 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node)
   std::vector<Node> vec;
   bool changed = false;
   Node emptyRe;
+
+  // get the string type that are members of this regular expression
+  TypeNode rtype = node.getType();
+  TypeNode stype;
+  if (rtype.isRegExp())
+  {
+    // standard regular expressions are for strings
+    stype = nm->stringType();
+  }
+  else
+  {
+    Unimplemented();
+  }
+
   for (const Node& c : node)
   {
     if (c.getKind() == REGEXP_CONCAT)
@@ -978,8 +991,7 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node)
     {
       Assert(!lastAllStar);
       // this groups consecutive strings a++b ---> ab
-      Node acc = nm->mkNode(STRING_TO_REGEXP,
-                            utils::mkConcat(STRING_CONCAT, preReStr));
+      Node acc = nm->mkNode(STRING_TO_REGEXP, utils::mkConcat(preReStr, stype));
       cvec.push_back(acc);
       preReStr.clear();
     }
@@ -1022,11 +1034,11 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node)
     }
   }
   Assert(!cvec.empty());
-  retNode = utils::mkConcat(REGEXP_CONCAT, cvec);
+  retNode = utils::mkConcat(cvec, rtype);
   if (retNode != node)
   {
-    // handles all cases where consecutive re constants are combined or dropped
-    // as described in the loop above.
+    // handles all cases where consecutive re constants are combined or
+    // dropped as described in the loop above.
     return returnRewrite(node, retNode, "re.concat");
   }
 
@@ -1043,7 +1055,7 @@ Node SequencesRewriter::rewriteConcatRegExp(TNode node)
   }
   if (changed)
   {
-    retNode = utils::mkConcat(REGEXP_CONCAT, cvec);
+    retNode = utils::mkConcat(cvec, rtype);
     return returnRewrite(node, retNode, "re.concat.opt");
   }
   return node;
@@ -1215,10 +1227,11 @@ Node SequencesRewriter::rewriteLoopRegExp(TNode node)
     {
       std::vector<Node> vec2;
       vec2.push_back(n);
+      TypeNode rtype = nm->regExpType();
       for (unsigned j = l; j < u; j++)
       {
         vec_nodes.push_back(r);
-        n = utils::mkConcat(REGEXP_CONCAT, vec_nodes);
+        n = utils::mkConcat(vec_nodes, rtype);
         vec2.push_back(n);
       }
       retNode = nm->mkNode(REGEXP_UNION, vec2);
@@ -1518,9 +1531,12 @@ Node SequencesRewriter::rewriteMembership(TNode node)
   Node x = node[0];
   Node r = node[1];
 
-  if (r.getKind() == kind::REGEXP_EMPTY)
+  TypeNode stype = x.getType();
+  TypeNode rtype = r.getType();
+
+  if(r.getKind() == kind::REGEXP_EMPTY) 
   {
-    retNode = NodeManager::currentNM()->mkConst(false);
+    retNode = NodeManager::currentNM()->mkConst( false );
   }
   else if (x.isConst() && isConstRegExp(r))
   {
@@ -1708,7 +1724,7 @@ Node SequencesRewriter::rewriteMembership(TNode node)
             else
             {
               retNode = nm->mkNode(STRING_IN_REGEXP,
-                                   utils::mkConcat(STRING_CONCAT, mchildren),
+                                   utils::mkConcat(mchildren, stype),
                                    r);
               success = true;
             }
@@ -1743,7 +1759,7 @@ Node SequencesRewriter::rewriteMembership(TNode node)
           // Given a membership (str.++ x1 ... xn) in (re.++ r1 ... rm),
           // above, we strip components to construct an equivalent membership:
           // (str.++ xi .. xj) in (re.++ rk ... rl).
-          Node xn = utils::mkConcat(STRING_CONCAT, mchildren);
+          Node xn = utils::mkConcat(mchildren, stype);
           Node emptyStr = nm->mkConst(String(""));
           if (children.empty())
           {
@@ -1756,7 +1772,7 @@ Node SequencesRewriter::rewriteMembership(TNode node)
           {
             // otherwise, construct the updated regular expression
             retNode = nm->mkNode(
-                STRING_IN_REGEXP, xn, utils::mkConcat(REGEXP_CONCAT, children));
+                STRING_IN_REGEXP, xn, utils::mkConcat(children, rtype));
           }
           Trace("regexp-ext-rewrite") << "Regexp : rewrite : " << node << " -> "
                                       << retNode << std::endl;
@@ -2115,6 +2131,7 @@ Node SequencesRewriter::rewriteSubstr(Node node)
 
   std::vector<Node> n1;
   utils::getConcat(node[0], n1);
+  TypeNode stype = node.getType();
 
   // definite inclusion
   if (node[1] == zero)
@@ -2125,12 +2142,10 @@ Node SequencesRewriter::rewriteSubstr(Node node)
     {
       if (curr != zero && !n1.empty())
       {
-        childrenr.push_back(nm->mkNode(kind::STRING_SUBSTR,
-                                       utils::mkConcat(STRING_CONCAT, n1),
-                                       node[1],
-                                       curr));
+        childrenr.push_back(nm->mkNode(
+            kind::STRING_SUBSTR, utils::mkConcat(n1, stype), node[1], curr));
       }
-      Node ret = utils::mkConcat(STRING_CONCAT, childrenr);
+      Node ret = utils::mkConcat(childrenr, stype);
       return returnRewrite(node, ret, "ss-len-include");
     }
   }
@@ -2211,16 +2226,14 @@ Node SequencesRewriter::rewriteSubstr(Node node)
       {
         if (r == 0)
         {
-          Node ret = nm->mkNode(kind::STRING_SUBSTR,
-                                utils::mkConcat(STRING_CONCAT, n1),
-                                curr,
-                                node[2]);
+          Node ret = nm->mkNode(
+              kind::STRING_SUBSTR, utils::mkConcat(n1, stype), curr, node[2]);
           return returnRewrite(node, ret, "ss-strip-start-pt");
         }
         else
         {
           Node ret = nm->mkNode(kind::STRING_SUBSTR,
-                                utils::mkConcat(STRING_CONCAT, n1),
+                                utils::mkConcat(n1, stype),
                                 node[1],
                                 node[2]);
           return returnRewrite(node, ret, "ss-strip-end-pt");
@@ -2395,6 +2408,7 @@ Node SequencesRewriter::rewriteContains(Node node)
     Node ret = NodeManager::currentNM()->mkConst(true);
     return returnRewrite(node, ret, "ctn-component");
   }
+  TypeNode stype = node[0].getType();
 
   // strip endpoints
   std::vector<Node> nb;
@@ -2402,7 +2416,7 @@ Node SequencesRewriter::rewriteContains(Node node)
   if (stripConstantEndpoints(nc1, nc2, nb, ne))
   {
     Node ret = NodeManager::currentNM()->mkNode(
-        kind::STRING_STRCTN, utils::mkConcat(STRING_CONCAT, nc1), node[1]);
+        kind::STRING_STRCTN, utils::mkConcat(nc1, stype), node[1]);
     return returnRewrite(node, ret, "ctn-strip-endpt");
   }
 
@@ -2517,14 +2531,12 @@ Node SequencesRewriter::rewriteContains(Node node)
             spl[1].insert(spl[1].end(), nc0.begin() + i + 1, nc0.end());
             Node ret = NodeManager::currentNM()->mkNode(
                 kind::OR,
-                NodeManager::currentNM()->mkNode(
-                    kind::STRING_STRCTN,
-                    utils::mkConcat(STRING_CONCAT, spl[0]),
-                    node[1]),
-                NodeManager::currentNM()->mkNode(
-                    kind::STRING_STRCTN,
-                    utils::mkConcat(STRING_CONCAT, spl[1]),
-                    node[1]));
+                NodeManager::currentNM()->mkNode(kind::STRING_STRCTN,
+                                                 utils::mkConcat(spl[0], stype),
+                                                 node[1]),
+                NodeManager::currentNM()->mkNode(kind::STRING_STRCTN,
+                                                 utils::mkConcat(spl[1], stype),
+                                                 node[1]));
             return returnRewrite(node, ret, "ctn-split");
           }
         }
@@ -2643,6 +2655,9 @@ Node SequencesRewriter::rewriteIndexof(Node node)
     return returnRewrite(node, negone, "idof-neg");
   }
 
+  // the string type
+  TypeNode stype = node[0].getType();
+
   // evaluation and simple cases
   std::vector<Node> children0;
   utils::getConcat(node[0], children0);
@@ -2751,7 +2766,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         {
           // For example:
           // str.indexof(str.++(x,y,z),y,0) ---> str.indexof(str.++(x,y),y,0)
-          Node nn = utils::mkConcat(STRING_CONCAT, children0);
+          Node nn = utils::mkConcat(children0, stype);
           Node ret = nm->mkNode(kind::STRING_STRIDOF, nn, node[1], node[2]);
           return returnRewrite(node, ret, "idof-def-ctn");
         }
@@ -2761,14 +2776,13 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         {
           // str.indexof(str.++("AB", x, "C"), "C", 0) --->
           // 2 + str.indexof(str.++(x, "C"), "C", 0)
-          Node ret =
-              nm->mkNode(kind::PLUS,
-                         nm->mkNode(kind::STRING_LENGTH,
-                                    utils::mkConcat(STRING_CONCAT, nb)),
-                         nm->mkNode(kind::STRING_STRIDOF,
-                                    utils::mkConcat(STRING_CONCAT, children0),
-                                    node[1],
-                                    node[2]));
+          Node ret = nm->mkNode(
+              kind::PLUS,
+              nm->mkNode(kind::STRING_LENGTH, utils::mkConcat(nb, stype)),
+              nm->mkNode(kind::STRING_STRIDOF,
+                         utils::mkConcat(children0, stype),
+                         node[1],
+                         node[2]));
           return returnRewrite(node, ret, "idof-strip-cnst-endpts");
         }
       }
@@ -2783,7 +2797,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         // implies
         // str.indexof( str.++( x1, x2 ), y, z ) --->
         // str.len( x1 ) + str.indexof( x2, y, z-str.len(x1) )
-        Node nn = utils::mkConcat(STRING_CONCAT, children0);
+        Node nn = utils::mkConcat(children0, stype);
         Node ret =
             nm->mkNode(kind::PLUS,
                        nm->mkNode(kind::MINUS, node[2], new_len),
@@ -2809,7 +2823,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
       // For example:
       // str.indexof(str.++("ABCD", x), y, 3) --->
       // str.indexof(str.++("AAAD", x), y, 3)
-      Node nodeNr = utils::mkConcat(STRING_CONCAT, nr);
+      Node nodeNr = utils::mkConcat(nr, stype);
       Node normNr = lengthPreserveRewrite(nodeNr);
       if (normNr != nodeNr)
       {
@@ -2817,7 +2831,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
         utils::getConcat(normNr, normNrChildren);
         std::vector<Node> children(normNrChildren);
         children.insert(children.end(), children0.begin(), children0.end());
-        Node nn = utils::mkConcat(STRING_CONCAT, children);
+        Node nn = utils::mkConcat(children, stype);
         Node res = nm->mkNode(kind::STRING_STRIDOF, nn, node[1], node[2]);
         return returnRewrite(node, res, "idof-norm-prefix");
       }
@@ -2830,7 +2844,7 @@ Node SequencesRewriter::rewriteIndexof(Node node)
     std::vector<Node> ce;
     if (stripConstantEndpoints(children0, children1, cb, ce, -1))
     {
-      Node ret = utils::mkConcat(STRING_CONCAT, children0);
+      Node ret = utils::mkConcat(children0, stype);
       ret = nm->mkNode(STRING_STRIDOF, ret, node[1], node[2]);
       // For example:
       // str.indexof( str.++( x, "A" ), "B", 0 ) ---> str.indexof( x, "B", 0 )
@@ -2852,6 +2866,8 @@ Node SequencesRewriter::rewriteReplace(Node node)
     Node ret = nm->mkNode(STRING_CONCAT, node[2], node[0]);
     return returnRewrite(node, ret, "rpl-rpl-empty");
   }
+  // the string type
+  TypeNode stype = node.getType();
 
   std::vector<Node> children0;
   utils::getConcat(node[0], children0);
@@ -2883,7 +2899,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
         children.push_back(s3);
       }
       children.insert(children.end(), children0.begin() + 1, children0.end());
-      Node ret = utils::mkConcat(STRING_CONCAT, children);
+      Node ret = utils::mkConcat(children, stype);
       return returnRewrite(node, ret, "rpl-const-find");
     }
   }
@@ -2921,7 +2937,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
 
         if (allEmptyEqs)
         {
-          Node nn1 = utils::mkConcat(STRING_CONCAT, emptyNodes);
+          Node nn1 = utils::mkConcat(emptyNodes, stype);
           if (node[1] != nn1)
           {
             Node ret = nm->mkNode(STRING_STRREPL, node[0], nn1, node[2]);
@@ -2957,7 +2973,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
           std::vector<Node> cres;
           cres.push_back(node[2]);
           cres.insert(cres.end(), ce.begin(), ce.end());
-          Node ret = utils::mkConcat(STRING_CONCAT, cres);
+          Node ret = utils::mkConcat(cres, stype);
           return returnRewrite(node, ret, "rpl-cctn-rpl");
         }
         else if (!ce.empty())
@@ -2970,11 +2986,11 @@ Node SequencesRewriter::rewriteReplace(Node node)
           std::vector<Node> scc;
           scc.push_back(NodeManager::currentNM()->mkNode(
               kind::STRING_STRREPL,
-              utils::mkConcat(STRING_CONCAT, children0),
+              utils::mkConcat(children0, stype),
               node[1],
               node[2]));
           scc.insert(scc.end(), ce.begin(), ce.end());
-          Node ret = utils::mkConcat(STRING_CONCAT, scc);
+          Node ret = utils::mkConcat(scc, stype);
           return returnRewrite(node, ret, "rpl-cctn");
         }
       }
@@ -3022,7 +3038,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
       if (node[0] == empty && allEmptyEqs)
       {
         std::vector<Node> emptyNodesList(emptyNodes.begin(), emptyNodes.end());
-        Node nn1 = utils::mkConcat(STRING_CONCAT, emptyNodesList);
+        Node nn1 = utils::mkConcat(emptyNodesList, stype);
         if (nn1 != node[1] || nn2 != node[2])
         {
           Node res = nm->mkNode(kind::STRING_STRREPL, node[0], nn1, nn2);
@@ -3052,13 +3068,13 @@ Node SequencesRewriter::rewriteReplace(Node node)
       {
         std::vector<Node> cc;
         cc.insert(cc.end(), cb.begin(), cb.end());
-        cc.push_back(NodeManager::currentNM()->mkNode(
-            kind::STRING_STRREPL,
-            utils::mkConcat(STRING_CONCAT, children0),
-            node[1],
-            node[2]));
+        cc.push_back(
+            NodeManager::currentNM()->mkNode(kind::STRING_STRREPL,
+                                             utils::mkConcat(children0, stype),
+                                             node[1],
+                                             node[2]));
         cc.insert(cc.end(), ce.begin(), ce.end());
-        Node ret = utils::mkConcat(STRING_CONCAT, cc);
+        Node ret = utils::mkConcat(cc, stype);
         return returnRewrite(node, ret, "rpl-pull-endpt");
       }
     }
@@ -3080,8 +3096,8 @@ Node SequencesRewriter::rewriteReplace(Node node)
 
     children1.pop_back();
     // Length of the non-substr components in the second argument
-    Node partLen1 = nm->mkNode(kind::STRING_LENGTH,
-                               utils::mkConcat(STRING_CONCAT, children1));
+    Node partLen1 =
+        nm->mkNode(kind::STRING_LENGTH, utils::mkConcat(children1, stype));
     Node maxLen1 = nm->mkNode(kind::PLUS, partLen1, lastChild1[2]);
 
     Node zero = nm->mkConst(Rational(0));
@@ -3099,7 +3115,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
               kind::PLUS, len0, one, nm->mkNode(kind::UMINUS, partLen1))));
       Node res = nm->mkNode(kind::STRING_STRREPL,
                             node[0],
-                            utils::mkConcat(STRING_CONCAT, children1),
+                            utils::mkConcat(children1, stype),
                             node[2]);
       return returnRewrite(node, res, "repl-subst-idx");
     }
@@ -3285,7 +3301,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
       std::vector<Node> checkLhs;
       checkLhs.insert(
           checkLhs.end(), children0.begin(), children0.begin() + checkIndex);
-      Node lhs = utils::mkConcat(STRING_CONCAT, checkLhs);
+      Node lhs = utils::mkConcat(checkLhs, stype);
       Node rhs = children0[checkIndex];
       Node ctn = checkEntailContains(lhs, rhs);
       if (!ctn.isNull() && ctn.getConst<bool>())
@@ -3302,7 +3318,7 @@ Node SequencesRewriter::rewriteReplace(Node node)
     {
       std::vector<Node> remc(children0.begin() + lastCheckIndex,
                              children0.end());
-      Node rem = utils::mkConcat(STRING_CONCAT, remc);
+      Node rem = utils::mkConcat(remc, stype);
       Node ret =
           nm->mkNode(STRING_CONCAT,
                      nm->mkNode(STRING_STRREPL, lastLhs, node[1], node[2]),
@@ -3331,6 +3347,8 @@ Node SequencesRewriter::rewriteReplaceAll(Node node)
 {
   Assert(node.getKind() == STRING_STRREPLALL);
 
+  TypeNode stype = node.getType();
+
   if (node[0].isConst() && node[1].isConst())
   {
     std::vector<Node> children;
@@ -3362,7 +3380,7 @@ Node SequencesRewriter::rewriteReplaceAll(Node node)
       }
     } while (curr != std::string::npos && curr < sizeS);
     // constant evaluation
-    Node res = utils::mkConcat(STRING_CONCAT, children);
+    Node res = utils::mkConcat(children, stype);
     return returnRewrite(node, res, "replall-const");
   }
 
@@ -5411,6 +5429,8 @@ Node SequencesRewriter::inferEqsFromContains(Node x, Node y)
 {
   NodeManager* nm = NodeManager::currentNM();
   Node emp = nm->mkConst(String(""));
+  Assert(x.getType() == y.getType());
+  TypeNode stype = x.getType();
 
   Node xLen = nm->mkNode(STRING_LENGTH, x);
   std::vector<Node> yLens;
@@ -5481,7 +5501,7 @@ Node SequencesRewriter::inferEqsFromContains(Node x, Node y)
   // (= x (str.++ y1' ... ym'))
   if (!cs.empty())
   {
-    nb << nm->mkNode(EQUAL, x, utils::mkConcat(STRING_CONCAT, cs));
+    nb << nm->mkNode(EQUAL, x, utils::mkConcat(cs, stype));
   }
   // (= y1'' "") ... (= yk'' "")
   for (const Node& zeroLen : zeroLens)
index e6e0f8557601080bc0e6035d3781756d40f48de4..a26669fbf2fb3096dc4c15430722cf3110eaa78b 100644 (file)
@@ -519,7 +519,7 @@ bool TheoryStrings::collectModelInfoType(
         Assert(r.isConst() || processed.find(r) != processed.end());
         nc.push_back(r.isConst() ? r : processed[r]);
       }
-      Node cc = utils::mkNConcat(nc);
+      Node cc = utils::mkNConcat(nc, tn);
       Assert(cc.isConst());
       Trace("strings-model") << "*** Determined constant " << cc << " for " << nodes[i] << std::endl;
       processed[nodes[i]] = cc;
@@ -1115,7 +1115,7 @@ void TheoryStrings::checkRegisterTermsNormalForms()
     Node lt = ei ? ei->d_lengthTerm : Node::null();
     if (lt.isNull())
     {
-      Node c = utils::mkNConcat(nfi.d_nf);
+      Node c = utils::mkNConcat(nfi.d_nf, eqc.getType());
       registerTerm(c, 3);
     }
   }
index a325108e446a0970ea41906c6d48b863121685ce..5d27b8e2b06db84d911d1ee4213a575aa99e47cf 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "options/strings_options.h"
 #include "theory/rewriter.h"
+#include "theory/strings/word.h"
 
 using namespace CVC4::kind;
 
@@ -115,12 +116,20 @@ void getConcat(Node n, std::vector<Node>& c)
   }
 }
 
-Node mkConcat(Kind k, const std::vector<Node>& c)
+Node mkConcat(const std::vector<Node>& c, TypeNode tn)
 {
-  Assert(!c.empty() || k == STRING_CONCAT);
-  NodeManager* nm = NodeManager::currentNM();
-  return c.size() > 1 ? nm->mkNode(k, c)
-                      : (c.size() == 1 ? c[0] : nm->mkConst(String("")));
+  Assert(tn.isStringLike() || tn.isRegExp());
+  if (c.empty())
+  {
+    Assert(tn.isStringLike());
+    return Word::mkEmptyWord(tn);
+  }
+  else if (c.size() == 1)
+  {
+    return c[0];
+  }
+  Kind k = tn.isStringLike() ? STRING_CONCAT : REGEXP_CONCAT;
+  return NodeManager::currentNM()->mkNode(k, c);
 }
 
 Node mkNConcat(Node n1, Node n2)
@@ -135,9 +144,9 @@ Node mkNConcat(Node n1, Node n2, Node n3)
       NodeManager::currentNM()->mkNode(STRING_CONCAT, n1, n2, n3));
 }
 
-Node mkNConcat(const std::vector<Node>& c)
+Node mkNConcat(const std::vector<Node>& c, TypeNode tn)
 {
-  return Rewriter::rewrite(mkConcat(STRING_CONCAT, c));
+  return Rewriter::rewrite(mkConcat(c, tn));
 }
 
 Node mkNLength(Node t)
@@ -147,12 +156,11 @@ Node mkNLength(Node t)
 
 Node getConstantComponent(Node t)
 {
-  Kind tk = t.getKind();
-  if (tk == STRING_TO_REGEXP)
+  if (t.getKind() == STRING_TO_REGEXP)
   {
     return t[0].isConst() ? t[0] : Node::null();
   }
-  return tk == CONST_STRING ? t : Node::null();
+  return t.isConst() ? t : Node::null();
 }
 
 Node getConstantEndpoint(Node e, bool isSuf)
index 51fe8cfc7650d2e1b77515c82b65712676b3e276..5f18d393611b141545198f8ad2258f3cd286b7b7 100644 (file)
@@ -56,10 +56,10 @@ void flattenOp(Kind k, Node n, std::vector<Node>& conj);
 void getConcat(Node n, std::vector<Node>& c);
 
 /**
- * Make the concatentation from vector c
- * The kind k is either STRING_CONCAT or REGEXP_CONCAT.
+ * Make the concatentation from vector c of (string-like or regular
+ * expression) type tn.
  */
-Node mkConcat(Kind k, const std::vector<Node>& c);
+Node mkConcat(const std::vector<Node>& c, TypeNode tn);
 
 /**
  * Returns the rewritten form of the string concatenation of n1 and n2.
@@ -72,9 +72,10 @@ Node mkNConcat(Node n1, Node n2);
 Node mkNConcat(Node n1, Node n2, Node n3);
 
 /**
- * Returns the rewritten form of the string concatenation of nodes in c.
+ * Returns the rewritten form of the concatentation from vector c of
+ * (string-like) type tn.
  */
-Node mkNConcat(const std::vector<Node>& c);
+Node mkNConcat(const std::vector<Node>& c, TypeNode tn);
 
 /**
  * Returns the rewritten form of the length of string term t.