sygus-inst: Add more special BV values. (#5191)
[cvc5.git] / src / theory / quantifiers / sygus_inst.cpp
1 /********************* */
2 /*! \file sygus_inst.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Mathias Preiner, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief SyGuS instantiator class.
13 **/
14
15 #include "theory/quantifiers/sygus_inst.h"
16
17 #include <sstream>
18 #include <unordered_set>
19
20 #include "expr/node_algorithm.h"
21 #include "theory/bv/theory_bv_utils.h"
22 #include "theory/datatypes/sygus_datatype_utils.h"
23 #include "theory/quantifiers/sygus/sygus_enumerator.h"
24 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 #include "theory/quantifiers/sygus/synth_engine.h"
26 #include "theory/quantifiers_engine.h"
27 #include "theory/theory_engine.h"
28
29 namespace CVC4 {
30 namespace theory {
31 namespace quantifiers {
32
33 namespace {
34
35 /**
36 * Collect maximal ground terms with type tn in node n.
37 *
38 * @param n: Node to traverse.
39 * @param tn: Collects only terms with type tn.
40 * @param terms: Collected terms.
41 * @param cache: Caches visited nodes.
42 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
43 */
44 void getMaxGroundTerms(TNode n,
45 TypeNode tn,
46 std::unordered_set<Node, NodeHashFunction>& terms,
47 std::unordered_set<TNode, TNodeHashFunction>& cache,
48 bool skip_quant = false)
49 {
50 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
51 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
52 {
53 return;
54 }
55
56 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
57 << " in: " << n << std::endl;
58
59 Node cur;
60 std::vector<TNode> visit;
61
62 visit.push_back(n);
63 do
64 {
65 cur = visit.back();
66 visit.pop_back();
67
68 if (cache.find(cur) != cache.end())
69 {
70 continue;
71 }
72 cache.insert(cur);
73
74 if (expr::hasBoundVar(cur) || cur.getType() != tn)
75 {
76 if (!skip_quant || cur.getKind() != kind::FORALL)
77 {
78 visit.insert(visit.end(), cur.begin(), cur.end());
79 }
80 }
81 else
82 {
83 terms.insert(cur);
84 Trace("sygus-inst-term") << " found: " << cur << std::endl;
85 }
86 } while (!visit.empty());
87 }
88
89 /*
90 * Collect minimal ground terms with type tn in node n.
91 *
92 * @param n: Node to traverse.
93 * @param tn: Collects only terms with type tn.
94 * @param terms: Collected terms.
95 * @param cache: Caches visited nodes and flags indicating whether a minimal
96 * term was already found in a subterm.
97 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
98 */
99 void getMinGroundTerms(
100 TNode n,
101 TypeNode tn,
102 std::unordered_set<Node, NodeHashFunction>& terms,
103 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>& cache,
104 bool skip_quant = false)
105 {
106 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
107 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH)
108 {
109 return;
110 }
111
112 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
113 << " in: " << n << std::endl;
114
115 Node cur;
116 std::vector<TNode> visit;
117
118 visit.push_back(n);
119 do
120 {
121 cur = visit.back();
122 visit.pop_back();
123
124 auto it = cache.find(cur);
125 if (it == cache.end())
126 {
127 cache.emplace(cur, std::make_pair(false, false));
128 if (!skip_quant || cur.getKind() != kind::FORALL)
129 {
130 visit.push_back(cur);
131 visit.insert(visit.end(), cur.begin(), cur.end());
132 }
133 }
134 /* up-traversal */
135 else if (!it->second.first)
136 {
137 bool found_min_term = false;
138
139 /* Check if we found a minimal term in one of the children. */
140 for (const Node& c : cur)
141 {
142 found_min_term |= cache[c].second;
143 if (found_min_term) break;
144 }
145
146 /* If we haven't found a minimal term yet, add this term if it has the
147 * right type. */
148 if (cur.getType() == tn && !expr::hasBoundVar(cur) && !found_min_term)
149 {
150 terms.insert(cur);
151 found_min_term = true;
152 Trace("sygus-inst-term") << " found: " << cur << std::endl;
153 }
154
155 it->second.first = true;
156 it->second.second = found_min_term;
157 }
158 } while (!visit.empty());
159 }
160
161 /*
162 * Add special values for a given type.
163 *
164 * @param tn: The type node.
165 * @param extra_cons: A map of TypeNode to constants, which are added in
166 * addition to the default grammar.
167 */
168 void addSpecialValues(
169 const TypeNode& tn,
170 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons)
171 {
172 if (tn.isBitVector())
173 {
174 uint32_t size = tn.getBitVectorSize();
175 extra_cons[tn].insert(bv::utils::mkOnes(size));
176 extra_cons[tn].insert(bv::utils::mkMinSigned(size));
177 extra_cons[tn].insert(bv::utils::mkMaxSigned(size));
178 }
179 }
180
181 } // namespace
182
183 SygusInst::SygusInst(QuantifiersEngine* qe)
184 : QuantifiersModule(qe),
185 d_lemma_cache(qe->getUserContext()),
186 d_ce_lemma_added(qe->getUserContext()),
187 d_global_terms(qe->getUserContext()),
188 d_notified_assertions(qe->getUserContext())
189 {
190 }
191
192 bool SygusInst::needsCheck(Theory::Effort e)
193 {
194 return e >= Theory::EFFORT_LAST_CALL;
195 }
196
197 QuantifiersModule::QEffort SygusInst::needsModel(Theory::Effort e)
198 {
199 return QEFFORT_STANDARD;
200 }
201
202 void SygusInst::reset_round(Theory::Effort e)
203 {
204 d_active_quant.clear();
205 d_inactive_quant.clear();
206
207 FirstOrderModel* model = d_quantEngine->getModel();
208 uint32_t nasserted = model->getNumAssertedQuantifiers();
209
210 for (uint32_t i = 0; i < nasserted; ++i)
211 {
212 Node q = model->getAssertedQuantifier(i);
213
214 if (model->isQuantifierActive(q))
215 {
216 d_active_quant.insert(q);
217 Node lit = getCeLiteral(q);
218
219 bool value;
220 if (d_quantEngine->getValuation().hasSatValue(lit, value))
221 {
222 if (!value)
223 {
224 if (!d_quantEngine->getValuation().isDecision(lit))
225 {
226 model->setQuantifierActive(q, false);
227 d_active_quant.erase(q);
228 d_inactive_quant.insert(q);
229 Trace("sygus-inst") << "Set inactive: " << q << std::endl;
230 }
231 }
232 }
233 }
234 }
235 }
236
237 void SygusInst::check(Theory::Effort e, QEffort quant_e)
238 {
239 Trace("sygus-inst") << "Check " << e << ", " << quant_e << std::endl;
240
241 if (quant_e != QEFFORT_STANDARD) return;
242
243 FirstOrderModel* model = d_quantEngine->getModel();
244 Instantiate* inst = d_quantEngine->getInstantiate();
245 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
246 SygusExplain syexplain(db);
247 NodeManager* nm = NodeManager::currentNM();
248
249 for (const Node& q : d_active_quant)
250 {
251 std::vector<Node> terms;
252 for (const Node& var : q[0])
253 {
254 Node dt_var = d_inst_constants[var];
255 Node dt_eval = d_var_eval[var];
256 Node value = model->getValue(dt_var);
257 Node t = datatypes::utils::sygusToBuiltin(value);
258 terms.push_back(t);
259
260 std::vector<Node> exp;
261 syexplain.getExplanationForEquality(dt_var, value, exp);
262 Node lem;
263 if (exp.empty())
264 {
265 lem = dt_eval.eqNode(t);
266 }
267 else
268 {
269 lem = nm->mkNode(kind::IMPLIES,
270 exp.size() == 1 ? exp[0] : nm->mkNode(kind::AND, exp),
271 dt_eval.eqNode(t));
272 }
273
274 if (d_lemma_cache.find(lem) == d_lemma_cache.end())
275 {
276 Trace("sygus-inst") << "Evaluation unfolding: " << lem << std::endl;
277 d_quantEngine->addLemma(lem, false);
278 d_lemma_cache.insert(lem);
279 }
280 }
281
282 if (inst->addInstantiation(q, terms))
283 {
284 Trace("sygus-inst") << "Instantiate " << q << std::endl;
285 }
286 }
287 }
288
289 bool SygusInst::checkCompleteFor(Node q)
290 {
291 return d_inactive_quant.find(q) != d_inactive_quant.end();
292 }
293
294 void SygusInst::registerQuantifier(Node q)
295 {
296 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
297
298 Trace("sygus-inst") << "Register " << q << std::endl;
299
300 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
301 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
302 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
303 std::unordered_set<Node, NodeHashFunction> term_irrelevant;
304
305 /* Collect relevant local ground terms for each variable type. */
306 if (options::sygusInstScope() == options::SygusInstScope::IN
307 || options::sygusInstScope() == options::SygusInstScope::BOTH)
308 {
309 std::unordered_map<TypeNode,
310 std::unordered_set<Node, NodeHashFunction>,
311 TypeNodeHashFunction>
312 relevant_terms;
313 for (const Node& var : q[0])
314 {
315 TypeNode tn = var.getType();
316
317 /* Collect relevant ground terms for type tn. */
318 if (relevant_terms.find(tn) == relevant_terms.end())
319 {
320 std::unordered_set<Node, NodeHashFunction> terms;
321 std::unordered_set<TNode, TNodeHashFunction> cache_max;
322 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
323 cache_min;
324
325 getMinGroundTerms(q, tn, terms, cache_min);
326 getMaxGroundTerms(q, tn, terms, cache_max);
327 relevant_terms.emplace(tn, terms);
328 }
329
330 /* Add relevant ground terms to grammar. */
331 auto& terms = relevant_terms[tn];
332 for (const auto& t : terms)
333 {
334 TypeNode ttn = t.getType();
335 extra_cons[ttn].insert(t);
336 Trace("sygus-inst") << "Adding (local) extra cons: " << t << std::endl;
337 }
338 }
339 }
340
341 /* Collect relevant global ground terms for each variable type. */
342 if (options::sygusInstScope() == options::SygusInstScope::OUT
343 || options::sygusInstScope() == options::SygusInstScope::BOTH)
344 {
345 for (const Node& var : q[0])
346 {
347 TypeNode tn = var.getType();
348
349 /* Collect relevant ground terms for type tn. */
350 if (d_global_terms.find(tn) == d_global_terms.end())
351 {
352 std::unordered_set<Node, NodeHashFunction> terms;
353 std::unordered_set<TNode, TNodeHashFunction> cache_max;
354 std::unordered_map<TNode, std::pair<bool, bool>, TNodeHashFunction>
355 cache_min;
356
357 for (const Node& a : d_notified_assertions)
358 {
359 getMinGroundTerms(a, tn, terms, cache_min, true);
360 getMaxGroundTerms(a, tn, terms, cache_max, true);
361 }
362 d_global_terms.insert(tn, terms);
363 }
364
365 /* Add relevant ground terms to grammar. */
366 auto it = d_global_terms.find(tn);
367 if (it != d_global_terms.end())
368 {
369 for (const auto& t : (*it).second)
370 {
371 TypeNode ttn = t.getType();
372 extra_cons[ttn].insert(t);
373 Trace("sygus-inst")
374 << "Adding (global) extra cons: " << t << std::endl;
375 }
376 }
377 }
378 }
379
380 /* Construct grammar for each bound variable of 'q'. */
381 Trace("sygus-inst") << "Process variables of " << q << std::endl;
382 std::vector<TypeNode> types;
383 for (const Node& var : q[0])
384 {
385 addSpecialValues(var.getType(), extra_cons);
386 TypeNode tn = CegGrammarConstructor::mkSygusDefaultType(var.getType(),
387 Node(),
388 var.toString(),
389 extra_cons,
390 exclude_cons,
391 include_cons,
392 term_irrelevant);
393 types.push_back(tn);
394
395 Trace("sygus-inst") << "Construct (default) datatype for " << var
396 << std::endl
397 << tn << std::endl;
398 }
399
400 registerCeLemma(q, types);
401 }
402
403 /* Construct grammars for all bound variables of quantifier 'q'. Currently,
404 * we use the default grammar of the variable's type.
405 */
406 void SygusInst::preRegisterQuantifier(Node q)
407 {
408 Trace("sygus-inst") << "preRegister " << q << std::endl;
409 addCeLemma(q);
410 }
411
412 void SygusInst::ppNotifyAssertions(const std::vector<Node>& assertions)
413 {
414 for (const Node& a : assertions)
415 {
416 d_notified_assertions.insert(a);
417 }
418 }
419
420 /*****************************************************************************/
421 /* private methods */
422 /*****************************************************************************/
423
424 Node SygusInst::getCeLiteral(Node q)
425 {
426 auto it = d_ce_lits.find(q);
427 if (it != d_ce_lits.end())
428 {
429 return it->second;
430 }
431
432 NodeManager* nm = NodeManager::currentNM();
433 Node sk = nm->mkSkolem("CeLiteral", nm->booleanType());
434 Node lit = d_quantEngine->getValuation().ensureLiteral(sk);
435 d_ce_lits[q] = lit;
436 return lit;
437 }
438
439 void SygusInst::registerCeLemma(Node q, std::vector<TypeNode>& types)
440 {
441 Assert(q[0].getNumChildren() == types.size());
442 Assert(d_ce_lemmas.find(q) == d_ce_lemmas.end());
443
444 /* Generate counterexample lemma for 'q'. */
445 NodeManager* nm = NodeManager::currentNM();
446 TermDbSygus* db = d_quantEngine->getTermDatabaseSygus();
447
448 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
449 * instantiation constant ic_i with type types[i] and wrap each ic_i in
450 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
451 std::vector<Node> vars;
452 std::vector<Node> evals;
453 for (size_t i = 0, size = types.size(); i < size; ++i)
454 {
455 TypeNode tn = types[i];
456 TNode var = q[0][i];
457
458 /* Create the instantiation constant and set attribute accordingly. */
459 Node ic = nm->mkInstConstant(tn);
460 InstConstantAttribute ica;
461 ic.setAttribute(ica, q);
462
463 db->registerEnumerator(ic, ic, nullptr, ROLE_ENUM_MULTI_SOLUTION);
464
465 std::vector<Node> args = {ic};
466 Node svl = tn.getDType().getSygusVarList();
467 if (!svl.isNull())
468 {
469 args.insert(args.end(), svl.begin(), svl.end());
470 }
471 Node eval = nm->mkNode(kind::DT_SYGUS_EVAL, args);
472
473 d_inst_constants.emplace(std::make_pair(var, ic));
474 d_var_eval.emplace(std::make_pair(var, eval));
475
476 vars.push_back(var);
477 evals.push_back(eval);
478 }
479
480 Node lit = getCeLiteral(q);
481 d_quantEngine->addRequirePhase(lit, true);
482
483 /* The decision strategy for quantified formula 'q' ensures that its
484 * counterexample literal is decided on first. It is user-context dependent.
485 */
486 Assert(d_dstrat.find(q) == d_dstrat.end());
487 DecisionStrategy* ds =
488 new DecisionStrategySingleton("CeLiteral",
489 lit,
490 d_quantEngine->getSatContext(),
491 d_quantEngine->getValuation());
492
493 d_dstrat[q].reset(ds);
494 d_quantEngine->getDecisionManager()->registerStrategy(
495 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE, ds);
496
497 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
498 Node body =
499 q[1].substitute(vars.begin(), vars.end(), evals.begin(), evals.end());
500 Node lem = nm->mkNode(kind::OR, lit.negate(), body.negate());
501 lem = Rewriter::rewrite(lem);
502
503 d_ce_lemmas.emplace(std::make_pair(q, lem));
504 Trace("sygus-inst") << "Register CE Lemma: " << lem << std::endl;
505 }
506
507 void SygusInst::addCeLemma(Node q)
508 {
509 Assert(d_ce_lemmas.find(q) != d_ce_lemmas.end());
510
511 /* Already added in previous contexts. */
512 if (d_ce_lemma_added.find(q) != d_ce_lemma_added.end()) return;
513
514 Node lem = d_ce_lemmas[q];
515 d_quantEngine->addLemma(lem, false);
516 d_ce_lemma_added.insert(q);
517 Trace("sygus-inst") << "Add CE Lemma: " << lem << std::endl;
518 }
519
520 } // namespace quantifiers
521 } // namespace theory
522 } // namespace CVC4