Prepare theory of sets for dynamic allocation of equality engine (#4868)
[cvc5.git] / src / theory / sets / solver_state.cpp
1 /********************* */
2 /*! \file solver_state.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mudathir Mohamed
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of sets state object
13 **/
14
15 #include "theory/sets/solver_state.h"
16
17 #include "expr/emptyset.h"
18 #include "options/sets_options.h"
19 #include "theory/sets/theory_sets_private.h"
20
21 using namespace std;
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace sets {
27
28 SolverState::SolverState(TheorySetsPrivate& p,
29 context::Context* c,
30 context::UserContext* u)
31 : d_conflict(c), d_parent(p), d_ee(nullptr), d_proxy(u), d_proxy_to_term(u)
32 {
33 d_true = NodeManager::currentNM()->mkConst(true);
34 d_false = NodeManager::currentNM()->mkConst(false);
35 }
36
37 void SolverState::finishInit(eq::EqualityEngine* ee)
38 {
39 Assert(ee != nullptr);
40 d_ee = ee;
41 }
42
43 void SolverState::reset()
44 {
45 d_set_eqc.clear();
46 d_eqc_emptyset.clear();
47 d_eqc_univset.clear();
48 d_eqc_singleton.clear();
49 d_congruent.clear();
50 d_nvar_sets.clear();
51 d_var_set.clear();
52 d_compSets.clear();
53 d_pol_mems[0].clear();
54 d_pol_mems[1].clear();
55 d_members_index.clear();
56 d_singleton_index.clear();
57 d_bop_index.clear();
58 d_op_list.clear();
59 d_allCompSets.clear();
60 }
61
62 void SolverState::registerEqc(TypeNode tn, Node r)
63 {
64 if (tn.isSet())
65 {
66 d_set_eqc.push_back(r);
67 }
68 }
69
70 void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
71 {
72 Kind nk = n.getKind();
73 if (nk == MEMBER)
74 {
75 if (r.isConst())
76 {
77 Node s = d_ee->getRepresentative(n[1]);
78 Node x = d_ee->getRepresentative(n[0]);
79 int pindex = r == d_true ? 0 : (r == d_false ? 1 : -1);
80 if (pindex != -1)
81 {
82 if (d_pol_mems[pindex][s].find(x) == d_pol_mems[pindex][s].end())
83 {
84 d_pol_mems[pindex][s][x] = n;
85 Trace("sets-debug2") << "Membership[" << x << "][" << s << "] : " << n
86 << ", pindex = " << pindex << std::endl;
87 }
88 if (d_members_index[s].find(x) == d_members_index[s].end())
89 {
90 d_members_index[s][x] = n;
91 d_op_list[MEMBER].push_back(n);
92 }
93 }
94 else
95 {
96 Assert(false);
97 }
98 }
99 }
100 else if (nk == SINGLETON || nk == UNION || nk == INTERSECTION
101 || nk == SETMINUS || nk == EMPTYSET || nk == UNIVERSE_SET)
102 {
103 if (nk == SINGLETON)
104 {
105 // singleton lemma
106 getProxy(n);
107 Node re = d_ee->getRepresentative(n[0]);
108 if (d_singleton_index.find(re) == d_singleton_index.end())
109 {
110 d_singleton_index[re] = n;
111 d_eqc_singleton[r] = n;
112 d_op_list[SINGLETON].push_back(n);
113 }
114 else
115 {
116 d_congruent[n] = d_singleton_index[re];
117 }
118 }
119 else if (nk == EMPTYSET)
120 {
121 d_eqc_emptyset[tnn] = r;
122 }
123 else if (nk == UNIVERSE_SET)
124 {
125 Assert(options::setsExt());
126 d_eqc_univset[tnn] = r;
127 }
128 else
129 {
130 Node r1 = d_ee->getRepresentative(n[0]);
131 Node r2 = d_ee->getRepresentative(n[1]);
132 std::map<Node, Node>& binr1 = d_bop_index[nk][r1];
133 std::map<Node, Node>::iterator itb = binr1.find(r2);
134 if (itb == binr1.end())
135 {
136 binr1[r2] = n;
137 d_op_list[nk].push_back(n);
138 }
139 else
140 {
141 d_congruent[n] = itb->second;
142 }
143 }
144 d_nvar_sets[r].push_back(n);
145 Trace("sets-debug2") << "Non-var-set[" << r << "] : " << n << std::endl;
146 }
147 else if (nk == COMPREHENSION)
148 {
149 d_compSets[r].push_back(n);
150 d_allCompSets.push_back(n);
151 Trace("sets-debug2") << "Comp-set[" << r << "] : " << n << std::endl;
152 }
153 else if (n.isVar() && !d_skCache.isSkolem(n))
154 {
155 // it is important that we check it is a variable, but not an internally
156 // introduced skolem, due to the semantics of the universe set.
157 if (tnn.isSet())
158 {
159 if (d_var_set.find(r) == d_var_set.end())
160 {
161 d_var_set[r] = n;
162 Trace("sets-debug2") << "var-set[" << r << "] : " << n << std::endl;
163 }
164 }
165 }
166 else
167 {
168 Trace("sets-debug2") << "Unknown-set[" << r << "] : " << n << std::endl;
169 }
170 }
171
172 Node SolverState::getRepresentative(Node a) const
173 {
174 if (d_ee->hasTerm(a))
175 {
176 return d_ee->getRepresentative(a);
177 }
178 return a;
179 }
180
181 bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); }
182
183 bool SolverState::areEqual(Node a, Node b) const
184 {
185 if (a == b)
186 {
187 return true;
188 }
189 if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
190 {
191 return d_ee->areEqual(a, b);
192 }
193 return false;
194 }
195
196 bool SolverState::areDisequal(Node a, Node b) const
197 {
198 if (a == b)
199 {
200 return false;
201 }
202 else if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
203 {
204 return d_ee->areDisequal(a, b, false);
205 }
206 return a.isConst() && b.isConst();
207 }
208
209 eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; }
210
211 void SolverState::setConflict() { d_conflict = true; }
212 void SolverState::setConflict(Node conf)
213 {
214 d_parent.getOutputChannel()->conflict(conf);
215 d_conflict = true;
216 }
217
218 void SolverState::addEqualityToExp(Node a, Node b, std::vector<Node>& exp) const
219 {
220 if (a != b)
221 {
222 Assert(areEqual(a, b));
223 exp.push_back(a.eqNode(b));
224 }
225 }
226
227 Node SolverState::getEmptySetEqClass(TypeNode tn) const
228 {
229 std::map<TypeNode, Node>::const_iterator it = d_eqc_emptyset.find(tn);
230 if (it != d_eqc_emptyset.end())
231 {
232 return it->second;
233 }
234 return Node::null();
235 }
236
237 Node SolverState::getUnivSetEqClass(TypeNode tn) const
238 {
239 std::map<TypeNode, Node>::const_iterator it = d_univset.find(tn);
240 if (it != d_univset.end())
241 {
242 return it->second;
243 }
244 return Node::null();
245 }
246
247 Node SolverState::getSingletonEqClass(Node r) const
248 {
249 std::map<Node, Node>::const_iterator it = d_eqc_singleton.find(r);
250 if (it != d_eqc_singleton.end())
251 {
252 return it->second;
253 }
254 return Node::null();
255 }
256
257 Node SolverState::getBinaryOpTerm(Kind k, Node r1, Node r2) const
258 {
259 std::map<Kind, std::map<Node, std::map<Node, Node> > >::const_iterator itk =
260 d_bop_index.find(k);
261 if (itk == d_bop_index.end())
262 {
263 return Node::null();
264 }
265 std::map<Node, std::map<Node, Node> >::const_iterator it1 =
266 itk->second.find(r1);
267 if (it1 == itk->second.end())
268 {
269 return Node::null();
270 }
271 std::map<Node, Node>::const_iterator it2 = it1->second.find(r2);
272 if (it2 == it1->second.end())
273 {
274 return Node::null();
275 }
276 return it2->second;
277 }
278
279 bool SolverState::isEntailed(Node n, bool polarity) const
280 {
281 if (n.getKind() == NOT)
282 {
283 return isEntailed(n[0], !polarity);
284 }
285 else if (n.getKind() == EQUAL)
286 {
287 if (polarity)
288 {
289 return areEqual(n[0], n[1]);
290 }
291 return areDisequal(n[0], n[1]);
292 }
293 else if (n.getKind() == MEMBER)
294 {
295 if (areEqual(n, polarity ? d_true : d_false))
296 {
297 return true;
298 }
299 // check members cache
300 if (polarity && d_ee->hasTerm(n[1]))
301 {
302 Node r = d_ee->getRepresentative(n[1]);
303 if (d_parent.isMember(n[0], r))
304 {
305 return true;
306 }
307 }
308 }
309 else if (n.getKind() == AND || n.getKind() == OR)
310 {
311 bool conj = (n.getKind() == AND) == polarity;
312 for (const Node& nc : n)
313 {
314 bool isEnt = isEntailed(nc, polarity);
315 if (isEnt != conj)
316 {
317 return !conj;
318 }
319 }
320 return conj;
321 }
322 else if (n.isConst())
323 {
324 return (polarity && n == d_true) || (!polarity && n == d_false);
325 }
326 return false;
327 }
328
329 bool SolverState::isSetDisequalityEntailed(Node r1, Node r2) const
330 {
331 Assert(d_ee->hasTerm(r1) && d_ee->getRepresentative(r1) == r1);
332 Assert(d_ee->hasTerm(r2) && d_ee->getRepresentative(r2) == r2);
333 TypeNode tn = r1.getType();
334 Node re = getEmptySetEqClass(tn);
335 for (unsigned e = 0; e < 2; e++)
336 {
337 Node a = e == 0 ? r1 : r2;
338 Node b = e == 0 ? r2 : r1;
339 if (isSetDisequalityEntailedInternal(a, b, re))
340 {
341 return true;
342 }
343 }
344 return false;
345 }
346
347 bool SolverState::isSetDisequalityEntailedInternal(Node a,
348 Node b,
349 Node re) const
350 {
351 // if there are members in a
352 std::map<Node, std::map<Node, Node> >::const_iterator itpma =
353 d_pol_mems[0].find(a);
354 if (itpma == d_pol_mems[0].end())
355 {
356 // no positive members, continue
357 return false;
358 }
359 // if b is empty
360 if (b == re)
361 {
362 if (!itpma->second.empty())
363 {
364 Trace("sets-deq") << "Disequality is satisfied because members are in "
365 << a << " and " << b << " is empty" << std::endl;
366 return true;
367 }
368 else
369 {
370 // a should not be singleton
371 Assert(d_eqc_singleton.find(a) == d_eqc_singleton.end());
372 }
373 return false;
374 }
375 std::map<Node, Node>::const_iterator itsb = d_eqc_singleton.find(b);
376 std::map<Node, std::map<Node, Node> >::const_iterator itpmb =
377 d_pol_mems[1].find(b);
378 std::vector<Node> prev;
379 for (const std::pair<const Node, Node>& itm : itpma->second)
380 {
381 // if b is a singleton
382 if (itsb != d_eqc_singleton.end())
383 {
384 if (areDisequal(itm.first, itsb->second[0]))
385 {
386 Trace("sets-deq") << "Disequality is satisfied because of "
387 << itm.second << ", singleton eq " << itsb->second[0]
388 << std::endl;
389 return true;
390 }
391 // or disequal with another member
392 for (const Node& p : prev)
393 {
394 if (areDisequal(itm.first, p))
395 {
396 Trace("sets-deq")
397 << "Disequality is satisfied because of disequal members "
398 << itm.first << " " << p << ", singleton eq " << std::endl;
399 return true;
400 }
401 }
402 // if a has positive member that is negative member in b
403 }
404 else if (itpmb != d_pol_mems[1].end())
405 {
406 for (const std::pair<const Node, Node>& itnm : itpmb->second)
407 {
408 if (areEqual(itm.first, itnm.first))
409 {
410 Trace("sets-deq") << "Disequality is satisfied because of "
411 << itm.second << " " << itnm.second << std::endl;
412 return true;
413 }
414 }
415 }
416 prev.push_back(itm.first);
417 }
418 return false;
419 }
420
421 Node SolverState::getProxy(Node n)
422 {
423 Kind nk = n.getKind();
424 if (nk != EMPTYSET && nk != SINGLETON && nk != INTERSECTION && nk != SETMINUS
425 && nk != UNION && nk != UNIVERSE_SET)
426 {
427 return n;
428 }
429 NodeMap::const_iterator it = d_proxy.find(n);
430 if (it != d_proxy.end())
431 {
432 return (*it).second;
433 }
434 NodeManager* nm = NodeManager::currentNM();
435 Node k = d_skCache.mkTypedSkolemCached(
436 n.getType(), n, SkolemCache::SK_PURIFY, "sp");
437 d_proxy[n] = k;
438 d_proxy_to_term[k] = n;
439 Node eq = k.eqNode(n);
440 Trace("sets-lemma") << "Sets::Lemma : " << eq << " by proxy" << std::endl;
441 d_parent.getOutputChannel()->lemma(eq);
442 if (nk == SINGLETON)
443 {
444 Node slem = nm->mkNode(MEMBER, n[0], k);
445 Trace("sets-lemma") << "Sets::Lemma : " << slem << " by singleton"
446 << std::endl;
447 d_parent.getOutputChannel()->lemma(slem);
448 }
449 return k;
450 }
451
452 Node SolverState::getCongruent(Node n) const
453 {
454 Assert(d_ee->hasTerm(n));
455 std::map<Node, Node>::const_iterator it = d_congruent.find(n);
456 if (it == d_congruent.end())
457 {
458 return n;
459 }
460 return it->second;
461 }
462 bool SolverState::isCongruent(Node n) const
463 {
464 return d_congruent.find(n) != d_congruent.end();
465 }
466
467 Node SolverState::getEmptySet(TypeNode tn)
468 {
469 std::map<TypeNode, Node>::iterator it = d_emptyset.find(tn);
470 if (it != d_emptyset.end())
471 {
472 return it->second;
473 }
474 Node n = NodeManager::currentNM()->mkConst(EmptySet(tn));
475 d_emptyset[tn] = n;
476 return n;
477 }
478 Node SolverState::getUnivSet(TypeNode tn)
479 {
480 std::map<TypeNode, Node>::iterator it = d_univset.find(tn);
481 if (it != d_univset.end())
482 {
483 return it->second;
484 }
485 NodeManager* nm = NodeManager::currentNM();
486 Node n = nm->mkNullaryOperator(tn, UNIVERSE_SET);
487 for (it = d_univset.begin(); it != d_univset.end(); ++it)
488 {
489 Node n1;
490 Node n2;
491 if (tn.isSubtypeOf(it->first))
492 {
493 n1 = n;
494 n2 = it->second;
495 }
496 else if (it->first.isSubtypeOf(tn))
497 {
498 n1 = it->second;
499 n2 = n;
500 }
501 if (!n1.isNull())
502 {
503 Node ulem = nm->mkNode(SUBSET, n1, n2);
504 Trace("sets-lemma") << "Sets::Lemma : " << ulem << " by univ-type"
505 << std::endl;
506 d_parent.getOutputChannel()->lemma(ulem);
507 }
508 }
509 d_univset[tn] = n;
510 return n;
511 }
512
513 Node SolverState::getTypeConstraintSkolem(Node n, TypeNode tn)
514 {
515 std::map<TypeNode, Node>::iterator it = d_tc_skolem[n].find(tn);
516 if (it == d_tc_skolem[n].end())
517 {
518 Node k = NodeManager::currentNM()->mkSkolem("tc_k", tn);
519 d_tc_skolem[n][tn] = k;
520 return k;
521 }
522 return it->second;
523 }
524
525 const std::vector<Node>& SolverState::getNonVariableSets(Node r) const
526 {
527 std::map<Node, std::vector<Node> >::const_iterator it = d_nvar_sets.find(r);
528 if (it == d_nvar_sets.end())
529 {
530 return d_emptyVec;
531 }
532 return it->second;
533 }
534
535 Node SolverState::getVariableSet(Node r) const
536 {
537 std::map<Node, Node>::const_iterator it = d_var_set.find(r);
538 if (it != d_var_set.end())
539 {
540 return it->second;
541 }
542 return Node::null();
543 }
544
545 const std::vector<Node>& SolverState::getComprehensionSets(Node r) const
546 {
547 std::map<Node, std::vector<Node> >::const_iterator it = d_compSets.find(r);
548 if (it == d_compSets.end())
549 {
550 return d_emptyVec;
551 }
552 return it->second;
553 }
554
555 const std::map<Node, Node>& SolverState::getMembers(Node r) const
556 {
557 return getMembersInternal(r, 0);
558 }
559 const std::map<Node, Node>& SolverState::getNegativeMembers(Node r) const
560 {
561 return getMembersInternal(r, 1);
562 }
563 const std::map<Node, Node>& SolverState::getMembersInternal(Node r,
564 unsigned i) const
565 {
566 std::map<Node, std::map<Node, Node> >::const_iterator itp =
567 d_pol_mems[i].find(r);
568 if (itp == d_pol_mems[i].end())
569 {
570 return d_emptyMap;
571 }
572 return itp->second;
573 }
574
575 bool SolverState::hasMembers(Node r) const
576 {
577 std::map<Node, std::map<Node, Node> >::const_iterator it =
578 d_pol_mems[0].find(r);
579 if (it == d_pol_mems[0].end())
580 {
581 return false;
582 }
583 return !it->second.empty();
584 }
585 const std::map<Kind, std::map<Node, std::map<Node, Node> > >&
586 SolverState::getBinaryOpIndex() const
587 {
588 return d_bop_index;
589 }
590 const std::map<Kind, std::vector<Node> >& SolverState::getOperatorList() const
591 {
592 return d_op_list;
593 }
594
595 const std::vector<Node>& SolverState::getComprehensionSets() const
596 {
597 return d_allCompSets;
598 }
599
600 void SolverState::debugPrintSet(Node s, const char* c) const
601 {
602 if (s.getNumChildren() == 0)
603 {
604 NodeMap::const_iterator it = d_proxy_to_term.find(s);
605 if (it != d_proxy_to_term.end())
606 {
607 debugPrintSet((*it).second, c);
608 }
609 else
610 {
611 Trace(c) << s;
612 }
613 }
614 else
615 {
616 Trace(c) << "(" << s.getOperator();
617 for (const Node& sc : s)
618 {
619 Trace(c) << " ";
620 debugPrintSet(sc, c);
621 }
622 Trace(c) << ")";
623 }
624 }
625
626 const vector<Node> SolverState::getSetsEqClasses(const TypeNode& t) const
627 {
628 vector<Node> representatives;
629 for (const Node& eqc : getSetsEqClasses())
630 {
631 if (eqc.getType().getSetElementType() == t)
632 {
633 representatives.push_back(eqc);
634 }
635 }
636 return representatives;
637 }
638
639 } // namespace sets
640 } // namespace theory
641 } // namespace CVC4