Move specific attributes out of term util (#3279)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_abduct.cpp
1 /********************* */
2 /*! \file sygus_abduct.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of sygus abduction utility, which
13 ** transforms an arbitrary input into an abduction problem.
14 **/
15
16 #include "theory/quantifiers/sygus/sygus_abduct.h"
17
18 #include "expr/datatype.h"
19 #include "expr/node_algorithm.h"
20 #include "printer/sygus_print_callback.h"
21 #include "theory/quantifiers/quantifiers_attributes.h"
22 #include "theory/quantifiers/quantifiers_rewriter.h"
23 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
24 #include "theory/quantifiers/term_util.h"
25 #include "theory/rewriter.h"
26
27 using namespace std;
28 using namespace CVC4::kind;
29
30 namespace CVC4 {
31 namespace theory {
32 namespace quantifiers {
33
34 SygusAbduct::SygusAbduct() {}
35
36 Node SygusAbduct::mkAbductionConjecture(const std::string& name,
37 const std::vector<Node>& asserts,
38 const std::vector<Node>& axioms,
39 TypeNode abdGType)
40 {
41 NodeManager* nm = NodeManager::currentNM();
42 std::unordered_set<Node, NodeHashFunction> symset;
43 for (size_t i = 0, size = asserts.size(); i < size; i++)
44 {
45 expr::getSymbols(asserts[i], symset);
46 }
47 Trace("sygus-abduct-debug")
48 << "...finish, got " << symset.size() << " symbols." << std::endl;
49
50 Trace("sygus-abduct-debug") << "Setup symbols..." << std::endl;
51 std::vector<Node> syms;
52 std::vector<Node> vars;
53 std::vector<Node> varlist;
54 std::vector<TypeNode> varlistTypes;
55 for (const Node& s : symset)
56 {
57 TypeNode tn = s.getType();
58 if (tn.isFirstClass())
59 {
60 std::stringstream ss;
61 ss << s;
62 Node var = nm->mkBoundVar(tn);
63 syms.push_back(s);
64 vars.push_back(var);
65 Node vlv = nm->mkBoundVar(ss.str(), tn);
66 varlist.push_back(vlv);
67 varlistTypes.push_back(tn);
68 // set that this variable encodes the term s
69 SygusVarToTermAttribute sta;
70 vlv.setAttribute(sta, s);
71 }
72 }
73 // make the sygus variable list
74 Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
75 Trace("sygus-abduct-debug") << "...finish" << std::endl;
76
77 Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl;
78 // make the abduction predicate to synthesize
79 TypeNode abdType = varlistTypes.empty() ? nm->booleanType()
80 : nm->mkPredicateType(varlistTypes);
81 Node abd = nm->mkBoundVar(name.c_str(), abdType);
82 Trace("sygus-abduct-debug") << "...finish" << std::endl;
83
84 // if provided, we will associate it with the function-to-synthesize
85 if (!abdGType.isNull())
86 {
87 Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus());
88 // must convert all constructors to version with bound variables in "vars"
89 std::vector<Datatype> datatypes;
90 std::set<Type> unres;
91
92 Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
93 Trace("sygus-abduct-debug") << abdGType.getDatatype() << std::endl;
94
95 // datatype types we need to process
96 std::vector<TypeNode> dtToProcess;
97 // datatype types we have processed
98 std::map<TypeNode, TypeNode> dtProcessed;
99 dtToProcess.push_back(abdGType);
100 std::stringstream ssutn0;
101 ssutn0 << abdGType.getDatatype().getName() << "_s";
102 TypeNode abdTNew =
103 nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
104 unres.insert(abdTNew.toType());
105 dtProcessed[abdGType] = abdTNew;
106
107 // We must convert all symbols in the sygus datatype type abdGType to
108 // apply the substitution { syms -> varlist }, where syms is the free
109 // variables of the input problem, and varlist is the formal argument list
110 // of the abduct-to-synthesize. For example, given user-provided sygus
111 // grammar:
112 // G -> a | +( b, G )
113 // we synthesize a abduct A with two arguments x_a and x_b corresponding to
114 // a and b, and reconstruct the grammar:
115 // G' -> x_a | +( x_b, G' )
116 // In this way, x_a and x_b are treated as bound variables and handled as
117 // arguments of the abduct-to-synthesize instead of as free variables with
118 // no relation to A. We additionally require that x_a, when printed, prints
119 // "a", which we do with a custom sygus callback below.
120
121 // We are traversing over the subfield types of the datatype to convert
122 // them into the form described above.
123 while (!dtToProcess.empty())
124 {
125 std::vector<TypeNode> dtNextToProcess;
126 for (const TypeNode& curr : dtToProcess)
127 {
128 Assert(curr.isDatatype() && curr.getDatatype().isSygus());
129 const Datatype& dtc = curr.getDatatype();
130 std::stringstream ssdtn;
131 ssdtn << dtc.getName() << "_s";
132 datatypes.push_back(Datatype(ssdtn.str()));
133 Trace("sygus-abduct-debug")
134 << "Process datatype " << datatypes.back().getName() << "..."
135 << std::endl;
136 for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
137 {
138 Node op = Node::fromExpr(dtc[j].getSygusOp());
139 // apply the substitution to the argument
140 Node ops = op.substitute(
141 syms.begin(), syms.end(), varlist.begin(), varlist.end());
142 Trace("sygus-abduct-debug") << " Process constructor " << op << " / "
143 << ops << "..." << std::endl;
144 std::vector<Type> cargs;
145 for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
146 {
147 TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k));
148 std::map<TypeNode, TypeNode>::iterator itdp =
149 dtProcessed.find(argt);
150 TypeNode argtNew;
151 if (itdp == dtProcessed.end())
152 {
153 std::stringstream ssutn;
154 ssutn << argt.getDatatype().getName() << "_s";
155 argtNew =
156 nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
157 Trace("sygus-abduct-debug")
158 << " ...unresolved type " << argtNew << " for " << argt
159 << std::endl;
160 unres.insert(argtNew.toType());
161 dtProcessed[argt] = argtNew;
162 dtNextToProcess.push_back(argt);
163 }
164 else
165 {
166 argtNew = itdp->second;
167 }
168 Trace("sygus-abduct-debug")
169 << " Arg #" << k << ": " << argtNew << std::endl;
170 cargs.push_back(argtNew.toType());
171 }
172 // callback prints as the expression
173 std::shared_ptr<SygusPrintCallback> spc;
174 std::vector<Expr> args;
175 if (op.getKind() == LAMBDA)
176 {
177 Node opBody = op[1];
178 for (const Node& v : op[0])
179 {
180 args.push_back(v.toExpr());
181 }
182 spc = std::make_shared<printer::SygusExprPrintCallback>(
183 opBody.toExpr(), args);
184 }
185 else if (cargs.empty())
186 {
187 spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
188 args);
189 }
190 std::stringstream ss;
191 ss << ops.getKind();
192 Trace("sygus-abduct-debug")
193 << "Add constructor : " << ops << std::endl;
194 datatypes.back().addSygusConstructor(
195 ops.toExpr(), ss.str(), cargs, spc);
196 }
197 Trace("sygus-abduct-debug")
198 << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
199 datatypes.back().setSygus(dtc.getSygusType(),
200 abvl.toExpr(),
201 dtc.getSygusAllowConst(),
202 dtc.getSygusAllowAll());
203 }
204 dtToProcess.clear();
205 dtToProcess.insert(
206 dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
207 }
208 Trace("sygus-abduct-debug")
209 << "Make " << datatypes.size() << " datatype types..." << std::endl;
210 // make the datatype types
211 std::vector<DatatypeType> datatypeTypes =
212 nm->toExprManager()->mkMutualDatatypeTypes(
213 datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
214 TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]);
215 if (Trace.isOn("sygus-abduct-debug"))
216 {
217 Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl;
218 for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
219 {
220 const Datatype& dtj = datatypeTypes[j].getDatatype();
221 Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl;
222 for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
223 {
224 for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
225 {
226 if (!dtj[k].getArgType(l).isDatatype())
227 {
228 Trace("sygus-abduct-debug")
229 << "Argument " << l << " of " << dtj[k]
230 << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
231 AlwaysAssert(false);
232 }
233 }
234 }
235 }
236 }
237
238 Trace("sygus-abduct-debug")
239 << "Make sygus grammar attribute..." << std::endl;
240 Node sym = nm->mkBoundVar("sfproxy_abduct", abdGTypeS);
241 // Set the sygus grammar attribute to indicate that abdGTypeS encodes the
242 // grammar for abd.
243 theory::SygusSynthGrammarAttribute ssg;
244 abd.setAttribute(ssg, sym);
245 Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl;
246 }
247
248 Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;
249 std::vector<Node> achildren;
250 achildren.push_back(abd);
251 achildren.insert(achildren.end(), vars.begin(), vars.end());
252 Node abdApp = vars.empty() ? abd : nm->mkNode(APPLY_UF, achildren);
253 Trace("sygus-abduct-debug") << "...finish" << std::endl;
254
255 Trace("sygus-abduct-debug") << "Set attributes..." << std::endl;
256 // set the sygus bound variable list
257 abd.setAttribute(theory::SygusSynthFunVarListAttribute(), abvl);
258 Trace("sygus-abduct-debug") << "...finish" << std::endl;
259
260 Trace("sygus-abduct-debug") << "Make conjecture body..." << std::endl;
261 Node input = asserts.size() == 1 ? asserts[0] : nm->mkNode(AND, asserts);
262 input = input.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
263 // A(x) => ~input( x )
264 input = nm->mkNode(OR, abdApp.negate(), input.negate());
265 Trace("sygus-abduct-debug") << "...finish" << std::endl;
266
267 Trace("sygus-abduct-debug") << "Make conjecture..." << std::endl;
268 Node res = input.negate();
269 if (!vars.empty())
270 {
271 Node bvl = nm->mkNode(BOUND_VAR_LIST, vars);
272 // exists x. ~( A( x ) => ~input( x ) )
273 res = nm->mkNode(EXISTS, bvl, res);
274 }
275 // sygus attribute
276 Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
277 theory::SygusAttribute ca;
278 sygusVar.setAttribute(ca, true);
279 Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar);
280 std::vector<Node> iplc;
281 iplc.push_back(instAttr);
282 Node aconj = axioms.size() == 0
283 ? nm->mkConst(true)
284 : (axioms.size() == 1 ? axioms[0] : nm->mkNode(AND, axioms));
285 aconj = aconj.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
286 Trace("sygus-abduct") << "---> Assumptions: " << aconj << std::endl;
287 Node sc = nm->mkNode(AND, aconj, abdApp);
288 Node vbvl = nm->mkNode(BOUND_VAR_LIST, vars);
289 sc = nm->mkNode(EXISTS, vbvl, sc);
290 Node sygusScVar = nm->mkSkolem("sygus_sc", nm->booleanType());
291 sygusScVar.setAttribute(theory::SygusSideConditionAttribute(), sc);
292 instAttr = nm->mkNode(INST_ATTRIBUTE, sygusScVar);
293 // build in the side condition
294 // exists x. A( x ) ^ input_axioms( x )
295 // as an additional annotation on the sygus conjecture. In other words,
296 // the abducts A we procedure must be consistent with our axioms.
297 iplc.push_back(instAttr);
298 Node instAttrList = nm->mkNode(INST_PATTERN_LIST, iplc);
299
300 Node fbvl = nm->mkNode(BOUND_VAR_LIST, abd);
301
302 // forall A. exists x. ~( A( x ) => ~input( x ) )
303 res = nm->mkNode(FORALL, fbvl, res, instAttrList);
304 Trace("sygus-abduct-debug") << "...finish" << std::endl;
305
306 res = theory::Rewriter::rewrite(res);
307
308 Trace("sygus-abduct") << "Generate: " << res << std::endl;
309
310 return res;
311 }
312
313 } // namespace quantifiers
314 } // namespace theory
315 } // namespace CVC4