/*********************                                                        */
/*! \file sygus_inst.cpp
 ** \verbatim
 ** Top contributors (to current version):
 **   Mathias Preiner, Aina Niemetz, Andrew Reynolds
 ** This file is part of the CVC4 project.
 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
 ** in the top-level source directory and their institutional affiliations.
 ** All rights reserved. See the file COPYING in the top-level source
 ** directory for licensing information.\endverbatim
 **
 ** \brief SyGuS instantiator class.
 **/
15 #include "theory/quantifiers/sygus_inst.h"
18 #include <unordered_set>
20 #include "expr/node_algorithm.h"
21 #include "theory/bv/theory_bv_utils.h"
22 #include "theory/datatypes/sygus_datatype_utils.h"
23 #include "theory/quantifiers/sygus/sygus_enumerator.h"
24 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 #include "theory/quantifiers/sygus/synth_engine.h"
26 #include "theory/quantifiers_engine.h"
30 namespace quantifiers
{
/**
 * Collect maximal ground terms with type tn in node n.
 *
 * @param n: Node to traverse.
 * @param tn: Collects only terms with type tn.
 * @param terms: Collected terms.
 * @param cache: Caches visited nodes.
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
 */
// Worklist traversal over the subterms of n that looks for terms of type tn
// containing no bound variable.
// NOTE(review): several lines of this function are missing from this view
// (the tn parameter, the braces, the declaration of cur, the statements that
// pop from visit, and whatever is added to terms and cache). The comments
// below describe only the statements that are visible.
43 void getMaxGroundTerms(TNode n
,
45 std::unordered_set
<Node
, NodeHashFunction
>& terms
,
46 std::unordered_set
<TNode
, TNodeHashFunction
>& cache
,
47 bool skip_quant
= false)
// Guard on the term selection option: the condition holds unless the mode is
// MAX or BOTH. The guarded statement is not visible; presumably an early
// return - confirm.
49 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MAX
50 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH
)
55 Trace("sygus-inst-term") << "Find maximal terms with type " << tn
56 << " in: " << n
<< std::endl
;
// Pending nodes still to be examined.
59 std::vector
<TNode
> visit
;
// Nodes already present in cache get separate handling (body not visible).
67 if (cache
.find(cur
) != cache
.end())
// A node that contains a bound variable or has a type other than tn is not
// itself reported; the search continues below it instead.
73 if (expr::hasBoundVar(cur
) || cur
.getType() != tn
)
// With skip_quant set, the children of a FORALL node are not queued.
75 if (!skip_quant
|| cur
.getKind() != kind::FORALL
)
77 visit
.insert(visit
.end(), cur
.begin(), cur
.end());
// Trace for a reported term; its children are not queued by any visible
// statement. NOTE(review): presumably the alternative branch of the test
// above - confirm.
83 Trace("sygus-inst-term") << " found: " << cur
<< std::endl
;
// Loop until the worklist is drained.
85 } while (!visit
.empty());
/**
 * Collect minimal ground terms with type tn in node n.
 *
 * @param n: Node to traverse.
 * @param tn: Collects only terms with type tn.
 * @param terms: Collected terms.
 * @param cache: Caches visited nodes and flags indicating whether a minimal
 *               term was already found in a subterm.
 * @param skip_quant: Do not traverse quantified formulas (skip quantifiers).
 */
