Remove logic request (#6089)
[cvc5.git] / src / preprocessing / passes / synth_rew_rules.cpp
1 /********************* */
2 /*! \file synth_rew_rules.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2,
13 ** where t1 and t2 are subterms of the input.
14 **/
15
16 #include "preprocessing/passes/synth_rew_rules.h"
17
18 #include <sstream>
19
20 #include "expr/sygus_datatype.h"
21 #include "expr/term_canonize.h"
22 #include "options/base_options.h"
23 #include "options/quantifiers_options.h"
24 #include "preprocessing/assertion_pipeline.h"
25 #include "printer/printer.h"
26 #include "theory/quantifiers/candidate_rewrite_database.h"
27 #include "theory/quantifiers/quantifiers_attributes.h"
28 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
29 #include "theory/quantifiers/sygus/sygus_utils.h"
30 #include "theory/quantifiers/term_util.h"
31
32 using namespace std;
33 using namespace CVC4::kind;
34
35 namespace CVC4 {
36 namespace preprocessing {
37 namespace passes {
38
39 SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext)
40 : PreprocessingPass(preprocContext, "synth-rr"){};
41
42 PreprocessingPassResult SynthRewRulesPass::applyInternal(
43 AssertionPipeline* assertionsToPreprocess)
44 {
45 Trace("srs-input") << "Synthesize rewrite rules from assertions..."
46 << std::endl;
47 const std::vector<Node>& assertions = assertionsToPreprocess->ref();
48 if (assertions.empty())
49 {
50 return PreprocessingPassResult::NO_CONFLICT;
51 }
52
53 NodeManager* nm = NodeManager::currentNM();
54
55 // initialize the candidate rewrite
56 std::unordered_map<TNode, bool, TNodeHashFunction> visited;
57 std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it;
58 std::vector<TNode> visit;
59 // Get all usable terms from the input. A term is usable if it does not
60 // contain a quantified subterm
61 std::vector<Node> terms;
62 // all variables (free constants) appearing in the input
63 std::vector<Node> vars;
64 // does the input contain a Boolean variable?
65 bool hasBoolVar = false;
66 // the types of subterms of our input
67 std::map<TypeNode, bool> typesFound;
68 // standard constants for each type (e.g. true, false for Bool)
69 std::map<TypeNode, std::vector<Node> > consts;
70
71 TNode cur;
72 Trace("srs-input") << "Collect terms in assertions..." << std::endl;
73 for (const Node& a : assertions)
74 {
75 Trace("srs-input-debug") << "Assertion : " << a << std::endl;
76 visit.push_back(a);
77 do
78 {
79 cur = visit.back();
80 visit.pop_back();
81 it = visited.find(cur);
82 if (it == visited.end())
83 {
84 Trace("srs-input-debug") << "...preprocess " << cur << std::endl;
85 visited[cur] = false;
86 bool isQuant = cur.isClosure();
87 // we recurse on this node if it is not a quantified formula
88 if (!isQuant)
89 {
90 visit.push_back(cur);
91 for (const Node& cc : cur)
92 {
93 visit.push_back(cc);
94 }
95 }
96 }
97 else if (!it->second)
98 {
99 Trace("srs-input-debug") << "...postprocess " << cur << std::endl;
100 // check if all of the children are valid
101 // this ensures we do not register terms that have e.g. quantified
102 // formulas as subterms
103 bool childrenValid = true;
104 for (const Node& cc : cur)
105 {
106 Assert(visited.find(cc) != visited.end());
107 if (!visited[cc])
108 {
109 childrenValid = false;
110 }
111 }
112 if (childrenValid)
113 {
114 Trace("srs-input-debug") << "...children are valid" << std::endl;
115 Trace("srs-input-debug") << "Add term " << cur << std::endl;
116 TypeNode tn = cur.getType();
117 if (cur.isVar())
118 {
119 vars.push_back(cur);
120 if (tn.isBoolean())
121 {
122 hasBoolVar = true;
123 }
124 }
125 // register type information
126 if (typesFound.find(tn) == typesFound.end())
127 {
128 typesFound[tn] = true;
129 // add the standard constants for this type
130 theory::quantifiers::CegGrammarConstructor::mkSygusConstantsForType(
131 tn, consts[tn]);
132 // We prepend them so that they come first in the grammar
133 // construction. The motivation is we'd prefer seeing e.g. "true"
134 // instead of (= x x) as a canonical term.
135 terms.insert(terms.begin(), consts[tn].begin(), consts[tn].end());
136 }
137 terms.push_back(cur);
138 }
139 visited[cur] = childrenValid;
140 }
141 } while (!visit.empty());
142 }
143 Trace("srs-input") << "...finished." << std::endl;
144
145 Trace("srs-input") << "Make synth variables for types..." << std::endl;
146 // We will generate a fixed number of variables per type. These are the
147 // variables that appear as free variables in the rewrites we generate.
148 unsigned nvars = options::sygusRewSynthInputNVars();
149 // must have at least one variable per type
150 nvars = nvars < 1 ? 1 : nvars;
151 std::map<TypeNode, std::vector<Node> > tvars;
152 std::vector<TypeNode> allVarTypes;
153 std::vector<Node> allVars;
154 unsigned varCounter = 0;
155 for (std::pair<const TypeNode, bool> tfp : typesFound)
156 {
157 TypeNode tn = tfp.first;
158 // If we are not interested in purely propositional rewrites, we only
159 // need to make one Boolean variable if the input has a Boolean variable.
160 // This ensures that no type in our grammar has zero constructors. If
161 // our input does not contain a Boolean variable, we need not allocate any
162 // Boolean variables here.
163 unsigned useNVars =
164 (options::sygusRewSynthInputUseBool() || !tn.isBoolean())
165 ? nvars
166 : (hasBoolVar ? 1 : 0);
167 for (unsigned i = 0; i < useNVars; i++)
168 {
169 // We must have a good name for these variables, these are
170 // the ones output in rewrite rules. We choose
171 // a,b,c,...,y,z,x1,x2,...
172 std::stringstream ssv;
173 if (varCounter < 26)
174 {
175 ssv << static_cast<char>(varCounter + 61);
176 }
177 else
178 {
179 ssv << "x" << (varCounter - 26);
180 }
181 varCounter++;
182 Node v = nm->mkBoundVar(ssv.str(), tn);
183 tvars[tn].push_back(v);
184 allVars.push_back(v);
185 allVarTypes.push_back(tn);
186 }
187 }
188 Trace("srs-input") << "...finished." << std::endl;
189
190 Trace("srs-input") << "Convert subterms to free variable form..."
191 << std::endl;
192 // Replace all free variables with bound variables. This ensures that
193 // we can perform term canonization on subterms.
194 std::vector<Node> vsubs;
195 for (const Node& v : vars)
196 {
197 TypeNode tnv = v.getType();
198 Node vs = nm->mkBoundVar(tnv);
199 vsubs.push_back(vs);
200 }
201 if (!vars.empty())
202 {
203 for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
204 {
205 terms[i] = terms[i].substitute(
206 vars.begin(), vars.end(), vsubs.begin(), vsubs.end());
207 }
208 }
209 Trace("srs-input") << "...finished." << std::endl;
210
211 Trace("srs-input") << "Process " << terms.size() << " subterms..."
212 << std::endl;
213 // We've collected all terms in the input. We construct a sygus grammar in
214 // following which generates terms that correspond to abstractions of the
215 // terms in the input.
216
217 // We map terms to a canonical (ordered variable) form. This ensures that
218 // we don't generate distinct grammar types for distinct alpha-equivalent
219 // terms, which would produce grammars of identical shape.
220 std::map<Node, Node> term_to_cterm;
221 std::map<Node, Node> cterm_to_term;
222 std::vector<Node> cterms;
223 // canonical terms for each type
224 std::map<TypeNode, std::vector<Node> > t_cterms;
225 expr::TermCanonize tcanon;
226 for (unsigned i = 0, nterms = terms.size(); i < nterms; i++)
227 {
228 Node n = terms[i];
229 Node cn = tcanon.getCanonicalTerm(n);
230 term_to_cterm[n] = cn;
231 Trace("srs-input-debug") << "Canon : " << n << " -> " << cn << std::endl;
232 std::map<Node, Node>::iterator itc = cterm_to_term.find(cn);
233 if (itc == cterm_to_term.end())
234 {
235 cterm_to_term[cn] = n;
236 cterms.push_back(cn);
237 t_cterms[cn.getType()].push_back(cn);
238 }
239 }
240 Trace("srs-input") << "...finished." << std::endl;
241 // the sygus variable list
242 Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
243 Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
244 << std::endl;
245
246 Trace("srs-input") << "Construct unresolved types..." << std::endl;
247 // each canonical subterm corresponds to a grammar type
248 std::set<TypeNode> unres;
249 std::vector<SygusDatatype> sdts;
250 // make unresolved types for each canonical term
251 std::map<Node, TypeNode> cterm_to_utype;
252 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
253 {
254 Node ct = cterms[i];
255 std::stringstream ss;
256 ss << "T" << i;
257 std::string tname = ss.str();
258 TypeNode tnu = nm->mkSort(tname, NodeManager::SORT_FLAG_PLACEHOLDER);
259 cterm_to_utype[ct] = tnu;
260 unres.insert(tnu);
261 sdts.push_back(SygusDatatype(tname));
262 }
263 Trace("srs-input") << "...finished." << std::endl;
264
265 Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
266 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
267 {
268 Node ct = cterms[i];
269 Node t = cterm_to_term[ct];
270
271 // add the variables for the type
272 TypeNode ctt = ct.getType();
273 Assert(tvars.find(ctt) != tvars.end());
274 std::vector<TypeNode> argList;
275 // we add variable constructors if we are not Boolean, we are interested
276 // in purely propositional rewrites (via the option), or this term is
277 // a Boolean variable.
278 if (!ctt.isBoolean() || options::sygusRewSynthInputUseBool()
279 || ct.getKind() == BOUND_VARIABLE)
280 {
281 for (const Node& v : tvars[ctt])
282 {
283 std::stringstream ssc;
284 ssc << "C_" << i << "_" << v;
285 sdts[i].addConstructor(v, ssc.str(), argList);
286 }
287 }
288 // add the constructor for the operator if it is not a variable
289 if (ct.getKind() != BOUND_VARIABLE)
290 {
291 Assert(!ct.isVar());
292 Node op = ct.hasOperator() ? ct.getOperator() : ct;
293 // iterate over the original term
294 for (const Node& tc : t)
295 {
296 // map its arguments back to canonical
297 Assert(term_to_cterm.find(tc) != term_to_cterm.end());
298 Node ctc = term_to_cterm[tc];
299 Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
300 // get the type
301 argList.push_back(cterm_to_utype[ctc]);
302 }
303 // check if we should chain
304 bool do_chain = false;
305 if (argList.size() > 2)
306 {
307 Kind k = NodeManager::operatorToKind(op);
308 do_chain = theory::quantifiers::TermUtil::isAssoc(k)
309 && theory::quantifiers::TermUtil::isComm(k);
310 // eliminate duplicate child types
311 std::vector<TypeNode> argListTmp = argList;
312 argList.clear();
313 std::map<TypeNode, bool> hasArgType;
314 for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
315 {
316 TypeNode tn = argListTmp[j];
317 if (hasArgType.find(tn) == hasArgType.end())
318 {
319 hasArgType[tn] = true;
320 argList.push_back(tn);
321 }
322 }
323 }
324 if (do_chain)
325 {
326 // we make one type per child
327 // the operator of each constructor is a no-op
328 Node tbv = nm->mkBoundVar(ctt);
329 Node lambdaOp =
330 nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
331 std::vector<TypeNode> argListc;
332 // the following construction admits any number of repeated factors,
333 // so for instance, t1+t2+t3, we generate the grammar:
334 // T_{t1+t2+t3} ->
335 // +( T_{t1+t2+t3}, T_{t1+t2+t3} ) | T_{t1} | T_{t2} | T_{t3}
336 // where we write T_t to denote "the type that abstracts term t".
337 // Notice this construction allows to abstract subsets of the factors
338 // of t1+t2+t3. This is particularly helpful for terms t1+...+tn for
339 // large n, where we would like to consider binary applications of +.
340 for (unsigned j = 0, size = argList.size(); j < size; j++)
341 {
342 argListc.clear();
343 argListc.push_back(argList[j]);
344 std::stringstream sscs;
345 sscs << "C_factor_" << i << "_" << j;
346 // ID function is not printed and does not count towards weight
347 sdts[i].addConstructor(lambdaOp,
348 sscs.str(),
349 argListc,
350 0);
351 }
352 // recursive apply
353 TypeNode recType = cterm_to_utype[ct];
354 argListc.clear();
355 argListc.push_back(recType);
356 argListc.push_back(recType);
357 std::stringstream ssc;
358 ssc << "C_" << i << "_rec_" << op;
359 sdts[i].addConstructor(op, ssc.str(), argListc);
360 }
361 else
362 {
363 std::stringstream ssc;
364 ssc << "C_" << i << "_" << op;
365 sdts[i].addConstructor(op, ssc.str(), argList);
366 }
367 }
368 Assert(sdts[i].getNumConstructors() > 0);
369 sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
370 }
371 Trace("srs-input") << "...finished." << std::endl;
372
373 Trace("srs-input") << "Make mutual datatype types for subterms..."
374 << std::endl;
375 // extract the datatypes
376 std::vector<DType> datatypes;
377 for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
378 {
379 datatypes.push_back(sdts[i].getDatatype());
380 }
381 std::vector<TypeNode> types = nm->mkMutualDatatypeTypes(
382 datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
383 Trace("srs-input") << "...finished." << std::endl;
384 Assert(types.size() == unres.size());
385 std::map<Node, TypeNode> subtermTypes;
386 for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
387 {
388 subtermTypes[cterms[i]] = types[i];
389 }
390
391 Trace("srs-input") << "Construct the top-level types..." << std::endl;
392 // we now are ready to create the "top-level" types
393 std::map<TypeNode, TypeNode> tlGrammarTypes;
394 for (std::pair<const TypeNode, std::vector<Node> >& tcp : t_cterms)
395 {
396 TypeNode t = tcp.first;
397 std::stringstream ss;
398 ss << "T_" << t;
399 SygusDatatype sdttl(ss.str());
400 Node tbv = nm->mkBoundVar(t);
401 // the operator of each constructor is a no-op
402 Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
403 Trace("srs-input") << " We have " << tcp.second.size()
404 << " subterms of type " << t << std::endl;
405 for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
406 {
407 Node n = tcp.second[i];
408 // add constructor that encodes abstractions of this subterm
409 std::vector<TypeNode> argList;
410 Assert(subtermTypes.find(n) != subtermTypes.end());
411 argList.push_back(subtermTypes[n]);
412 std::stringstream ssc;
413 ssc << "Ctl_" << i;
414 // the no-op should not be printed, hence we pass an empty callback
415 sdttl.addConstructor(lambdaOp,
416 ssc.str(),
417 argList,
418 0);
419 Trace("srs-input-debug")
420 << "Grammar for subterm " << n << " is: " << std::endl;
421 Trace("srs-input-debug") << subtermTypes[n].getDType() << std::endl;
422 }
423 // set that this is a sygus datatype
424 sdttl.initializeDatatype(t, sygusVarList, false, false);
425 DType dttl = sdttl.getDatatype();
426 TypeNode tlt =
427 nm->mkDatatypeType(dttl, NodeManager::DATATYPE_FLAG_PLACEHOLDER);
428 tlGrammarTypes[t] = tlt;
429 Trace("srs-input") << "Grammar is: " << std::endl;
430 Trace("srs-input") << tlt.getDType() << std::endl;
431 }
432 Trace("srs-input") << "...finished." << std::endl;
433
434 // sygus attribute to mark the conjecture as a sygus conjecture
435 Trace("srs-input") << "Make sygus conjecture..." << std::endl;
436 // we are "synthesizing" functions for each type of subterm
437 std::vector<Node> synthConj;
438 unsigned fCounter = 1;
439 theory::SygusSynthGrammarAttribute ssg;
440 for (std::pair<const TypeNode, TypeNode> ttp : tlGrammarTypes)
441 {
442 Node gvar = nm->mkBoundVar("sfproxy", ttp.second);
443 TypeNode ft = nm->mkFunctionType(allVarTypes, ttp.first);
444 // likewise, it is helpful if these have good names, we choose f1, f2, ...
445 std::stringstream ssf;
446 ssf << "f" << fCounter;
447 fCounter++;
448 Node sfun = nm->mkBoundVar(ssf.str(), ft);
449 // this marks that the grammar used for solutions for sfun is the type of
450 // gvar, which is the sygus datatype type constructed above.
451 sfun.setAttribute(ssg, gvar);
452
453 Node body = nm->mkConst(false);
454 body = theory::quantifiers::SygusUtils::mkSygusConjecture({sfun}, body);
455 synthConj.push_back(body);
456 }
457 Node trueNode = nm->mkConst(true);
458 Node res =
459 synthConj.empty()
460 ? trueNode
461 : (synthConj.size() == 1 ? synthConj[0] : nm->mkNode(AND, synthConj));
462
463 Trace("srs-input") << "got : " << res << std::endl;
464 Trace("srs-input") << "...finished." << std::endl;
465
466 assertionsToPreprocess->replace(0, res);
467 for (unsigned i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
468 {
469 assertionsToPreprocess->replace(i, trueNode);
470 }
471
472 return PreprocessingPassResult::NO_CONFLICT;
473 }
474
475 } // namespace passes
476 } // namespace preprocessing
477 } // namespace CVC4