/*********************                                                        */
/*! \file sygus_inst.cpp
 ** \verbatim
 ** Top contributors (to current version):
 **   Mathias Preiner, Aina Niemetz, Andrew Reynolds
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
 ** in the top-level source directory and their institutional affiliations.
 ** All rights reserved.  See the file COPYING in the top-level source
 ** directory for licensing information.\endverbatim
 **
 ** \brief SyGuS instantiator class.
 **/
15 #include "theory/quantifiers/sygus_inst.h"
18 #include <unordered_set>
20 #include "expr/node_algorithm.h"
21 #include "options/quantifiers_options.h"
22 #include "theory/bv/theory_bv_utils.h"
23 #include "theory/datatypes/sygus_datatype_utils.h"
24 #include "theory/quantifiers/first_order_model.h"
25 #include "theory/quantifiers/sygus/sygus_enumerator.h"
26 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
27 #include "theory/quantifiers/sygus/synth_engine.h"
28 #include "theory/quantifiers/term_util.h"
29 #include "theory/quantifiers_engine.h"
30 #include "theory/rewriter.h"
34 namespace quantifiers
{
/**
 * Collect maximal ground terms with type tn in node n.
 *
 * @param n: Node to traverse.
 * @param tn: Collects only terms with type tn.
 * @param terms: Collected terms.
 * @param cache: Caches visited nodes.
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
 */
// Gathers ground terms of type `tn` that occur in `n`, driven by a worklist
// (`visit`) and a set of already-visited nodes (`cache`).
// NOTE(review): this listing is fragmentary. Declarations are split across
// lines and some statements and braces (for example the declarations of `tn`
// and `cur`) are not visible here; the comments describe only what is shown.
47 void getMaxGroundTerms(TNode n
,
49 std::unordered_set
<Node
, NodeHashFunction
>& terms
,
50 std::unordered_set
<TNode
, TNodeHashFunction
>& cache
,
51 bool skip_quant
= false)
// Option guard: the condition holds when the term selection mode is neither
// MAX nor BOTH. The guarded statement is not visible (presumably an early
// exit -- confirm).
53 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
54 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH
)
59 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
60 << " in: " << n
<< std::endl
;
// Worklist of nodes still to be examined.
63 std::vector
<TNode
> visit
;
// True when `cur` was already recorded in `cache`, i.e. it is a repeat visit.
71 if (cache
.find(cur
) != cache
.end())
// A node that contains bound variables, or whose type differs from `tn`, is
// not itself a candidate.
77 if (expr::hasBoundVar(cur
) || cur
.getType() != tn
)
// Children are queued unless `skip_quant` is set and `cur` is a FORALL.
// NOTE(review): nesting under the previous test is inferred from statement
// order -- confirm.
79 if (!skip_quant
|| cur
.getKind() != kind::FORALL
)
81 visit
.insert(visit
.end(), cur
.begin(), cur
.end());
// Trace output naming a term that was selected.
87 Trace("sygus-inst-term") << " found: " << cur
<< std::endl
;
// The traversal continues until the worklist is empty.
89 } while (!visit
.empty());
/**
 * Collect minimal ground terms with type tn in node n.
 *
 * @param n: Node to traverse.
 * @param tn: Collects only terms with type tn.
 * @param terms: Collected terms.
 * @param cache: Caches visited nodes and flags indicating whether a minimal
 *               term was already found in a subterm.
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
 */