// Worklist traversal that looks for terms of type tn without bound variables
// that have no such term recorded for any of their children.
// Each cache entry holds two flags: first = the node has been fully processed,
// second = a minimal term was found at or below the node.
// NOTE(review): several lines of this function are missing from this view
// (the n and tn parameters, the braces, the declaration of cur, the statements
// that pop from visit, and whatever is added to terms). The comments below
// describe only the statements that are visible.
98 void getMinGroundTerms(
101 std::unordered_set
<Node
, NodeHashFunction
>& terms
,
102 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>& cache
,
103 bool skip_quant
= false)
// Guard on the term selection option: the condition holds unless the mode is
// MIN or BOTH. The guarded statement is not visible; presumably an early
// return - confirm.
105 if (options::sygusInstTermSel() != options::SygusInstTermSelMode::MIN
106 && options::sygusInstTermSel() != options::SygusInstTermSelMode::BOTH
)
111 Trace("sygus-inst-term") << "Find minimal terms with type " << tn
112 << " in: " << n
<< std::endl
;
// Pending nodes still to be examined.
115 std::vector
<TNode
> visit
;
123 auto it
= cache
.find(cur
);
// First encounter of cur: register it with both flags cleared.
124 if (it
== cache
.end())
126 cache
.emplace(cur
, std::make_pair(false, false));
// Unless skip_quant is set and cur is a FORALL, cur is pushed again together
// with its children, so it is seen a second time later on.
127 if (!skip_quant
|| cur
.getKind() != kind::FORALL
)
129 visit
.push_back(cur
);
130 visit
.insert(visit
.end(), cur
.begin(), cur
.end());
// Later encounter of a node whose first flag is still false.
134 else if (!it
->second
.first
)
136 bool found_min_term
= false;
138 /* Check if we found a minimal term in one of the children. */
139 for (const Node
& c
: cur
)
// Reads the second flag of the cache entry of child c.
141 found_min_term
|= cache
[c
].second
;
142 if (found_min_term
) break;
// NOTE(review): the block comment that starts on the next line has no
// terminator in this view; its closing line appears to have been lost.
145 /* If we haven't found a minimal term yet, add this term if it has the
// cur qualifies when its type is tn, it contains no bound variable, and no
// child already yielded a minimal term.
147 if (cur
.getType() == tn
&& !expr::hasBoundVar(cur
) && !found_min_term
)
150 found_min_term
= true;
151 Trace("sygus-inst-term") << " found: " << cur
<< std::endl
;
// Mark cur as processed and record whether a minimal term exists at or below
// it.
154 it
->second
.first
= true;
155 it
->second
.second
= found_min_term
;
// Loop until the worklist is drained.
157 } while (!visit
.empty());
/**
 * Add special values for a given type.
 *
 * @param tn: The type node.
 * @param extra_cons: A map of TypeNode to constants, which are added in
 *                    addition to the default grammar.
 */
