f27e14121b6e383636830cb6cd12050801c4476b
[cvc5.git] / src / theory / quantifiers / entailment_check.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Implementation of entailment check class.
14 */
15
16 #include "theory/quantifiers/entailment_check.h"
17
18 #include "theory/quantifiers/quantifiers_state.h"
19 #include "theory/quantifiers/term_database.h"
20
21 using namespace cvc5::kind;
22 using namespace cvc5::context;
23
24 namespace cvc5 {
25 namespace theory {
26 namespace quantifiers {
27
28 EntailmentCheck::EntailmentCheck(Env& env, QuantifiersState& qs, TermDb& tdb)
29 : EnvObj(env), d_qstate(qs), d_tdb(tdb)
30 {
31 d_true = NodeManager::currentNM()->mkConst(true);
32 d_false = NodeManager::currentNM()->mkConst(false);
33 }
34
35 EntailmentCheck::~EntailmentCheck() {}
36 Node EntailmentCheck::evaluateTerm2(TNode n,
37 std::map<TNode, Node>& visited,
38 std::vector<Node>& exp,
39 bool useEntailmentTests,
40 bool computeExp,
41 bool reqHasTerm)
42 {
43 std::map<TNode, Node>::iterator itv = visited.find(n);
44 if (itv != visited.end())
45 {
46 return itv->second;
47 }
48 size_t prevSize = exp.size();
49 Trace("term-db-eval") << "evaluate term : " << n << std::endl;
50 Node ret = n;
51 if (n.getKind() == FORALL || n.getKind() == BOUND_VARIABLE)
52 {
53 // do nothing
54 }
55 else if (d_qstate.hasTerm(n))
56 {
57 Trace("term-db-eval") << "...exists in ee, return rep" << std::endl;
58 ret = d_qstate.getRepresentative(n);
59 if (computeExp)
60 {
61 if (n != ret)
62 {
63 exp.push_back(n.eqNode(ret));
64 }
65 }
66 reqHasTerm = false;
67 }
68 else if (n.hasOperator())
69 {
70 std::vector<TNode> args;
71 bool ret_set = false;
72 Kind k = n.getKind();
73 std::vector<Node> tempExp;
74 for (unsigned i = 0, nchild = n.getNumChildren(); i < nchild; i++)
75 {
76 TNode c = evaluateTerm2(
77 n[i], visited, tempExp, useEntailmentTests, computeExp, reqHasTerm);
78 if (c.isNull())
79 {
80 ret = Node::null();
81 ret_set = true;
82 break;
83 }
84 else if (c == d_true || c == d_false)
85 {
86 // short-circuiting
87 if ((k == AND && c == d_false) || (k == OR && c == d_true))
88 {
89 ret = c;
90 ret_set = true;
91 reqHasTerm = false;
92 break;
93 }
94 else if (k == ITE && i == 0)
95 {
96 ret = evaluateTerm2(n[c == d_true ? 1 : 2],
97 visited,
98 tempExp,
99 useEntailmentTests,
100 computeExp,
101 reqHasTerm);
102 ret_set = true;
103 reqHasTerm = false;
104 break;
105 }
106 }
107 if (computeExp)
108 {
109 exp.insert(exp.end(), tempExp.begin(), tempExp.end());
110 }
111 Trace("term-db-eval") << " child " << i << " : " << c << std::endl;
112 args.push_back(c);
113 }
114 if (ret_set)
115 {
116 // if we short circuited
117 if (computeExp)
118 {
119 exp.clear();
120 exp.insert(exp.end(), tempExp.begin(), tempExp.end());
121 }
122 }
123 else
124 {
125 // get the (indexed) operator of n, if it exists
126 TNode f = d_tdb.getMatchOperator(n);
127 // if it is an indexed term, return the congruent term
128 if (!f.isNull())
129 {
130 // if f is congruent to a term indexed by this class
131 TNode nn = d_tdb.getCongruentTerm(f, args);
132 Trace("term-db-eval") << " got congruent term " << nn
133 << " from DB for " << n << std::endl;
134 if (!nn.isNull())
135 {
136 if (computeExp)
137 {
138 Assert(nn.getNumChildren() == n.getNumChildren());
139 for (size_t i = 0, nchild = nn.getNumChildren(); i < nchild; i++)
140 {
141 if (nn[i] != n[i])
142 {
143 exp.push_back(nn[i].eqNode(n[i]));
144 }
145 }
146 }
147 ret = d_qstate.getRepresentative(nn);
148 Trace("term-db-eval") << "return rep" << std::endl;
149 ret_set = true;
150 reqHasTerm = false;
151 Assert(!ret.isNull());
152 if (computeExp)
153 {
154 if (n != ret)
155 {
156 exp.push_back(nn.eqNode(ret));
157 }
158 }
159 }
160 }
161 if (!ret_set)
162 {
163 Trace("term-db-eval") << "return rewrite" << std::endl;
164 // a theory symbol or a new UF term
165 if (n.getMetaKind() == metakind::PARAMETERIZED)
166 {
167 args.insert(args.begin(), n.getOperator());
168 }
169 ret = NodeManager::currentNM()->mkNode(n.getKind(), args);
170 ret = rewrite(ret);
171 if (ret.getKind() == EQUAL)
172 {
173 if (d_qstate.areDisequal(ret[0], ret[1]))
174 {
175 ret = d_false;
176 }
177 }
178 if (useEntailmentTests)
179 {
180 if (ret.getKind() == EQUAL || ret.getKind() == GEQ)
181 {
182 Valuation& val = d_qstate.getValuation();
183 for (unsigned j = 0; j < 2; j++)
184 {
185 std::pair<bool, Node> et = val.entailmentCheck(
186 options::TheoryOfMode::THEORY_OF_TYPE_BASED,
187 j == 0 ? ret : ret.negate());
188 if (et.first)
189 {
190 ret = j == 0 ? d_true : d_false;
191 if (computeExp)
192 {
193 exp.push_back(et.second);
194 }
195 break;
196 }
197 }
198 }
199 }
200 }
201 }
202 }
203 // must have the term
204 if (reqHasTerm && !ret.isNull())
205 {
206 Kind k = ret.getKind();
207 if (k != OR && k != AND && k != EQUAL && k != ITE && k != NOT
208 && k != FORALL)
209 {
210 if (!d_qstate.hasTerm(ret))
211 {
212 ret = Node::null();
213 }
214 }
215 }
216 Trace("term-db-eval") << "evaluated term : " << n << ", got : " << ret
217 << ", reqHasTerm = " << reqHasTerm << std::endl;
218 // clear the explanation if failed
219 if (computeExp && ret.isNull())
220 {
221 exp.resize(prevSize);
222 }
223 visited[n] = ret;
224 return ret;
225 }
226
227 TNode EntailmentCheck::getEntailedTerm2(TNode n,
228 std::map<TNode, TNode>& subs,
229 bool subsRep,
230 bool hasSubs)
231 {
232 Trace("term-db-entail") << "get entailed term : " << n << std::endl;
233 if (d_qstate.hasTerm(n))
234 {
235 Trace("term-db-entail") << "...exists in ee, return rep " << std::endl;
236 return n;
237 }
238 else if (n.getKind() == BOUND_VARIABLE)
239 {
240 if (hasSubs)
241 {
242 std::map<TNode, TNode>::iterator it = subs.find(n);
243 if (it != subs.end())
244 {
245 Trace("term-db-entail")
246 << "...substitution is : " << it->second << std::endl;
247 if (subsRep)
248 {
249 Assert(d_qstate.hasTerm(it->second));
250 Assert(d_qstate.getRepresentative(it->second) == it->second);
251 return it->second;
252 }
253 return getEntailedTerm2(it->second, subs, subsRep, hasSubs);
254 }
255 }
256 }
257 else if (n.getKind() == ITE)
258 {
259 for (uint32_t i = 0; i < 2; i++)
260 {
261 if (isEntailed2(n[0], subs, subsRep, hasSubs, i == 0))
262 {
263 return getEntailedTerm2(n[i == 0 ? 1 : 2], subs, subsRep, hasSubs);
264 }
265 }
266 }
267 else
268 {
269 if (n.hasOperator())
270 {
271 TNode f = d_tdb.getMatchOperator(n);
272 if (!f.isNull())
273 {
274 std::vector<TNode> args;
275 for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
276 {
277 TNode c = getEntailedTerm2(n[i], subs, subsRep, hasSubs);
278 if (c.isNull())
279 {
280 return TNode::null();
281 }
282 c = d_qstate.getRepresentative(c);
283 Trace("term-db-entail") << " child " << i << " : " << c << std::endl;
284 args.push_back(c);
285 }
286 TNode nn = d_tdb.getCongruentTerm(f, args);
287 Trace("term-db-entail")
288 << " got congruent term " << nn << " for " << n << std::endl;
289 return nn;
290 }
291 }
292 }
293 return TNode::null();
294 }
295
296 Node EntailmentCheck::evaluateTerm(TNode n,
297 bool useEntailmentTests,
298 bool reqHasTerm)
299 {
300 std::map<TNode, Node> visited;
301 std::vector<Node> exp;
302 return evaluateTerm2(n, visited, exp, useEntailmentTests, false, reqHasTerm);
303 }
304
305 Node EntailmentCheck::evaluateTerm(TNode n,
306 std::vector<Node>& exp,
307 bool useEntailmentTests,
308 bool reqHasTerm)
309 {
310 std::map<TNode, Node> visited;
311 return evaluateTerm2(n, visited, exp, useEntailmentTests, true, reqHasTerm);
312 }
313
314 TNode EntailmentCheck::getEntailedTerm(TNode n,
315 std::map<TNode, TNode>& subs,
316 bool subsRep)
317 {
318 return getEntailedTerm2(n, subs, subsRep, true);
319 }
320
321 TNode EntailmentCheck::getEntailedTerm(TNode n)
322 {
323 std::map<TNode, TNode> subs;
324 return getEntailedTerm2(n, subs, false, false);
325 }
326
327 bool EntailmentCheck::isEntailed2(
328 TNode n, std::map<TNode, TNode>& subs, bool subsRep, bool hasSubs, bool pol)
329 {
330 Trace("term-db-entail") << "Check entailed : " << n << ", pol = " << pol
331 << std::endl;
332 Assert(n.getType().isBoolean());
333 if (n.getKind() == EQUAL && !n[0].getType().isBoolean())
334 {
335 TNode n1 = getEntailedTerm2(n[0], subs, subsRep, hasSubs);
336 if (!n1.isNull())
337 {
338 TNode n2 = getEntailedTerm2(n[1], subs, subsRep, hasSubs);
339 if (!n2.isNull())
340 {
341 if (n1 == n2)
342 {
343 return pol;
344 }
345 else
346 {
347 Assert(d_qstate.hasTerm(n1));
348 Assert(d_qstate.hasTerm(n2));
349 if (pol)
350 {
351 return d_qstate.areEqual(n1, n2);
352 }
353 else
354 {
355 return d_qstate.areDisequal(n1, n2);
356 }
357 }
358 }
359 }
360 }
361 else if (n.getKind() == NOT)
362 {
363 return isEntailed2(n[0], subs, subsRep, hasSubs, !pol);
364 }
365 else if (n.getKind() == OR || n.getKind() == AND)
366 {
367 bool simPol = (pol && n.getKind() == OR) || (!pol && n.getKind() == AND);
368 for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
369 {
370 if (isEntailed2(n[i], subs, subsRep, hasSubs, pol))
371 {
372 if (simPol)
373 {
374 return true;
375 }
376 }
377 else
378 {
379 if (!simPol)
380 {
381 return false;
382 }
383 }
384 }
385 return !simPol;
386 // Boolean equality here
387 }
388 else if (n.getKind() == EQUAL || n.getKind() == ITE)
389 {
390 for (size_t i = 0; i < 2; i++)
391 {
392 if (isEntailed2(n[0], subs, subsRep, hasSubs, i == 0))
393 {
394 size_t ch = (n.getKind() == EQUAL || i == 0) ? 1 : 2;
395 bool reqPol = (n.getKind() == ITE || i == 0) ? pol : !pol;
396 return isEntailed2(n[ch], subs, subsRep, hasSubs, reqPol);
397 }
398 }
399 }
400 else if (n.getKind() == APPLY_UF)
401 {
402 TNode n1 = getEntailedTerm2(n, subs, subsRep, hasSubs);
403 if (!n1.isNull())
404 {
405 Assert(d_qstate.hasTerm(n1));
406 if (n1 == d_true)
407 {
408 return pol;
409 }
410 else if (n1 == d_false)
411 {
412 return !pol;
413 }
414 else
415 {
416 return d_qstate.getRepresentative(n1) == (pol ? d_true : d_false);
417 }
418 }
419 }
420 else if (n.getKind() == FORALL && !pol)
421 {
422 return isEntailed2(n[1], subs, subsRep, hasSubs, pol);
423 }
424 return false;
425 }
426
427 bool EntailmentCheck::isEntailed(TNode n, bool pol)
428 {
429 std::map<TNode, TNode> subs;
430 return isEntailed2(n, subs, false, false, pol);
431 }
432
433 bool EntailmentCheck::isEntailed(TNode n,
434 std::map<TNode, TNode>& subs,
435 bool subsRep,
436 bool pol)
437 {
438 return isEntailed2(n, subs, subsRep, true, pol);
439 }
440
441 } // namespace quantifiers
442 } // namespace theory
443 } // namespace cvc5