// Gathers ground terms of type `tn` occurring in `n`, preferring terms that
// have no already-selected term among their children. `cache` maps each node
// to a pair of flags: `.first` is set once the node has been finalized and
// `.second` records whether a minimal term was found at or below it.
// NOTE(review): this listing is fragmentary. The `n` and `tn` parameters, the
// declaration of `cur` and several statements and braces are not visible;
// the comments describe only what is shown.
102 void getMinGroundTerms(
105 std::unordered_set
<Node
, NodeHashFunction
>& terms
,
106 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>& cache
,
107 bool skip_quant
= false)
// Option guard: the condition holds when the term selection mode is neither
// MIN nor BOTH. The guarded statement is not visible (presumably an early
// exit -- confirm).
109 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
110 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH
)
115 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
116 << " in: " << n
<< std::endl
;
// Worklist of nodes still to be examined.
119 std::vector
<TNode
> visit
;
127 auto it
= cache
.find(cur
);
// First encounter of `cur`: create its cache entry with both flags cleared.
128 if (it
== cache
.end())
130 cache
.emplace(cur
, std::make_pair(false, false));
// Unless `skip_quant` is set and `cur` is a FORALL, `cur` is queued again
// together with its children (presumably so that it is revisited after the
// children have been processed -- confirm).
131 if (!skip_quant
|| cur
.getKind() != kind::FORALL
)
133 visit
.push_back(cur
);
134 visit
.insert(visit
.end(), cur
.begin(), cur
.end());
// Later encounter of a node whose entry exists but is not finalized yet.
138 else if (!it
->second
.first
)
140 bool found_min_term
= false;
142 /* Check if we found a minimal term in one of the children. */
143 for (const Node
& c
: cur
)
// OR in the child's "minimal term found" flag; stop at the first hit.
145 found_min_term
|= cache
[c
].second
;
146 if (found_min_term
) break;
149 /* If we haven't found a minimal term yet, add this term if it has the
// `cur` qualifies when it has type `tn`, contains no bound variable, and no
// child already reported a minimal term.
151 if (cur
.getType() == tn
&& !expr::hasBoundVar(cur
) && !found_min_term
)
154 found_min_term
= true;
155 Trace("sygus-inst-term") << " found: " << cur
<< std::endl
;
// Finalize the entry: mark it processed and store the result for parents.
158 it
->second
.first
= true;
159 it
->second
.second
= found_min_term
;
// The traversal continues until the worklist is empty.
161 } while (!visit
.empty());
/**
 * Add special values for a given type.
 *
 * @param tn: The type node.
 * @param extra_cons: A map of TypeNode to constants, which are added in
 *                    addition to the default grammar.
 */
