Add missing inference ids (#6242)
[cvc5.git] / src / theory / quantifiers / sygus_inst.cpp
1 /********************* */
2 /*! \file sygus_inst.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Aina Niemetz, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief SyGuS instantiator class.
13 **/
14
15 #include "theory/quantifiers/sygus_inst.h"
16
17 #include <sstream>
18 #include <unordered_set>
19
20 #include "expr/node_algorithm.h"
21 #include "options/quantifiers_options.h"
22 #include "theory/bv/theory_bv_utils.h"
23 #include "theory/datatypes/sygus_datatype_utils.h"
24 #include "theory/quantifiers/first_order_model.h"
25 #include "theory/quantifiers/sygus/sygus_enumerator.h"
26 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
27 #include "theory/quantifiers/sygus/synth_engine.h"
28 #include "theory/quantifiers/term_util.h"
29 #include "theory/quantifiers_engine.h"
30 #include "theory/rewriter.h"
31
32 namespace CVC4 {
33 namespace theory {
34 namespace quantifiers {
35
36 namespace {
37
38 /**
39 * Collect maximal ground terms with type tn in node n.
40 *
41 * @param n: Node to traverse.
42 * @param tn: Collects only terms with type tn.
43 * @param terms: Collected terms.
44 * @param cache: Caches visited nodes.
45 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
46 */
47 void getMaxGroundTerms(TNode n,
48 TypeNode tn,
49 std::unordered_set<Node, NodeHashFunction>& terms,
50 std::unordered_set<TNode, TNodeHashFunction>& cache,
51 bool skip_quant = false)
52 {
53 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
54 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
55 {
56 return;
57 }
58
59 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
60 << " in: " << n << std::endl;
61
62 Node cur;
63 std::vector<TNode> visit;
64
65 visit.push_back(n);
66 do
67 {
68 cur = visit.back();
69 visit.pop_back();
70
71 if (cache.find(cur) != cache.end())
72 {
73 continue;
74 }
75 cache.insert(cur);
76
77 if (expr::hasBoundVar(cur) || cur.getType() != tn)
78 {
79 if (!skip_quant || cur.getKind() != kind::FORALL)
80 {
81 visit.insert(visit.end(), cur.begin(), cur.end());
82 }
83 }
84 else
85 {
86 terms.insert(cur);
87 Trace("sygus-inst-term") << " found: " << cur << std::endl;
88 }
89 } while (!visit.empty());
90 }
91
92 /*
93 * Collect minimal ground terms with type tn in node n.
94 *
95 * @param n: Node to traverse.
96 * @param tn: Collects only terms with type tn.
97 * @param terms: Collected terms.
98 * @param cache: Caches visited nodes and flags indicating whether a minimal
99 * term was already found in a subterm.
100 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
101 */
102 void getMinGroundTerms(
103 TNode n,
104 TypeNode tn,
105 std::unordered_set<Node, NodeHashFunction>& terms,
106 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
107 bool skip_quant = false)
108 {
109 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
110 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
111 {
112 return;
113 }
114
115 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
116 << " in: " << n << std::endl;
117
118 Node cur;
119 std::vector<TNode> visit;
120
121 visit.push_back(n);
122 do
123 {
124 cur = visit.back();
125 visit.pop_back();
126
127 auto it = cache.find(cur);
128 if (it == cache.end())
129 {
130 cache.emplace(cur, std::make_pair(false, false));
131 if (!skip_quant || cur.getKind() != kind::FORALL)
132 {
133 visit.push_back(cur);
134 visit.insert(visit.end(), cur.begin(), cur.end());
135 }
136 }
137 /* up-traversal */
138 else if (!it->second.first)
139 {
140 bool found_min_term = false;
141
142 /* Check if we found a minimal term in one of the children. */
143 for (const Node& c : cur)
144 {
145 found_min_term |= cache[c].second;
146 if (found_min_term) break;
147 }
148
149 /* If we haven't found a minimal term yet, add this term if it has the
150 * right type. */
151 if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
152 {
153 terms.insert(cur);
154 found_min_term = true;
155 Trace("sygus-inst-term") << " found: " << cur << std::endl;
156 }
157
158 it->second.first = true;
159 it->second.second = found_min_term;
160 }
161 } while (!visit.empty());
162 }
163
164 /*
165 * Add special values for a given type.
166 *
167 * @param tn: The type node.
168 * @param extra_cons: A map of TypeNode to constants, which are added in
169 * addition to the default grammar.
170 */
171 void addSpecialValues(
172 const TypeNode& tn,
173 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons)
174 {
175 if (tn.isBitVector())
176 {
177 uint32_t size = tn.getBitVectorSize();
178 extra_cons[tn].insert(bv::utils::mkOnes(size));
179 extra_cons[tn].insert(bv::utils::mkMinSigned(size));
180 extra_cons[tn].insert(bv::utils::mkMaxSigned(size));
181 }
182 }
183
184 } // namespace
185
186 SygusInst::SygusInst(QuantifiersEngine* qe,
187 QuantifiersState& qs,
188 QuantifiersInferenceManager& qim,
189 QuantifiersRegistry& qr,
190 TermRegistry& tr)
191 : QuantifiersModule(qs, qim, qr, tr, qe),
192 d_ce_lemma_added(qs.getUserContext()),
193 d_global_terms(qs.getUserContext()),
194 d_notified_assertions(qs.getUserContext())
195 {
196 }
197
198 bool SygusInst::needsCheck(Theory::Effort e)
199 {
200 return e >= Theory::EFFORT_LAST_CALL;
201 }
202
203 QuantifiersModule::QEffort SygusInst::needsModel(Theory::Effort e)
204 {
205 return QEFFORT_STANDARD;
206 }
207
208 void SygusInst::reset_round(Theory::Effort e)
209 {
210 d_active_quant.clear();
211 d_inactive_quant.clear();
212
213 FirstOrderModel* model = d_treg.getModel();
214 uint32_t nasserted = model->getNumAssertedQuantifiers();
215
216 for (uint32_t i = 0; i < nasserted; ++i)
217 {
218 Node q = model->getAssertedQuantifier(i);
219
220 if (model->isQuantifierActive(q))
221 {
222 d_active_quant.insert(q);
223 Node lit = getCeLiteral(q);
224
225 bool value;
226 if (d_qstate.getValuation().hasSatValue(lit, value))
227 {
228 if (!value)
229 {
230 if (!d_qstate.getValuation().isDecision(lit))
231 {
232 model->setQuantifierActive(q, false);
233 d_active_quant.erase(q);
234 d_inactive_quant.insert(q);
235 Trace("sygus-inst") << "Set inactive: " << q << std::endl;
236 }
237 }
238 }
239 }
240 }
241 }
242
243 void SygusInst::check(Theory::Effort e, QEffort quant_e)
244 {
245 Trace("sygus-inst") << "Check " << e << ", " << quant_e << std::endl;
246
247 if (quant_e != QEFFORT_STANDARD) return;
248
249 FirstOrderModel* model = d_treg.getModel();
250 Instantiate* inst = d_qim.getInstantiate();
251 TermDbSygus* db = d_treg.getTermDatabaseSygus();
252 SygusExplain syexplain(db);
253 NodeManager* nm = NodeManager::currentNM();
254 options::SygusInstMode mode = options::sygusInstMode();
255
256 for (const Node& q : d_active_quant)
257 {
258 const std::vector<Node>& inst_constants = d_inst_constants.at(q);
259 const std::vector<Node>& dt_evals = d_var_eval.at(q);
260 Assert(inst_constants.size() == dt_evals.size());
261 Assert(inst_constants.size() == q[0].getNumChildren());
262
263 std::vector<Node> terms, eval_unfold_lemmas;
264 for (size_t i = 0, size = q[0].getNumChildren(); i < size; ++i)
265 {
266 Node dt_var = inst_constants[i];
267 Node dt_eval = dt_evals[i];
268 Node value = model->getValue(dt_var);
269 Node t = datatypes::utils::sygusToBuiltin(value);
270 terms.push_back(t);
271
272 std::vector<Node> exp;
273 syexplain.getExplanationForEquality(dt_var, value, exp);
274 Node lem;
275 if (exp.empty())
276 {
277 lem = dt_eval.eqNode(t);
278 }
279 else
280 {
281 lem = nm->mkNode(kind::IMPLIES,
282 exp.size() == 1 ? exp[0] : nm->mkNode(kind::AND, exp),
283 dt_eval.eqNode(t));
284 }
285 eval_unfold_lemmas.push_back(lem);
286 }
287
288 if (mode == options::SygusInstMode::PRIORITY_INST)
289 {
290 if (!inst->addInstantiation(q, terms, InferenceId::QUANTIFIERS_INST_SYQI))
291 {
292 sendEvalUnfoldLemmas(eval_unfold_lemmas);
293 }
294 }
295 else if (mode == options::SygusInstMode::PRIORITY_EVAL)
296 {
297 if (!sendEvalUnfoldLemmas(eval_unfold_lemmas))
298 {
299 inst->addInstantiation(q, terms, InferenceId::QUANTIFIERS_INST_SYQI);
300 }
301 }
302 else
303 {
304 Assert(mode == options::SygusInstMode::INTERLEAVE);
305 inst->addInstantiation(q, terms, InferenceId::QUANTIFIERS_INST_SYQI);
306 sendEvalUnfoldLemmas(eval_unfold_lemmas);
307 }
308 }
309 }
310
311 bool SygusInst::sendEvalUnfoldLemmas(const std::vector<Node>& lemmas)
312 {
313 bool added_lemma = false;
314 for (const Node& lem : lemmas)
315 {
316 Trace("sygus-inst") << "Evaluation unfolding: " << lem << std::endl;
317 added_lemma |=
318 d_qim.addPendingLemma(lem, InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD);
319 }
320 return added_lemma;
321 }
322
323 bool SygusInst::checkCompleteFor(Node q)
324 {
325 return d_inactive_quant.find(q) != d_inactive_quant.end();
326 }
327
328 void SygusInst::registerQuantifier(Node q)
329 {
330 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
331
332 Trace("sygus-inst") << "Register " << q << std::endl;
333
334 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
335 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
336 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
337 std::unordered_set<Node, NodeHashFunction> term_irrelevant;
338
339 /* Collect relevant local ground terms for each variable type. */
340 if (options::sygusInstScope() == options::SygusInstScope::IN
341 || options::sygusInstScope() == options::SygusInstScope::BOTH)
342 {
343 std::unordered_map<TypeNode,
344 std::unordered_set<Node, NodeHashFunction>,
345 TypeNodeHashFunction>
346 relevant_terms;
347 for (const Node& var : q[0])
348 {
349 TypeNode tn = var.getType();
350
351 /* Collect relevant ground terms for type tn. */
352 if (relevant_terms.find(tn) == relevant_terms.end())
353 {
354 std::unordered_set<Node, NodeHashFunction> terms;
355 std::unordered_set<TNode, TNodeHashFunction> cache_max;
356 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
357 cache_min;
358
359 getMinGroundTerms(q, tn, terms, cache_min);
360 getMaxGroundTerms(q, tn, terms, cache_max);
361 relevant_terms.emplace(tn, terms);
362 }
363
364 /* Add relevant ground terms to grammar. */
365 auto& terms = relevant_terms[tn];
366 for (const auto& t : terms)
367 {
368 TypeNode ttn = t.getType();
369 extra_cons[ttn].insert(t);
370 Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
371 }
372 }
373 }
374
375 /* Collect relevant global ground terms for each variable type. */
376 if (options::sygusInstScope() == options::SygusInstScope::OUT
377 || options::sygusInstScope() == options::SygusInstScope::BOTH)
378 {
379 for (const Node& var : q[0])
380 {
381 TypeNode tn = var.getType();
382
383 /* Collect relevant ground terms for type tn. */
384 if (d_global_terms.find(tn) == d_global_terms.end())
385 {
386 std::unordered_set<Node, NodeHashFunction> terms;
387 std::unordered_set<TNode, TNodeHashFunction> cache_max;
388 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
389 cache_min;
390
391 for (const Node& a : d_notified_assertions)
392 {
393 getMinGroundTerms(a, tn, terms, cache_min, true);
394 getMaxGroundTerms(a, tn, terms, cache_max, true);
395 }
396 d_global_terms.insert(tn, terms);
397 }
398
399 /* Add relevant ground terms to grammar. */
400 auto it = d_global_terms.find(tn);
401 if (it != d_global_terms.end())
402 {
403 for (const auto& t : (*it).second)
404 {
405 TypeNode ttn = t.getType();
406 extra_cons[ttn].insert(t);
407 Trace("sygus-inst")
408 << "Adding (global) extra cons: " << t << std::endl;
409 }
410 }
411 }
412 }
413
414 /* Construct grammar for each bound variable of 'q'. */
415 Trace("sygus-inst") << "Process variables of " << q << std::endl;
416 std::vector<TypeNode> types;
417 for (const Node& var : q[0])
418 {
419 addSpecialValues(var.getType(), extra_cons);
420 TypeNode tn = CegGrammarConstructor::mkSygusDefaultType(var.getType(),
421 Node(),
422 var.toString(),
423 extra_cons,
424 exclude_cons,
425 include_cons,
426 term_irrelevant);
427 types.push_back(tn);
428
429 Trace("sygus-inst") << "Construct (default) datatype for " << var
430 << std::endl
431 << tn << std::endl;
432 }
433
434 registerCeLemma(q, types);
435 }
436
437 /* Construct grammars for all bound variables of quantifier 'q'. Currently,
438 * we use the default grammar of the variable's type.
439 */
440 void SygusInst::preRegisterQuantifier(Node q)
441 {
442 Trace("sygus-inst") << "preRegister " << q << std::endl;
443 addCeLemma(q);
444 }
445
446 void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
447 {
448 for (const Node& a : assertions)
449 {
450 d_notified_assertions.insert(a);
451 }
452 }
453
454 /*****************************************************************************/
455 /* private methods */
456 /*****************************************************************************/
457
458 Node SygusInst::getCeLiteral(Node q)
459 {
460 auto it = d_ce_lits.find(q);
461 if (it != d_ce_lits.end())
462 {
463 return it->second;
464 }
465
466 NodeManager* nm = NodeManager::currentNM();
467 Node sk = nm->mkSkolem("CeLiteral", nm->booleanType());
468 Node lit = d_qstate.getValuation().ensureLiteral(sk);
469 d_ce_lits[q] = lit;
470 return lit;
471 }
472
473 void SygusInst::registerCeLemma(Node q, std::vector<TypeNode>& types)
474 {
475 Assert(q[0].getNumChildren() == types.size());
476 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
477 Assert(d_inst_constants.find(q) == d_inst_constants.end());
478 Assert(d_var_eval.find(q) == d_var_eval.end());
479
480 Trace("sygus-inst") << "Register CE Lemma for " << q << std::endl;
481
482 /* Generate counterexample lemma for 'q'. */
483 NodeManager* nm = NodeManager::currentNM();
484 TermDbSygus* db = d_treg.getTermDatabaseSygus();
485
486 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
487 * instantiation constant ic_i with type types[i] and wrap each ic_i in
488 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
489 std::vector<Node> evals;
490 std::vector<Node> inst_constants;
491 for (size_t i = 0, size = types.size(); i < size; ++i)
492 {
493 TypeNode tn = types[i];
494 TNode var = q[0][i];
495
496 /* Create the instantiation constant and set attribute accordingly. */
497 Node ic = nm->mkInstConstant(tn);
498 InstConstantAttribute ica;
499 ic.setAttribute(ica, q);
500 Trace("sygus-inst") << "Create " << ic << " for " << var << std::endl;
501
502 db->registerEnumerator(ic, ic, nullptr, ROLE_ENUM_MULTI_SOLUTION);
503
504 std::vector<Node> args = {ic};
505 Node svl = tn.getDType().getSygusVarList();
506 if (!svl.isNull())
507 {
508 args.insert(args.end(), svl.begin(), svl.end());
509 }
510 Node eval = nm->mkNode(kind::DT_SYGUS_EVAL, args);
511
512 inst_constants.push_back(ic);
513 evals.push_back(eval);
514 }
515
516 d_inst_constants.emplace(q, inst_constants);
517 d_var_eval.emplace(q, evals);
518
519 Node lit = getCeLiteral(q);
520 d_qim.addPendingPhaseRequirement(lit, true);
521
522 /* The decision strategy for quantified formula 'q' ensures that its
523 * counterexample literal is decided on first. It is user-context dependent.
524 */
525 Assert(d_dstrat.find(q) == d_dstrat.end());
526 DecisionStrategy* ds = new DecisionStrategySingleton(
527 "CeLiteral", lit, d_qstate.getSatContext(), d_qstate.getValuation());
528
529 d_dstrat[q].reset(ds);
530 d_qim.getDecisionManager()->registerStrategy(
531 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE, ds);
532
533 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
534 Node body =
535 q[1].substitute(q[0].begin(), q[0].end(), evals.begin(), evals.end());
536 Node lem = nm->mkNode(kind::OR, lit.negate(), body.negate());
537 lem = Rewriter::rewrite(lem);
538
539 d_ce_lemmas.emplace(std::make_pair(q, lem));
540 Trace("sygus-inst") << "Register CE Lemma: " << lem << std::endl;
541 }
542
543 void SygusInst::addCeLemma(Node q)
544 {
545 Assert(d_ce_lemmas.find(q) != d_ce_lemmas.end());
546
547 /* Already added in previous contexts. */
548 if (d_ce_lemma_added.find(q) != d_ce_lemma_added.end()) return;
549
550 Node lem = d_ce_lemmas[q];
551 d_qim.addPendingLemma(lem, InferenceId::QUANTIFIERS_SYQI_CEX);
552 d_ce_lemma_added.insert(q);
553 Trace("sygus-inst") << "Add CE Lemma: " << lem << std::endl;
554 }
555
556 } // namespace quantifiers
557 } // namespace theory
558 } // namespace CVC4