Split regular expression solver (#2891)
[cvc5.git] / src / theory / strings / regexp_solver.cpp
1 /********************* */
2 /*! \file regexp_solver.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of the regular expression solver for the theory of
13 ** strings.
14 **
15 **/
16
17 #include "theory/strings/regexp_solver.h"
18
19 #include <cmath>
20
21 #include "options/strings_options.h"
22 #include "theory/strings/theory_strings.h"
23 #include "theory/strings/theory_strings_rewriter.h"
24 #include "theory/theory_model.h"
25
26 using namespace std;
27 using namespace CVC4::context;
28 using namespace CVC4::kind;
29
30 namespace CVC4 {
31 namespace theory {
32 namespace strings {
33
34 RegExpSolver::RegExpSolver(TheoryStrings& p,
35 context::Context* c,
36 context::UserContext* u)
37 : d_parent(p),
38 d_regexp_memberships(c),
39 d_regexp_ucached(u),
40 d_regexp_ccached(c),
41 d_pos_memberships(c),
42 d_neg_memberships(c),
43 d_inter_cache(c),
44 d_inter_index(c),
45 d_processed_memberships(c)
46 {
47 d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
48 std::vector<Node> nvec;
49 d_emptyRegexp = NodeManager::currentNM()->mkNode(REGEXP_EMPTY, nvec);
50 d_true = NodeManager::currentNM()->mkConst(true);
51 d_false = NodeManager::currentNM()->mkConst(false);
52 }
53
54 unsigned RegExpSolver::getNumMemberships(Node n, bool isPos)
55 {
56 if (isPos)
57 {
58 NodeUIntMap::const_iterator it = d_pos_memberships.find(n);
59 if (it != d_pos_memberships.end())
60 {
61 return (*it).second;
62 }
63 }
64 else
65 {
66 NodeUIntMap::const_iterator it = d_neg_memberships.find(n);
67 if (it != d_neg_memberships.end())
68 {
69 return (*it).second;
70 }
71 }
72 return 0;
73 }
74
75 Node RegExpSolver::getMembership(Node n, bool isPos, unsigned i)
76 {
77 return isPos ? d_pos_memberships_data[n][i] : d_neg_memberships_data[n][i];
78 }
79
80 Node RegExpSolver::mkAnd(Node c1, Node c2)
81 {
82 return NodeManager::currentNM()->mkNode(AND, c1, c2);
83 }
84
85 void RegExpSolver::check()
86 {
87 bool addedLemma = false;
88 bool changed = false;
89 std::vector<Node> processed;
90 std::vector<Node> cprocessed;
91
92 Trace("regexp-debug") << "Checking Memberships ... " << std::endl;
93 for (NodeUIntMap::const_iterator itr_xr = d_pos_memberships.begin();
94 itr_xr != d_pos_memberships.end();
95 ++itr_xr)
96 {
97 bool spflag = false;
98 Node x = (*itr_xr).first;
99 Trace("regexp-debug") << "Checking Memberships for " << x << std::endl;
100 if (d_inter_index.find(x) == d_inter_index.end())
101 {
102 d_inter_index[x] = 0;
103 }
104 int cur_inter_idx = d_inter_index[x];
105 unsigned n_pmem = (*itr_xr).second;
106 Assert(getNumMemberships(x, true) == n_pmem);
107 if (cur_inter_idx != (int)n_pmem)
108 {
109 if (n_pmem == 1)
110 {
111 d_inter_cache[x] = getMembership(x, true, 0);
112 d_inter_index[x] = 1;
113 Trace("regexp-debug") << "... only one choice " << std::endl;
114 }
115 else if (n_pmem > 1)
116 {
117 Node r;
118 if (d_inter_cache.find(x) != d_inter_cache.end())
119 {
120 r = d_inter_cache[x];
121 }
122 if (r.isNull())
123 {
124 r = getMembership(x, true, 0);
125 cur_inter_idx = 1;
126 }
127
128 unsigned k_start = cur_inter_idx;
129 Trace("regexp-debug") << "... staring from : " << cur_inter_idx
130 << ", we have " << n_pmem << std::endl;
131 for (unsigned k = k_start; k < n_pmem; k++)
132 {
133 Node r2 = getMembership(x, true, k);
134 r = d_regexp_opr.intersect(r, r2, spflag);
135 if (spflag)
136 {
137 break;
138 }
139 else if (r == d_emptyRegexp)
140 {
141 std::vector<Node> vec_nodes;
142 for (unsigned kk = 0; kk <= k; kk++)
143 {
144 Node rr = getMembership(x, true, kk);
145 Node n =
146 NodeManager::currentNM()->mkNode(STRING_IN_REGEXP, x, rr);
147 vec_nodes.push_back(n);
148 }
149 Node conc;
150 d_parent.sendInference(vec_nodes, conc, "INTERSECT CONFLICT", true);
151 addedLemma = true;
152 break;
153 }
154 if (d_parent.inConflict())
155 {
156 break;
157 }
158 }
159 // updates
160 if (!d_parent.inConflict() && !spflag)
161 {
162 d_inter_cache[x] = r;
163 d_inter_index[x] = (int)n_pmem;
164 }
165 }
166 }
167 }
168
169 Trace("regexp-debug")
170 << "... No Intersect Conflict in Memberships, addedLemma: " << addedLemma
171 << std::endl;
172 if (!addedLemma)
173 {
174 NodeManager* nm = NodeManager::currentNM();
175 for (unsigned i = 0; i < d_regexp_memberships.size(); i++)
176 {
177 // check regular expression membership
178 Node assertion = d_regexp_memberships[i];
179 Trace("regexp-debug")
180 << "Check : " << assertion << " "
181 << (d_regexp_ucached.find(assertion) == d_regexp_ucached.end()) << " "
182 << (d_regexp_ccached.find(assertion) == d_regexp_ccached.end())
183 << std::endl;
184 if (d_regexp_ucached.find(assertion) == d_regexp_ucached.end()
185 && d_regexp_ccached.find(assertion) == d_regexp_ccached.end())
186 {
187 Trace("strings-regexp")
188 << "We have regular expression assertion : " << assertion
189 << std::endl;
190 Node atom = assertion.getKind() == NOT ? assertion[0] : assertion;
191 bool polarity = assertion.getKind() != NOT;
192 bool flag = true;
193 Node x = atom[0];
194 Node r = atom[1];
195 std::vector<Node> rnfexp;
196
197 if (!x.isConst())
198 {
199 x = d_parent.getNormalString(x, rnfexp);
200 changed = true;
201 }
202 if (!d_regexp_opr.checkConstRegExp(r))
203 {
204 r = getNormalSymRegExp(r, rnfexp);
205 changed = true;
206 }
207 Trace("strings-regexp-nf") << "Term " << atom << " is normalized to "
208 << x << " IN " << r << std::endl;
209 if (changed)
210 {
211 Node tmp = Rewriter::rewrite(nm->mkNode(STRING_IN_REGEXP, x, r));
212 if (!polarity)
213 {
214 tmp = tmp.negate();
215 }
216 if (tmp == d_true)
217 {
218 d_regexp_ccached.insert(assertion);
219 continue;
220 }
221 else if (tmp == d_false)
222 {
223 std::vector<Node> exp_n;
224 exp_n.push_back(assertion);
225 Node conc = Node::null();
226 d_parent.sendInference(rnfexp, exp_n, conc, "REGEXP NF Conflict");
227 addedLemma = true;
228 break;
229 }
230 }
231
232 if (polarity)
233 {
234 flag = checkPDerivative(x, r, atom, addedLemma, rnfexp);
235 }
236 else
237 {
238 if (!options::stringExp())
239 {
240 throw LogicException(
241 "Strings Incomplete (due to Negative Membership) by default, "
242 "try --strings-exp option.");
243 }
244 }
245 if (flag)
246 {
247 // check if the term is atomic
248 Node xr = d_parent.getRepresentative(x);
249 Trace("strings-regexp")
250 << "Unroll/simplify membership of atomic term " << xr
251 << std::endl;
252 // if so, do simple unrolling
253 std::vector<Node> nvec;
254 if (nvec.empty())
255 {
256 d_regexp_opr.simplify(atom, nvec, polarity);
257 }
258 std::vector<Node> exp_n;
259 exp_n.push_back(assertion);
260 Node conc = nvec.size() == 1 ? nvec[0] : nm->mkNode(AND, nvec);
261 conc = Rewriter::rewrite(conc);
262 d_parent.sendInference(rnfexp, exp_n, conc, "REGEXP_Unfold");
263 addedLemma = true;
264 if (changed)
265 {
266 cprocessed.push_back(assertion);
267 }
268 else
269 {
270 processed.push_back(assertion);
271 }
272 }
273 }
274 if (d_parent.inConflict())
275 {
276 break;
277 }
278 }
279 }
280 if (addedLemma)
281 {
282 if (!d_parent.inConflict())
283 {
284 for (unsigned i = 0; i < processed.size(); i++)
285 {
286 Trace("strings-regexp")
287 << "...add " << processed[i] << " to u-cache." << std::endl;
288 d_regexp_ucached.insert(processed[i]);
289 }
290 for (unsigned i = 0; i < cprocessed.size(); i++)
291 {
292 Trace("strings-regexp")
293 << "...add " << cprocessed[i] << " to c-cache." << std::endl;
294 d_regexp_ccached.insert(cprocessed[i]);
295 }
296 }
297 }
298 }
299
300 bool RegExpSolver::checkPDerivative(
301 Node x, Node r, Node atom, bool& addedLemma, std::vector<Node>& nf_exp)
302 {
303 if (d_parent.areEqual(x, d_emptyString))
304 {
305 Node exp;
306 switch (d_regexp_opr.delta(r, exp))
307 {
308 case 0:
309 {
310 std::vector<Node> exp_n;
311 exp_n.push_back(atom);
312 exp_n.push_back(x.eqNode(d_emptyString));
313 d_parent.sendInference(nf_exp, exp_n, exp, "RegExp Delta");
314 addedLemma = true;
315 d_regexp_ccached.insert(atom);
316 return false;
317 }
318 case 1:
319 {
320 d_regexp_ccached.insert(atom);
321 break;
322 }
323 case 2:
324 {
325 std::vector<Node> exp_n;
326 exp_n.push_back(atom);
327 exp_n.push_back(x.eqNode(d_emptyString));
328 Node conc;
329 d_parent.sendInference(nf_exp, exp_n, conc, "RegExp Delta CONFLICT");
330 addedLemma = true;
331 d_regexp_ccached.insert(atom);
332 return false;
333 }
334 default:
335 // Impossible
336 break;
337 }
338 }
339 else
340 {
341 if (deriveRegExp(x, r, atom, nf_exp))
342 {
343 addedLemma = true;
344 d_regexp_ccached.insert(atom);
345 return false;
346 }
347 }
348 return true;
349 }
350
351 CVC4::String RegExpSolver::getHeadConst(Node x)
352 {
353 if (x.isConst())
354 {
355 return x.getConst<String>();
356 }
357 else if (x.getKind() == STRING_CONCAT)
358 {
359 if (x[0].isConst())
360 {
361 return x[0].getConst<String>();
362 }
363 }
364 return d_emptyString.getConst<String>();
365 }
366
367 bool RegExpSolver::deriveRegExp(Node x,
368 Node r,
369 Node atom,
370 std::vector<Node>& ant)
371 {
372 Assert(x != d_emptyString);
373 Trace("regexp-derive") << "RegExpSolver::deriveRegExp: x=" << x
374 << ", r= " << r << std::endl;
375 CVC4::String s = getHeadConst(x);
376 if (!s.isEmptyString() && d_regexp_opr.checkConstRegExp(r))
377 {
378 Node conc = Node::null();
379 Node dc = r;
380 bool flag = true;
381 for (unsigned i = 0; i < s.size(); ++i)
382 {
383 CVC4::String c = s.substr(i, 1);
384 Node dc2;
385 int rt = d_regexp_opr.derivativeS(dc, c, dc2);
386 dc = dc2;
387 if (rt == 2)
388 {
389 // CONFLICT
390 flag = false;
391 break;
392 }
393 }
394 // send lemma
395 if (flag)
396 {
397 if (x.isConst())
398 {
399 Assert(false,
400 "Impossible: RegExpSolver::deriveRegExp: const string in const "
401 "regular expression.");
402 return false;
403 }
404 else
405 {
406 Assert(x.getKind() == STRING_CONCAT);
407 std::vector<Node> vec_nodes;
408 for (unsigned int i = 1; i < x.getNumChildren(); ++i)
409 {
410 vec_nodes.push_back(x[i]);
411 }
412 Node left = TheoryStringsRewriter::mkConcat(STRING_CONCAT, vec_nodes);
413 left = Rewriter::rewrite(left);
414 conc = NodeManager::currentNM()->mkNode(STRING_IN_REGEXP, left, dc);
415 }
416 }
417 std::vector<Node> exp_n;
418 exp_n.push_back(atom);
419 d_parent.sendInference(ant, exp_n, conc, "RegExp-Derive");
420 return true;
421 }
422 return false;
423 }
424
425 void RegExpSolver::addMembership(Node assertion)
426 {
427 bool polarity = assertion.getKind() != NOT;
428 TNode atom = polarity ? assertion : assertion[0];
429 Node x = atom[0];
430 Node r = atom[1];
431 if (polarity)
432 {
433 unsigned index = 0;
434 NodeUIntMap::const_iterator it = d_pos_memberships.find(x);
435 if (it != d_pos_memberships.end())
436 {
437 index = (*it).second;
438 for (unsigned k = 0; k < index; k++)
439 {
440 if (k < d_pos_memberships_data[x].size())
441 {
442 if (d_pos_memberships_data[x][k] == r)
443 {
444 return;
445 }
446 }
447 else
448 {
449 break;
450 }
451 }
452 }
453 d_pos_memberships[x] = index + 1;
454 if (index < d_pos_memberships_data[x].size())
455 {
456 d_pos_memberships_data[x][index] = r;
457 }
458 else
459 {
460 d_pos_memberships_data[x].push_back(r);
461 }
462 }
463 else if (!options::stringIgnNegMembership())
464 {
465 unsigned index = 0;
466 NodeUIntMap::const_iterator it = d_neg_memberships.find(x);
467 if (it != d_neg_memberships.end())
468 {
469 index = (*it).second;
470 for (unsigned k = 0; k < index; k++)
471 {
472 if (k < d_neg_memberships_data[x].size())
473 {
474 if (d_neg_memberships_data[x][k] == r)
475 {
476 return;
477 }
478 }
479 else
480 {
481 break;
482 }
483 }
484 }
485 d_neg_memberships[x] = index + 1;
486 if (index < d_neg_memberships_data[x].size())
487 {
488 d_neg_memberships_data[x][index] = r;
489 }
490 else
491 {
492 d_neg_memberships_data[x].push_back(r);
493 }
494 }
495 // old
496 if (polarity || !options::stringIgnNegMembership())
497 {
498 d_regexp_memberships.push_back(assertion);
499 }
500 }
501
502 Node RegExpSolver::getNormalSymRegExp(Node r, std::vector<Node>& nf_exp)
503 {
504 Node ret = r;
505 switch (r.getKind())
506 {
507 case REGEXP_EMPTY:
508 case REGEXP_SIGMA: break;
509 case STRING_TO_REGEXP:
510 {
511 if (!r[0].isConst())
512 {
513 Node tmp = d_parent.getNormalString(r[0], nf_exp);
514 if (tmp != r[0])
515 {
516 ret = NodeManager::currentNM()->mkNode(STRING_TO_REGEXP, tmp);
517 }
518 }
519 break;
520 }
521 case REGEXP_CONCAT:
522 case REGEXP_UNION:
523 case REGEXP_INTER:
524 case REGEXP_STAR:
525 {
526 std::vector<Node> vec_nodes;
527 for (const Node& cr : r)
528 {
529 vec_nodes.push_back(getNormalSymRegExp(cr, nf_exp));
530 }
531 ret = Rewriter::rewrite(
532 NodeManager::currentNM()->mkNode(r.getKind(), vec_nodes));
533 break;
534 }
535 default:
536 {
537 Trace("strings-error") << "Unsupported term: " << r
538 << " in normalization SymRegExp." << std::endl;
539 Assert(false);
540 }
541 }
542 return ret;
543 }
544
545 } // namespace strings
546 } // namespace theory
547 } // namespace CVC4