Updates not related to creation for eliminating Expr-level datatype (#4838)
[cvc5.git] / src / preprocessing / passes / synth_rew_rules.cpp
1 /********************* */
2 /*! \file synth_rew_rules.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2,
13 ** where t1 and t2 are subterms of the input.
14 **/
15
16 #include "preprocessing/passes/synth_rew_rules.h"
17
18 #include "expr/sygus_datatype.h"
19 #include "expr/term_canonize.h"
20 #include "options/base_options.h"
21 #include "options/quantifiers_options.h"
22 #include "printer/printer.h"
23 #include "theory/quantifiers/candidate_rewrite_database.h"
24 #include "theory/quantifiers/quantifiers_attributes.h"
25 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
26 #include "theory/quantifiers/term_util.h"
27
28 using namespace std;
29 using namespace CVC4::kind;
30
31 namespace CVC4 {
32 namespace preprocessing {
33 namespace passes {
34
35 SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext)
36 : PreprocessingPass(preprocContext, "synth-rr"){};
37
38 PreprocessingPassResult SynthRewRulesPass::applyInternal(
39 AssertionPipeline* assertionsToPreprocess)
40 {
41 Trace("srs-input") << "Synthesize rewrite rules from assertions..."
42 << std::endl;
43 std::vector<Node>& assertions = assertionsToPreprocess->ref();
44 if (assertions.empty())
45 {
46 return PreprocessingPassResult::NO_CONFLICT;
47 }
48
49 NodeManager* nm = NodeManager::currentNM();
50
51 // initialize the candidate rewrite
52 std::unordered_map<TNode, bool, TNodeHashFunction> visited;
53 std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it;
54 std::vector<TNode> visit;
55 // Get all usable terms from the input. A term is usable if it does not
56 // contain a quantified subterm
57 std::vector<Node> terms;
58 // all variables (free constants) appearing in the input
59 std::vector<Node> vars;
60 // does the input contain a Boolean variable?
61 bool hasBoolVar = false;
62 // the types of subterms of our input
63 std::map<TypeNode, bool> typesFound;
64 // standard constants for each type (e.g. true, false for Bool)
65 std::map<TypeNode, std::vector<Node> > consts;
66
67 TNode cur;
68 Trace("srs-input") << "Collect terms in assertions..." << std::endl;
69 for (const Node& a : assertions)
70 {
71 Trace("srs-input-debug") << "Assertion : " << a << std::endl;
72 visit.push_back(a);
73 do
74 {
75 cur = visit.back();
76 visit.pop_back();
77 it = visited.find(cur);
78 if (it == visited.end())
79 {
80 Trace("srs-input-debug") << "...preprocess " << cur << std::endl;
81 visited[cur] = false;
82 bool isQuant = cur.isClosure();
83 // we recurse on this node if it is not a quantified formula
84 if (!isQuant)
85 {
86 visit.push_back(cur);
87 for (const Node& cc : cur)
88 {
89 visit.push_back(cc);
90 }
91 }
92 }
93 else if (!it->second)
94 {
95 Trace("srs-input-debug") << "...postprocess " << cur << std::endl;
96 // check if all of the children are valid
97 // this ensures we do not register terms that have e.g. quantified
98 // formulas as subterms
99 bool childrenValid = true;
100 for (const Node& cc : cur)
101 {
102 Assert(visited.find(cc) != visited.end());
103 if (!visited[cc])
104 {
105 childrenValid = false;
106 }
107 }
108 if (childrenValid)
109 {
110 Trace("srs-input-debug") << "...children are valid" << std::endl;
111 Trace("srs-input-debug") << "Add term " << cur << std::endl;
112 TypeNode tn = cur.getType();
113 if (cur.isVar())
114 {
115 vars.push_back(cur);
116 if (tn.isBoolean())
117 {
118 hasBoolVar = true;
119 }
120 }
121 // register type information
122 if (typesFound.find(tn) == typesFound.end())
123 {
124 typesFound[tn] = true;
125 // add the standard constants for this type
126 theory::quantifiers::CegGrammarConstructor::mkSygusConstantsForType(
127 tn, consts[tn]);
128 // We prepend them so that they come first in the grammar
129 // construction. The motivation is we'd prefer seeing e.g. "true"
130 // instead of (= x x) as a canonical term.
131 terms.insert(terms.begin(), consts[tn].begin(), consts[tn].end());
132 }
133 terms.push_back(cur);
134 }
135 visited[cur] = childrenValid;
136 }
137 } while (!visit.empty());
138 }
139 Trace("srs-input") << "...finished." << std::endl;
140
141 Trace("srs-input") << "Make synth variables for types..." << std::endl;
142 // We will generate a fixed number of variables per type. These are the
143 // variables that appear as free variables in the rewrites we generate.
144 unsigned nvars = options::sygusRewSynthInputNVars();
145 // must have at least one variable per type
146 nvars = nvars < 1 ? 1 : nvars;
147 std::map<TypeNode, std::vector<Node> > tvars;
148 std::vector<TypeNode> allVarTypes;
149 std::vector<Node> allVars;
150 unsigned varCounter = 0;
151 for (std::pair<const TypeNode, bool> tfp : typesFound)
152 {
153 TypeNode tn = tfp.first;
154 // If we are not interested in purely propositional rewrites, we only
155 // need to make one Boolean variable if the input has a Boolean variable.
156 // This ensures that no type in our grammar has zero constructors. If
157 // our input does not contain a Boolean variable, we need not allocate any
158 // Boolean variables here.
159 unsigned useNVars =
160 (options::sygusRewSynthInputUseBool() || !tn.isBoolean())
161 ? nvars
162 : (hasBoolVar ? 1 : 0);
163 for (unsigned i = 0; i < useNVars; i++)
164 {
165 // We must have a good name for these variables, these are
166 // the ones output in rewrite rules. We choose
167 // a,b,c,...,y,z,x1,x2,...
168 std::stringstream ssv;
169 if (varCounter < 26)
170 {
171 ssv << static_cast<char>(varCounter + 61);
172 }
173 else
174 {
175 ssv << "x" << (varCounter - 26);
176 }
177 varCounter++;
178 Node v = nm->mkBoundVar(ssv.str(), tn);
179 tvars[tn].push_back(v);
180 allVars.push_back(v);
181 allVarTypes.push_back(tn);
182 }
183 }
184 Trace("srs-input") << "...finished." << std::endl;
185
186 Trace("srs-input") << "Convert subterms to free variable form..."
187 << std::endl;
188 // Replace all free variables with bound variables. This ensures that
189 // we can perform term canonization on subterms.
190 std::vector<Node> vsubs;
191 for (const Node& v : vars)
192 {
193 TypeNode tnv = v.getType();
194 Node vs = nm->mkBoundVar(tnv);
195 vsubs.push_back(vs);
196 }
197 if (!vars.empty())
198 {
199 for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
200 {
201 terms[i] = terms[i].substitute(
202 vars.begin(), vars.end(), vsubs.begin(), vsubs.end());
203 }
204 }
205 Trace("srs-input") << "...finished." << std::endl;
206
207 Trace("srs-input") << "Process " << terms.size() << " subterms..."
208 << std::endl;
209 // We've collected all terms in the input. We construct a sygus grammar in
210 // following which generates terms that correspond to abstractions of the
211 // terms in the input.
212
213 // We map terms to a canonical (ordered variable) form. This ensures that
214 // we don't generate distinct grammar types for distinct alpha-equivalent
215 // terms, which would produce grammars of identical shape.
216 std::map<Node, Node> term_to_cterm;
217 std::map<Node, Node> cterm_to_term;
218 std::vector<Node> cterms;
219 // canonical terms for each type
220 std::map<TypeNode, std::vector<Node> > t_cterms;
221 expr::TermCanonize tcanon;
222 for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
223 {
224 Node n = terms[i];
225 Node cn = tcanon.getCanonicalTerm(n);
226 term_to_cterm[n] = cn;
227 Trace("srs-input-debug") << "Canon : " << n << " -> " << cn << std::endl;
228 std::map<Node, Node>::iterator itc = cterm_to_term.find(cn);
229 if (itc == cterm_to_term.end())
230 {
231 cterm_to_term[cn] = n;
232 cterms.push_back(cn);
233 t_cterms[cn.getType()].push_back(cn);
234 }
235 }
236 Trace("srs-input") << "...finished." << std::endl;
237 // the sygus variable list
238 Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
239 Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
240 << std::endl;
241
242 Trace("srs-input") << "Construct unresolved types..." << std::endl;
243 // each canonical subterm corresponds to a grammar type
244 std::set<Type> unres;
245 std::vector<SygusDatatype> sdts;
246 // make unresolved types for each canonical term
247 std::map<Node, TypeNode> cterm_to_utype;
248 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
249 {
250 Node ct = cterms[i];
251 std::stringstream ss;
252 ss << "T" << i;
253 std::string tname = ss.str();
254 TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER);
255 cterm_to_utype[ct] = tnu;
256 unres.insert(tnu.toType());
257 sdts.push_back(SygusDatatype(tname));
258 }
259 Trace("srs-input") << "...finished." << std::endl;
260
261 Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
262 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
263 {
264 Node ct = cterms[i];
265 Node t = cterm_to_term[ct];
266
267 // add the variables for the type
268 TypeNode ctt = ct.getType();
269 Assert(tvars.find(ctt) != tvars.end());
270 std::vector<TypeNode> argList;
271 // we add variable constructors if we are not Boolean, we are interested
272 // in purely propositional rewrites (via the option), or this term is
273 // a Boolean variable.
274 if (!ctt.isBoolean() || options::sygusRewSynthInputUseBool()
275 || ct.getKind() == BOUND_VARIABLE)
276 {
277 for (const Node& v : tvars[ctt])
278 {
279 std::stringstream ssc;
280 ssc << "C_" << i << "_" << v;
281 sdts[i].addConstructor(v, ssc.str(), argList);
282 }
283 }
284 // add the constructor for the operator if it is not a variable
285 if (ct.getKind() != BOUND_VARIABLE)
286 {
287 Assert(!ct.isVar());
288 Node op = ct.hasOperator() ? ct.getOperator() : ct;
289 // iterate over the original term
290 for (const Node& tc : t)
291 {
292 // map its arguments back to canonical
293 Assert(term_to_cterm.find(tc) != term_to_cterm.end());
294 Node ctc = term_to_cterm[tc];
295 Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
296 // get the type
297 argList.push_back(cterm_to_utype[ctc]);
298 }
299 // check if we should chain
300 bool do_chain = false;
301 if (argList.size() > 2)
302 {
303 Kind k = NodeManager::operatorToKind(op);
304 do_chain = theory::quantifiers::TermUtil::isAssoc(k)
305 && theory::quantifiers::TermUtil::isComm(k);
306 // eliminate duplicate child types
307 std::vector<TypeNode> argListTmp = argList;
308 argList.clear();
309 std::map<TypeNode, bool> hasArgType;
310 for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
311 {
312 TypeNode tn = argListTmp[j];
313 if (hasArgType.find(tn) == hasArgType.end())
314 {
315 hasArgType[tn] = true;
316 argList.push_back(tn);
317 }
318 }
319 }
320 if (do_chain)
321 {
322 // we make one type per child
323 // the operator of each constructor is a no-op
324 Node tbv = nm->mkBoundVar(ctt);
325 Node lambdaOp =
326 nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
327 std::vector<TypeNode> argListc;
328 // the following construction admits any number of repeated factors,
329 // so for instance, t1+t2+t3, we generate the grammar:
330 // T_{t1+t2+t3} ->
331 // +( T_{t1+t2+t3}, T_{t1+t2+t3} ) | T_{t1} | T_{t2} | T_{t3}
332 // where we write T_t to denote "the type that abstracts term t".
333 // Notice this construction allows to abstract subsets of the factors
334 // of t1+t2+t3. This is particularly helpful for terms t1+...+tn for
335 // large n, where we would like to consider binary applications of +.
336 for (unsigned j = 0, size = argList.size(); j < size; j++)
337 {
338 argListc.clear();
339 argListc.push_back(argList[j]);
340 std::stringstream sscs;
341 sscs << "C_factor_" << i << "_" << j;
342 // ID function is not printed and does not count towards weight
343 sdts[i].addConstructor(lambdaOp,
344 sscs.str(),
345 argListc,
346 0);
347 }
348 // recursive apply
349 TypeNode recType = cterm_to_utype[ct];
350 argListc.clear();
351 argListc.push_back(recType);
352 argListc.push_back(recType);
353 std::stringstream ssc;
354 ssc << "C_" << i << "_rec_" << op;
355 sdts[i].addConstructor(op, ssc.str(), argListc);
356 }
357 else
358 {
359 std::stringstream ssc;
360 ssc << "C_" << i << "_" << op;
361 sdts[i].addConstructor(op, ssc.str(), argList);
362 }
363 }
364 Assert(sdts[i].getNumConstructors() > 0);
365 sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
366 }
367 Trace("srs-input") << "...finished." << std::endl;
368
369 Trace("srs-input") << "Make mutual datatype types for subterms..."
370 << std::endl;
371 // extract the datatypes
372 std::vector<Datatype> datatypes;
373 for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
374 {
375 datatypes.push_back(sdts[i].getDatatype());
376 }
377 std::vector<DatatypeType> types = nm->toExprManager()->mkMutualDatatypeTypes(
378 datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
379 Trace("srs-input") << "...finished." << std::endl;
380 Assert(types.size() == unres.size());
381 std::map<Node, TypeNode> subtermTypes;
382 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
383 {
384 subtermTypes[cterms[i]] = TypeNode::fromType(types[i]);
385 }
386
387 Trace("srs-input") << "Construct the top-level types..." << std::endl;
388 // we now are ready to create the "top-level" types
389 std::map<TypeNode, TypeNode> tlGrammarTypes;
390 for (std::pair<const TypeNode, std::vector<Node> >& tcp : t_cterms)
391 {
392 TypeNode t = tcp.first;
393 std::stringstream ss;
394 ss << "T_" << t;
395 SygusDatatype sdttl(ss.str());
396 Node tbv = nm->mkBoundVar(t);
397 // the operator of each constructor is a no-op
398 Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
399 Trace("srs-input") << " We have " << tcp.second.size()
400 << " subterms of type " << t << std::endl;
401 for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
402 {
403 Node n = tcp.second[i];
404 // add constructor that encodes abstractions of this subterm
405 std::vector<TypeNode> argList;
406 Assert(subtermTypes.find(n) != subtermTypes.end());
407 argList.push_back(subtermTypes[n]);
408 std::stringstream ssc;
409 ssc << "Ctl_" << i;
410 // the no-op should not be printed, hence we pass an empty callback
411 sdttl.addConstructor(lambdaOp,
412 ssc.str(),
413 argList,
414 0);
415 Trace("srs-input-debug")
416 << "Grammar for subterm " << n << " is: " << std::endl;
417 Trace("srs-input-debug") << subtermTypes[n].getDType() << std::endl;
418 }
419 // set that this is a sygus datatype
420 sdttl.initializeDatatype(t, sygusVarList, false, false);
421 Datatype dttl = sdttl.getDatatype();
422 DatatypeType tlt = nm->toExprManager()->mkDatatypeType(
423 dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
424 tlGrammarTypes[t] = TypeNode::fromType(tlt);
425 Trace("srs-input") << "Grammar is: " << std::endl;
426 Trace("srs-input") << tlt.getDatatype() << std::endl;
427 }
428 Trace("srs-input") << "...finished." << std::endl;
429
430 // sygus attribute to mark the conjecture as a sygus conjecture
431 Trace("srs-input") << "Make sygus conjecture..." << std::endl;
432 Node iaVar = nm->mkSkolem("sygus", nm->booleanType());
433 // the attribute to mark the conjecture as being a sygus conjecture
434 theory::SygusAttribute ca;
435 iaVar.setAttribute(ca, true);
436 Node instAttr = nm->mkNode(INST_ATTRIBUTE, iaVar);
437 Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr);
438 // we are "synthesizing" functions for each type of subterm
439 std::vector<Node> synthConj;
440 unsigned fCounter = 1;
441 theory::SygusSynthGrammarAttribute ssg;
442 for (std::pair<const TypeNode, TypeNode> ttp : tlGrammarTypes)
443 {
444 Node gvar = nm->mkBoundVar("sfproxy", ttp.second);
445 TypeNode ft = nm->mkFunctionType(allVarTypes, ttp.first);
446 // likewise, it is helpful if these have good names, we choose f1, f2, ...
447 std::stringstream ssf;
448 ssf << "f" << fCounter;
449 fCounter++;
450 Node sfun = nm->mkBoundVar(ssf.str(), ft);
451 // this marks that the grammar used for solutions for sfun is the type of
452 // gvar, which is the sygus datatype type constructed above.
453 sfun.setAttribute(ssg, gvar);
454 Node fvarBvl = nm->mkNode(BOUND_VAR_LIST, sfun);
455
456 Node body = nm->mkConst(false);
457 body = nm->mkNode(FORALL, fvarBvl, body, instAttrList);
458 synthConj.push_back(body);
459 }
460 Node trueNode = nm->mkConst(true);
461 Node res =
462 synthConj.empty()
463 ? trueNode
464 : (synthConj.size() == 1 ? synthConj[0] : nm->mkNode(AND, synthConj));
465
466 Trace("srs-input") << "got : " << res << std::endl;
467 Trace("srs-input") << "...finished." << std::endl;
468
469 assertionsToPreprocess->replace(0, res);
470 for (unsigned i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
471 {
472 assertionsToPreprocess->replace(i, trueNode);
473 }
474
475 return PreprocessingPassResult::NO_CONFLICT;
476 }
477
478 } // namespace passes
479 } // namespace preprocessing
480 } // namespace CVC4