// For a bit-vector type `tn`, inserts three constants of that width into
// `extra_cons[tn]`: the results of bv::utils::mkOnes, bv::utils::mkMinSigned
// and bv::utils::mkMaxSigned. No handling of other types is visible.
// NOTE(review): the declaration of the `tn` parameter is not visible in this
// listing.
171 void addSpecialValues(
173 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>>& extra_cons
)
175 if (tn
.isBitVector())
// Width of the bit-vector type, used to size each constant.
177 uint32_t size
= tn
.getBitVectorSize();
178 extra_cons
[tn
].insert(bv::utils::mkOnes(size
));
179 extra_cons
[tn
].insert(bv::utils::mkMinSigned(size
));
180 extra_cons
[tn
].insert(bv::utils::mkMaxSigned(size
));
// Constructor: forwards its arguments to the QuantifiersModule base class and
// initializes d_ce_lemma_added, d_global_terms and d_notified_assertions with
// the user context obtained from `qs`.
// NOTE(review): the declaration of the `tr` parameter used in the base-class
// initializer is not visible in this listing.
186 SygusInst::SygusInst(QuantifiersEngine
* qe
,
187 QuantifiersState
& qs
,
188 QuantifiersInferenceManager
& qim
,
189 QuantifiersRegistry
& qr
,
191 : QuantifiersModule(qs
, qim
, qr
, tr
, qe
),
192 d_ce_lemma_added(qs
.getUserContext()),
193 d_global_terms(qs
.getUserContext()),
194 d_notified_assertions(qs
.getUserContext())
198 bool SygusInst::needsCheck(Theory::Effort e
)
200 return e
>= Theory::EFFORT_LAST_CALL
;
203 QuantifiersModule::QEffort
SygusInst::needsModel(Theory::Effort e
)
205 return QEFFORT_STANDARD
;
// Rebuilds d_active_quant and d_inactive_quant from the quantified formulas
// currently asserted in the model.
// NOTE(review): this listing is fragmentary. The declaration of `value`, any
// test on it, and the braces are not visible; the nesting described below is
// inferred from statement order.
208 void SygusInst::reset_round(Theory::Effort e
)
// Start from empty sets.
210 d_active_quant
.clear();
211 d_inactive_quant
.clear();
213 FirstOrderModel
* model
= d_treg
.getModel();
214 uint32_t nasserted
= model
->getNumAssertedQuantifiers();
// Visit every asserted quantified formula by index.
216 for (uint32_t i
= 0; i
< nasserted
; ++i
)
218 Node q
= model
->getAssertedQuantifier(i
);
// Only formulas the model still considers active are tracked.
220 if (model
->isQuantifierActive(q
))
222 d_active_quant
.insert(q
);
223 Node lit
= getCeLiteral(q
);
// Ask whether the counterexample literal of `q` currently has a SAT value;
// `value` receives it. How `value` is then used is not visible here.
226 if (d_qstate
.getValuation().hasSatValue(lit
, value
))
// When the literal is not a decision, `q` is deactivated in the model and
// moved from the active set to the inactive set.
230 if (!d_qstate
.getValuation().isDecision(lit
))
232 model
->setQuantifierActive(q
, false);
233 d_active_quant
.erase(q
);
234 d_inactive_quant
.insert(q
);
235 Trace("sygus-inst") << "Set inactive: " << q
<< std::endl
;
// For every quantified formula in d_active_quant: reads the model value of
// each of its instantiation constants, converts it to a builtin term, builds
// one evaluation unfolding lemma per variable, and then adds an instantiation
// and/or sends the lemmas depending on options::sygusInstMode().
// NOTE(review): this listing is fragmentary. The declaration of `lem`, the
// statement that fills `terms`, some branch conditions and the braces are
// not visible; the comments describe only what is shown.
243 void SygusInst::check(Theory::Effort e
, QEffort quant_e
)
245 Trace("sygus-inst") << "Check " << e
<< ", " << quant_e
<< std::endl
;
// Only the standard quantifier effort level is handled.
247 if (quant_e
!= QEFFORT_STANDARD
) return;
249 FirstOrderModel
* model
= d_treg
.getModel();
250 Instantiate
* inst
= d_qim
.getInstantiate();
251 TermDbSygus
* db
= d_treg
.getTermDatabaseSygus();
252 SygusExplain
syexplain(db
);
253 NodeManager
* nm
= NodeManager::currentNM();
254 options::SygusInstMode mode
= options::sygusInstMode();
256 for (const Node
& q
: d_active_quant
)
// Per-quantifier instantiation constants and their evaluation terms, looked
// up by `q`. Both lists must have one entry per bound variable of `q`.
258 const std::vector
<Node
>& inst_constants
= d_inst_constants
.at(q
);
259 const std::vector
<Node
>& dt_evals
= d_var_eval
.at(q
);
260 Assert(inst_constants
.size() == dt_evals
.size());
261 Assert(inst_constants
.size() == q
[0].getNumChildren());
263 std::vector
<Node
> terms
, eval_unfold_lemmas
;
// One iteration per bound variable of `q`.
264 for (size_t i
= 0, size
= q
[0].getNumChildren(); i
< size
; ++i
)
266 Node dt_var
= inst_constants
[i
];
267 Node dt_eval
= dt_evals
[i
];
// Model value of the instantiation constant and the builtin term it stands
// for.
268 Node value
= model
->getValue(dt_var
);
269 Node t
= datatypes::utils::sygusToBuiltin(value
);
// Collect an explanation for `dt_var` being equal to `value`.
272 std::vector
<Node
> exp
;
273 syexplain
.getExplanationForEquality(dt_var
, value
, exp
);
// Two forms of the lemma are visible: a plain equality between the
// evaluation term and `t`, and an implication whose antecedent is the
// explanation (a single node, or the conjunction of several).
// NOTE(review): the condition selecting between them and the consequent of
// the implication are not visible.
277 lem
= dt_eval
.eqNode(t
);
281 lem
= nm
->mkNode(kind::IMPLIES
,
282 exp
.size() == 1 ? exp
[0] : nm
->mkNode(kind::AND
, exp
),
285 eval_unfold_lemmas
.push_back(lem
);
// PRIORITY_INST: try the instantiation first; the unfolding lemmas are sent
// only when addInstantiation returns false.
288 if (mode
== options::SygusInstMode::PRIORITY_INST
)
290 if (!inst
->addInstantiation(q
, terms
, InferenceId::QUANTIFIERS_INST_SYQI
))
292 sendEvalUnfoldLemmas(eval_unfold_lemmas
);
// PRIORITY_EVAL: send the unfolding lemmas first; the instantiation is added
// only when sendEvalUnfoldLemmas returns false.
295 else if (mode
== options::SygusInstMode::PRIORITY_EVAL
)
297 if (!sendEvalUnfoldLemmas(eval_unfold_lemmas
))
299 inst
->addInstantiation(q
, terms
, InferenceId::QUANTIFIERS_INST_SYQI
);
// Remaining mode (asserted to be INTERLEAVE): do both.
304 Assert(mode
== options::SygusInstMode::INTERLEAVE
);
305 inst
->addInstantiation(q
, terms
, InferenceId::QUANTIFIERS_INST_SYQI
);
306 sendEvalUnfoldLemmas(eval_unfold_lemmas
);
// Queues each lemma in `lemmas` as a pending lemma tagged with
// InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD, tracing each one.
// NOTE(review): the statements that update `added_lemma` and return a result
// are not visible in this listing. Presumably the result reports whether any
// lemma was added -- confirm.
311 bool SygusInst::sendEvalUnfoldLemmas(const std::vector
<Node
>& lemmas
)
313 bool added_lemma
= false;
314 for (const Node
& lem
: lemmas
)
316 Trace("sygus-inst") << "Evaluation unfolding: " << lem
<< std::endl
;
318 d_qim
.addPendingLemma(lem
, InferenceId::QUANTIFIERS_SYQI_EVAL_UNFOLD
);
323 bool SygusInst::checkCompleteFor(Node q
)
325 return d_inactive_quant
.find(q
) != d_inactive_quant
.end();
// Registers quantified formula `q`: collects ground terms to offer as extra
// grammar constructors (from `q` itself and/or from the notified assertions,
// depending on options::sygusInstScope()), processes each bound variable, and
// finally calls registerCeLemma(q, types).
// NOTE(review): this listing is fragmentary. The declarations of
// `relevant_terms` and `cache_min`, most arguments of mkSygusDefaultType, the
// statement that fills `types`, and the braces are not visible; the comments
// describe only what is shown.
328 void SygusInst::registerQuantifier(Node q
)
// `q` must not have been registered before.
330 Assert(d_ce_lemmas
.find(q
) == d_ce_lemmas
.end());
332 Trace("sygus-inst") << "Register " << q
<< std::endl
;
// Per-type constructor sets and a set of irrelevant terms. Only `extra_cons`
// is filled in the visible code.
334 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> extra_cons
;
335 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> exclude_cons
;
336 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> include_cons
;
337 std::unordered_set
<Node
, NodeHashFunction
> term_irrelevant
;
339 /* Collect relevant local ground terms for each variable type. */
// Local scope (IN or BOTH): terms are taken from `q` itself.
340 if (options::sygusInstScope() == options::SygusInstScope::IN
341 || options::sygusInstScope() == options::SygusInstScope::BOTH
)
343 std::unordered_map
<TypeNode
,
344 std::unordered_set
<Node
, NodeHashFunction
>,
345 TypeNodeHashFunction
>
347 for (const Node
& var
: q
[0])
349 TypeNode tn
= var
.getType();
351 /* Collect relevant ground terms for type tn. */
// Each type is traversed once; later variables of the same type reuse the
// stored result.
352 if (relevant_terms
.find(tn
) == relevant_terms
.end())
354 std::unordered_set
<Node
, NodeHashFunction
> terms
;
355 std::unordered_set
<TNode
, TNodeHashFunction
> cache_max
;
356 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>
// Both collectors write into the same `terms` set. skip_quant keeps its
// default here.
359 getMinGroundTerms(q
, tn
, terms
, cache_min
);
360 getMaxGroundTerms(q
, tn
, terms
, cache_max
);
361 relevant_terms
.emplace(tn
, terms
);
364 /* Add relevant ground terms to grammar. */
365 auto& terms
= relevant_terms
[tn
];
366 for (const auto& t
: terms
)
// Terms are filed under their own type.
368 TypeNode ttn
= t
.getType();
369 extra_cons
[ttn
].insert(t
);
370 Trace("sygus-inst") << "Adding (local) extra cons: " << t
<< std::endl
;
375 /* Collect relevant global ground terms for each variable type. */
// Global scope (OUT or BOTH): terms are taken from the notified assertions.
376 if (options::sygusInstScope() == options::SygusInstScope::OUT
377 || options::sygusInstScope() == options::SygusInstScope::BOTH
)
379 for (const Node
& var
: q
[0])
381 TypeNode tn
= var
.getType();
383 /* Collect relevant ground terms for type tn. */
// d_global_terms stores the result per type, so each type is traversed only
// while it has no entry there.
384 if (d_global_terms
.find(tn
) == d_global_terms
.end())
386 std::unordered_set
<Node
, NodeHashFunction
> terms
;
387 std::unordered_set
<TNode
, TNodeHashFunction
> cache_max
;
388 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>
391 for (const Node
& a
: d_notified_assertions
)
// skip_quant is passed as true for the global traversal.
393 getMinGroundTerms(a
, tn
, terms
, cache_min
, true);
394 getMaxGroundTerms(a
, tn
, terms
, cache_max
, true);
396 d_global_terms
.insert(tn
, terms
);
399 /* Add relevant ground terms to grammar. */
400 auto it
= d_global_terms
.find(tn
);
401 if (it
!= d_global_terms
.end())
403 for (const auto& t
: (*it
).second
)
405 TypeNode ttn
= t
.getType();
406 extra_cons
[ttn
].insert(t
);
408 << "Adding (global) extra cons: " << t
<< std::endl
;
414 /* Construct grammar for each bound variable of 'q'. */
415 Trace("sygus-inst") << "Process variables of " << q
<< std::endl
;
416 std::vector
<TypeNode
> types
;
417 for (const Node
& var
: q
[0])
// Special constants for the variable's type are added before the type is
// passed to CegGrammarConstructor::mkSygusDefaultType (remaining arguments
// not visible).
419 addSpecialValues(var
.getType(), extra_cons
);
420 TypeNode tn
= CegGrammarConstructor::mkSygusDefaultType(var
.getType(),
429 Trace("sygus-inst") << "Construct (default) datatype for " << var
// Hand the collected types on to the counterexample lemma registration.
434 registerCeLemma(q
, types
);
/* Construct grammars for all bound variables of quantifier 'q'. Currently,
 * we use the default grammar of the variable's type.
 */
// Pre-registration hook for quantified formula `q`; the visible code emits a
// trace message naming `q`.
// NOTE(review): at least one statement after the trace is not visible in this
// listing.
440 void SygusInst::preRegisterQuantifier(Node q
)
442 Trace("sygus-inst") << "preRegister " << q
<< std::endl
;
446 void SygusInst::ppNotifyAssertions(const std::vector
<Node
>& assertions
)
448 for (const Node
& a
: assertions
)
450 d_notified_assertions
.insert(a
);
454 /*****************************************************************************/
455 /* private methods */
456 /*****************************************************************************/
// Looks up the counterexample literal for `q` in d_ce_lits; when none is
// found, a fresh Boolean skolem named "CeLiteral" is created and turned into
// a literal via ensureLiteral.
// NOTE(review): the return statements and any statement storing the new
// literal in d_ce_lits are not visible in this listing.
458 Node
SygusInst::getCeLiteral(Node q
)
460 auto it
= d_ce_lits
.find(q
);
// Hit: a literal already exists for `q`.
461 if (it
!= d_ce_lits
.end())
// Miss: build a new literal from a fresh Boolean skolem.
466 NodeManager
* nm
= NodeManager::currentNM();
467 Node sk
= nm
->mkSkolem("CeLiteral", nm
->booleanType());
468 Node lit
= d_qstate
.getValuation().ensureLiteral(sk
);
// Builds and stores the counterexample lemma for quantified formula `q`.
// For each entry of `types` it creates an instantiation constant and a
// DT_SYGUS_EVAL term over it, then sets up the counterexample literal and
// its decision strategy, and records the lemma in d_ce_lemmas.
// NOTE(review): this listing is fragmentary. The declarations of `var` and
// `body`, and the braces, are not visible; the comments describe only what
// is shown.
473 void SygusInst::registerCeLemma(Node q
, std::vector
<TypeNode
>& types
)
// One type per bound variable, and `q` must not have been registered before.
475 Assert(q
[0].getNumChildren() == types
.size());
476 Assert(d_ce_lemmas
.find(q
) == d_ce_lemmas
.end());
477 Assert(d_inst_constants
.find(q
) == d_inst_constants
.end());
478 Assert(d_var_eval
.find(q
) == d_var_eval
.end());
480 Trace("sygus-inst") << "Register CE Lemma for " << q
<< std::endl
;
482 /* Generate counterexample lemma for 'q'. */
483 NodeManager
* nm
= NodeManager::currentNM();
484 TermDbSygus
* db
= d_treg
.getTermDatabaseSygus();
486 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
487 * instantiation constant ic_i with type types[i] and wrap each ic_i in
488 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
489 std::vector
<Node
> evals
;
490 std::vector
<Node
> inst_constants
;
491 for (size_t i
= 0, size
= types
.size(); i
< size
; ++i
)
493 TypeNode tn
= types
[i
];
496 /* Create the instantiation constant and set attribute accordingly. */
497 Node ic
= nm
->mkInstConstant(tn
);
// The attribute ties the instantiation constant to `q`.
498 InstConstantAttribute ica
;
499 ic
.setAttribute(ica
, q
);
500 Trace("sygus-inst") << "Create " << ic
<< " for " << var
<< std::endl
;
// Register `ic` as its own enumerator with role ROLE_ENUM_MULTI_SOLUTION.
502 db
->registerEnumerator(ic
, ic
, nullptr, ROLE_ENUM_MULTI_SOLUTION
);
// Arguments of the evaluation term: `ic` followed by the entries of the
// datatype's sygus variable list.
504 std::vector
<Node
> args
= {ic
};
505 Node svl
= tn
.getDType().getSygusVarList();
508 args
.insert(args
.end(), svl
.begin(), svl
.end());
510 Node eval
= nm
->mkNode(kind::DT_SYGUS_EVAL
, args
);
512 inst_constants
.push_back(ic
);
513 evals
.push_back(eval
);
// Store both lists under `q`.
516 d_inst_constants
.emplace(q
, inst_constants
);
517 d_var_eval
.emplace(q
, evals
);
// Counterexample literal for `q`, with a pending phase requirement of true.
519 Node lit
= getCeLiteral(q
);
520 d_qim
.addPendingPhaseRequirement(lit
, true);
522 /* The decision strategy for quantified formula 'q' ensures that its
523 * counterexample literal is decided on first. It is user-context dependent.
525 Assert(d_dstrat
.find(q
) == d_dstrat
.end());
526 DecisionStrategy
* ds
= new DecisionStrategySingleton(
527 "CeLiteral", lit
, d_qstate
.getSatContext(), d_qstate
.getValuation());
// The strategy pointer is handed to d_dstrat[q] through reset() (presumably
// an owning smart pointer -- confirm) and registered with the decision
// manager.
529 d_dstrat
[q
].reset(ds
);
530 d_qim
.getDecisionManager()->registerStrategy(
531 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE
, ds
);
533 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
// Replace the bound variables of `q` by the evaluation terms in the body of
// `q`.
535 q
[1].substitute(q
[0].begin(), q
[0].end(), evals
.begin(), evals
.end());
// The lemma is the disjunction of the negated literal and the negated body,
// rewritten before being stored.
536 Node lem
= nm
->mkNode(kind::OR
, lit
.negate(), body
.negate());
537 lem
= Rewriter::rewrite(lem
);
539 d_ce_lemmas
.emplace(std::make_pair(q
, lem
));
540 Trace("sygus-inst") << "Register CE Lemma: " << lem
<< std::endl
;
543 void SygusInst::addCeLemma(Node q
)
545 Assert(d_ce_lemmas
.find(q
) != d_ce_lemmas
.end());
547 /* Already added in previous contexts. */
548 if (d_ce_lemma_added
.find(q
) != d_ce_lemma_added
.end()) return;
550 Node lem
= d_ce_lemmas
[q
];
551 d_qim
.addPendingLemma(lem
, InferenceId::QUANTIFIERS_SYQI_CEX
);
552 d_ce_lemma_added
.insert(q
);
553 Trace("sygus-inst") << "Add CE Lemma: " << lem
<< std::endl
;
556 } // namespace quantifiers
557 } // namespace theory