Move cardinality inference scheme to base solver in strings (#3792)
[cvc5.git] / src / theory / strings / base_solver.cpp
1 /********************* */
2 /*! \file base_solver.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Base solver for the theory of strings. This class implements term
13 ** indexing and constant inference for the theory of strings.
14 **/
15
16 #include "theory/strings/base_solver.h"
17
18 #include "options/strings_options.h"
19 #include "theory/strings/theory_strings_rewriter.h"
20 #include "theory/strings/theory_strings_utils.h"
21
22 using namespace std;
23 using namespace CVC4::context;
24 using namespace CVC4::kind;
25
26 namespace CVC4 {
27 namespace theory {
28 namespace strings {
29
30 BaseSolver::BaseSolver(context::Context* c,
31 context::UserContext* u,
32 SolverState& s,
33 InferenceManager& im)
34 : d_state(s), d_im(im), d_congruent(c)
35 {
36 d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
37 d_false = NodeManager::currentNM()->mkConst(false);
38 d_cardSize = utils::getAlphabetCardinality();
39 }
40
41 BaseSolver::~BaseSolver() {}
42
43 void BaseSolver::checkInit()
44 {
45 // build term index
46 d_eqcToConst.clear();
47 d_eqcToConstBase.clear();
48 d_eqcToConstExp.clear();
49 d_termIndex.clear();
50 d_stringsEqc.clear();
51
52 std::map<Kind, uint32_t> ncongruent;
53 std::map<Kind, uint32_t> congruent;
54 eq::EqualityEngine* ee = d_state.getEqualityEngine();
55 Assert(d_state.getRepresentative(d_emptyString) == d_emptyString);
56 eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
57 while (!eqcs_i.isFinished())
58 {
59 Node eqc = (*eqcs_i);
60 TypeNode tn = eqc.getType();
61 if (!tn.isRegExp())
62 {
63 if (tn.isString())
64 {
65 d_stringsEqc.push_back(eqc);
66 }
67 Node var;
68 eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, ee);
69 while (!eqc_i.isFinished())
70 {
71 Node n = *eqc_i;
72 if (n.isConst())
73 {
74 d_eqcToConst[eqc] = n;
75 d_eqcToConstBase[eqc] = n;
76 d_eqcToConstExp[eqc] = Node::null();
77 }
78 else if (tn.isInteger())
79 {
80 // do nothing
81 }
82 else if (n.getNumChildren() > 0)
83 {
84 Kind k = n.getKind();
85 if (k != EQUAL)
86 {
87 if (d_congruent.find(n) == d_congruent.end())
88 {
89 std::vector<Node> c;
90 Node nc = d_termIndex[k].add(n, 0, d_state, d_emptyString, c);
91 if (nc != n)
92 {
93 // check if we have inferred a new equality by removal of empty
94 // components
95 if (n.getKind() == STRING_CONCAT && !d_state.areEqual(nc, n))
96 {
97 std::vector<Node> exp;
98 size_t count[2] = {0, 0};
99 while (count[0] < nc.getNumChildren()
100 || count[1] < n.getNumChildren())
101 {
102 // explain empty prefixes
103 for (unsigned t = 0; t < 2; t++)
104 {
105 Node nn = t == 0 ? nc : n;
106 while (
107 count[t] < nn.getNumChildren()
108 && (nn[count[t]] == d_emptyString
109 || d_state.areEqual(nn[count[t]], d_emptyString)))
110 {
111 if (nn[count[t]] != d_emptyString)
112 {
113 exp.push_back(nn[count[t]].eqNode(d_emptyString));
114 }
115 count[t]++;
116 }
117 }
118 // explain equal components
119 if (count[0] < nc.getNumChildren())
120 {
121 Assert(count[1] < n.getNumChildren());
122 if (nc[count[0]] != n[count[1]])
123 {
124 exp.push_back(nc[count[0]].eqNode(n[count[1]]));
125 }
126 count[0]++;
127 count[1]++;
128 }
129 }
130 // infer the equality
131 d_im.sendInference(exp, n.eqNode(nc), "I_Norm");
132 }
133 else
134 {
135 // mark as congruent : only process if neither has been
136 // reduced
137 d_im.markCongruent(nc, n);
138 }
139 // this node is congruent to another one, we can ignore it
140 Trace("strings-process-debug")
141 << " congruent term : " << n << " (via " << nc << ")"
142 << std::endl;
143 d_congruent.insert(n);
144 congruent[k]++;
145 }
146 else if (k == STRING_CONCAT && c.size() == 1)
147 {
148 Trace("strings-process-debug")
149 << " congruent term by singular : " << n << " " << c[0]
150 << std::endl;
151 // singular case
152 if (!d_state.areEqual(c[0], n))
153 {
154 Node ns;
155 std::vector<Node> exp;
156 // explain empty components
157 bool foundNEmpty = false;
158 for (const Node& nc : n)
159 {
160 if (d_state.areEqual(nc, d_emptyString))
161 {
162 if (nc != d_emptyString)
163 {
164 exp.push_back(nc.eqNode(d_emptyString));
165 }
166 }
167 else
168 {
169 Assert(!foundNEmpty);
170 ns = nc;
171 foundNEmpty = true;
172 }
173 }
174 AlwaysAssert(foundNEmpty);
175 // infer the equality
176 d_im.sendInference(exp, n.eqNode(ns), "I_Norm_S");
177 }
178 d_congruent.insert(n);
179 congruent[k]++;
180 }
181 else
182 {
183 ncongruent[k]++;
184 }
185 }
186 else
187 {
188 congruent[k]++;
189 }
190 }
191 }
192 else
193 {
194 if (d_congruent.find(n) == d_congruent.end())
195 {
196 // We mark all but the oldest variable in the equivalence class as
197 // congruent.
198 if (var.isNull())
199 {
200 var = n;
201 }
202 else if (var > n)
203 {
204 Trace("strings-process-debug")
205 << " congruent variable : " << var << std::endl;
206 d_congruent.insert(var);
207 var = n;
208 }
209 else
210 {
211 Trace("strings-process-debug")
212 << " congruent variable : " << n << std::endl;
213 d_congruent.insert(n);
214 }
215 }
216 }
217 ++eqc_i;
218 }
219 }
220 ++eqcs_i;
221 }
222 if (Trace.isOn("strings-process"))
223 {
224 for (std::map<Kind, TermIndex>::iterator it = d_termIndex.begin();
225 it != d_termIndex.end();
226 ++it)
227 {
228 Trace("strings-process")
229 << " Terms[" << it->first << "] = " << ncongruent[it->first] << "/"
230 << (congruent[it->first] + ncongruent[it->first]) << std::endl;
231 }
232 }
233 }
234
235 void BaseSolver::checkConstantEquivalenceClasses()
236 {
237 // do fixed point
238 size_t prevSize = 0;
239 std::vector<Node> vecc;
240 do
241 {
242 vecc.clear();
243 Trace("strings-process-debug")
244 << "Check constant equivalence classes..." << std::endl;
245 prevSize = d_eqcToConst.size();
246 checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc);
247 } while (!d_im.hasProcessed() && d_eqcToConst.size() > prevSize);
248 }
249
250 void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
251 std::vector<Node>& vecc)
252 {
253 Node n = ti->d_data;
254 if (!n.isNull())
255 {
256 // construct the constant
257 Node c = utils::mkNConcat(vecc);
258 if (!d_state.areEqual(n, c))
259 {
260 if (Trace.isOn("strings-debug"))
261 {
262 Trace("strings-debug")
263 << "Constant eqc : " << c << " for " << n << std::endl;
264 Trace("strings-debug") << " ";
265 for (const Node& v : vecc)
266 {
267 Trace("strings-debug") << v << " ";
268 }
269 Trace("strings-debug") << std::endl;
270 }
271 size_t count = 0;
272 size_t countc = 0;
273 std::vector<Node> exp;
274 while (count < n.getNumChildren())
275 {
276 while (count < n.getNumChildren()
277 && d_state.areEqual(n[count], d_emptyString))
278 {
279 d_im.addToExplanation(n[count], d_emptyString, exp);
280 count++;
281 }
282 if (count < n.getNumChildren())
283 {
284 Trace("strings-debug")
285 << "...explain " << n[count] << " " << vecc[countc] << std::endl;
286 if (!d_state.areEqual(n[count], vecc[countc]))
287 {
288 Node nrr = d_state.getRepresentative(n[count]);
289 Assert(!d_eqcToConstExp[nrr].isNull());
290 d_im.addToExplanation(n[count], d_eqcToConstBase[nrr], exp);
291 exp.push_back(d_eqcToConstExp[nrr]);
292 }
293 else
294 {
295 d_im.addToExplanation(n[count], vecc[countc], exp);
296 }
297 countc++;
298 count++;
299 }
300 }
301 // exp contains an explanation of n==c
302 Assert(countc == vecc.size());
303 if (d_state.hasTerm(c))
304 {
305 d_im.sendInference(exp, n.eqNode(c), "I_CONST_MERGE");
306 return;
307 }
308 else if (!d_im.hasProcessed())
309 {
310 Node nr = d_state.getRepresentative(n);
311 std::map<Node, Node>::iterator it = d_eqcToConst.find(nr);
312 if (it == d_eqcToConst.end())
313 {
314 Trace("strings-debug")
315 << "Set eqc const " << n << " to " << c << std::endl;
316 d_eqcToConst[nr] = c;
317 d_eqcToConstBase[nr] = n;
318 d_eqcToConstExp[nr] = utils::mkAnd(exp);
319 }
320 else if (c != it->second)
321 {
322 // conflict
323 Trace("strings-debug")
324 << "Conflict, other constant was " << it->second
325 << ", this constant was " << c << std::endl;
326 if (d_eqcToConstExp[nr].isNull())
327 {
328 // n==c ^ n == c' => false
329 d_im.addToExplanation(n, it->second, exp);
330 }
331 else
332 {
333 // n==c ^ n == d_eqcToConstBase[nr] == c' => false
334 exp.push_back(d_eqcToConstExp[nr]);
335 d_im.addToExplanation(n, d_eqcToConstBase[nr], exp);
336 }
337 d_im.sendInference(exp, d_false, "I_CONST_CONFLICT");
338 return;
339 }
340 else
341 {
342 Trace("strings-debug") << "Duplicate constant." << std::endl;
343 }
344 }
345 }
346 }
347 for (std::pair<const TNode, TermIndex>& p : ti->d_children)
348 {
349 std::map<Node, Node>::iterator itc = d_eqcToConst.find(p.first);
350 if (itc != d_eqcToConst.end())
351 {
352 vecc.push_back(itc->second);
353 checkConstantEquivalenceClasses(&p.second, vecc);
354 vecc.pop_back();
355 if (d_im.hasProcessed())
356 {
357 break;
358 }
359 }
360 }
361 }
362
363 void BaseSolver::checkCardinality()
364 {
365 // This will create a partition of eqc, where each collection has length that
366 // are pairwise propagated to be equal. We do not require disequalities
367 // between the lengths of each collection, since we split on disequalities
368 // between lengths of string terms that are disequal (DEQ-LENGTH-SP).
369 std::vector<std::vector<Node> > cols;
370 std::vector<Node> lts;
371 d_state.separateByLength(d_stringsEqc, cols, lts);
372 NodeManager* nm = NodeManager::currentNM();
373 Trace("strings-card") << "Check cardinality...." << std::endl;
374 // for each collection
375 for (unsigned i = 0, csize = cols.size(); i < csize; ++i)
376 {
377 Node lr = lts[i];
378 Trace("strings-card") << "Number of strings with length equal to " << lr
379 << " is " << cols[i].size() << std::endl;
380 if (cols[i].size() <= 1)
381 {
382 // no restriction on sets in the partition of size 1
383 continue;
384 }
385 // size > c^k
386 unsigned card_need = 1;
387 double curr = static_cast<double>(cols[i].size());
388 while (curr > d_cardSize)
389 {
390 curr = curr / static_cast<double>(d_cardSize);
391 card_need++;
392 }
393 Trace("strings-card")
394 << "Need length " << card_need
395 << " for this number of strings (where alphabet size is " << d_cardSize
396 << ")." << std::endl;
397 // check if we need to split
398 bool needsSplit = true;
399 if (lr.isConst())
400 {
401 // if constant, compare
402 Node cmp = nm->mkNode(GEQ, lr, nm->mkConst(Rational(card_need)));
403 cmp = Rewriter::rewrite(cmp);
404 needsSplit = !cmp.getConst<bool>();
405 }
406 else
407 {
408 // find the minimimum constant that we are unknown to be disequal from, or
409 // otherwise stop if we increment such that cardinality does not apply
410 unsigned r = 0;
411 bool success = true;
412 while (r < card_need && success)
413 {
414 Node rr = nm->mkConst(Rational(r));
415 if (d_state.areDisequal(rr, lr))
416 {
417 r++;
418 }
419 else
420 {
421 success = false;
422 }
423 }
424 if (r > 0)
425 {
426 Trace("strings-card")
427 << "Symbolic length " << lr << " must be at least " << r
428 << " due to constant disequalities." << std::endl;
429 }
430 needsSplit = r < card_need;
431 }
432
433 if (!needsSplit)
434 {
435 // don't need to split
436 continue;
437 }
438 // first, try to split to merge equivalence classes
439 for (std::vector<Node>::iterator itr1 = cols[i].begin();
440 itr1 != cols[i].end();
441 ++itr1)
442 {
443 for (std::vector<Node>::iterator itr2 = itr1 + 1; itr2 != cols[i].end();
444 ++itr2)
445 {
446 if (!d_state.areDisequal(*itr1, *itr2))
447 {
448 // add split lemma
449 if (d_im.sendSplit(*itr1, *itr2, "CARD-SP"))
450 {
451 return;
452 }
453 }
454 }
455 }
456 // otherwise, we need a length constraint
457 uint32_t int_k = static_cast<uint32_t>(card_need);
458 EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true);
459 Trace("strings-card") << "Previous cardinality used for " << lr << " is "
460 << ((int)ei->d_cardinalityLemK.get() - 1)
461 << std::endl;
462 if (int_k + 1 > ei->d_cardinalityLemK.get())
463 {
464 Node k_node = nm->mkConst(Rational(int_k));
465 // add cardinality lemma
466 Node dist = nm->mkNode(DISTINCT, cols[i]);
467 std::vector<Node> vec_node;
468 vec_node.push_back(dist);
469 for (std::vector<Node>::iterator itr1 = cols[i].begin();
470 itr1 != cols[i].end();
471 ++itr1)
472 {
473 Node len = nm->mkNode(STRING_LENGTH, *itr1);
474 if (len != lr)
475 {
476 Node len_eq_lr = len.eqNode(lr);
477 vec_node.push_back(len_eq_lr);
478 }
479 }
480 Node len = nm->mkNode(STRING_LENGTH, cols[i][0]);
481 Node cons = nm->mkNode(GEQ, len, k_node);
482 cons = Rewriter::rewrite(cons);
483 ei->d_cardinalityLemK.set(int_k + 1);
484 if (!cons.isConst() || !cons.getConst<bool>())
485 {
486 std::vector<Node> emptyVec;
487 d_im.sendInference(emptyVec, vec_node, cons, "CARDINALITY", true);
488 return;
489 }
490 }
491 }
492 Trace("strings-card") << "...end check cardinality" << std::endl;
493 }
494
495 bool BaseSolver::isCongruent(Node n)
496 {
497 return d_congruent.find(n) != d_congruent.end();
498 }
499
500 Node BaseSolver::getConstantEqc(Node eqc)
501 {
502 std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
503 if (it != d_eqcToConst.end())
504 {
505 return it->second;
506 }
507 return Node::null();
508 }
509
510 Node BaseSolver::explainConstantEqc(Node n, Node eqc, std::vector<Node>& exp)
511 {
512 std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
513 if (it != d_eqcToConst.end())
514 {
515 if (!d_eqcToConstExp[eqc].isNull())
516 {
517 exp.push_back(d_eqcToConstExp[eqc]);
518 }
519 if (!d_eqcToConstBase[eqc].isNull())
520 {
521 d_im.addToExplanation(n, d_eqcToConstBase[eqc], exp);
522 }
523 return it->second;
524 }
525 return Node::null();
526 }
527
528 const std::vector<Node>& BaseSolver::getStringEqc() const
529 {
530 return d_stringsEqc;
531 }
532
533 Node BaseSolver::TermIndex::add(TNode n,
534 unsigned index,
535 const SolverState& s,
536 Node er,
537 std::vector<Node>& c)
538 {
539 if (index == n.getNumChildren())
540 {
541 if (d_data.isNull())
542 {
543 d_data = n;
544 }
545 return d_data;
546 }
547 Assert(index < n.getNumChildren());
548 TNode nir = s.getRepresentative(n[index]);
549 // if it is empty, and doing CONCAT, ignore
550 if (nir == er && n.getKind() == STRING_CONCAT)
551 {
552 return add(n, index + 1, s, er, c);
553 }
554 c.push_back(nir);
555 return d_children[nir].add(n, index + 1, s, er, c);
556 }
557
558 } // namespace strings
559 } // namespace theory
560 } // namespace CVC4