Use the proper evaluator for optimized SyGuS datatype rewriting (#7266)
[cvc5.git] / src / theory / datatypes / sygus_datatype_utils.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds, Mathias Preiner, Aina Niemetz
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Implementation of utility functions for sygus datatypes.
14 */
15
16 #include "theory/datatypes/sygus_datatype_utils.h"
17
18 #include <sstream>
19
20 #include "expr/dtype.h"
21 #include "expr/dtype_cons.h"
22 #include "expr/node_algorithm.h"
23 #include "expr/sygus_datatype.h"
24 #include "smt/env.h"
25 #include "theory/evaluator.h"
26 #include "theory/rewriter.h"
27
28 using namespace cvc5;
29 using namespace cvc5::kind;
30
31 namespace cvc5 {
32 namespace theory {
33 namespace datatypes {
34 namespace utils {
35
36 Node applySygusArgs(const DType& dt,
37 Node op,
38 Node n,
39 const std::vector<Node>& args)
40 {
41 // optimization: if n is just a sygus bound variable, return immediately
42 // by replacing with the proper argument, or returning unchanged if it is
43 // a bound variable not corresponding to a formal argument.
44 if (n.getKind() == BOUND_VARIABLE)
45 {
46 if (n.hasAttribute(SygusVarNumAttribute()))
47 {
48 int vn = n.getAttribute(SygusVarNumAttribute());
49 Assert(dt.getSygusVarList()[vn] == n);
50 return args[vn];
51 }
52 // it is a different bound variable, it is unchanged
53 return n;
54 }
55 // n is an application of operator op.
56 // We must compute the free variables in op to determine if there are
57 // any substitutions we need to make to n.
58 TNode val;
59 if (!op.hasAttribute(SygusVarFreeAttribute()))
60 {
61 std::unordered_set<Node> fvs;
62 if (expr::getFreeVariables(op, fvs))
63 {
64 if (fvs.size() == 1)
65 {
66 for (const Node& v : fvs)
67 {
68 val = v;
69 }
70 }
71 else
72 {
73 val = op;
74 }
75 }
76 Trace("dt-sygus-fv") << "Free var in " << op << " : " << val << std::endl;
77 op.setAttribute(SygusVarFreeAttribute(), val);
78 }
79 else
80 {
81 val = op.getAttribute(SygusVarFreeAttribute());
82 }
83 if (val.isNull())
84 {
85 return n;
86 }
87 if (val.getKind() == BOUND_VARIABLE)
88 {
89 // single substitution case
90 int vn = val.getAttribute(SygusVarNumAttribute());
91 TNode sub = args[vn];
92 return n.substitute(val, sub);
93 }
94 // do the full substitution
95 std::vector<Node> vars;
96 Node bvl = dt.getSygusVarList();
97 for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++)
98 {
99 vars.push_back(bvl[i]);
100 }
101 return n.substitute(vars.begin(), vars.end(), args.begin(), args.end());
102 }
103
104 Kind getOperatorKindForSygusBuiltin(Node op)
105 {
106 Assert(op.getKind() != BUILTIN);
107 if (op.getKind() == LAMBDA)
108 {
109 return APPLY_UF;
110 }
111 return NodeManager::getKindForFunction(op);
112 }
113
114 struct SygusOpRewrittenAttributeId
115 {
116 };
117 typedef expr::Attribute<SygusOpRewrittenAttributeId, Node>
118 SygusOpRewrittenAttribute;
119
120 Kind getEliminateKind(Kind ok)
121 {
122 Kind nk = ok;
123 // We also must ensure that builtin operators which are eliminated
124 // during expand definitions are replaced by the proper operator.
125 if (ok == DIVISION)
126 {
127 nk = DIVISION_TOTAL;
128 }
129 else if (ok == INTS_DIVISION)
130 {
131 nk = INTS_DIVISION_TOTAL;
132 }
133 else if (ok == INTS_MODULUS)
134 {
135 nk = INTS_MODULUS_TOTAL;
136 }
137 return nk;
138 }
139
140 Node mkSygusTerm(const DType& dt,
141 unsigned i,
142 const std::vector<Node>& children,
143 bool doBetaReduction,
144 bool isExternal)
145 {
146 Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i
147 << "] with children: " << children << std::endl;
148 Assert(i < dt.getNumConstructors());
149 Assert(dt.isSygus());
150 Assert(!dt[i].getSygusOp().isNull());
151 Node op = dt[i].getSygusOp();
152 Node opn = op;
153 if (!isExternal)
154 {
155 // Get the normalized version of the sygus operator. We do this by
156 // expanding definitions, rewriting it, and eliminating partial operators.
157 if (!op.hasAttribute(SygusOpRewrittenAttribute()))
158 {
159 if (op.isConst())
160 {
161 // If it is a builtin operator, convert to total version if necessary.
162 // First, get the kind for the operator.
163 Kind ok = NodeManager::operatorToKind(op);
164 Trace("sygus-grammar-normalize-debug")
165 << "...builtin kind is " << ok << std::endl;
166 Kind nk = getEliminateKind(ok);
167 if (nk != ok)
168 {
169 Trace("sygus-grammar-normalize-debug")
170 << "...replace by builtin operator " << nk << std::endl;
171 opn = NodeManager::currentNM()->operatorOf(nk);
172 }
173 }
174 else
175 {
176 // Get the expanded definition form, if it has been marked. This ensures
177 // that user-defined functions have been eliminated from op.
178 opn = getExpandedDefinitionForm(op);
179 opn = Rewriter::rewrite(opn);
180 SygusOpRewrittenAttribute sora;
181 op.setAttribute(sora, opn);
182 }
183 }
184 else
185 {
186 opn = op.getAttribute(SygusOpRewrittenAttribute());
187 }
188 }
189 return mkSygusTerm(opn, children, doBetaReduction);
190 }
191
192 Node mkSygusTerm(Node op,
193 const std::vector<Node>& children,
194 bool doBetaReduction)
195 {
196 Trace("dt-sygus-util") << "Operator is " << op << std::endl;
197 if (children.empty())
198 {
199 // no children, return immediately
200 Trace("dt-sygus-util") << "...return direct op" << std::endl;
201 return op;
202 }
203 // if it is the any constant, we simply return the child
204 if (op.getAttribute(SygusAnyConstAttribute()))
205 {
206 Assert(children.size() == 1);
207 return children[0];
208 }
209 std::vector<Node> schildren;
210 // get the kind of the operator
211 Kind ok = op.getKind();
212 if (ok != BUILTIN)
213 {
214 if (ok == LAMBDA && doBetaReduction)
215 {
216 // Do immediate beta reduction. It suffices to use a normal substitution
217 // since neither op nor children have quantifiers, since they are
218 // generated by sygus grammars.
219 std::vector<Node> vars{op[0].begin(), op[0].end()};
220 Assert(vars.size() == children.size());
221 Node ret = op[1].substitute(
222 vars.begin(), vars.end(), children.begin(), children.end());
223 Trace("dt-sygus-util") << "...return (beta-reduce) " << ret << std::endl;
224 return ret;
225 }
226 else
227 {
228 schildren.push_back(op);
229 }
230 }
231 schildren.insert(schildren.end(), children.begin(), children.end());
232 Node ret;
233 if (ok == BUILTIN)
234 {
235 ret = NodeManager::currentNM()->mkNode(op, schildren);
236 Trace("dt-sygus-util") << "...return (builtin) " << ret << std::endl;
237 return ret;
238 }
239 // get the kind used for applying op
240 Kind otk = NodeManager::operatorToKind(op);
241 Trace("dt-sygus-util") << "operator kind is " << otk << std::endl;
242 if (otk != UNDEFINED_KIND)
243 {
244 // If it is an APPLY_UF operator, we should have at least an operator and
245 // a child.
246 Assert(otk != APPLY_UF || schildren.size() != 1);
247 ret = NodeManager::currentNM()->mkNode(otk, schildren);
248 Trace("dt-sygus-util") << "...return (op) " << ret << std::endl;
249 return ret;
250 }
251 Kind tok = getOperatorKindForSygusBuiltin(op);
252 if (schildren.size() == 1 && tok == UNDEFINED_KIND)
253 {
254 ret = schildren[0];
255 }
256 else
257 {
258 ret = NodeManager::currentNM()->mkNode(tok, schildren);
259 }
260 Trace("dt-sygus-util") << "...return " << ret << std::endl;
261 return ret;
262 }
263
264 struct SygusToBuiltinTermAttributeId
265 {
266 };
267 typedef expr::Attribute<SygusToBuiltinTermAttributeId, Node>
268 SygusToBuiltinTermAttribute;
269
270 // A variant of the above attribute for cases where we introduce a fresh
271 // variable. This is to support sygusToBuiltin on non-constant sygus terms,
272 // where sygus variables should be mapped to canonical builtin variables.
273 // It is important to cache this so that sygusToBuiltin is deterministic.
274 struct SygusToBuiltinVarAttributeId
275 {
276 };
277 typedef expr::Attribute<SygusToBuiltinVarAttributeId, Node>
278 SygusToBuiltinVarAttribute;
279
280 // A variant of the above attribute for cases where we introduce a fresh
281 // variable. This is to support sygusToBuiltin on non-constant sygus terms,
282 // where sygus variables should be mapped to canonical builtin variables.
283 // It is important to cache this so that sygusToBuiltin is deterministic.
284 struct BuiltinVarToSygusAttributeId
285 {
286 };
287 typedef expr::Attribute<BuiltinVarToSygusAttributeId, Node>
288 BuiltinVarToSygusAttribute;
289
290 Node sygusToBuiltin(Node n, bool isExternal)
291 {
292 std::unordered_map<TNode, Node> visited;
293 std::unordered_map<TNode, Node>::iterator it;
294 std::vector<TNode> visit;
295 TNode cur;
296 unsigned index;
297 visit.push_back(n);
298 do
299 {
300 cur = visit.back();
301 visit.pop_back();
302 it = visited.find(cur);
303 if (it == visited.end())
304 {
305 // Notice this condition succeeds in roughly 99% of the executions of this
306 // method (based on our coverage tests), hence the else if / else cases
307 // below do not significantly impact performance.
308 if (cur.getKind() == APPLY_CONSTRUCTOR)
309 {
310 if (!isExternal && cur.hasAttribute(SygusToBuiltinTermAttribute()))
311 {
312 visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute());
313 }
314 else
315 {
316 visited[cur] = Node::null();
317 visit.push_back(cur);
318 for (const Node& cn : cur)
319 {
320 visit.push_back(cn);
321 }
322 }
323 }
324 else if (cur.getType().isSygusDatatype())
325 {
326 Assert (cur.isVar());
327 if (cur.hasAttribute(SygusToBuiltinVarAttribute()))
328 {
329 // use the previously constructed variable for it
330 visited[cur] = cur.getAttribute(SygusToBuiltinVarAttribute());
331 }
332 else
333 {
334 std::stringstream ss;
335 ss << cur;
336 const DType& dt = cur.getType().getDType();
337 // make a fresh variable
338 NodeManager * nm = NodeManager::currentNM();
339 Node var = nm->mkBoundVar(ss.str(), dt.getSygusType());
340 SygusToBuiltinVarAttribute stbv;
341 cur.setAttribute(stbv, var);
342 visited[cur] = var;
343 // create backwards mapping
344 BuiltinVarToSygusAttribute bvtsa;
345 var.setAttribute(bvtsa, cur);
346 }
347 }
348 else
349 {
350 // non-datatypes are themselves
351 visited[cur] = cur;
352 }
353 }
354 else if (it->second.isNull())
355 {
356 Node ret = cur;
357 Assert(cur.getKind() == APPLY_CONSTRUCTOR);
358 const DType& dt = cur.getType().getDType();
359 // Non sygus-datatype terms are also themselves. Notice we treat the
360 // case of non-sygus datatypes this way since it avoids computing
361 // the type / datatype of the node in the pre-traversal above. The
362 // case of non-sygus datatypes is very rare, so the extra addition to
363 // visited is justified performance-wise.
364 if (dt.isSygus())
365 {
366 std::vector<Node> children;
367 for (const Node& cn : cur)
368 {
369 it = visited.find(cn);
370 Assert(it != visited.end());
371 Assert(!it->second.isNull());
372 children.push_back(it->second);
373 }
374 index = indexOf(cur.getOperator());
375 ret = mkSygusTerm(dt, index, children, true, isExternal);
376 }
377 visited[cur] = ret;
378 // cache
379 if (!isExternal)
380 {
381 SygusToBuiltinTermAttribute stbt;
382 cur.setAttribute(stbt, ret);
383 }
384 }
385 } while (!visit.empty());
386 Assert(visited.find(n) != visited.end());
387 Assert(!visited.find(n)->second.isNull());
388 return visited[n];
389 }
390
391 Node builtinVarToSygus(Node v)
392 {
393 BuiltinVarToSygusAttribute bvtsa;
394 if (v.hasAttribute(bvtsa))
395 {
396 return v.getAttribute(bvtsa);
397 }
398 return Node::null();
399 }
400
401 void getFreeSymbolsSygusType(TypeNode sdt, std::unordered_set<Node>& syms)
402 {
403 // datatype types we need to process
404 std::vector<TypeNode> typeToProcess;
405 // datatype types we have processed
406 std::map<TypeNode, TypeNode> typesProcessed;
407 typeToProcess.push_back(sdt);
408 while (!typeToProcess.empty())
409 {
410 std::vector<TypeNode> typeNextToProcess;
411 for (const TypeNode& curr : typeToProcess)
412 {
413 Assert(curr.isDatatype() && curr.getDType().isSygus());
414 const DType& dtc = curr.getDType();
415 for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
416 {
417 // collect the symbols from the operator
418 Node op = dtc[j].getSygusOp();
419 expr::getSymbols(op, syms);
420 // traverse the argument types
421 for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
422 {
423 TypeNode argt = dtc[j].getArgType(k);
424 if (!argt.isDatatype() || !argt.getDType().isSygus())
425 {
426 // not a sygus datatype
427 continue;
428 }
429 if (typesProcessed.find(argt) == typesProcessed.end())
430 {
431 typeNextToProcess.push_back(argt);
432 }
433 }
434 }
435 }
436 typeToProcess.clear();
437 typeToProcess.insert(typeToProcess.end(),
438 typeNextToProcess.begin(),
439 typeNextToProcess.end());
440 }
441 }
442
443 TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
444 const std::vector<Node>& syms,
445 const std::vector<Node>& vars)
446 {
447 NodeManager* nm = NodeManager::currentNM();
448 const DType& sdtd = sdt.getDType();
449 // compute the new formal argument list
450 std::vector<Node> formalVars;
451 Node prevVarList = sdtd.getSygusVarList();
452 if (!prevVarList.isNull())
453 {
454 for (const Node& v : prevVarList)
455 {
456 // if it is not being replaced
457 if (std::find(syms.begin(), syms.end(), v) != syms.end())
458 {
459 formalVars.push_back(v);
460 }
461 }
462 }
463 for (const Node& v : vars)
464 {
465 if (v.getKind() == BOUND_VARIABLE)
466 {
467 formalVars.push_back(v);
468 }
469 }
470 // make the sygus variable list for the formal argument list
471 Node abvl = nm->mkNode(BOUND_VAR_LIST, formalVars);
472 Trace("sygus-abduct-debug") << "...finish" << std::endl;
473
474 // must convert all constructors to version with variables in "vars"
475 std::vector<SygusDatatype> sdts;
476 std::set<TypeNode> unres;
477
478 Trace("dtsygus-gen-debug") << "Process sygus type:" << std::endl;
479 Trace("dtsygus-gen-debug") << sdtd.getName() << std::endl;
480
481 // datatype types we need to process
482 std::vector<TypeNode> dtToProcess;
483 // datatype types we have processed
484 std::map<TypeNode, TypeNode> dtProcessed;
485 dtToProcess.push_back(sdt);
486 std::stringstream ssutn0;
487 ssutn0 << sdtd.getName() << "_s";
488 TypeNode abdTNew =
489 nm->mkSort(ssutn0.str(), NodeManager::SORT_FLAG_PLACEHOLDER);
490 unres.insert(abdTNew);
491 dtProcessed[sdt] = abdTNew;
492
493 // We must convert all symbols in the sygus datatype type sdt to
494 // apply the substitution { syms -> vars }, where syms is the free
495 // variables of the input problem, and vars is the formal argument list
496 // of the function-to-synthesize.
497
498 // We are traversing over the subfield types of the datatype to convert
499 // them into the form described above.
500 while (!dtToProcess.empty())
501 {
502 std::vector<TypeNode> dtNextToProcess;
503 for (const TypeNode& curr : dtToProcess)
504 {
505 Assert(curr.isDatatype() && curr.getDType().isSygus());
506 const DType& dtc = curr.getDType();
507 std::stringstream ssdtn;
508 ssdtn << dtc.getName() << "_s";
509 sdts.push_back(SygusDatatype(ssdtn.str()));
510 Trace("dtsygus-gen-debug")
511 << "Process datatype " << sdts.back().getName() << "..." << std::endl;
512 for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
513 {
514 Node op = dtc[j].getSygusOp();
515 // apply the substitution to the argument
516 Node ops =
517 op.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
518 Trace("dtsygus-gen-debug") << " Process constructor " << op << " / "
519 << ops << "..." << std::endl;
520 std::vector<TypeNode> cargs;
521 for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
522 {
523 TypeNode argt = dtc[j].getArgType(k);
524 std::map<TypeNode, TypeNode>::iterator itdp = dtProcessed.find(argt);
525 TypeNode argtNew;
526 if (itdp == dtProcessed.end())
527 {
528 std::stringstream ssutn;
529 ssutn << argt.getDType().getName() << "_s";
530 argtNew =
531 nm->mkSort(ssutn.str(), NodeManager::SORT_FLAG_PLACEHOLDER);
532 Trace("dtsygus-gen-debug") << " ...unresolved type " << argtNew
533 << " for " << argt << std::endl;
534 unres.insert(argtNew);
535 dtProcessed[argt] = argtNew;
536 dtNextToProcess.push_back(argt);
537 }
538 else
539 {
540 argtNew = itdp->second;
541 }
542 Trace("dtsygus-gen-debug")
543 << " Arg #" << k << ": " << argtNew << std::endl;
544 cargs.push_back(argtNew);
545 }
546 std::stringstream ss;
547 ss << ops.getKind();
548 Trace("dtsygus-gen-debug") << "Add constructor : " << ops << std::endl;
549 sdts.back().addConstructor(ops, ss.str(), cargs);
550 }
551 Trace("dtsygus-gen-debug")
552 << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
553 TypeNode stn = dtc.getSygusType();
554 sdts.back().initializeDatatype(
555 stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
556 }
557 dtToProcess.clear();
558 dtToProcess.insert(
559 dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
560 }
561 Trace("dtsygus-gen-debug")
562 << "Make " << sdts.size() << " datatype types..." << std::endl;
563 // extract the datatypes
564 std::vector<DType> datatypes;
565 for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
566 {
567 datatypes.push_back(sdts[i].getDatatype());
568 }
569 // make the datatype types
570 std::vector<TypeNode> datatypeTypes = nm->mkMutualDatatypeTypes(
571 datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
572 TypeNode sdtS = datatypeTypes[0];
573 if (Trace.isOn("dtsygus-gen-debug"))
574 {
575 Trace("dtsygus-gen-debug") << "Made datatype types:" << std::endl;
576 for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
577 {
578 const DType& dtj = datatypeTypes[j].getDType();
579 Trace("dtsygus-gen-debug") << "#" << j << ": " << dtj << std::endl;
580 for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
581 {
582 for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
583 {
584 if (!dtj[k].getArgType(l).isDatatype())
585 {
586 Trace("dtsygus-gen-debug")
587 << "Argument " << l << " of " << dtj[k]
588 << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
589 AlwaysAssert(false);
590 }
591 }
592 }
593 }
594 }
595 return sdtS;
596 }
597
598 unsigned getSygusTermSize(Node n)
599 {
600 if (n.getKind() != APPLY_CONSTRUCTOR)
601 {
602 return 0;
603 }
604 unsigned sum = 0;
605 for (const Node& nc : n)
606 {
607 sum += getSygusTermSize(nc);
608 }
609 const DType& dt = datatypeOf(n.getOperator());
610 int cindex = indexOf(n.getOperator());
611 Assert(cindex >= 0 && static_cast<size_t>(cindex) < dt.getNumConstructors());
612 unsigned weight = dt[cindex].getWeight();
613 return weight + sum;
614 }
615
616 /**
617 * Map terms to the result of expand definitions calling smt::expandDefinitions
618 * on it.
619 */
620 struct SygusExpDefFormAttributeId
621 {
622 };
623 typedef expr::Attribute<SygusExpDefFormAttributeId, Node>
624 SygusExpDefFormAttribute;
625
626 void setExpandedDefinitionForm(Node op, Node eop)
627 {
628 op.setAttribute(SygusExpDefFormAttribute(), eop);
629 }
630
631 Node getExpandedDefinitionForm(Node op)
632 {
633 Node eop = op.getAttribute(SygusExpDefFormAttribute());
634 // if not set, assume original
635 return eop.isNull() ? op : eop;
636 }
637
638 } // namespace utils
639 } // namespace datatypes
640 } // namespace theory
641 } // namespace cvc5