// Adds extra constants for the type tn to extra_cons. Only bit-vector types
// are handled by the visible code.
// NOTE(review): the tn parameter and the braces are not visible in this view.
167 void addSpecialValues(
169 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>>& extra_cons
)
171 if (tn
.isBitVector())
// Three constants of the bit-width of tn are inserted into extra_cons[tn]: the
// results of bv::utils::mkOnes, mkMinSigned and mkMaxSigned.
173 uint32_t size
= tn
.getBitVectorSize();
174 extra_cons
[tn
].insert(bv::utils::mkOnes(size
));
175 extra_cons
[tn
].insert(bv::utils::mkMinSigned(size
));
176 extra_cons
[tn
].insert(bv::utils::mkMaxSigned(size
));
// Constructor. Forwards qs, qim and qe to the QuantifiersModule base and
// initialises d_ce_lemma_added, d_global_terms and d_notified_assertions with
// the user context obtained from qs.
// NOTE(review): the constructor body is not visible in this view.
182 SygusInst::SygusInst(QuantifiersEngine
* qe
,
183 QuantifiersState
& qs
,
184 QuantifiersInferenceManager
& qim
)
185 : QuantifiersModule(qs
, qim
, qe
),
186 d_ce_lemma_added(qs
.getUserContext()),
187 d_global_terms(qs
.getUserContext()),
188 d_notified_assertions(qs
.getUserContext())
// Returns true when the effort level e is at least Theory::EFFORT_LAST_CALL.
192 bool SygusInst::needsCheck(Theory::Effort e
)
194 return e
>= Theory::EFFORT_LAST_CALL
;
// Always returns QEFFORT_STANDARD; the effort argument e is not consulted.
197 QuantifiersModule::QEffort
SygusInst::needsModel(Theory::Effort e
)
199 return QEFFORT_STANDARD
;
// Rebuilds d_active_quant and d_inactive_quant from the quantifiers currently
// asserted in the first-order model. The argument e is not used by any visible
// statement.
// NOTE(review): some lines are missing from this view (braces, and the
// declaration and any test of `value`).
202 void SygusInst::reset_round(Theory::Effort e
)
// Both sets are recomputed from scratch on every call.
204 d_active_quant
.clear();
205 d_inactive_quant
.clear();
207 FirstOrderModel
* model
= d_quantEngine
->getModel();
208 uint32_t nasserted
= model
->getNumAssertedQuantifiers();
210 for (uint32_t i
= 0; i
< nasserted
; ++i
)
212 Node q
= model
->getAssertedQuantifier(i
);
// Quantifiers the model reports as active are recorded, and their
// counterexample literal is fetched.
214 if (model
->isQuantifierActive(q
))
216 d_active_quant
.insert(q
);
217 Node lit
= getCeLiteral(q
);
// hasSatValue receives `value` as a second argument; how the result in `value`
// is used afterwards is not visible here.
220 if (d_quantEngine
->getValuation().hasSatValue(lit
, value
))
// Visible effect when lit is not a decision literal: q is deactivated in the
// model and moved from d_active_quant to d_inactive_quant.
224 if (!d_quantEngine
->getValuation().isDecision(lit
))
226 model
->setQuantifierActive(q
, false);
227 d_active_quant
.erase(q
);
228 d_inactive_quant
.insert(q
);
229 Trace("sygus-inst") << "Set inactive: " << q
<< std::endl
;
// Main check. For every quantifier in d_active_quant it builds one evaluation
// unfolding lemma per bound variable from the current model values, then adds
// an instantiation and/or sends those lemmas in an order chosen by
// options::sygusInstMode.
// NOTE(review): some lines are missing from this view (braces, the declaration
// of lem, the statement that fills terms, and the conclusion of the
// implication built for lem).
237 void SygusInst::check(Theory::Effort e
, QEffort quant_e
)
239 Trace("sygus-inst") << "Check " << e
<< ", " << quant_e
<< std::endl
;
// Nothing is done at any quantifier effort other than QEFFORT_STANDARD.
241 if (quant_e
!= QEFFORT_STANDARD
) return;
243 FirstOrderModel
* model
= d_quantEngine
->getModel();
244 Instantiate
* inst
= d_quantEngine
->getInstantiate();
245 TermDbSygus
* db
= d_quantEngine
->getTermDatabaseSygus();
246 SygusExplain
syexplain(db
);
247 NodeManager
* nm
= NodeManager::currentNM();
248 options::SygusInstMode mode
= options::sygusInstMode();
250 for (const Node
& q
: d_active_quant
)
// Per-quantifier data; both vectors are asserted to have one entry per bound
// variable of q. The at() lookups require q to be present in both maps.
252 const std::vector
<Node
>& inst_constants
= d_inst_constants
.at(q
);
253 const std::vector
<Node
>& dt_evals
= d_var_eval
.at(q
);
254 Assert(inst_constants
.size() == dt_evals
.size());
255 Assert(inst_constants
.size() == q
[0].getNumChildren());
257 std::vector
<Node
> terms
, eval_unfold_lemmas
;
258 for (size_t i
= 0, size
= q
[0].getNumChildren(); i
< size
; ++i
)
// Model value of the i-th instantiation constant, and the term t obtained from
// it through datatypes::utils::sygusToBuiltin.
260 Node dt_var
= inst_constants
[i
];
261 Node dt_eval
= dt_evals
[i
];
262 Node value
= model
->getValue(dt_var
);
263 Node t
= datatypes::utils::sygusToBuiltin(value
);
// exp is filled by getExplanationForEquality for dt_var and value.
266 std::vector
<Node
> exp
;
267 syexplain
.getExplanationForEquality(dt_var
, value
, exp
);
// lem first becomes the equality between dt_eval and t.
271 lem
= dt_eval
.eqNode(t
);
// lem is then assigned an IMPLIES node whose antecedent is the single element
// of exp, or the AND of all its elements. The consequent is not visible here.
275 lem
= nm
->mkNode(kind::IMPLIES
,
276 exp
.size() == 1 ? exp
[0] : nm
->mkNode(kind::AND
, exp
),
279 eval_unfold_lemmas
.push_back(lem
);
// PRIORITY_INST: the lemmas are sent only if the instantiation was not added.
282 if (mode
== options::SygusInstMode::PRIORITY_INST
)
284 if (!inst
->addInstantiation(q
, terms
))
286 sendEvalUnfoldLemmas(eval_unfold_lemmas
);
// PRIORITY_EVAL: the instantiation is added only if no lemma was added.
289 else if (mode
== options::SygusInstMode::PRIORITY_EVAL
)
291 if (!sendEvalUnfoldLemmas(eval_unfold_lemmas
))
293 inst
->addInstantiation(q
, terms
);
// Remaining case, asserted to be INTERLEAVE: both are done unconditionally.
298 Assert(mode
== options::SygusInstMode::INTERLEAVE
);
299 inst
->addInstantiation(q
, terms
);
300 sendEvalUnfoldLemmas(eval_unfold_lemmas
);
// Passes every lemma in lemmas to d_quantEngine->addLemma, tracing each one,
// and accumulates in added_lemma whether any call returned true.
// NOTE(review): the return statement is not visible in this view; presumably
// added_lemma is returned - confirm.
305 bool SygusInst::sendEvalUnfoldLemmas(const std::vector
<Node
>& lemmas
)
307 bool added_lemma
= false;
308 for (const Node
& lem
: lemmas
)
310 Trace("sygus-inst") << "Evaluation unfolding: " << lem
<< std::endl
;
// Bitwise OR-assignment, so addLemma is called for every lemma even after one
// has already been added.
311 added_lemma
|= d_quantEngine
->addLemma(lem
);
// Returns true exactly when q is a member of d_inactive_quant.
316 bool SygusInst::checkCompleteFor(Node q
)
318 return d_inactive_quant
.find(q
) != d_inactive_quant
.end();
// Registers the quantified formula q. Depending on options::sygusInstScope it
// gathers ground terms from q itself (IN or BOTH) and from the notified
// assertions (OUT or BOTH) as extra constants, then creates a sygus datatype
// per bound variable and hands the result to registerCeLemma.
// NOTE(review): some lines are missing from this view (braces, the declarators
// of relevant_terms and cache_min, the remaining arguments of
// mkSygusDefaultType, and the statement that fills types).
321 void SygusInst::registerQuantifier(Node q
)
// q must not have a counterexample lemma registered yet.
323 Assert(d_ce_lemmas
.find(q
) == d_ce_lemmas
.end());
325 Trace("sygus-inst") << "Register " << q
<< std::endl
;
// Only extra_cons is filled by the visible statements below; how exclude_cons,
// include_cons and term_irrelevant are used is not visible here.
327 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> extra_cons
;
328 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> exclude_cons
;
329 std::map
<TypeNode
, std::unordered_set
<Node
, NodeHashFunction
>> include_cons
;
330 std::unordered_set
<Node
, NodeHashFunction
> term_irrelevant
;
332 /* Collect relevant local ground terms for each variable type. */
333 if (options::sygusInstScope() == options::SygusInstScope::IN
334 || options::sygusInstScope() == options::SygusInstScope::BOTH
)
// Type of a map from TypeNode to a set of terms. NOTE(review): presumably the
// type of relevant_terms, whose declarator is missing from this view.
336 std::unordered_map
<TypeNode
,
337 std::unordered_set
<Node
, NodeHashFunction
>,
338 TypeNodeHashFunction
>
340 for (const Node
& var
: q
[0])
342 TypeNode tn
= var
.getType();
344 /* Collect relevant ground terms for type tn. */
// Terms are collected once per distinct variable type.
345 if (relevant_terms
.find(tn
) == relevant_terms
.end())
347 std::unordered_set
<Node
, NodeHashFunction
> terms
;
348 std::unordered_set
<TNode
, TNodeHashFunction
> cache_max
;
349 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>
// Local collection traverses q itself and leaves skip_quant at its default.
352 getMinGroundTerms(q
, tn
, terms
, cache_min
);
353 getMaxGroundTerms(q
, tn
, terms
, cache_max
);
354 relevant_terms
.emplace(tn
, terms
);
357 /* Add relevant ground terms to grammar. */
358 auto& terms
= relevant_terms
[tn
];
359 for (const auto& t
: terms
)
361 TypeNode ttn
= t
.getType();
362 extra_cons
[ttn
].insert(t
);
363 Trace("sygus-inst") << "Adding (local) extra cons: " << t
<< std::endl
;
368 /* Collect relevant global ground terms for each variable type. */
369 if (options::sygusInstScope() == options::SygusInstScope::OUT
370 || options::sygusInstScope() == options::SygusInstScope::BOTH
)
372 for (const Node
& var
: q
[0])
374 TypeNode tn
= var
.getType();
376 /* Collect relevant ground terms for type tn. */
// Global terms are kept in d_global_terms and computed only for types that
// have no entry yet.
377 if (d_global_terms
.find(tn
) == d_global_terms
.end())
379 std::unordered_set
<Node
, NodeHashFunction
> terms
;
380 std::unordered_set
<TNode
, TNodeHashFunction
> cache_max
;
381 std::unordered_map
<TNode
, std::pair
<bool, bool>, TNodeHashFunction
>
// Global collection traverses every notified assertion with skip_quant set to
// true.
384 for (const Node
& a
: d_notified_assertions
)
386 getMinGroundTerms(a
, tn
, terms
, cache_min
, true);
387 getMaxGroundTerms(a
, tn
, terms
, cache_max
, true);
389 d_global_terms
.insert(tn
, terms
);
392 /* Add relevant ground terms to grammar. */
393 auto it
= d_global_terms
.find(tn
);
394 if (it
!= d_global_terms
.end())
396 for (const auto& t
: (*it
).second
)
398 TypeNode ttn
= t
.getType();
399 extra_cons
[ttn
].insert(t
);
// NOTE(review): the start of this output statement is missing from this view.
401 << "Adding (global) extra cons: " << t
<< std::endl
;
407 /* Construct grammar for each bound variable of 'q'. */
408 Trace("sygus-inst") << "Process variables of " << q
<< std::endl
;
409 std::vector
<TypeNode
> types
;
410 for (const Node
& var
: q
[0])
// Special values for the type of the variable are added to extra_cons before
// the datatype is requested from CegGrammarConstructor::mkSygusDefaultType.
412 addSpecialValues(var
.getType(), extra_cons
);
413 TypeNode tn
= CegGrammarConstructor::mkSygusDefaultType(var
.getType(),
422 Trace("sygus-inst") << "Construct (default) datatype for " << var
// Final step: register the counterexample lemma for q with the collected types.
427 registerCeLemma(q
, types
);
/* Construct grammars for all bound variables of quantifier 'q'. Currently,
 * we use the default grammar of the variable's type.
 */
// Pre-registration hook for quantifier q; emits a trace message.
// NOTE(review): only the trace statement is visible in this view. Any further
// statements of the body are missing here.
433 void SygusInst::preRegisterQuantifier(Node q
)
435 Trace("sygus-inst") << "preRegister " << q
<< std::endl
;
// Records every assertion of the given vector in d_notified_assertions. These
// are the assertions traversed by registerQuantifier when it collects global
// ground terms.
439 void SygusInst::ppNotifyAssertions(const std::vector
<Node
>& assertions
)
441 for (const Node
& a
: assertions
)
443 d_notified_assertions
.insert(a
);
447 /*****************************************************************************/
448 /* private methods */
449 /*****************************************************************************/
// Looks up the counterexample literal of q in d_ce_lits. In the visible
// fall-through path a Boolean skolem named CeLiteral is created and turned
// into a literal through the valuation (ensureLiteral).
// NOTE(review): the return statements and any store into d_ce_lits are
// missing from this view; presumably the cached literal is returned on a hit
// and the new literal is stored and returned otherwise - confirm.
451 Node
SygusInst::getCeLiteral(Node q
)
453 auto it
= d_ce_lits
.find(q
);
454 if (it
!= d_ce_lits
.end())
459 NodeManager
* nm
= NodeManager::currentNM();
460 Node sk
= nm
->mkSkolem("CeLiteral", nm
->booleanType());
461 Node lit
= d_quantEngine
->getValuation().ensureLiteral(sk
);
// Builds and stores the counterexample lemma for q. For each entry of types it
// creates an instantiation constant and a DT_SYGUS_EVAL term, records them in
// d_inst_constants and d_var_eval, registers a decision strategy for the
// counterexample literal, and stores the rewritten lemma in d_ce_lemmas.
// NOTE(review): some lines are missing from this view (braces, the
// declarations of var and body, and some arguments of the
// DecisionStrategySingleton constructor).
466 void SygusInst::registerCeLemma(Node q
, std::vector
<TypeNode
>& types
)
// Preconditions: one type per bound variable, and q not yet present in
// d_ce_lemmas, d_inst_constants or d_var_eval.
468 Assert(q
[0].getNumChildren() == types
.size());
469 Assert(d_ce_lemmas
.find(q
) == d_ce_lemmas
.end());
470 Assert(d_inst_constants
.find(q
) == d_inst_constants
.end());
471 Assert(d_var_eval
.find(q
) == d_var_eval
.end());
473 Trace("sygus-inst") << "Register CE Lemma for " << q
<< std::endl
;
475 /* Generate counterexample lemma for 'q'. */
476 NodeManager
* nm
= NodeManager::currentNM();
477 TermDbSygus
* db
= d_quantEngine
->getTermDatabaseSygus();
479 /* For each variable x_i of \forall x_i . P[x_i], create a fresh datatype
480 * instantiation constant ic_i with type types[i] and wrap each ic_i in
481 * DT_SYGUS_EVAL(ic_i), which will be used to instantiate x_i. */
482 std::vector
<Node
> evals
;
483 std::vector
<Node
> inst_constants
;
484 for (size_t i
= 0, size
= types
.size(); i
< size
; ++i
)
486 TypeNode tn
= types
[i
];
489 /* Create the instantiation constant and set attribute accordingly. */
490 Node ic
= nm
->mkInstConstant(tn
);
491 InstConstantAttribute ica
;
492 ic
.setAttribute(ica
, q
);
493 Trace("sygus-inst") << "Create " << ic
<< " for " << var
<< std::endl
;
// ic is registered with the sygus term database as an enumerator, passing ic
// itself as the second argument and the role ROLE_ENUM_MULTI_SOLUTION.
495 db
->registerEnumerator(ic
, ic
, nullptr, ROLE_ENUM_MULTI_SOLUTION
);
// The evaluation term takes ic followed by the sygus variable list of the
// datatype of tn as its children.
497 std::vector
<Node
> args
= {ic
};
498 Node svl
= tn
.getDType().getSygusVarList();
501 args
.insert(args
.end(), svl
.begin(), svl
.end());
503 Node eval
= nm
->mkNode(kind::DT_SYGUS_EVAL
, args
);
505 inst_constants
.push_back(ic
);
506 evals
.push_back(eval
);
509 d_inst_constants
.emplace(q
, inst_constants
);
510 d_var_eval
.emplace(q
, evals
);
// The counterexample literal is required to have phase true.
512 Node lit
= getCeLiteral(q
);
513 d_quantEngine
->addRequirePhase(lit
, true);
// NOTE(review): the block comment that starts on the next line has no
// terminator in this view; its closing line appears to have been lost.
515 /* The decision strategy for quantified formula 'q' ensures that its
516 * counterexample literal is decided on first. It is user-context dependent.
518 Assert(d_dstrat
.find(q
) == d_dstrat
.end());
// The strategy object is allocated with new and handed to d_dstrat[q] through
// reset; the raw pointer ds is also given to the decision manager.
519 DecisionStrategy
* ds
=
520 new DecisionStrategySingleton("CeLiteral",
522 d_qstate
.getSatContext(),
523 d_quantEngine
->getValuation());
525 d_dstrat
[q
].reset(ds
);
526 d_quantEngine
->getDecisionManager()->registerStrategy(
527 DecisionManager::STRAT_QUANT_CEGQI_FEASIBLE
, ds
);
529 /* Add counterexample lemma (lit => ~P[x_i/eval_i]) */
// The bound variables q[0] are replaced in q[1] by the evaluation terms. The
// left-hand side of this statement is missing; presumably it declares body.
531 q
[1].substitute(q
[0].begin(), q
[0].end(), evals
.begin(), evals
.end());
// The implication is encoded as the disjunction of the negated literal and the
// negated body, then rewritten.
532 Node lem
= nm
->mkNode(kind::OR
, lit
.negate(), body
.negate());
533 lem
= Rewriter::rewrite(lem
);
535 d_ce_lemmas
.emplace(std::make_pair(q
, lem
));
536 Trace("sygus-inst") << "Register CE Lemma: " << lem
<< std::endl
;
// Sends the counterexample lemma stored for q to the quantifiers engine,
// unless q is already recorded in d_ce_lemma_added.
539 void SygusInst::addCeLemma(Node q
)
// A lemma must have been stored for q beforehand.
541 Assert(d_ce_lemmas
.find(q
) != d_ce_lemmas
.end());
543 /* Already added in previous contexts. */
544 if (d_ce_lemma_added
.find(q
) != d_ce_lemma_added
.end()) return;
// The lemma is passed to addLemma with false as the second argument, and q is
// then recorded so the lemma is not sent again.
546 Node lem
= d_ce_lemmas
[q
];
547 d_quantEngine
->addLemma(lem
, false);
548 d_ce_lemma_added
.insert(q
);
549 Trace("sygus-inst") << "Add CE Lemma: " << lem
<< std::endl
;
552 } // namespace quantifiers
553 } // namespace theory