SyGuS instantiation modes (#5228)
[cvc5.git] / src / theory / quantifiers / sygus_inst.cpp
1 /********************* */
2 /*! \file sygus_inst.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief SyGuS instantiator class.
13 **/
14
15 #include "theory/quantifiers/sygus_inst.h"
16
17 #include <sstream>
18 #include <unordered_set>
19
20 #include "expr/node_algorithm.h"
21 #include "theory/bv/theory_bv_utils.h"
22 #include "theory/datatypes/sygus_datatype_utils.h"
23 #include "theory/quantifiers/sygus/sygus_enumerator.h"
24 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 #include "theory/quantifiers/sygus/synth_engine.h"
26 #include "theory/quantifiers_engine.h"
27 #include "theory/theory_engine.h"
28
29 namespace CVC4 {
30 namespace theory {
31 namespace quantifiers {
32
33 namespace {
34
35 /**
36 * Collect maximal ground terms with type tn in node n.
37 *
38 * @param n: Node to traverse.
39 * @param tn: Collects only terms with type tn.
40 * @param terms: Collected terms.
41 * @param cache: Caches visited nodes.
42 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
43 */
44 void getMaxGroundTerms(TNode n,
45 TypeNode tn,
46 std::unordered_set<Node, NodeHashFunction>& terms,
47 std::unordered_set<TNode, TNodeHashFunction>& cache,
48 bool skip_quant = false)
49 {
50 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
51 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
52 {
53 return;
54 }
55
56 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
57 << " in: " << n << std::endl;
58
59 Node cur;
60 std::vector<TNode> visit;
61
62 visit.push_back(n);
63 do
64 {
65 cur = visit.back();
66 visit.pop_back();
67
68 if (cache.find(cur) != cache.end())
69 {
70 continue;
71 }
72 cache.insert(cur);
73
74 if (expr::hasBoundVar(cur) || cur.getType() != tn)
75 {
76 if (!skip_quant || cur.getKind() != kind::FORALL)
77 {
78 visit.insert(visit.end(), cur.begin(), cur.end());
79 }
80 }
81 else
82 {
83 terms.insert(cur);
84 Trace("sygus-inst-term") << " found: " << cur << std::endl;
85 }
86 } while (!visit.empty());
87 }
88
89 /*
90 * Collect minimal ground terms with type tn in node n.
91 *
92 * @param n: Node to traverse.
93 * @param tn: Collects only terms with type tn.
94 * @param terms: Collected terms.
95 * @param cache: Caches visited nodes and flags indicating whether a minimal
96 * term was already found in a subterm.
97 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
98 */
99 void getMinGroundTerms(
100 TNode n,
101 TypeNode tn,
102 std::unordered_set<Node, NodeHashFunction>& terms,
103 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
104 bool skip_quant = false)
105 {
106 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
107 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
108 {
109 return;
110 }
111
112 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
113 << " in: " << n << std::endl;
114
115 Node cur;
116 std::vector<TNode> visit;
117
118 visit.push_back(n);
119 do
120 {
121 cur = visit.back();
122 visit.pop_back();
123
124 auto it = cache.find(cur);
125 if (it == cache.end())
126 {
127 cache.emplace(cur, std::make_pair(false, false));
128 if (!skip_quant || cur.getKind() != kind::FORALL)
129 {
130 visit.push_back(cur);
131 visit.insert(visit.end(), cur.begin(), cur.end());
132 }
133 }
134 /* up-traversal */
135 else if (!it->second.first)
136 {
137 bool found_min_term = false;
138
139 /* Check if we found a minimal term in one of the children. */
140 for (const Node& c : cur)
141 {
142 found_min_term |= cache[c].second;
143 if (found_min_term) break;
144 }
145
146 /* If we haven't found a minimal term yet, add this term if it has the
147 * right type. */
148 if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
149 {
150 terms.insert(cur);
151 found_min_term = true;
152 Trace("sygus-inst-term") << " found: " << cur << std::endl;
153 }
154
155 it->second.first = true;
156 it->second.second = found_min_term;
157 }
158 } while (!visit.empty());
159 }
160
161 /*
162 * Add special values for a given type.
163 *
164 * @param tn: The type node.
165 * @param extra_cons: A map of TypeNode to constants, which are added in
166 * addition to the default grammar.
167 */
168 void addSpecialValues(
169 const TypeNode& tn,
170 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons)
171 {
172 if (tn.isBitVector())
173 {
174 uint32_t size = tn.getBitVectorSize();
175 extra_cons[tn].insert(bv::utils::mkOnes(size));
176 extra_cons[tn].insert(bv::utils::mkMinSigned(size));
177 extra_cons[tn].insert(bv::utils::mkMaxSigned(size));
178 }
179 }
180
181 } // namespace
182
183 SygusInst::SygusInst(QuantifiersEngine* qe)
184 : QuantifiersModule(qe),
185 d_ce_lemma_added(qe->getUserContext()),
186 d_global_terms(qe->getUserContext()),
187 d_notified_assertions(qe->getUserContext())
188 {
189 }
190
191 bool SygusInst::needsCheck(Theory::Effort e)
192 {
193 return e >= Theory::EFFORT_LAST_CALL;
194 }
195
196 QuantifiersModule::QEffort SygusInst::needsModel(Theory::Effort e)
197 {
198 return QEFFORT_STANDARD;
199 }
200
201 void SygusInst::reset_round(Theory::Effort e)
202 {
203 d_active_quant.clear();
204 d_inactive_quant.clear();
205
206 FirstOrderModel* model = d_quantEngine->getModel();
207 uint32_t nasserted = model->getNumAssertedQuantifiers();
208
209 for (uint32_t i = 0; i < nasserted; ++i)
210 {
211 Node q = model->getAssertedQuantifier(i);
212
213 if (model->isQuantifierActive(q))
214 {
215 d_active_quant.insert(q);
216 Node lit = getCeLiteral(q);
217
218 bool value;
219 if (d_quantEngine->getValuation().hasSatValue(lit, value))
220 {
221 if (!value)
222 {
223 if (!d_quantEngine->getValuation().isDecision(lit))
224 {
225 model->setQuantifierActive(q, false);
226 d_active_quant.erase(q);
227 d_inactive_quant.insert(q);
228 Trace("sygus-inst") << "Set inactive: " << q << std::endl;
229 }
230 }
231 }
232 }
233 }
234 }
235
236 void SygusInst::check(Theory::Effort e, QEffort quant_e)
237 {
238 Trace("sygus-inst") << "Check " << e << ", " << quant_e << std::endl;
239
240 if (quant_e != QEFFORT_STANDARD) return;
241
242 FirstOrderModel* model = d_quantEngine->getModel();
243 Instantiate* inst = d_quantEngine->getInstantiate();
244 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
245 SygusExplain syexplain(db);
246 NodeManager* nm = NodeManager::currentNM();
247 options::SygusInstMode mode = options::sygusInstMode();
248
249 for (const Node& q : d_active_quant)
250 {
251 const std::vector<Node>& inst_constants = d_inst_constants.at(q);
252 const std::vector<Node>& dt_evals = d_var_eval.at(q);
253 Assert(inst_constants.size() == dt_evals.size());
254 Assert(inst_constants.size() == q[0].getNumChildren());
255
256 std::vector<Node> terms, eval_unfold_lemmas;
257 for (size_t i = 0, size = q[0].getNumChildren(); i < size; ++i)
258 {
259 Node dt_var = inst_constants[i];
260 Node dt_eval = dt_evals[i];
261 Node value = model->getValue(dt_var);
262 Node t = datatypes::utils::sygusToBuiltin(value);
263 terms.push_back(t);
264
265 std::vector<Node> exp;
266 syexplain.getExplanationForEquality(dt_var, value, exp);
267 Node lem;
268 if (exp.empty())
269 {
270 lem = dt_eval.eqNode(t);
271 }
272 else
273 {
274 lem = nm->mkNode(kind::IMPLIES,
275 exp.size() == 1 ? exp[0] : nm->mkNode(kind::AND, exp),
276 dt_eval.eqNode(t));
277 }
278 eval_unfold_lemmas.push_back(lem);
279 }
280
281 if (mode == options::SygusInstMode::PRIORITY_INST)
282 {
283 if (!inst->addInstantiation(q, terms))
284 {
285 sendEvalUnfoldLemmas(eval_unfold_lemmas);
286 }
287 }
288 else if (mode == options::SygusInstMode::PRIORITY_EVAL)
289 {
290 if (!sendEvalUnfoldLemmas(eval_unfold_lemmas))
291 {
292 inst->addInstantiation(q, terms);
293 }
294 }
295 else
296 {
297 Assert(mode == options::SygusInstMode::INTERLEAVE);
298 inst->addInstantiation(q, terms);
299 sendEvalUnfoldLemmas(eval_unfold_lemmas);
300 }
301 }
302 }
303
304 bool SygusInst::sendEvalUnfoldLemmas(const std::vector<Node>& lemmas)
305 {
306 bool added_lemma = false;
307 for (const Node& lem : lemmas)
308 {
309 Trace("sygus-inst") << "Evaluation unfolding: " << lem << std::endl;
310 added_lemma |= d_quantEngine->addLemma(lem);
311 }
312 return added_lemma;
313 }
314
315 bool SygusInst::checkCompleteFor(Node q)
316 {
317 return d_inactive_quant.find(q) != d_inactive_quant.end();
318 }
319
320 void SygusInst::registerQuantifier(Node q)
321 {
322 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
323
324 Trace("sygus-inst") << "Register " << q << std::endl;
325
326 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
327 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
328 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
329 std::unordered_set<Node, NodeHashFunction> term_irrelevant;
330
331 /* Collect relevant local ground terms for each variable type. */
332 if (options::sygusInstScope() == options::SygusInstScope::IN
333 || options::sygusInstScope() == options::SygusInstScope::BOTH)
334 {
335 std::unordered_map<TypeNode,
336 std::unordered_set<Node, NodeHashFunction>,
337 TypeNodeHashFunction>
338 relevant_terms;
339 for (const Node& var : q[0])
340 {
341 TypeNode tn = var.getType();
342
343 /* Collect relevant ground terms for type tn. */
344 if (relevant_terms.find(tn) == relevant_terms.end())
345 {
346 std::unordered_set<Node, NodeHashFunction> terms;
347 std::unordered_set<TNode, TNodeHashFunction> cache_max;
348 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
349 cache_min;
350
351 getMinGroundTerms(q, tn, terms, cache_min);
352 getMaxGroundTerms(q, tn, terms, cache_max);
353 relevant_terms.emplace(tn, terms);
354 }
355
356 /* Add relevant ground terms to grammar. */
357 auto& terms = relevant_terms[tn];
358 for (const auto& t : terms)
359 {
360 TypeNode ttn = t.getType();
361 extra_cons[ttn].insert(t);
362 Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
363 }
364 }
365 }
366
367 /* Collect relevant global ground terms for each variable type. */
368 if (options::sygusInstScope() == options::SygusInstScope::OUT
369 || options::sygusInstScope() == options::SygusInstScope::BOTH)
370 {
371 for (const Node& var : q[0])
372 {
373 TypeNode tn = var.getType();
374
375 /* Collect relevant ground terms for type tn. */
376 if (d_global_terms.find(tn) == d_global_terms.end())
377 {
378 std::unordered_set<Node, NodeHashFunction> terms;
379 std::unordered_set<TNode, TNodeHashFunction> cache_max;
380 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
381 cache_min;
382
383 for (const Node& a : d_notified_assertions)
384 {
385 getMinGroundTerms(a, tn, terms, cache_min, true);
386 getMaxGroundTerms(a, tn, terms, cache_max, true);
387 }
388 d_global_terms.insert(tn, terms);
389 }
390
391 /* Add relevant ground terms to grammar. */
392 auto it = d_global_terms.find(tn);
393 if (it != d_global_terms.end())
394 {
395 for (const auto& t : (*it).second)
396 {
397 TypeNode ttn = t.getType();
398 extra_cons[ttn].insert(t);
399 Trace("sygus-inst")
400 << "Adding (global) extra cons: " << t << std::endl;
401 }
402 }
403 }
404 }
405
406 /* Construct grammar for each bound variable of 'q'. */
407 Trace("sygus-inst") << "Process variables of " << q << std::endl;
408 std::vector<TypeNode> types;
409 for (const Node& var : q[0])
410 {
411 addSpecialValues(var.getType(), extra_cons);
412 TypeNode tn = CegGrammarConstructor::mkSygusDefaultType(var.getType(),
413 Node(),
414 var.toString(),
415 extra_cons,
416 exclude_cons,
417 include_cons,
418 term_irrelevant);
419 types.push_back(tn);
420
421 Trace("sygus-inst") << "Construct (default) datatype for " << var
422 << std::endl
423 << tn << std::endl;
424 }
425
426 registerCeLemma(q, types);
427 }
428
429 /* Construct grammars for all bound variables of quantifier 'q'. Currently,
430 * we use the default grammar of the variable's type.
431 */
432 void SygusInst::preRegisterQuantifier(Node q)
433 {
434 Trace("sygus-inst") << "preRegister " << q << std::endl;
435 addCeLemma(q);
436 }
437
438 void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
439 {
440 for (const Node& a : assertions)
441 {
442 d_notified_assertions.insert(a);
443 }
444 }
445
446 /*****************************************************************************/
447 /* private methods */
448 /*****************************************************************************/
449
450 Node SygusInst::getCeLiteral(Node q)
451 {
452 auto it = d_ce_lits.find(q);
453 if (it != d_ce_lits.end())
454 {
455 return it->second;
456 }
457
458 NodeManager* nm = NodeManager::currentNM();
459 Node sk = nm->mkSkolem("CeLiteral", nm->booleanType());
460 Node lit = d_quantEngine->getValuation().ensureLiteral(sk);
461 d_ce_lits[q] = lit;
462 return lit;
463 }
464
465 void SygusInst::registerCeLemma(Node q, std::vector<TypeNode>& types)
466 {
467 Assert(q[0].getNumChildren() == types.size());
468 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
469 Assert(d_inst_constants.find(q) == d_inst_constants.end());
470 Assert(d_var_eval.find(q) == d_var_eval.end());
471
472 Trace("sygus-inst") << "Register CE Lemma for " << q << std::endl;
473
474 /* Generate counterexample lemma for 'q'. */
475 NodeManager* nm = NodeManager::currentNM();
476 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
477
478 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
479 * instantiation constant ic_i with type types[i] and wrap each ic_i in
480 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
481 std::vector<Node> evals;
482 std::vector<Node> inst_constants;
483 for (size_t i = 0, size = types.size(); i < size; ++i)
484 {
485 TypeNode tn = types[i];
486 TNode var = q[0][i];
487
488 /* Create the instantiation constant and set attribute accordingly. */
489 Node ic = nm->mkInstConstant(tn);
490 InstConstantAttribute ica;
491 ic.setAttribute(ica, q);
492 Trace("sygus-inst") << "Create " << ic << " for " << var << std::endl;
493
494 db->registerEnumerator(ic, ic, nullptr, ROLE_ENUM_MULTI_SOLUTION);
495
496 std::vector<Node> args = {ic};
497 Node svl = tn.getDType().getSygusVarList();
498 if (!svl.isNull())
499 {
500 args.insert(args.end(), svl.begin(), svl.end());
501 }
502 Node eval = nm->mkNode(kind::DT_SYGUS_EVAL, args);
503
504 inst_constants.push_back(ic);
505 evals.push_back(eval);
506 }
507
508 d_inst_constants.emplace(q, inst_constants);
509 d_var_eval.emplace(q, evals);
510
511 Node lit = getCeLiteral(q);
512 d_quantEngine->addRequirePhase(lit, true);
513
514 /* The decision strategy for quantified formula 'q' ensures that its
515 * counterexample literal is decided on first. It is user-context dependent.
516 */
517 Assert(d_dstrat.find(q) == d_dstrat.end());
518 DecisionStrategy* ds =
519 new DecisionStrategySingleton("CeLiteral",
520 lit,
521 d_quantEngine->getSatContext(),
522 d_quantEngine->getValuation());
523
524 d_dstrat[q].reset(ds);
525 d_quantEngine->getDecisionManager()->registerStrategy(
526 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE, ds);
527
528 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
529 Node body =
530 q[1].substitute(q[0].begin(), q[0].end(), evals.begin(), evals.end());
531 Node lem = nm->mkNode(kind::OR, lit.negate(), body.negate());
532 lem = Rewriter::rewrite(lem);
533
534 d_ce_lemmas.emplace(std::make_pair(q, lem));
535 Trace("sygus-inst") << "Register CE Lemma: " << lem << std::endl;
536 }
537
538 void SygusInst::addCeLemma(Node q)
539 {
540 Assert(d_ce_lemmas.find(q) != d_ce_lemmas.end());
541
542 /* Already added in previous contexts. */
543 if (d_ce_lemma_added.find(q) != d_ce_lemma_added.end()) return;
544
545 Node lem = d_ce_lemmas[q];
546 d_quantEngine->addLemma(lem, false);
547 d_ce_lemma_added.insert(q);
548 Trace("sygus-inst") << "Add CE Lemma: " << lem << std::endl;
549 }
550
551 } // namespace quantifiers
552 } // namespace theory
553 } // namespace CVC4