Update sets inference manager to inherit from InferenceManagerBuffered (#5007)
[cvc5.git] / src / theory / sets / solver_state.cpp
1 /********************* */
2 /*! \file solver_state.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mudathir Mohamed
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of sets state object
13 **/
14
15 #include "theory/sets/solver_state.h"
16
17 #include "expr/emptyset.h"
18 #include "options/sets_options.h"
19 #include "theory/sets/theory_sets_private.h"
20
21 using namespace std;
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace sets {
27
28 SolverState::SolverState(context::Context* c,
29 context::UserContext* u,
30 Valuation val)
31 : TheoryState(c, u, val), d_parent(nullptr), d_proxy(u), d_proxy_to_term(u)
32 {
33 d_true = NodeManager::currentNM()->mkConst(true);
34 d_false = NodeManager::currentNM()->mkConst(false);
35 }
36
37 void SolverState::setParent(TheorySetsPrivate* p) { d_parent = p; }
38
39 void SolverState::reset()
40 {
41 d_set_eqc.clear();
42 d_eqc_emptyset.clear();
43 d_eqc_univset.clear();
44 d_eqc_singleton.clear();
45 d_congruent.clear();
46 d_nvar_sets.clear();
47 d_var_set.clear();
48 d_compSets.clear();
49 d_pol_mems[0].clear();
50 d_pol_mems[1].clear();
51 d_members_index.clear();
52 d_singleton_index.clear();
53 d_bop_index.clear();
54 d_op_list.clear();
55 d_allCompSets.clear();
56 }
57
58 void SolverState::registerEqc(TypeNode tn, Node r)
59 {
60 if (tn.isSet())
61 {
62 d_set_eqc.push_back(r);
63 }
64 }
65
66 void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
67 {
68 Kind nk = n.getKind();
69 if (nk == MEMBER)
70 {
71 if (r.isConst())
72 {
73 Node s = d_ee->getRepresentative(n[1]);
74 Node x = d_ee->getRepresentative(n[0]);
75 int pindex = r == d_true ? 0 : (r == d_false ? 1 : -1);
76 if (pindex != -1)
77 {
78 if (d_pol_mems[pindex][s].find(x) == d_pol_mems[pindex][s].end())
79 {
80 d_pol_mems[pindex][s][x] = n;
81 Trace("sets-debug2") << "Membership[" << x << "][" << s << "] : " << n
82 << ", pindex = " << pindex << std::endl;
83 }
84 if (d_members_index[s].find(x) == d_members_index[s].end())
85 {
86 d_members_index[s][x] = n;
87 d_op_list[MEMBER].push_back(n);
88 }
89 }
90 else
91 {
92 Assert(false);
93 }
94 }
95 }
96 else if (nk == SINGLETON || nk == UNION || nk == INTERSECTION
97 || nk == SETMINUS || nk == EMPTYSET || nk == UNIVERSE_SET)
98 {
99 if (nk == SINGLETON)
100 {
101 // singleton lemma
102 getProxy(n);
103 Node re = d_ee->getRepresentative(n[0]);
104 if (d_singleton_index.find(re) == d_singleton_index.end())
105 {
106 d_singleton_index[re] = n;
107 d_eqc_singleton[r] = n;
108 d_op_list[SINGLETON].push_back(n);
109 }
110 else
111 {
112 d_congruent[n] = d_singleton_index[re];
113 }
114 }
115 else if (nk == EMPTYSET)
116 {
117 d_eqc_emptyset[tnn] = r;
118 }
119 else if (nk == UNIVERSE_SET)
120 {
121 Assert(options::setsExt());
122 d_eqc_univset[tnn] = r;
123 }
124 else
125 {
126 Node r1 = d_ee->getRepresentative(n[0]);
127 Node r2 = d_ee->getRepresentative(n[1]);
128 std::map<Node, Node>& binr1 = d_bop_index[nk][r1];
129 std::map<Node, Node>::iterator itb = binr1.find(r2);
130 if (itb == binr1.end())
131 {
132 binr1[r2] = n;
133 d_op_list[nk].push_back(n);
134 }
135 else
136 {
137 d_congruent[n] = itb->second;
138 }
139 }
140 d_nvar_sets[r].push_back(n);
141 Trace("sets-debug2") << "Non-var-set[" << r << "] : " << n << std::endl;
142 }
143 else if (nk == COMPREHENSION)
144 {
145 d_compSets[r].push_back(n);
146 d_allCompSets.push_back(n);
147 Trace("sets-debug2") << "Comp-set[" << r << "] : " << n << std::endl;
148 }
149 else if (n.isVar() && !d_skCache.isSkolem(n))
150 {
151 // it is important that we check it is a variable, but not an internally
152 // introduced skolem, due to the semantics of the universe set.
153 if (tnn.isSet())
154 {
155 if (d_var_set.find(r) == d_var_set.end())
156 {
157 d_var_set[r] = n;
158 Trace("sets-debug2") << "var-set[" << r << "] : " << n << std::endl;
159 }
160 }
161 }
162 else
163 {
164 Trace("sets-debug2") << "Unknown-set[" << r << "] : " << n << std::endl;
165 }
166 }
167
168 void SolverState::addEqualityToExp(Node a, Node b, std::vector<Node>& exp) const
169 {
170 if (a != b)
171 {
172 Assert(areEqual(a, b));
173 exp.push_back(a.eqNode(b));
174 }
175 }
176
177 Node SolverState::getEmptySetEqClass(TypeNode tn) const
178 {
179 std::map<TypeNode, Node>::const_iterator it = d_eqc_emptyset.find(tn);
180 if (it != d_eqc_emptyset.end())
181 {
182 return it->second;
183 }
184 return Node::null();
185 }
186
187 Node SolverState::getUnivSetEqClass(TypeNode tn) const
188 {
189 std::map<TypeNode, Node>::const_iterator it = d_univset.find(tn);
190 if (it != d_univset.end())
191 {
192 return it->second;
193 }
194 return Node::null();
195 }
196
197 Node SolverState::getSingletonEqClass(Node r) const
198 {
199 std::map<Node, Node>::const_iterator it = d_eqc_singleton.find(r);
200 if (it != d_eqc_singleton.end())
201 {
202 return it->second;
203 }
204 return Node::null();
205 }
206
207 Node SolverState::getBinaryOpTerm(Kind k, Node r1, Node r2) const
208 {
209 std::map<Kind, std::map<Node, std::map<Node, Node> > >::const_iterator itk =
210 d_bop_index.find(k);
211 if (itk == d_bop_index.end())
212 {
213 return Node::null();
214 }
215 std::map<Node, std::map<Node, Node> >::const_iterator it1 =
216 itk->second.find(r1);
217 if (it1 == itk->second.end())
218 {
219 return Node::null();
220 }
221 std::map<Node, Node>::const_iterator it2 = it1->second.find(r2);
222 if (it2 == it1->second.end())
223 {
224 return Node::null();
225 }
226 return it2->second;
227 }
228
229 bool SolverState::isEntailed(Node n, bool polarity) const
230 {
231 if (n.getKind() == NOT)
232 {
233 return isEntailed(n[0], !polarity);
234 }
235 else if (n.getKind() == EQUAL)
236 {
237 if (polarity)
238 {
239 return areEqual(n[0], n[1]);
240 }
241 return areDisequal(n[0], n[1]);
242 }
243 else if (n.getKind() == MEMBER)
244 {
245 if (areEqual(n, polarity ? d_true : d_false))
246 {
247 return true;
248 }
249 // check members cache
250 if (polarity && d_ee->hasTerm(n[1]))
251 {
252 Node r = d_ee->getRepresentative(n[1]);
253 if (d_parent->isMember(n[0], r))
254 {
255 return true;
256 }
257 }
258 }
259 else if (n.getKind() == AND || n.getKind() == OR)
260 {
261 bool conj = (n.getKind() == AND) == polarity;
262 for (const Node& nc : n)
263 {
264 bool isEnt = isEntailed(nc, polarity);
265 if (isEnt != conj)
266 {
267 return !conj;
268 }
269 }
270 return conj;
271 }
272 else if (n.isConst())
273 {
274 return (polarity && n == d_true) || (!polarity && n == d_false);
275 }
276 return false;
277 }
278
279 bool SolverState::isSetDisequalityEntailed(Node r1, Node r2) const
280 {
281 Assert(d_ee->hasTerm(r1) && d_ee->getRepresentative(r1) == r1);
282 Assert(d_ee->hasTerm(r2) && d_ee->getRepresentative(r2) == r2);
283 TypeNode tn = r1.getType();
284 Node re = getEmptySetEqClass(tn);
285 for (unsigned e = 0; e < 2; e++)
286 {
287 Node a = e == 0 ? r1 : r2;
288 Node b = e == 0 ? r2 : r1;
289 if (isSetDisequalityEntailedInternal(a, b, re))
290 {
291 return true;
292 }
293 }
294 return false;
295 }
296
297 bool SolverState::isSetDisequalityEntailedInternal(Node a,
298 Node b,
299 Node re) const
300 {
301 // if there are members in a
302 std::map<Node, std::map<Node, Node> >::const_iterator itpma =
303 d_pol_mems[0].find(a);
304 if (itpma == d_pol_mems[0].end())
305 {
306 // no positive members, continue
307 return false;
308 }
309 // if b is empty
310 if (b == re)
311 {
312 if (!itpma->second.empty())
313 {
314 Trace("sets-deq") << "Disequality is satisfied because members are in "
315 << a << " and " << b << " is empty" << std::endl;
316 return true;
317 }
318 else
319 {
320 // a should not be singleton
321 Assert(d_eqc_singleton.find(a) == d_eqc_singleton.end());
322 }
323 return false;
324 }
325 std::map<Node, Node>::const_iterator itsb = d_eqc_singleton.find(b);
326 std::map<Node, std::map<Node, Node> >::const_iterator itpmb =
327 d_pol_mems[1].find(b);
328 std::vector<Node> prev;
329 for (const std::pair<const Node, Node>& itm : itpma->second)
330 {
331 // if b is a singleton
332 if (itsb != d_eqc_singleton.end())
333 {
334 if (areDisequal(itm.first, itsb->second[0]))
335 {
336 Trace("sets-deq") << "Disequality is satisfied because of "
337 << itm.second << ", singleton eq " << itsb->second[0]
338 << std::endl;
339 return true;
340 }
341 // or disequal with another member
342 for (const Node& p : prev)
343 {
344 if (areDisequal(itm.first, p))
345 {
346 Trace("sets-deq")
347 << "Disequality is satisfied because of disequal members "
348 << itm.first << " " << p << ", singleton eq " << std::endl;
349 return true;
350 }
351 }
352 // if a has positive member that is negative member in b
353 }
354 else if (itpmb != d_pol_mems[1].end())
355 {
356 for (const std::pair<const Node, Node>& itnm : itpmb->second)
357 {
358 if (areEqual(itm.first, itnm.first))
359 {
360 Trace("sets-deq") << "Disequality is satisfied because of "
361 << itm.second << " " << itnm.second << std::endl;
362 return true;
363 }
364 }
365 }
366 prev.push_back(itm.first);
367 }
368 return false;
369 }
370
371 Node SolverState::getProxy(Node n)
372 {
373 Kind nk = n.getKind();
374 if (nk != EMPTYSET && nk != SINGLETON && nk != INTERSECTION && nk != SETMINUS
375 && nk != UNION && nk != UNIVERSE_SET)
376 {
377 return n;
378 }
379 NodeMap::const_iterator it = d_proxy.find(n);
380 if (it != d_proxy.end())
381 {
382 return (*it).second;
383 }
384 NodeManager* nm = NodeManager::currentNM();
385 Node k = d_skCache.mkTypedSkolemCached(
386 n.getType(), n, SkolemCache::SK_PURIFY, "sp");
387 d_proxy[n] = k;
388 d_proxy_to_term[k] = n;
389 Node eq = k.eqNode(n);
390 Trace("sets-lemma") << "Sets::Lemma : " << eq << " by proxy" << std::endl;
391 d_parent->getOutputChannel()->lemma(eq);
392 if (nk == SINGLETON)
393 {
394 Node slem = nm->mkNode(MEMBER, n[0], k);
395 Trace("sets-lemma") << "Sets::Lemma : " << slem << " by singleton"
396 << std::endl;
397 d_parent->getOutputChannel()->lemma(slem);
398 }
399 return k;
400 }
401
402 Node SolverState::getCongruent(Node n) const
403 {
404 Assert(d_ee->hasTerm(n));
405 std::map<Node, Node>::const_iterator it = d_congruent.find(n);
406 if (it == d_congruent.end())
407 {
408 return n;
409 }
410 return it->second;
411 }
412 bool SolverState::isCongruent(Node n) const
413 {
414 return d_congruent.find(n) != d_congruent.end();
415 }
416
417 Node SolverState::getEmptySet(TypeNode tn)
418 {
419 std::map<TypeNode, Node>::iterator it = d_emptyset.find(tn);
420 if (it != d_emptyset.end())
421 {
422 return it->second;
423 }
424 Node n = NodeManager::currentNM()->mkConst(EmptySet(tn));
425 d_emptyset[tn] = n;
426 return n;
427 }
428 Node SolverState::getUnivSet(TypeNode tn)
429 {
430 std::map<TypeNode, Node>::iterator it = d_univset.find(tn);
431 if (it != d_univset.end())
432 {
433 return it->second;
434 }
435 NodeManager* nm = NodeManager::currentNM();
436 Node n = nm->mkNullaryOperator(tn, UNIVERSE_SET);
437 for (it = d_univset.begin(); it != d_univset.end(); ++it)
438 {
439 Node n1;
440 Node n2;
441 if (tn.isSubtypeOf(it->first))
442 {
443 n1 = n;
444 n2 = it->second;
445 }
446 else if (it->first.isSubtypeOf(tn))
447 {
448 n1 = it->second;
449 n2 = n;
450 }
451 if (!n1.isNull())
452 {
453 Node ulem = nm->mkNode(SUBSET, n1, n2);
454 Trace("sets-lemma") << "Sets::Lemma : " << ulem << " by univ-type"
455 << std::endl;
456 d_parent->getOutputChannel()->lemma(ulem);
457 }
458 }
459 d_univset[tn] = n;
460 return n;
461 }
462
463 Node SolverState::getTypeConstraintSkolem(Node n, TypeNode tn)
464 {
465 std::map<TypeNode, Node>::iterator it = d_tc_skolem[n].find(tn);
466 if (it == d_tc_skolem[n].end())
467 {
468 Node k = NodeManager::currentNM()->mkSkolem("tc_k", tn);
469 d_tc_skolem[n][tn] = k;
470 return k;
471 }
472 return it->second;
473 }
474
475 const std::vector<Node>& SolverState::getNonVariableSets(Node r) const
476 {
477 std::map<Node, std::vector<Node> >::const_iterator it = d_nvar_sets.find(r);
478 if (it == d_nvar_sets.end())
479 {
480 return d_emptyVec;
481 }
482 return it->second;
483 }
484
485 Node SolverState::getVariableSet(Node r) const
486 {
487 std::map<Node, Node>::const_iterator it = d_var_set.find(r);
488 if (it != d_var_set.end())
489 {
490 return it->second;
491 }
492 return Node::null();
493 }
494
495 const std::vector<Node>& SolverState::getComprehensionSets(Node r) const
496 {
497 std::map<Node, std::vector<Node> >::const_iterator it = d_compSets.find(r);
498 if (it == d_compSets.end())
499 {
500 return d_emptyVec;
501 }
502 return it->second;
503 }
504
505 const std::map<Node, Node>& SolverState::getMembers(Node r) const
506 {
507 return getMembersInternal(r, 0);
508 }
509 const std::map<Node, Node>& SolverState::getNegativeMembers(Node r) const
510 {
511 return getMembersInternal(r, 1);
512 }
513 const std::map<Node, Node>& SolverState::getMembersInternal(Node r,
514 unsigned i) const
515 {
516 std::map<Node, std::map<Node, Node> >::const_iterator itp =
517 d_pol_mems[i].find(r);
518 if (itp == d_pol_mems[i].end())
519 {
520 return d_emptyMap;
521 }
522 return itp->second;
523 }
524
525 bool SolverState::hasMembers(Node r) const
526 {
527 std::map<Node, std::map<Node, Node> >::const_iterator it =
528 d_pol_mems[0].find(r);
529 if (it == d_pol_mems[0].end())
530 {
531 return false;
532 }
533 return !it->second.empty();
534 }
535 const std::map<Kind, std::map<Node, std::map<Node, Node> > >&
536 SolverState::getBinaryOpIndex() const
537 {
538 return d_bop_index;
539 }
540 const std::map<Kind, std::vector<Node> >& SolverState::getOperatorList() const
541 {
542 return d_op_list;
543 }
544
545 const std::vector<Node>& SolverState::getComprehensionSets() const
546 {
547 return d_allCompSets;
548 }
549
550 void SolverState::debugPrintSet(Node s, const char* c) const
551 {
552 if (s.getNumChildren() == 0)
553 {
554 NodeMap::const_iterator it = d_proxy_to_term.find(s);
555 if (it != d_proxy_to_term.end())
556 {
557 debugPrintSet((*it).second, c);
558 }
559 else
560 {
561 Trace(c) << s;
562 }
563 }
564 else
565 {
566 Trace(c) << "(" << s.getOperator();
567 for (const Node& sc : s)
568 {
569 Trace(c) << " ";
570 debugPrintSet(sc, c);
571 }
572 Trace(c) << ")";
573 }
574 }
575
576 const vector<Node> SolverState::getSetsEqClasses(const TypeNode& t) const
577 {
578 vector<Node> representatives;
579 for (const Node& eqc : getSetsEqClasses())
580 {
581 if (eqc.getType().getSetElementType() == t)
582 {
583 representatives.push_back(eqc);
584 }
585 }
586 return representatives;
587 }
588
589 } // namespace sets
590 } // namespace theory
591 } // namespace CVC4