Make sets and strings solver states inherit from TheoryState (#4918)
[cvc5.git] / src / theory / sets / solver_state.cpp
1 /********************* */
2 /*! \file solver_state.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mudathir Mohamed
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of sets state object
13 **/
14
15 #include "theory/sets/solver_state.h"
16
17 #include "expr/emptyset.h"
18 #include "options/sets_options.h"
19 #include "theory/sets/theory_sets_private.h"
20
21 using namespace std;
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace sets {
27
28 SolverState::SolverState(TheorySetsPrivate& p,
29 context::Context* c,
30 context::UserContext* u,
31 Valuation val)
32 : TheoryState(c, u, val), d_parent(p), d_proxy(u), d_proxy_to_term(u)
33 {
34 d_true = NodeManager::currentNM()->mkConst(true);
35 d_false = NodeManager::currentNM()->mkConst(false);
36 }
37
38 void SolverState::reset()
39 {
40 d_set_eqc.clear();
41 d_eqc_emptyset.clear();
42 d_eqc_univset.clear();
43 d_eqc_singleton.clear();
44 d_congruent.clear();
45 d_nvar_sets.clear();
46 d_var_set.clear();
47 d_compSets.clear();
48 d_pol_mems[0].clear();
49 d_pol_mems[1].clear();
50 d_members_index.clear();
51 d_singleton_index.clear();
52 d_bop_index.clear();
53 d_op_list.clear();
54 d_allCompSets.clear();
55 }
56
57 void SolverState::registerEqc(TypeNode tn, Node r)
58 {
59 if (tn.isSet())
60 {
61 d_set_eqc.push_back(r);
62 }
63 }
64
65 void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
66 {
67 Kind nk = n.getKind();
68 if (nk == MEMBER)
69 {
70 if (r.isConst())
71 {
72 Node s = d_ee->getRepresentative(n[1]);
73 Node x = d_ee->getRepresentative(n[0]);
74 int pindex = r == d_true ? 0 : (r == d_false ? 1 : -1);
75 if (pindex != -1)
76 {
77 if (d_pol_mems[pindex][s].find(x) == d_pol_mems[pindex][s].end())
78 {
79 d_pol_mems[pindex][s][x] = n;
80 Trace("sets-debug2") << "Membership[" << x << "][" << s << "] : " << n
81 << ", pindex = " << pindex << std::endl;
82 }
83 if (d_members_index[s].find(x) == d_members_index[s].end())
84 {
85 d_members_index[s][x] = n;
86 d_op_list[MEMBER].push_back(n);
87 }
88 }
89 else
90 {
91 Assert(false);
92 }
93 }
94 }
95 else if (nk == SINGLETON || nk == UNION || nk == INTERSECTION
96 || nk == SETMINUS || nk == EMPTYSET || nk == UNIVERSE_SET)
97 {
98 if (nk == SINGLETON)
99 {
100 // singleton lemma
101 getProxy(n);
102 Node re = d_ee->getRepresentative(n[0]);
103 if (d_singleton_index.find(re) == d_singleton_index.end())
104 {
105 d_singleton_index[re] = n;
106 d_eqc_singleton[r] = n;
107 d_op_list[SINGLETON].push_back(n);
108 }
109 else
110 {
111 d_congruent[n] = d_singleton_index[re];
112 }
113 }
114 else if (nk == EMPTYSET)
115 {
116 d_eqc_emptyset[tnn] = r;
117 }
118 else if (nk == UNIVERSE_SET)
119 {
120 Assert(options::setsExt());
121 d_eqc_univset[tnn] = r;
122 }
123 else
124 {
125 Node r1 = d_ee->getRepresentative(n[0]);
126 Node r2 = d_ee->getRepresentative(n[1]);
127 std::map<Node, Node>& binr1 = d_bop_index[nk][r1];
128 std::map<Node, Node>::iterator itb = binr1.find(r2);
129 if (itb == binr1.end())
130 {
131 binr1[r2] = n;
132 d_op_list[nk].push_back(n);
133 }
134 else
135 {
136 d_congruent[n] = itb->second;
137 }
138 }
139 d_nvar_sets[r].push_back(n);
140 Trace("sets-debug2") << "Non-var-set[" << r << "] : " << n << std::endl;
141 }
142 else if (nk == COMPREHENSION)
143 {
144 d_compSets[r].push_back(n);
145 d_allCompSets.push_back(n);
146 Trace("sets-debug2") << "Comp-set[" << r << "] : " << n << std::endl;
147 }
148 else if (n.isVar() && !d_skCache.isSkolem(n))
149 {
150 // it is important that we check it is a variable, but not an internally
151 // introduced skolem, due to the semantics of the universe set.
152 if (tnn.isSet())
153 {
154 if (d_var_set.find(r) == d_var_set.end())
155 {
156 d_var_set[r] = n;
157 Trace("sets-debug2") << "var-set[" << r << "] : " << n << std::endl;
158 }
159 }
160 }
161 else
162 {
163 Trace("sets-debug2") << "Unknown-set[" << r << "] : " << n << std::endl;
164 }
165 }
166
167 void SolverState::addEqualityToExp(Node a, Node b, std::vector<Node>& exp) const
168 {
169 if (a != b)
170 {
171 Assert(areEqual(a, b));
172 exp.push_back(a.eqNode(b));
173 }
174 }
175
176 Node SolverState::getEmptySetEqClass(TypeNode tn) const
177 {
178 std::map<TypeNode, Node>::const_iterator it = d_eqc_emptyset.find(tn);
179 if (it != d_eqc_emptyset.end())
180 {
181 return it->second;
182 }
183 return Node::null();
184 }
185
186 Node SolverState::getUnivSetEqClass(TypeNode tn) const
187 {
188 std::map<TypeNode, Node>::const_iterator it = d_univset.find(tn);
189 if (it != d_univset.end())
190 {
191 return it->second;
192 }
193 return Node::null();
194 }
195
196 Node SolverState::getSingletonEqClass(Node r) const
197 {
198 std::map<Node, Node>::const_iterator it = d_eqc_singleton.find(r);
199 if (it != d_eqc_singleton.end())
200 {
201 return it->second;
202 }
203 return Node::null();
204 }
205
206 Node SolverState::getBinaryOpTerm(Kind k, Node r1, Node r2) const
207 {
208 std::map<Kind, std::map<Node, std::map<Node, Node> > >::const_iterator itk =
209 d_bop_index.find(k);
210 if (itk == d_bop_index.end())
211 {
212 return Node::null();
213 }
214 std::map<Node, std::map<Node, Node> >::const_iterator it1 =
215 itk->second.find(r1);
216 if (it1 == itk->second.end())
217 {
218 return Node::null();
219 }
220 std::map<Node, Node>::const_iterator it2 = it1->second.find(r2);
221 if (it2 == it1->second.end())
222 {
223 return Node::null();
224 }
225 return it2->second;
226 }
227
228 bool SolverState::isEntailed(Node n, bool polarity) const
229 {
230 if (n.getKind() == NOT)
231 {
232 return isEntailed(n[0], !polarity);
233 }
234 else if (n.getKind() == EQUAL)
235 {
236 if (polarity)
237 {
238 return areEqual(n[0], n[1]);
239 }
240 return areDisequal(n[0], n[1]);
241 }
242 else if (n.getKind() == MEMBER)
243 {
244 if (areEqual(n, polarity ? d_true : d_false))
245 {
246 return true;
247 }
248 // check members cache
249 if (polarity && d_ee->hasTerm(n[1]))
250 {
251 Node r = d_ee->getRepresentative(n[1]);
252 if (d_parent.isMember(n[0], r))
253 {
254 return true;
255 }
256 }
257 }
258 else if (n.getKind() == AND || n.getKind() == OR)
259 {
260 bool conj = (n.getKind() == AND) == polarity;
261 for (const Node& nc : n)
262 {
263 bool isEnt = isEntailed(nc, polarity);
264 if (isEnt != conj)
265 {
266 return !conj;
267 }
268 }
269 return conj;
270 }
271 else if (n.isConst())
272 {
273 return (polarity && n == d_true) || (!polarity && n == d_false);
274 }
275 return false;
276 }
277
278 bool SolverState::isSetDisequalityEntailed(Node r1, Node r2) const
279 {
280 Assert(d_ee->hasTerm(r1) && d_ee->getRepresentative(r1) == r1);
281 Assert(d_ee->hasTerm(r2) && d_ee->getRepresentative(r2) == r2);
282 TypeNode tn = r1.getType();
283 Node re = getEmptySetEqClass(tn);
284 for (unsigned e = 0; e < 2; e++)
285 {
286 Node a = e == 0 ? r1 : r2;
287 Node b = e == 0 ? r2 : r1;
288 if (isSetDisequalityEntailedInternal(a, b, re))
289 {
290 return true;
291 }
292 }
293 return false;
294 }
295
296 bool SolverState::isSetDisequalityEntailedInternal(Node a,
297 Node b,
298 Node re) const
299 {
300 // if there are members in a
301 std::map<Node, std::map<Node, Node> >::const_iterator itpma =
302 d_pol_mems[0].find(a);
303 if (itpma == d_pol_mems[0].end())
304 {
305 // no positive members, continue
306 return false;
307 }
308 // if b is empty
309 if (b == re)
310 {
311 if (!itpma->second.empty())
312 {
313 Trace("sets-deq") << "Disequality is satisfied because members are in "
314 << a << " and " << b << " is empty" << std::endl;
315 return true;
316 }
317 else
318 {
319 // a should not be singleton
320 Assert(d_eqc_singleton.find(a) == d_eqc_singleton.end());
321 }
322 return false;
323 }
324 std::map<Node, Node>::const_iterator itsb = d_eqc_singleton.find(b);
325 std::map<Node, std::map<Node, Node> >::const_iterator itpmb =
326 d_pol_mems[1].find(b);
327 std::vector<Node> prev;
328 for (const std::pair<const Node, Node>& itm : itpma->second)
329 {
330 // if b is a singleton
331 if (itsb != d_eqc_singleton.end())
332 {
333 if (areDisequal(itm.first, itsb->second[0]))
334 {
335 Trace("sets-deq") << "Disequality is satisfied because of "
336 << itm.second << ", singleton eq " << itsb->second[0]
337 << std::endl;
338 return true;
339 }
340 // or disequal with another member
341 for (const Node& p : prev)
342 {
343 if (areDisequal(itm.first, p))
344 {
345 Trace("sets-deq")
346 << "Disequality is satisfied because of disequal members "
347 << itm.first << " " << p << ", singleton eq " << std::endl;
348 return true;
349 }
350 }
351 // if a has positive member that is negative member in b
352 }
353 else if (itpmb != d_pol_mems[1].end())
354 {
355 for (const std::pair<const Node, Node>& itnm : itpmb->second)
356 {
357 if (areEqual(itm.first, itnm.first))
358 {
359 Trace("sets-deq") << "Disequality is satisfied because of "
360 << itm.second << " " << itnm.second << std::endl;
361 return true;
362 }
363 }
364 }
365 prev.push_back(itm.first);
366 }
367 return false;
368 }
369
370 Node SolverState::getProxy(Node n)
371 {
372 Kind nk = n.getKind();
373 if (nk != EMPTYSET && nk != SINGLETON && nk != INTERSECTION && nk != SETMINUS
374 && nk != UNION && nk != UNIVERSE_SET)
375 {
376 return n;
377 }
378 NodeMap::const_iterator it = d_proxy.find(n);
379 if (it != d_proxy.end())
380 {
381 return (*it).second;
382 }
383 NodeManager* nm = NodeManager::currentNM();
384 Node k = d_skCache.mkTypedSkolemCached(
385 n.getType(), n, SkolemCache::SK_PURIFY, "sp");
386 d_proxy[n] = k;
387 d_proxy_to_term[k] = n;
388 Node eq = k.eqNode(n);
389 Trace("sets-lemma") << "Sets::Lemma : " << eq << " by proxy" << std::endl;
390 d_parent.getOutputChannel()->lemma(eq);
391 if (nk == SINGLETON)
392 {
393 Node slem = nm->mkNode(MEMBER, n[0], k);
394 Trace("sets-lemma") << "Sets::Lemma : " << slem << " by singleton"
395 << std::endl;
396 d_parent.getOutputChannel()->lemma(slem);
397 }
398 return k;
399 }
400
401 Node SolverState::getCongruent(Node n) const
402 {
403 Assert(d_ee->hasTerm(n));
404 std::map<Node, Node>::const_iterator it = d_congruent.find(n);
405 if (it == d_congruent.end())
406 {
407 return n;
408 }
409 return it->second;
410 }
411 bool SolverState::isCongruent(Node n) const
412 {
413 return d_congruent.find(n) != d_congruent.end();
414 }
415
416 Node SolverState::getEmptySet(TypeNode tn)
417 {
418 std::map<TypeNode, Node>::iterator it = d_emptyset.find(tn);
419 if (it != d_emptyset.end())
420 {
421 return it->second;
422 }
423 Node n = NodeManager::currentNM()->mkConst(EmptySet(tn));
424 d_emptyset[tn] = n;
425 return n;
426 }
427 Node SolverState::getUnivSet(TypeNode tn)
428 {
429 std::map<TypeNode, Node>::iterator it = d_univset.find(tn);
430 if (it != d_univset.end())
431 {
432 return it->second;
433 }
434 NodeManager* nm = NodeManager::currentNM();
435 Node n = nm->mkNullaryOperator(tn, UNIVERSE_SET);
436 for (it = d_univset.begin(); it != d_univset.end(); ++it)
437 {
438 Node n1;
439 Node n2;
440 if (tn.isSubtypeOf(it->first))
441 {
442 n1 = n;
443 n2 = it->second;
444 }
445 else if (it->first.isSubtypeOf(tn))
446 {
447 n1 = it->second;
448 n2 = n;
449 }
450 if (!n1.isNull())
451 {
452 Node ulem = nm->mkNode(SUBSET, n1, n2);
453 Trace("sets-lemma") << "Sets::Lemma : " << ulem << " by univ-type"
454 << std::endl;
455 d_parent.getOutputChannel()->lemma(ulem);
456 }
457 }
458 d_univset[tn] = n;
459 return n;
460 }
461
462 Node SolverState::getTypeConstraintSkolem(Node n, TypeNode tn)
463 {
464 std::map<TypeNode, Node>::iterator it = d_tc_skolem[n].find(tn);
465 if (it == d_tc_skolem[n].end())
466 {
467 Node k = NodeManager::currentNM()->mkSkolem("tc_k", tn);
468 d_tc_skolem[n][tn] = k;
469 return k;
470 }
471 return it->second;
472 }
473
474 const std::vector<Node>& SolverState::getNonVariableSets(Node r) const
475 {
476 std::map<Node, std::vector<Node> >::const_iterator it = d_nvar_sets.find(r);
477 if (it == d_nvar_sets.end())
478 {
479 return d_emptyVec;
480 }
481 return it->second;
482 }
483
484 Node SolverState::getVariableSet(Node r) const
485 {
486 std::map<Node, Node>::const_iterator it = d_var_set.find(r);
487 if (it != d_var_set.end())
488 {
489 return it->second;
490 }
491 return Node::null();
492 }
493
494 const std::vector<Node>& SolverState::getComprehensionSets(Node r) const
495 {
496 std::map<Node, std::vector<Node> >::const_iterator it = d_compSets.find(r);
497 if (it == d_compSets.end())
498 {
499 return d_emptyVec;
500 }
501 return it->second;
502 }
503
504 const std::map<Node, Node>& SolverState::getMembers(Node r) const
505 {
506 return getMembersInternal(r, 0);
507 }
508 const std::map<Node, Node>& SolverState::getNegativeMembers(Node r) const
509 {
510 return getMembersInternal(r, 1);
511 }
512 const std::map<Node, Node>& SolverState::getMembersInternal(Node r,
513 unsigned i) const
514 {
515 std::map<Node, std::map<Node, Node> >::const_iterator itp =
516 d_pol_mems[i].find(r);
517 if (itp == d_pol_mems[i].end())
518 {
519 return d_emptyMap;
520 }
521 return itp->second;
522 }
523
524 bool SolverState::hasMembers(Node r) const
525 {
526 std::map<Node, std::map<Node, Node> >::const_iterator it =
527 d_pol_mems[0].find(r);
528 if (it == d_pol_mems[0].end())
529 {
530 return false;
531 }
532 return !it->second.empty();
533 }
534 const std::map<Kind, std::map<Node, std::map<Node, Node> > >&
535 SolverState::getBinaryOpIndex() const
536 {
537 return d_bop_index;
538 }
539 const std::map<Kind, std::vector<Node> >& SolverState::getOperatorList() const
540 {
541 return d_op_list;
542 }
543
544 const std::vector<Node>& SolverState::getComprehensionSets() const
545 {
546 return d_allCompSets;
547 }
548
549 void SolverState::debugPrintSet(Node s, const char* c) const
550 {
551 if (s.getNumChildren() == 0)
552 {
553 NodeMap::const_iterator it = d_proxy_to_term.find(s);
554 if (it != d_proxy_to_term.end())
555 {
556 debugPrintSet((*it).second, c);
557 }
558 else
559 {
560 Trace(c) << s;
561 }
562 }
563 else
564 {
565 Trace(c) << "(" << s.getOperator();
566 for (const Node& sc : s)
567 {
568 Trace(c) << " ";
569 debugPrintSet(sc, c);
570 }
571 Trace(c) << ")";
572 }
573 }
574
575 const vector<Node> SolverState::getSetsEqClasses(const TypeNode& t) const
576 {
577 vector<Node> representatives;
578 for (const Node& eqc : getSetsEqClasses())
579 {
580 if (eqc.getType().getSetElementType() == t)
581 {
582 representatives.push_back(eqc);
583 }
584 }
585 return representatives;
586 }
587
588 } // namespace sets
589 } // namespace theory
590 } // namespace CVC4