Cleanup some includes (#5847)
[cvc5.git] / src / theory / quantifiers / sygus_inst.cpp
1 /********************* */
2 /*! \file sygus_inst.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Aina Niemetz, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief SyGuS instantiator class.
13 **/
14
15 #include "theory/quantifiers/sygus_inst.h"
16
17 #include <sstream>
18 #include <unordered_set>
19
20 #include "expr/node_algorithm.h"
21 #include "theory/bv/theory_bv_utils.h"
22 #include "theory/datatypes/sygus_datatype_utils.h"
23 #include "theory/quantifiers/sygus/sygus_enumerator.h"
24 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 #include "theory/quantifiers/sygus/synth_engine.h"
26 #include "theory/quantifiers_engine.h"
27
28 namespace CVC4 {
29 namespace theory {
30 namespace quantifiers {
31
32 namespace {
33
34 /**
35 * Collect maximal ground terms with type tn in node n.
36 *
37 * @param n: Node to traverse.
38 * @param tn: Collects only terms with type tn.
39 * @param terms: Collected terms.
40 * @param cache: Caches visited nodes.
41 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
42 */
43 void getMaxGroundTerms(TNode n,
44 TypeNode tn,
45 std::unordered_set<Node, NodeHashFunction>& terms,
46 std::unordered_set<TNode, TNodeHashFunction>& cache,
47 bool skip_quant = false)
48 {
49 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
50 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
51 {
52 return;
53 }
54
55 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
56 << " in: " << n << std::endl;
57
58 Node cur;
59 std::vector<TNode> visit;
60
61 visit.push_back(n);
62 do
63 {
64 cur = visit.back();
65 visit.pop_back();
66
67 if (cache.find(cur) != cache.end())
68 {
69 continue;
70 }
71 cache.insert(cur);
72
73 if (expr::hasBoundVar(cur) || cur.getType() != tn)
74 {
75 if (!skip_quant || cur.getKind() != kind::FORALL)
76 {
77 visit.insert(visit.end(), cur.begin(), cur.end());
78 }
79 }
80 else
81 {
82 terms.insert(cur);
83 Trace("sygus-inst-term") << " found: " << cur << std::endl;
84 }
85 } while (!visit.empty());
86 }
87
88 /*
89 * Collect minimal ground terms with type tn in node n.
90 *
91 * @param n: Node to traverse.
92 * @param tn: Collects only terms with type tn.
93 * @param terms: Collected terms.
94 * @param cache: Caches visited nodes and flags indicating whether a minimal
95 * term was already found in a subterm.
96 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
97 */
98 void getMinGroundTerms(
99 TNode n,
100 TypeNode tn,
101 std::unordered_set<Node, NodeHashFunction>& terms,
102 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
103 bool skip_quant = false)
104 {
105 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
106 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
107 {
108 return;
109 }
110
111 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
112 << " in: " << n << std::endl;
113
114 Node cur;
115 std::vector<TNode> visit;
116
117 visit.push_back(n);
118 do
119 {
120 cur = visit.back();
121 visit.pop_back();
122
123 auto it = cache.find(cur);
124 if (it == cache.end())
125 {
126 cache.emplace(cur, std::make_pair(false, false));
127 if (!skip_quant || cur.getKind() != kind::FORALL)
128 {
129 visit.push_back(cur);
130 visit.insert(visit.end(), cur.begin(), cur.end());
131 }
132 }
133 /* up-traversal */
134 else if (!it->second.first)
135 {
136 bool found_min_term = false;
137
138 /* Check if we found a minimal term in one of the children. */
139 for (const Node& c : cur)
140 {
141 found_min_term |= cache[c].second;
142 if (found_min_term) break;
143 }
144
145 /* If we haven't found a minimal term yet, add this term if it has the
146 * right type. */
147 if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
148 {
149 terms.insert(cur);
150 found_min_term = true;
151 Trace("sygus-inst-term") << " found: " << cur << std::endl;
152 }
153
154 it->second.first = true;
155 it->second.second = found_min_term;
156 }
157 } while (!visit.empty());
158 }
159
160 /*
161 * Add special values for a given type.
162 *
163 * @param tn: The type node.
164 * @param extra_cons: A map of TypeNode to constants, which are added in
165 * addition to the default grammar.
166 */
167 void addSpecialValues(
168 const TypeNode& tn,
169 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons)
170 {
171 if (tn.isBitVector())
172 {
173 uint32_t size = tn.getBitVectorSize();
174 extra_cons[tn].insert(bv::utils::mkOnes(size));
175 extra_cons[tn].insert(bv::utils::mkMinSigned(size));
176 extra_cons[tn].insert(bv::utils::mkMaxSigned(size));
177 }
178 }
179
180 } // namespace
181
182 SygusInst::SygusInst(QuantifiersEngine* qe,
183 QuantifiersState& qs,
184 QuantifiersInferenceManager& qim)
185 : QuantifiersModule(qs, qim, qe),
186 d_ce_lemma_added(qs.getUserContext()),
187 d_global_terms(qs.getUserContext()),
188 d_notified_assertions(qs.getUserContext())
189 {
190 }
191
192 bool SygusInst::needsCheck(Theory::Effort e)
193 {
194 return e >= Theory::EFFORT_LAST_CALL;
195 }
196
197 QuantifiersModule::QEffort SygusInst::needsModel(Theory::Effort e)
198 {
199 return QEFFORT_STANDARD;
200 }
201
202 void SygusInst::reset_round(Theory::Effort e)
203 {
204 d_active_quant.clear();
205 d_inactive_quant.clear();
206
207 FirstOrderModel* model = d_quantEngine->getModel();
208 uint32_t nasserted = model->getNumAssertedQuantifiers();
209
210 for (uint32_t i = 0; i < nasserted; ++i)
211 {
212 Node q = model->getAssertedQuantifier(i);
213
214 if (model->isQuantifierActive(q))
215 {
216 d_active_quant.insert(q);
217 Node lit = getCeLiteral(q);
218
219 bool value;
220 if (d_quantEngine->getValuation().hasSatValue(lit, value))
221 {
222 if (!value)
223 {
224 if (!d_quantEngine->getValuation().isDecision(lit))
225 {
226 model->setQuantifierActive(q, false);
227 d_active_quant.erase(q);
228 d_inactive_quant.insert(q);
229 Trace("sygus-inst") << "Set inactive: " << q << std::endl;
230 }
231 }
232 }
233 }
234 }
235 }
236
237 void SygusInst::check(Theory::Effort e, QEffort quant_e)
238 {
239 Trace("sygus-inst") << "Check " << e << ", " << quant_e << std::endl;
240
241 if (quant_e != QEFFORT_STANDARD) return;
242
243 FirstOrderModel* model = d_quantEngine->getModel();
244 Instantiate* inst = d_quantEngine->getInstantiate();
245 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
246 SygusExplain syexplain(db);
247 NodeManager* nm = NodeManager::currentNM();
248 options::SygusInstMode mode = options::sygusInstMode();
249
250 for (const Node& q : d_active_quant)
251 {
252 const std::vector<Node>& inst_constants = d_inst_constants.at(q);
253 const std::vector<Node>& dt_evals = d_var_eval.at(q);
254 Assert(inst_constants.size() == dt_evals.size());
255 Assert(inst_constants.size() == q[0].getNumChildren());
256
257 std::vector<Node> terms, eval_unfold_lemmas;
258 for (size_t i = 0, size = q[0].getNumChildren(); i < size; ++i)
259 {
260 Node dt_var = inst_constants[i];
261 Node dt_eval = dt_evals[i];
262 Node value = model->getValue(dt_var);
263 Node t = datatypes::utils::sygusToBuiltin(value);
264 terms.push_back(t);
265
266 std::vector<Node> exp;
267 syexplain.getExplanationForEquality(dt_var, value, exp);
268 Node lem;
269 if (exp.empty())
270 {
271 lem = dt_eval.eqNode(t);
272 }
273 else
274 {
275 lem = nm->mkNode(kind::IMPLIES,
276 exp.size() == 1 ? exp[0] : nm->mkNode(kind::AND, exp),
277 dt_eval.eqNode(t));
278 }
279 eval_unfold_lemmas.push_back(lem);
280 }
281
282 if (mode == options::SygusInstMode::PRIORITY_INST)
283 {
284 if (!inst->addInstantiation(q, terms))
285 {
286 sendEvalUnfoldLemmas(eval_unfold_lemmas);
287 }
288 }
289 else if (mode == options::SygusInstMode::PRIORITY_EVAL)
290 {
291 if (!sendEvalUnfoldLemmas(eval_unfold_lemmas))
292 {
293 inst->addInstantiation(q, terms);
294 }
295 }
296 else
297 {
298 Assert(mode == options::SygusInstMode::INTERLEAVE);
299 inst->addInstantiation(q, terms);
300 sendEvalUnfoldLemmas(eval_unfold_lemmas);
301 }
302 }
303 }
304
305 bool SygusInst::sendEvalUnfoldLemmas(const std::vector<Node>& lemmas)
306 {
307 bool added_lemma = false;
308 for (const Node& lem : lemmas)
309 {
310 Trace("sygus-inst") << "Evaluation unfolding: " << lem << std::endl;
311 added_lemma |= d_quantEngine->addLemma(lem);
312 }
313 return added_lemma;
314 }
315
316 bool SygusInst::checkCompleteFor(Node q)
317 {
318 return d_inactive_quant.find(q) != d_inactive_quant.end();
319 }
320
321 void SygusInst::registerQuantifier(Node q)
322 {
323 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
324
325 Trace("sygus-inst") << "Register " << q << std::endl;
326
327 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
328 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
329 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
330 std::unordered_set<Node, NodeHashFunction> term_irrelevant;
331
332 /* Collect relevant local ground terms for each variable type. */
333 if (options::sygusInstScope() == options::SygusInstScope::IN
334 || options::sygusInstScope() == options::SygusInstScope::BOTH)
335 {
336 std::unordered_map<TypeNode,
337 std::unordered_set<Node, NodeHashFunction>,
338 TypeNodeHashFunction>
339 relevant_terms;
340 for (const Node& var : q[0])
341 {
342 TypeNode tn = var.getType();
343
344 /* Collect relevant ground terms for type tn. */
345 if (relevant_terms.find(tn) == relevant_terms.end())
346 {
347 std::unordered_set<Node, NodeHashFunction> terms;
348 std::unordered_set<TNode, TNodeHashFunction> cache_max;
349 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
350 cache_min;
351
352 getMinGroundTerms(q, tn, terms, cache_min);
353 getMaxGroundTerms(q, tn, terms, cache_max);
354 relevant_terms.emplace(tn, terms);
355 }
356
357 /* Add relevant ground terms to grammar. */
358 auto& terms = relevant_terms[tn];
359 for (const auto& t : terms)
360 {
361 TypeNode ttn = t.getType();
362 extra_cons[ttn].insert(t);
363 Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
364 }
365 }
366 }
367
368 /* Collect relevant global ground terms for each variable type. */
369 if (options::sygusInstScope() == options::SygusInstScope::OUT
370 || options::sygusInstScope() == options::SygusInstScope::BOTH)
371 {
372 for (const Node& var : q[0])
373 {
374 TypeNode tn = var.getType();
375
376 /* Collect relevant ground terms for type tn. */
377 if (d_global_terms.find(tn) == d_global_terms.end())
378 {
379 std::unordered_set<Node, NodeHashFunction> terms;
380 std::unordered_set<TNode, TNodeHashFunction> cache_max;
381 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
382 cache_min;
383
384 for (const Node& a : d_notified_assertions)
385 {
386 getMinGroundTerms(a, tn, terms, cache_min, true);
387 getMaxGroundTerms(a, tn, terms, cache_max, true);
388 }
389 d_global_terms.insert(tn, terms);
390 }
391
392 /* Add relevant ground terms to grammar. */
393 auto it = d_global_terms.find(tn);
394 if (it != d_global_terms.end())
395 {
396 for (const auto& t : (*it).second)
397 {
398 TypeNode ttn = t.getType();
399 extra_cons[ttn].insert(t);
400 Trace("sygus-inst")
401 << "Adding (global) extra cons: " << t << std::endl;
402 }
403 }
404 }
405 }
406
407 /* Construct grammar for each bound variable of 'q'. */
408 Trace("sygus-inst") << "Process variables of " << q << std::endl;
409 std::vector<TypeNode> types;
410 for (const Node& var : q[0])
411 {
412 addSpecialValues(var.getType(), extra_cons);
413 TypeNode tn = CegGrammarConstructor::mkSygusDefaultType(var.getType(),
414 Node(),
415 var.toString(),
416 extra_cons,
417 exclude_cons,
418 include_cons,
419 term_irrelevant);
420 types.push_back(tn);
421
422 Trace("sygus-inst") << "Construct (default) datatype for " << var
423 << std::endl
424 << tn << std::endl;
425 }
426
427 registerCeLemma(q, types);
428 }
429
430 /* Construct grammars for all bound variables of quantifier 'q'. Currently,
431 * we use the default grammar of the variable's type.
432 */
433 void SygusInst::preRegisterQuantifier(Node q)
434 {
435 Trace("sygus-inst") << "preRegister " << q << std::endl;
436 addCeLemma(q);
437 }
438
439 void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
440 {
441 for (const Node& a : assertions)
442 {
443 d_notified_assertions.insert(a);
444 }
445 }
446
447 /*****************************************************************************/
448 /* private methods */
449 /*****************************************************************************/
450
451 Node SygusInst::getCeLiteral(Node q)
452 {
453 auto it = d_ce_lits.find(q);
454 if (it != d_ce_lits.end())
455 {
456 return it->second;
457 }
458
459 NodeManager* nm = NodeManager::currentNM();
460 Node sk = nm->mkSkolem("CeLiteral", nm->booleanType());
461 Node lit = d_quantEngine->getValuation().ensureLiteral(sk);
462 d_ce_lits[q] = lit;
463 return lit;
464 }
465
466 void SygusInst::registerCeLemma(Node q, std::vector<TypeNode>& types)
467 {
468 Assert(q[0].getNumChildren() == types.size());
469 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
470 Assert(d_inst_constants.find(q) == d_inst_constants.end());
471 Assert(d_var_eval.find(q) == d_var_eval.end());
472
473 Trace("sygus-inst") << "Register CE Lemma for " << q << std::endl;
474
475 /* Generate counterexample lemma for 'q'. */
476 NodeManager* nm = NodeManager::currentNM();
477 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
478
479 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
480 * instantiation constant ic_i with type types[i] and wrap each ic_i in
481 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
482 std::vector<Node> evals;
483 std::vector<Node> inst_constants;
484 for (size_t i = 0, size = types.size(); i < size; ++i)
485 {
486 TypeNode tn = types[i];
487 TNode var = q[0][i];
488
489 /* Create the instantiation constant and set attribute accordingly. */
490 Node ic = nm->mkInstConstant(tn);
491 InstConstantAttribute ica;
492 ic.setAttribute(ica, q);
493 Trace("sygus-inst") << "Create " << ic << " for " << var << std::endl;
494
495 db->registerEnumerator(ic, ic, nullptr, ROLE_ENUM_MULTI_SOLUTION);
496
497 std::vector<Node> args = {ic};
498 Node svl = tn.getDType().getSygusVarList();
499 if (!svl.isNull())
500 {
501 args.insert(args.end(), svl.begin(), svl.end());
502 }
503 Node eval = nm->mkNode(kind::DT_SYGUS_EVAL, args);
504
505 inst_constants.push_back(ic);
506 evals.push_back(eval);
507 }
508
509 d_inst_constants.emplace(q, inst_constants);
510 d_var_eval.emplace(q, evals);
511
512 Node lit = getCeLiteral(q);
513 d_quantEngine->addRequirePhase(lit, true);
514
515 /* The decision strategy for quantified formula 'q' ensures that its
516 * counterexample literal is decided on first. It is user-context dependent.
517 */
518 Assert(d_dstrat.find(q) == d_dstrat.end());
519 DecisionStrategy* ds =
520 new DecisionStrategySingleton("CeLiteral",
521 lit,
522 d_qstate.getSatContext(),
523 d_quantEngine->getValuation());
524
525 d_dstrat[q].reset(ds);
526 d_quantEngine->getDecisionManager()->registerStrategy(
527 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE, ds);
528
529 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
530 Node body =
531 q[1].substitute(q[0].begin(), q[0].end(), evals.begin(), evals.end());
532 Node lem = nm->mkNode(kind::OR, lit.negate(), body.negate());
533 lem = Rewriter::rewrite(lem);
534
535 d_ce_lemmas.emplace(std::make_pair(q, lem));
536 Trace("sygus-inst") << "Register CE Lemma: " << lem << std::endl;
537 }
538
539 void SygusInst::addCeLemma(Node q)
540 {
541 Assert(d_ce_lemmas.find(q) != d_ce_lemmas.end());
542
543 /* Already added in previous contexts. */
544 if (d_ce_lemma_added.find(q) != d_ce_lemma_added.end()) return;
545
546 Node lem = d_ce_lemmas[q];
547 d_quantEngine->addLemma(lem, false);
548 d_ce_lemma_added.insert(q);
549 Trace("sygus-inst") << "Add CE Lemma: " << lem << std::endl;
550 }
551
552 } // namespace quantifiers
553 } // namespace theory
554 } // namespace CVC4