Enum for all remaining string inferences (#4220)
[cvc5.git] / src / theory / strings / base_solver.cpp
1 /********************* */
2 /*! \file base_solver.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Base solver for the theory of strings. This class implements term
13 ** indexing and constant inference for the theory of strings.
14 **/
15
16 #include "theory/strings/base_solver.h"
17
18 #include "options/strings_options.h"
19 #include "theory/strings/theory_strings_utils.h"
20
21 using namespace std;
22 using namespace CVC4::context;
23 using namespace CVC4::kind;
24
25 namespace CVC4 {
26 namespace theory {
27 namespace strings {
28
29 BaseSolver::BaseSolver(context::Context* c,
30 context::UserContext* u,
31 SolverState& s,
32 InferenceManager& im)
33 : d_state(s), d_im(im), d_congruent(c)
34 {
35 d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
36 d_false = NodeManager::currentNM()->mkConst(false);
37 d_cardSize = utils::getAlphabetCardinality();
38 }
39
40 BaseSolver::~BaseSolver() {}
41
42 void BaseSolver::checkInit()
43 {
44 // build term index
45 d_eqcToConst.clear();
46 d_eqcToConstBase.clear();
47 d_eqcToConstExp.clear();
48 d_termIndex.clear();
49 d_stringsEqc.clear();
50
51 std::map<Kind, uint32_t> ncongruent;
52 std::map<Kind, uint32_t> congruent;
53 eq::EqualityEngine* ee = d_state.getEqualityEngine();
54 Assert(d_state.getRepresentative(d_emptyString) == d_emptyString);
55 eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
56 while (!eqcs_i.isFinished())
57 {
58 Node eqc = (*eqcs_i);
59 TypeNode tn = eqc.getType();
60 if (!tn.isRegExp())
61 {
62 if (tn.isString())
63 {
64 d_stringsEqc.push_back(eqc);
65 }
66 Node var;
67 eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
68 while (!eqc_i.isFinished())
69 {
70 Node n = *eqc_i;
71 if (n.isConst())
72 {
73 d_eqcToConst[eqc] = n;
74 d_eqcToConstBase[eqc] = n;
75 d_eqcToConstExp[eqc] = Node::null();
76 }
77 else if (tn.isInteger())
78 {
79 // do nothing
80 }
81 else if (n.getNumChildren() > 0)
82 {
83 Kind k = n.getKind();
84 if (k != EQUAL)
85 {
86 if (d_congruent.find(n) == d_congruent.end())
87 {
88 std::vector<Node> c;
89 Node nc = d_termIndex[k].add(n, 0, d_state, d_emptyString, c);
90 if (nc != n)
91 {
92 // check if we have inferred a new equality by removal of empty
93 // components
94 if (n.getKind() == STRING_CONCAT && !d_state.areEqual(nc, n))
95 {
96 std::vector<Node> exp;
97 size_t count[2] = {0, 0};
98 while (count[0] < nc.getNumChildren()
99 || count[1] < n.getNumChildren())
100 {
101 // explain empty prefixes
102 for (unsigned t = 0; t < 2; t++)
103 {
104 Node nn = t == 0 ? nc : n;
105 while (
106 count[t] < nn.getNumChildren()
107 && (nn[count[t]] == d_emptyString
108 || d_state.areEqual(nn[count[t]], d_emptyString)))
109 {
110 if (nn[count[t]] != d_emptyString)
111 {
112 exp.push_back(nn[count[t]].eqNode(d_emptyString));
113 }
114 count[t]++;
115 }
116 }
117 // explain equal components
118 if (count[0] < nc.getNumChildren())
119 {
120 Assert(count[1] < n.getNumChildren());
121 if (nc[count[0]] != n[count[1]])
122 {
123 exp.push_back(nc[count[0]].eqNode(n[count[1]]));
124 }
125 count[0]++;
126 count[1]++;
127 }
128 }
129 // infer the equality
130 d_im.sendInference(exp, n.eqNode(nc), Inference::I_NORM);
131 }
132 else
133 {
134 // mark as congruent : only process if neither has been
135 // reduced
136 d_im.markCongruent(nc, n);
137 }
138 // this node is congruent to another one, we can ignore it
139 Trace("strings-process-debug")
140 << " congruent term : " << n << " (via " << nc << ")"
141 << std::endl;
142 d_congruent.insert(n);
143 congruent[k]++;
144 }
145 else if (k == STRING_CONCAT && c.size() == 1)
146 {
147 Trace("strings-process-debug")
148 << " congruent term by singular : " << n << " " << c[0]
149 << std::endl;
150 // singular case
151 if (!d_state.areEqual(c[0], n))
152 {
153 Node ns;
154 std::vector<Node> exp;
155 // explain empty components
156 bool foundNEmpty = false;
157 for (const Node& nnc : n)
158 {
159 if (d_state.areEqual(nnc, d_emptyString))
160 {
161 if (nnc != d_emptyString)
162 {
163 exp.push_back(nnc.eqNode(d_emptyString));
164 }
165 }
166 else
167 {
168 Assert(!foundNEmpty);
169 ns = nnc;
170 foundNEmpty = true;
171 }
172 }
173 AlwaysAssert(foundNEmpty);
174 // infer the equality
175 d_im.sendInference(exp, n.eqNode(ns), Inference::I_NORM_S);
176 }
177 d_congruent.insert(n);
178 congruent[k]++;
179 }
180 else
181 {
182 ncongruent[k]++;
183 }
184 }
185 else
186 {
187 congruent[k]++;
188 }
189 }
190 }
191 else
192 {
193 if (d_congruent.find(n) == d_congruent.end())
194 {
195 // We mark all but the oldest variable in the equivalence class as
196 // congruent.
197 if (var.isNull())
198 {
199 var = n;
200 }
201 else if (var > n)
202 {
203 Trace("strings-process-debug")
204 << " congruent variable : " << var << std::endl;
205 d_congruent.insert(var);
206 var = n;
207 }
208 else
209 {
210 Trace("strings-process-debug")
211 << " congruent variable : " << n << std::endl;
212 d_congruent.insert(n);
213 }
214 }
215 }
216 ++eqc_i;
217 }
218 }
219 ++eqcs_i;
220 }
221 if (Trace.isOn("strings-process"))
222 {
223 for (std::map<Kind, TermIndex>::iterator it = d_termIndex.begin();
224 it != d_termIndex.end();
225 ++it)
226 {
227 Trace("strings-process")
228 << " Terms[" << it->first << "] = " << ncongruent[it->first] << "/"
229 << (congruent[it->first] + ncongruent[it->first]) << std::endl;
230 }
231 }
232 }
233
234 void BaseSolver::checkConstantEquivalenceClasses()
235 {
236 // do fixed point
237 size_t prevSize = 0;
238 std::vector<Node> vecc;
239 do
240 {
241 vecc.clear();
242 Trace("strings-process-debug")
243 << "Check constant equivalence classes..." << std::endl;
244 prevSize = d_eqcToConst.size();
245 checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc);
246 } while (!d_im.hasProcessed() && d_eqcToConst.size() > prevSize);
247 }
248
249 void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
250 std::vector<Node>& vecc)
251 {
252 Node n = ti->d_data;
253 if (!n.isNull())
254 {
255 // construct the constant
256 Node c = utils::mkNConcat(vecc, n.getType());
257 if (!d_state.areEqual(n, c))
258 {
259 if (Trace.isOn("strings-debug"))
260 {
261 Trace("strings-debug")
262 << "Constant eqc : " << c << " for " << n << std::endl;
263 Trace("strings-debug") << " ";
264 for (const Node& v : vecc)
265 {
266 Trace("strings-debug") << v << " ";
267 }
268 Trace("strings-debug") << std::endl;
269 }
270 size_t count = 0;
271 size_t countc = 0;
272 std::vector<Node> exp;
273 while (count < n.getNumChildren())
274 {
275 while (count < n.getNumChildren()
276 && d_state.areEqual(n[count], d_emptyString))
277 {
278 d_im.addToExplanation(n[count], d_emptyString, exp);
279 count++;
280 }
281 if (count < n.getNumChildren())
282 {
283 Trace("strings-debug")
284 << "...explain " << n[count] << " " << vecc[countc] << std::endl;
285 if (!d_state.areEqual(n[count], vecc[countc]))
286 {
287 Node nrr = d_state.getRepresentative(n[count]);
288 Assert(!d_eqcToConstExp[nrr].isNull());
289 d_im.addToExplanation(n[count], d_eqcToConstBase[nrr], exp);
290 exp.push_back(d_eqcToConstExp[nrr]);
291 }
292 else
293 {
294 d_im.addToExplanation(n[count], vecc[countc], exp);
295 }
296 countc++;
297 count++;
298 }
299 }
300 // exp contains an explanation of n==c
301 Assert(countc == vecc.size());
302 if (d_state.hasTerm(c))
303 {
304 d_im.sendInference(exp, n.eqNode(c), Inference::I_CONST_MERGE);
305 return;
306 }
307 else if (!d_im.hasProcessed())
308 {
309 Node nr = d_state.getRepresentative(n);
310 std::map<Node, Node>::iterator it = d_eqcToConst.find(nr);
311 if (it == d_eqcToConst.end())
312 {
313 Trace("strings-debug")
314 << "Set eqc const " << n << " to " << c << std::endl;
315 d_eqcToConst[nr] = c;
316 d_eqcToConstBase[nr] = n;
317 d_eqcToConstExp[nr] = utils::mkAnd(exp);
318 }
319 else if (c != it->second)
320 {
321 // conflict
322 Trace("strings-debug")
323 << "Conflict, other constant was " << it->second
324 << ", this constant was " << c << std::endl;
325 if (d_eqcToConstExp[nr].isNull())
326 {
327 // n==c ^ n == c' => false
328 d_im.addToExplanation(n, it->second, exp);
329 }
330 else
331 {
332 // n==c ^ n == d_eqcToConstBase[nr] == c' => false
333 exp.push_back(d_eqcToConstExp[nr]);
334 d_im.addToExplanation(n, d_eqcToConstBase[nr], exp);
335 }
336 d_im.sendInference(exp, d_false, Inference::I_CONST_CONFLICT);
337 return;
338 }
339 else
340 {
341 Trace("strings-debug") << "Duplicate constant." << std::endl;
342 }
343 }
344 }
345 }
346 for (std::pair<const TNode, TermIndex>& p : ti->d_children)
347 {
348 std::map<Node, Node>::iterator itc = d_eqcToConst.find(p.first);
349 if (itc != d_eqcToConst.end())
350 {
351 vecc.push_back(itc->second);
352 checkConstantEquivalenceClasses(&p.second, vecc);
353 vecc.pop_back();
354 if (d_im.hasProcessed())
355 {
356 break;
357 }
358 }
359 }
360 }
361
362 void BaseSolver::checkCardinality()
363 {
364 // This will create a partition of eqc, where each collection has length that
365 // are pairwise propagated to be equal. We do not require disequalities
366 // between the lengths of each collection, since we split on disequalities
367 // between lengths of string terms that are disequal (DEQ-LENGTH-SP).
368 std::vector<std::vector<Node> > cols;
369 std::vector<Node> lts;
370 d_state.separateByLength(d_stringsEqc, cols, lts);
371 NodeManager* nm = NodeManager::currentNM();
372 Trace("strings-card") << "Check cardinality...." << std::endl;
373 // for each collection
374 for (unsigned i = 0, csize = cols.size(); i < csize; ++i)
375 {
376 Node lr = lts[i];
377 Trace("strings-card") << "Number of strings with length equal to " << lr
378 << " is " << cols[i].size() << std::endl;
379 if (cols[i].size() <= 1)
380 {
381 // no restriction on sets in the partition of size 1
382 continue;
383 }
384 // size > c^k
385 unsigned card_need = 1;
386 double curr = static_cast<double>(cols[i].size());
387 while (curr > d_cardSize)
388 {
389 curr = curr / static_cast<double>(d_cardSize);
390 card_need++;
391 }
392 Trace("strings-card")
393 << "Need length " << card_need
394 << " for this number of strings (where alphabet size is " << d_cardSize
395 << ")." << std::endl;
396 // check if we need to split
397 bool needsSplit = true;
398 if (lr.isConst())
399 {
400 // if constant, compare
401 Node cmp = nm->mkNode(GEQ, lr, nm->mkConst(Rational(card_need)));
402 cmp = Rewriter::rewrite(cmp);
403 needsSplit = !cmp.getConst<bool>();
404 }
405 else
406 {
407 // find the minimimum constant that we are unknown to be disequal from, or
408 // otherwise stop if we increment such that cardinality does not apply
409 unsigned r = 0;
410 bool success = true;
411 while (r < card_need && success)
412 {
413 Node rr = nm->mkConst(Rational(r));
414 if (d_state.areDisequal(rr, lr))
415 {
416 r++;
417 }
418 else
419 {
420 success = false;
421 }
422 }
423 if (r > 0)
424 {
425 Trace("strings-card")
426 << "Symbolic length " << lr << " must be at least " << r
427 << " due to constant disequalities." << std::endl;
428 }
429 needsSplit = r < card_need;
430 }
431
432 if (!needsSplit)
433 {
434 // don't need to split
435 continue;
436 }
437 // first, try to split to merge equivalence classes
438 for (std::vector<Node>::iterator itr1 = cols[i].begin();
439 itr1 != cols[i].end();
440 ++itr1)
441 {
442 for (std::vector<Node>::iterator itr2 = itr1 + 1; itr2 != cols[i].end();
443 ++itr2)
444 {
445 if (!d_state.areDisequal(*itr1, *itr2))
446 {
447 // add split lemma
448 if (d_im.sendSplit(*itr1, *itr2, Inference::CARD_SP))
449 {
450 return;
451 }
452 }
453 }
454 }
455 // otherwise, we need a length constraint
456 uint32_t int_k = static_cast<uint32_t>(card_need);
457 EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true);
458 Trace("strings-card") << "Previous cardinality used for " << lr << " is "
459 << ((int)ei->d_cardinalityLemK.get() - 1)
460 << std::endl;
461 if (int_k + 1 > ei->d_cardinalityLemK.get())
462 {
463 Node k_node = nm->mkConst(Rational(int_k));
464 // add cardinality lemma
465 Node dist = nm->mkNode(DISTINCT, cols[i]);
466 std::vector<Node> vec_node;
467 vec_node.push_back(dist);
468 for (std::vector<Node>::iterator itr1 = cols[i].begin();
469 itr1 != cols[i].end();
470 ++itr1)
471 {
472 Node len = nm->mkNode(STRING_LENGTH, *itr1);
473 if (len != lr)
474 {
475 Node len_eq_lr = len.eqNode(lr);
476 vec_node.push_back(len_eq_lr);
477 }
478 }
479 Node len = nm->mkNode(STRING_LENGTH, cols[i][0]);
480 Node cons = nm->mkNode(GEQ, len, k_node);
481 cons = Rewriter::rewrite(cons);
482 ei->d_cardinalityLemK.set(int_k + 1);
483 if (!cons.isConst() || !cons.getConst<bool>())
484 {
485 std::vector<Node> emptyVec;
486 d_im.sendInference(
487 emptyVec, vec_node, cons, Inference::CARDINALITY, true);
488 return;
489 }
490 }
491 }
492 Trace("strings-card") << "...end check cardinality" << std::endl;
493 }
494
495 bool BaseSolver::isCongruent(Node n)
496 {
497 return d_congruent.find(n) != d_congruent.end();
498 }
499
500 Node BaseSolver::getConstantEqc(Node eqc)
501 {
502 std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
503 if (it != d_eqcToConst.end())
504 {
505 return it->second;
506 }
507 return Node::null();
508 }
509
510 Node BaseSolver::explainConstantEqc(Node n, Node eqc, std::vector<Node>& exp)
511 {
512 std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
513 if (it != d_eqcToConst.end())
514 {
515 if (!d_eqcToConstExp[eqc].isNull())
516 {
517 exp.push_back(d_eqcToConstExp[eqc]);
518 }
519 if (!d_eqcToConstBase[eqc].isNull())
520 {
521 d_im.addToExplanation(n, d_eqcToConstBase[eqc], exp);
522 }
523 return it->second;
524 }
525 return Node::null();
526 }
527
528 const std::vector<Node>& BaseSolver::getStringEqc() const
529 {
530 return d_stringsEqc;
531 }
532
533 Node BaseSolver::TermIndex::add(TNode n,
534 unsigned index,
535 const SolverState& s,
536 Node er,
537 std::vector<Node>& c)
538 {
539 if (index == n.getNumChildren())
540 {
541 if (d_data.isNull())
542 {
543 d_data = n;
544 }
545 return d_data;
546 }
547 Assert(index < n.getNumChildren());
548 TNode nir = s.getRepresentative(n[index]);
549 // if it is empty, and doing CONCAT, ignore
550 if (nir == er && n.getKind() == STRING_CONCAT)
551 {
552 return add(n, index + 1, s, er, c);
553 }
554 c.push_back(nir);
555 return d_children[nir].add(n, index + 1, s, er, c);
556 }
557
558 } // namespace strings
559 } // namespace theory
560 } // namespace CVC4