Move enum value generator to own file (#6941)
[cvc5.git] / src / theory / quantifiers / sygus / enum_stream_substitution.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Haniel Barbosa, Andrew Reynolds, Gereon Kremer
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Class for streaming concrete values (through substitutions) from
14 * enumerated abstract ones.
15 */
16
17 #include "theory/quantifiers/sygus/enum_stream_substitution.h"
18
19 #include "expr/dtype_cons.h"
20 #include "options/base_options.h"
21 #include "options/datatypes_options.h"
22 #include "options/quantifiers_options.h"
23 #include "printer/printer.h"
24 #include "theory/quantifiers/sygus/term_database_sygus.h"
25
26 #include <numeric> // for std::iota
27 #include <sstream>
28
29 using namespace cvc5::kind;
30
31 namespace cvc5 {
32 namespace theory {
33 namespace quantifiers {
34
35 EnumStreamPermutation::EnumStreamPermutation(quantifiers::TermDbSygus* tds)
36 : d_tds(tds), d_first(true), d_curr_ind(0)
37 {
38 }
39
40 void EnumStreamPermutation::reset(Node value)
41 {
42 // clean state
43 d_var_classes.clear();
44 d_var_tn_cons.clear();
45 d_first = true;
46 d_perm_state_class.clear();
47 d_perm_values.clear();
48 d_value = value;
49 // get variables in value's type
50 TypeNode tn = value.getType();
51 Node var_list = tn.getDType().getSygusVarList();
52 NodeManager* nm = NodeManager::currentNM();
53 // get subtypes in value's type
54 SygusTypeInfo& ti = d_tds->getTypeInfo(tn);
55 std::vector<TypeNode> sf_types;
56 ti.getSubfieldTypes(sf_types);
57 // associate variables with constructors in all subfield types
58 std::map<Node, Node> cons_var;
59 for (const Node& v : var_list)
60 {
61 // collect constructors for variable in all subtypes
62 for (const TypeNode& stn : sf_types)
63 {
64 const DType& dt = stn.getDType();
65 for (unsigned i = 0, size = dt.getNumConstructors(); i < size; ++i)
66 {
67 if (dt[i].getNumArgs() == 0 && dt[i].getSygusOp() == v)
68 {
69 Node cons = nm->mkNode(APPLY_CONSTRUCTOR, dt[i].getConstructor());
70 d_var_tn_cons[v][stn] = cons;
71 cons_var[cons] = v;
72 }
73 }
74 }
75 }
76 // collect variables occurring in value
77 std::vector<Node> vars;
78 std::unordered_set<Node> visited;
79 collectVars(value, vars, visited);
80 // partition permutation variables
81 d_curr_ind = 0;
82 Trace("synth-stream-concrete") << " ..permutting vars :";
83 std::unordered_set<Node> seen_vars;
84 for (const Node& v_cons : vars)
85 {
86 Assert(cons_var.find(v_cons) != cons_var.end());
87 Node var = cons_var[v_cons];
88 if (seen_vars.insert(var).second)
89 {
90 // do not add repeated vars
91 d_var_classes[ti.getSubclassForVar(var)].push_back(var);
92 }
93 }
94 for (const std::pair<const unsigned, std::vector<Node>>& p : d_var_classes)
95 {
96 d_perm_state_class.push_back(PermutationState(p.second));
97 if (Trace.isOn("synth-stream-concrete"))
98 {
99 Trace("synth-stream-concrete") << " " << p.first << " -> [";
100 for (const Node& var : p.second)
101 {
102 std::stringstream ss;
103 TermDbSygus::toStreamSygus(ss, var);
104 Trace("synth-stream-concrete") << " " << ss.str();
105 }
106 Trace("synth-stream-concrete") << " ]";
107 }
108 }
109 Trace("synth-stream-concrete") << "\n";
110 }
111
112 Node EnumStreamPermutation::getNext()
113 {
114 if (Trace.isOn("synth-stream-concrete"))
115 {
116 std::stringstream ss;
117 TermDbSygus::toStreamSygus(ss, d_value);
118 Trace("synth-stream-concrete")
119 << " ....streaming next permutation for value : " << ss.str()
120 << " with " << d_perm_state_class.size() << " permutation classes\n";
121 }
122 // initial value
123 if (d_first)
124 {
125 d_first = false;
126 Node bultin_value = d_tds->sygusToBuiltin(d_value, d_value.getType());
127 d_perm_values.insert(
128 d_tds->getExtRewriter()->extendedRewrite(bultin_value));
129 return d_value;
130 }
131 unsigned n_classes = d_perm_state_class.size();
132 Node perm_value, bultin_perm_value;
133 do
134 {
135 bool new_perm = false;
136 while (!new_perm && d_curr_ind < n_classes)
137 {
138 if (d_perm_state_class[d_curr_ind].getNextPermutation())
139 {
140 new_perm = true;
141 Trace("synth-stream-concrete-debug2")
142 << " ....class " << d_curr_ind << " has new perm\n";
143 d_curr_ind = 0;
144 }
145 else
146 {
147 Trace("synth-stream-concrete-debug2")
148 << " ....class " << d_curr_ind << " reset\n";
149 d_perm_state_class[d_curr_ind].reset();
150 d_curr_ind++;
151 }
152 }
153 // no new permutation
154 if (!new_perm)
155 {
156 Trace("synth-stream-concrete") << " ....no new perm, return null\n";
157 return Node::null();
158 }
159 // building substitution
160 std::vector<Node> domain_sub, range_sub;
161 for (unsigned i = 0, size = d_perm_state_class.size(); i < size; ++i)
162 {
163 Trace("synth-stream-concrete") << " ..perm for class " << i << " is";
164 std::vector<Node> raw_sub;
165 d_perm_state_class[i].getLastPerm(raw_sub);
166 // retrieve variables for substitution domain
167 const std::vector<Node>& domain_sub_class =
168 d_perm_state_class[i].getVars();
169 Assert(domain_sub_class.size() == raw_sub.size());
170 // build proper substitution based on variables types and constructors
171 for (unsigned j = 0, size_j = raw_sub.size(); j < size_j; ++j)
172 {
173 for (std::pair<const TypeNode, Node>& p :
174 d_var_tn_cons[domain_sub_class[j]])
175 {
176 // get constructor of type p.first from variable being permuted
177 domain_sub.push_back(p.second);
178 // get constructor of type p.first from variable to be permuted for
179 range_sub.push_back(d_var_tn_cons[raw_sub[j]][p.first]);
180 Trace("synth-stream-concrete-debug2")
181 << "\n ....{ adding " << domain_sub.back() << " ["
182 << domain_sub.back().getType() << "] -> " << range_sub.back()
183 << " [" << range_sub.back().getType() << "] }";
184 }
185 }
186 Trace("synth-stream-concrete") << "\n";
187 }
188 perm_value = d_value.substitute(domain_sub.begin(),
189 domain_sub.end(),
190 range_sub.begin(),
191 range_sub.end());
192 bultin_perm_value = d_tds->sygusToBuiltin(perm_value, perm_value.getType());
193 Trace("synth-stream-concrete-debug")
194 << " ......perm builtin is " << bultin_perm_value;
195 if (options::sygusSymBreakDynamic())
196 {
197 bultin_perm_value =
198 d_tds->getExtRewriter()->extendedRewrite(bultin_perm_value);
199 Trace("synth-stream-concrete-debug")
200 << " and rewrites to " << bultin_perm_value;
201 }
202 Trace("synth-stream-concrete-debug") << "\n";
203 // if permuted value is equivalent modulo rewriting to a previous one, look
204 // for another
205 } while (!d_perm_values.insert(bultin_perm_value).second);
206 if (Trace.isOn("synth-stream-concrete"))
207 {
208 std::stringstream ss;
209 TermDbSygus::toStreamSygus(ss, perm_value);
210 Trace("synth-stream-concrete")
211 << " ....return new perm " << ss.str() << "\n";
212 }
213 return perm_value;
214 }
215
216 const std::vector<Node>& EnumStreamPermutation::getVarsClass(unsigned id) const
217 {
218 std::map<unsigned, std::vector<Node>>::const_iterator it =
219 d_var_classes.find(id);
220 Assert(it != d_var_classes.end());
221 return it->second;
222 }
223
224 unsigned EnumStreamPermutation::getVarClassSize(unsigned id) const
225 {
226 std::map<unsigned, std::vector<Node>>::const_iterator it =
227 d_var_classes.find(id);
228 if (it == d_var_classes.end())
229 {
230 return 0;
231 }
232 return it->second.size();
233 }
234
235 void EnumStreamPermutation::collectVars(Node n,
236 std::vector<Node>& vars,
237 std::unordered_set<Node>& visited)
238 {
239 if (visited.find(n) != visited.end())
240 {
241 return;
242 }
243 visited.insert(n);
244 if (n.getNumChildren() > 0)
245 {
246 for (const Node& ni : n)
247 {
248 collectVars(ni, vars, visited);
249 }
250 return;
251 }
252 if (d_tds->sygusToBuiltin(n, n.getType()).getKind() == kind::BOUND_VARIABLE)
253 {
254 if (std::find(vars.begin(), vars.end(), n) == vars.end())
255 {
256 vars.push_back(n);
257 }
258 return;
259 }
260 }
261
262 EnumStreamPermutation::PermutationState::PermutationState(
263 const std::vector<Node>& vars)
264 {
265 d_vars = vars;
266 d_curr_ind = 0;
267 d_seq.resize(vars.size());
268 std::fill(d_seq.begin(), d_seq.end(), 0);
269 // initialize variable indices
270 d_last_perm.resize(vars.size());
271 std::iota(d_last_perm.begin(), d_last_perm.end(), 0);
272 }
273
274 void EnumStreamPermutation::PermutationState::reset()
275 {
276 d_curr_ind = 0;
277 std::fill(d_seq.begin(), d_seq.end(), 0);
278 std::iota(d_last_perm.begin(), d_last_perm.end(), 0);
279 }
280
281 const std::vector<Node>& EnumStreamPermutation::PermutationState::getVars()
282 const
283 {
284 return d_vars;
285 }
286
287 void EnumStreamPermutation::PermutationState::getLastPerm(
288 std::vector<Node>& vars)
289 {
290 for (unsigned i = 0, size = d_last_perm.size(); i < size; ++i)
291 {
292 if (Trace.isOn("synth-stream-concrete"))
293 {
294 std::stringstream ss;
295 TermDbSygus::toStreamSygus(ss, d_vars[d_last_perm[i]]);
296 Trace("synth-stream-concrete") << " " << ss.str();
297 }
298 vars.push_back(d_vars[d_last_perm[i]]);
299 }
300 }
301
302 bool EnumStreamPermutation::PermutationState::getNextPermutation()
303 {
304 // exhausted permutations
305 if (d_curr_ind == d_vars.size())
306 {
307 Trace("synth-stream-concrete-debug2") << "exhausted perms, ";
308 return false;
309 }
310 if (d_seq[d_curr_ind] >= d_curr_ind)
311 {
312 d_seq[d_curr_ind] = 0;
313 d_curr_ind++;
314 return getNextPermutation();
315 }
316 if (d_curr_ind % 2 == 0)
317 {
318 // swap with first element
319 std::swap(d_last_perm[0], d_last_perm[d_curr_ind]);
320 }
321 else
322 {
323 // swap with element in index in sequence of current index
324 std::swap(d_last_perm[d_seq[d_curr_ind]], d_last_perm[d_curr_ind]);
325 }
326 d_seq[d_curr_ind] += 1;
327 d_curr_ind = 0;
328 return true;
329 }
330
331 EnumStreamSubstitution::EnumStreamSubstitution(quantifiers::TermDbSygus* tds)
332 : d_tds(tds), d_stream_permutations(tds), d_curr_ind(0)
333 {
334 }
335
336 void EnumStreamSubstitution::initialize(TypeNode tn)
337 {
338 d_tn = tn;
339 // get variables in value's type
340 Node var_list = tn.getDType().getSygusVarList();
341 // get subtypes in value's type
342 NodeManager* nm = NodeManager::currentNM();
343 SygusTypeInfo& ti = d_tds->getTypeInfo(tn);
344 std::vector<TypeNode> sf_types;
345 ti.getSubfieldTypes(sf_types);
346 // associate variables with constructors in all subfield types
347 for (const Node& v : var_list)
348 {
349 // collect constructors for variable in all subtypes
350 for (const TypeNode& stn : sf_types)
351 {
352 const DType& dt = stn.getDType();
353 for (unsigned i = 0, size = dt.getNumConstructors(); i < size; ++i)
354 {
355 if (dt[i].getNumArgs() == 0 && dt[i].getSygusOp() == v)
356 {
357 d_var_tn_cons[v][stn] =
358 nm->mkNode(APPLY_CONSTRUCTOR, dt[i].getConstructor());
359 }
360 }
361 }
362 }
363 // split initial variables into classes
364 for (const Node& v : var_list)
365 {
366 Assert(ti.getSubclassForVar(v) > 0);
367 d_var_classes[ti.getSubclassForVar(v)].push_back(v);
368 }
369 }
370
371 void EnumStreamSubstitution::resetValue(Node value)
372 {
373 if (Trace.isOn("synth-stream-concrete"))
374 {
375 std::stringstream ss;
376 TermDbSygus::toStreamSygus(ss, value);
377 Trace("synth-stream-concrete")
378 << " * Streaming concrete: registering value " << ss.str() << "\n";
379 }
380 d_last = Node::null();
381 d_value = value;
382 // reset permutation util
383 d_stream_permutations.reset(value);
384 // reset combination utils
385 d_curr_ind = 0;
386 d_comb_state_class.clear();
387 Trace("synth-stream-concrete") << " ..combining vars :";
388 for (const std::pair<const unsigned, std::vector<Node>>& p : d_var_classes)
389 {
390 // ignore classes without variables being permuted
391 unsigned perm_var_class_sz = d_stream_permutations.getVarClassSize(p.first);
392 if (perm_var_class_sz == 0)
393 {
394 continue;
395 }
396 d_comb_state_class.push_back(CombinationState(
397 p.second.size(), perm_var_class_sz, p.first, p.second));
398 if (Trace.isOn("synth-stream-concrete"))
399 {
400 Trace("synth-stream-concrete")
401 << " " << p.first << " -> " << perm_var_class_sz << " from [ ";
402 for (const Node& var : p.second)
403 {
404 std::stringstream ss;
405 TermDbSygus::toStreamSygus(ss, var);
406 Trace("synth-stream-concrete") << " " << ss.str();
407 }
408 Trace("synth-stream-concrete") << " ]";
409 }
410 }
411 Trace("synth-stream-concrete") << "\n";
412 }
413
414 Node EnumStreamSubstitution::getNext()
415 {
416 if (Trace.isOn("synth-stream-concrete"))
417 {
418 std::stringstream ss;
419 TermDbSygus::toStreamSygus(ss, d_value);
420 Trace("synth-stream-concrete")
421 << " ..streaming next combination of " << ss.str() << "\n";
422 }
423 unsigned n_classes = d_comb_state_class.size();
424 // intial case
425 if (d_last.isNull())
426 {
427 d_last = d_stream_permutations.getNext();
428 }
429 else
430 {
431 bool new_comb = false;
432 while (!new_comb && d_curr_ind < n_classes)
433 {
434 if (d_comb_state_class[d_curr_ind].getNextCombination())
435 {
436 new_comb = true;
437 Trace("synth-stream-concrete-debug2")
438 << " ....class " << d_curr_ind << " has new comb\n";
439 d_curr_ind = 0;
440 }
441 else
442 {
443 Trace("synth-stream-concrete-debug2")
444 << " ....class " << d_curr_ind << " reset\n";
445 d_comb_state_class[d_curr_ind].reset();
446 d_curr_ind++;
447 }
448 }
449 // no new combination
450 if (!new_comb)
451 {
452 Trace("synth-stream-concrete")
453 << " ..no new comb, get next permutation\n ....total combs until "
454 "here : "
455 << d_comb_values.size() << "\n";
456 d_last = d_stream_permutations.getNext();
457 // exhausted permutations
458 if (d_last.isNull())
459 {
460 Trace("synth-stream-concrete") << " ..no new comb, return null\n";
461 return Node::null();
462 }
463 // reset combination classes for next permutation
464 d_curr_ind = 0;
465 for (unsigned i = 0, size = d_comb_state_class.size(); i < size; ++i)
466 {
467 d_comb_state_class[i].reset();
468 }
469 }
470 }
471 if (Trace.isOn("synth-stream-concrete-debug"))
472 {
473 std::stringstream ss;
474 TermDbSygus::toStreamSygus(ss, d_last);
475 Trace("synth-stream-concrete-debug")
476 << " ..using base perm " << ss.str() << "\n";
477 }
478 // building substitution
479 std::vector<Node> domain_sub, range_sub;
480 for (unsigned i = 0, size = d_comb_state_class.size(); i < size; ++i)
481 {
482 Trace("synth-stream-concrete")
483 << " ..comb for class " << d_comb_state_class[i].getSubclassId()
484 << " is";
485 std::vector<Node> raw_sub;
486 d_comb_state_class[i].getLastComb(raw_sub);
487 // retrieve variables for substitution domain
488 const std::vector<Node>& domain_sub_class =
489 d_stream_permutations.getVarsClass(
490 d_comb_state_class[i].getSubclassId());
491 Assert(domain_sub_class.size() == raw_sub.size());
492 // build proper substitution based on variables types and constructors
493 for (unsigned j = 0, size_j = raw_sub.size(); j < size_j; ++j)
494 {
495 for (std::pair<const TypeNode, Node>& p :
496 d_var_tn_cons[domain_sub_class[j]])
497 {
498 // get constructor of type p.first from variable being permuted
499 domain_sub.push_back(p.second);
500 // get constructor of type p.first from variable to be permuted for
501 range_sub.push_back(d_var_tn_cons[raw_sub[j]][p.first]);
502 Trace("synth-stream-concrete-debug2")
503 << "\n ....{ adding " << domain_sub.back() << " ["
504 << domain_sub.back().getType() << "] -> " << range_sub.back()
505 << " [" << range_sub.back().getType() << "] }";
506 }
507 }
508 Trace("synth-stream-concrete") << "\n";
509 }
510 Node comb_value = d_last.substitute(
511 domain_sub.begin(), domain_sub.end(), range_sub.begin(), range_sub.end());
512 // the new combination value should be fresh, modulo rewriting, by
513 // construction (unless it's equiv to a constant, e.g. true / false)
514 Node builtin_comb_value =
515 d_tds->sygusToBuiltin(comb_value, comb_value.getType());
516 if (options::sygusSymBreakDynamic())
517 {
518 builtin_comb_value =
519 d_tds->getExtRewriter()->extendedRewrite(builtin_comb_value);
520 }
521 if (Trace.isOn("synth-stream-concrete"))
522 {
523 std::stringstream ss;
524 TermDbSygus::toStreamSygus(ss, comb_value);
525 Trace("synth-stream-concrete")
526 << " ....register new comb value " << ss.str()
527 << " with rewritten form " << builtin_comb_value
528 << (builtin_comb_value.isConst() ? " (isConst)" : "") << "\n";
529 }
530 if (!builtin_comb_value.isConst()
531 && !d_comb_values.insert(builtin_comb_value).second)
532 {
533 if (Trace.isOn("synth-stream-concrete"))
534 {
535 std::stringstream ss, ss1;
536 TermDbSygus::toStreamSygus(ss, comb_value);
537 Trace("synth-stream-concrete")
538 << " ..term " << ss.str() << " is REDUNDANT with " << builtin_comb_value
539 << "\n ..excluding all other concretizations (had "
540 << d_comb_values.size() << " already)\n\n";
541 }
542 return Node::null();
543 }
544 if (Trace.isOn("synth-stream-concrete"))
545 {
546 std::stringstream ss;
547 TermDbSygus::toStreamSygus(ss, comb_value);
548 Trace("synth-stream-concrete")
549 << " ..return new comb " << ss.str() << "\n\n";
550 }
551 return comb_value;
552 }
553
554 EnumStreamSubstitution::CombinationState::CombinationState(
555 unsigned n, unsigned k, unsigned subclass_id, const std::vector<Node>& vars)
556 : d_n(n), d_k(k)
557 {
558 Assert(!vars.empty());
559 Assert(k <= n);
560 d_last_comb.resize(k);
561 std::iota(d_last_comb.begin(), d_last_comb.end(), 0);
562 d_vars = vars;
563 d_subclass_id = subclass_id;
564 }
565
566 const unsigned EnumStreamSubstitution::CombinationState::getSubclassId() const
567 {
568 return d_subclass_id;
569 }
570
571 void EnumStreamSubstitution::CombinationState::reset()
572 {
573 std::iota(d_last_comb.begin(), d_last_comb.end(), 0);
574 }
575
576 void EnumStreamSubstitution::CombinationState::getLastComb(
577 std::vector<Node>& vars)
578 {
579 for (unsigned i = 0, size = d_last_comb.size(); i < size; ++i)
580 {
581 if (Trace.isOn("synth-stream-concrete"))
582 {
583 std::stringstream ss;
584 TermDbSygus::toStreamSygus(ss, d_vars[d_last_comb[i]]);
585 Trace("synth-stream-concrete") << " " << ss.str();
586 }
587 vars.push_back(d_vars[d_last_comb[i]]);
588 }
589 }
590
591 bool EnumStreamSubstitution::CombinationState::getNextCombination()
592 {
593 // find what to increment
594 bool new_comb = false;
595 for (int i = d_k - 1; i >= 0; --i)
596 {
597 if (d_last_comb[i] < d_n - d_k + i)
598 {
599 unsigned j = d_last_comb[i] + 1;
600 while (static_cast<unsigned>(i) <= d_k - 1)
601 {
602 d_last_comb[i++] = j++;
603 }
604 new_comb = true;
605 break;
606 }
607 }
608 return new_comb;
609 }
610
611 void EnumStreamConcrete::initialize(Node e) { d_ess.initialize(e.getType()); }
612 void EnumStreamConcrete::addValue(Node v)
613 {
614 d_ess.resetValue(v);
615 d_currTerm = d_ess.getNext();
616 }
617 bool EnumStreamConcrete::increment()
618 {
619 d_currTerm = d_ess.getNext();
620 return !d_currTerm.isNull();
621 }
622 Node EnumStreamConcrete::getCurrent() { return d_currTerm; }
623 } // namespace quantifiers
624 } // namespace theory
625 } // namespace cvc5