Use standard sygus interface for abduction and rewrite rule synthesis (#3471)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_abduct.cpp
1 /********************* */
2 /*! \file sygus_abduct.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of sygus abduction utility, which
13 ** transforms an arbitrary input into an abduction problem.
14 **/
15
16 #include "theory/quantifiers/sygus/sygus_abduct.h"
17
18 #include "expr/datatype.h"
19 #include "expr/node_algorithm.h"
20 #include "expr/sygus_datatype.h"
21 #include "printer/sygus_print_callback.h"
22 #include "theory/quantifiers/quantifiers_attributes.h"
23 #include "theory/quantifiers/quantifiers_rewriter.h"
24 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
25 #include "theory/quantifiers/term_util.h"
26 #include "theory/rewriter.h"
27
28 using namespace std;
29 using namespace CVC4::kind;
30
31 namespace CVC4 {
32 namespace theory {
33 namespace quantifiers {
34
35 SygusAbduct::SygusAbduct() {}
36
37 Node SygusAbduct::mkAbductionConjecture(const std::string& name,
38 const std::vector<Node>& asserts,
39 const std::vector<Node>& axioms,
40 TypeNode abdGType)
41 {
42 NodeManager* nm = NodeManager::currentNM();
43 std::unordered_set<Node, NodeHashFunction> symset;
44 for (size_t i = 0, size = asserts.size(); i < size; i++)
45 {
46 expr::getSymbols(asserts[i], symset);
47 }
48 Trace("sygus-abduct-debug")
49 << "...finish, got " << symset.size() << " symbols." << std::endl;
50
51 Trace("sygus-abduct-debug") << "Setup symbols..." << std::endl;
52 std::vector<Node> syms;
53 std::vector<Node> vars;
54 std::vector<Node> varlist;
55 std::vector<TypeNode> varlistTypes;
56 for (const Node& s : symset)
57 {
58 TypeNode tn = s.getType();
59 // Notice that we allow for non-first class (e.g. function) variables here.
60 // This is applicable to the case where we are doing get-abduct in a logic
61 // with UF.
62 std::stringstream ss;
63 ss << s;
64 Node var = nm->mkBoundVar(tn);
65 syms.push_back(s);
66 vars.push_back(var);
67 Node vlv = nm->mkBoundVar(ss.str(), tn);
68 varlist.push_back(vlv);
69 varlistTypes.push_back(tn);
70 // set that this variable encodes the term s
71 SygusVarToTermAttribute sta;
72 vlv.setAttribute(sta, s);
73 }
74 // make the sygus variable list
75 Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
76 Trace("sygus-abduct-debug") << "...finish" << std::endl;
77
78 Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl;
79 // make the abduction predicate to synthesize
80 TypeNode abdType = varlistTypes.empty() ? nm->booleanType()
81 : nm->mkPredicateType(varlistTypes);
82 Node abd = nm->mkBoundVar(name.c_str(), abdType);
83 Trace("sygus-abduct-debug") << "...finish" << std::endl;
84
85 // if provided, we will associate it with the function-to-synthesize
86 if (!abdGType.isNull())
87 {
88 Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus());
89 // must convert all constructors to version with bound variables in "vars"
90 std::vector<SygusDatatype> sdts;
91 std::set<Type> unres;
92
93 Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
94 Trace("sygus-abduct-debug") << abdGType.getDatatype() << std::endl;
95
96 // datatype types we need to process
97 std::vector<TypeNode> dtToProcess;
98 // datatype types we have processed
99 std::map<TypeNode, TypeNode> dtProcessed;
100 dtToProcess.push_back(abdGType);
101 std::stringstream ssutn0;
102 ssutn0 << abdGType.getDatatype().getName() << "_s";
103 TypeNode abdTNew =
104 nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
105 unres.insert(abdTNew.toType());
106 dtProcessed[abdGType] = abdTNew;
107
108 // We must convert all symbols in the sygus datatype type abdGType to
109 // apply the substitution { syms -> varlist }, where syms is the free
110 // variables of the input problem, and varlist is the formal argument list
111 // of the abduct-to-synthesize. For example, given user-provided sygus
112 // grammar:
113 // G -> a | +( b, G )
114 // we synthesize a abduct A with two arguments x_a and x_b corresponding to
115 // a and b, and reconstruct the grammar:
116 // G' -> x_a | +( x_b, G' )
117 // In this way, x_a and x_b are treated as bound variables and handled as
118 // arguments of the abduct-to-synthesize instead of as free variables with
119 // no relation to A. We additionally require that x_a, when printed, prints
120 // "a", which we do with a custom sygus callback below.
121
122 // We are traversing over the subfield types of the datatype to convert
123 // them into the form described above.
124 while (!dtToProcess.empty())
125 {
126 std::vector<TypeNode> dtNextToProcess;
127 for (const TypeNode& curr : dtToProcess)
128 {
129 Assert(curr.isDatatype() && curr.getDatatype().isSygus());
130 const Datatype& dtc = curr.getDatatype();
131 std::stringstream ssdtn;
132 ssdtn << dtc.getName() << "_s";
133 sdts.push_back(SygusDatatype(ssdtn.str()));
134 Trace("sygus-abduct-debug")
135 << "Process datatype " << sdts.back().getName() << "..."
136 << std::endl;
137 for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
138 {
139 Node op = Node::fromExpr(dtc[j].getSygusOp());
140 // apply the substitution to the argument
141 Node ops = op.substitute(
142 syms.begin(), syms.end(), varlist.begin(), varlist.end());
143 Trace("sygus-abduct-debug") << " Process constructor " << op << " / "
144 << ops << "..." << std::endl;
145 std::vector<TypeNode> cargs;
146 for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
147 {
148 TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k));
149 std::map<TypeNode, TypeNode>::iterator itdp =
150 dtProcessed.find(argt);
151 TypeNode argtNew;
152 if (itdp == dtProcessed.end())
153 {
154 std::stringstream ssutn;
155 ssutn << argt.getDatatype().getName() << "_s";
156 argtNew =
157 nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
158 Trace("sygus-abduct-debug")
159 << " ...unresolved type " << argtNew << " for " << argt
160 << std::endl;
161 unres.insert(argtNew.toType());
162 dtProcessed[argt] = argtNew;
163 dtNextToProcess.push_back(argt);
164 }
165 else
166 {
167 argtNew = itdp->second;
168 }
169 Trace("sygus-abduct-debug")
170 << " Arg #" << k << ": " << argtNew << std::endl;
171 cargs.push_back(argtNew);
172 }
173 // callback prints as the expression
174 std::shared_ptr<SygusPrintCallback> spc;
175 std::vector<Expr> args;
176 if (op.getKind() == LAMBDA)
177 {
178 Node opBody = op[1];
179 for (const Node& v : op[0])
180 {
181 args.push_back(v.toExpr());
182 }
183 spc = std::make_shared<printer::SygusExprPrintCallback>(
184 opBody.toExpr(), args);
185 }
186 else if (cargs.empty())
187 {
188 spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
189 args);
190 }
191 std::stringstream ss;
192 ss << ops.getKind();
193 Trace("sygus-abduct-debug")
194 << "Add constructor : " << ops << std::endl;
195 sdts.back().addConstructor(ops, ss.str(), cargs, spc);
196 }
197 Trace("sygus-abduct-debug")
198 << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
199 TypeNode stn = TypeNode::fromType(dtc.getSygusType());
200 sdts.back().initializeDatatype(
201 stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
202 }
203 dtToProcess.clear();
204 dtToProcess.insert(
205 dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
206 }
207 Trace("sygus-abduct-debug")
208 << "Make " << sdts.size() << " datatype types..." << std::endl;
209 // extract the datatypes
210 std::vector<Datatype> datatypes;
211 for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
212 {
213 datatypes.push_back(sdts[i].getDatatype());
214 }
215 // make the datatype types
216 std::vector<DatatypeType> datatypeTypes =
217 nm->toExprManager()->mkMutualDatatypeTypes(
218 datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
219 TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]);
220 if (Trace.isOn("sygus-abduct-debug"))
221 {
222 Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl;
223 for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
224 {
225 const Datatype& dtj = datatypeTypes[j].getDatatype();
226 Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl;
227 for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
228 {
229 for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
230 {
231 if (!dtj[k].getArgType(l).isDatatype())
232 {
233 Trace("sygus-abduct-debug")
234 << "Argument " << l << " of " << dtj[k]
235 << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
236 AlwaysAssert(false);
237 }
238 }
239 }
240 }
241 }
242
243 Trace("sygus-abduct-debug")
244 << "Make sygus grammar attribute..." << std::endl;
245 Node sym = nm->mkBoundVar("sfproxy_abduct", abdGTypeS);
246 // Set the sygus grammar attribute to indicate that abdGTypeS encodes the
247 // grammar for abd.
248 theory::SygusSynthGrammarAttribute ssg;
249 abd.setAttribute(ssg, sym);
250 Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl;
251 }
252
253 Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;
254 std::vector<Node> achildren;
255 achildren.push_back(abd);
256 achildren.insert(achildren.end(), vars.begin(), vars.end());
257 Node abdApp = vars.empty() ? abd : nm->mkNode(APPLY_UF, achildren);
258 Trace("sygus-abduct-debug") << "...finish" << std::endl;
259
260 Trace("sygus-abduct-debug") << "Set attributes..." << std::endl;
261 // set the sygus bound variable list
262 abd.setAttribute(theory::SygusSynthFunVarListAttribute(), abvl);
263 Trace("sygus-abduct-debug") << "...finish" << std::endl;
264
265 Trace("sygus-abduct-debug") << "Make conjecture body..." << std::endl;
266 Node input = asserts.size() == 1 ? asserts[0] : nm->mkNode(AND, asserts);
267 input = input.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
268 // A(x) => ~input( x )
269 input = nm->mkNode(OR, abdApp.negate(), input.negate());
270 Trace("sygus-abduct-debug") << "...finish" << std::endl;
271
272 Trace("sygus-abduct-debug") << "Make conjecture..." << std::endl;
273 Node res = input.negate();
274 if (!vars.empty())
275 {
276 Node bvl = nm->mkNode(BOUND_VAR_LIST, vars);
277 // exists x. ~( A( x ) => ~input( x ) )
278 res = nm->mkNode(EXISTS, bvl, res);
279 }
280 // sygus attribute
281 Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
282 theory::SygusAttribute ca;
283 sygusVar.setAttribute(ca, true);
284 Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar);
285 std::vector<Node> iplc;
286 iplc.push_back(instAttr);
287 Node aconj = axioms.size() == 0
288 ? nm->mkConst(true)
289 : (axioms.size() == 1 ? axioms[0] : nm->mkNode(AND, axioms));
290 aconj = aconj.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
291 Trace("sygus-abduct") << "---> Assumptions: " << aconj << std::endl;
292 Node sc = nm->mkNode(AND, aconj, abdApp);
293 Node vbvl = nm->mkNode(BOUND_VAR_LIST, vars);
294 sc = nm->mkNode(EXISTS, vbvl, sc);
295 Node sygusScVar = nm->mkSkolem("sygus_sc", nm->booleanType());
296 sygusScVar.setAttribute(theory::SygusSideConditionAttribute(), sc);
297 instAttr = nm->mkNode(INST_ATTRIBUTE, sygusScVar);
298 // build in the side condition
299 // exists x. A( x ) ^ input_axioms( x )
300 // as an additional annotation on the sygus conjecture. In other words,
301 // the abducts A we procedure must be consistent with our axioms.
302 iplc.push_back(instAttr);
303 Node instAttrList = nm->mkNode(INST_PATTERN_LIST, iplc);
304
305 Node fbvl = nm->mkNode(BOUND_VAR_LIST, abd);
306
307 // forall A. exists x. ~( A( x ) => ~input( x ) )
308 res = nm->mkNode(FORALL, fbvl, res, instAttrList);
309 Trace("sygus-abduct-debug") << "...finish" << std::endl;
310
311 res = theory::Rewriter::rewrite(res);
312
313 Trace("sygus-abduct") << "Generate: " << res << std::endl;
314
315 return res;
316 }
317
318 } // namespace quantifiers
319 } // namespace theory
320 } // namespace CVC4