Remove logic request (#6089)
[cvc5.git] / src / preprocessing / passes / fun_def_fmf.cpp
1 /********************* */
2 /*! \file fun_def_fmf.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Haniel Barbosa, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Function definition processor for finite model finding
13 **/
14
15 #include "preprocessing/passes/fun_def_fmf.h"
16
17 #include <sstream>
18
19 #include "options/smt_options.h"
20 #include "preprocessing/assertion_pipeline.h"
21 #include "preprocessing/preprocessing_pass_context.h"
22 #include "proof/proof_manager.h"
23 #include "theory/quantifiers/quantifiers_attributes.h"
24 #include "theory/quantifiers/term_database.h"
25 #include "theory/quantifiers/term_util.h"
26 #include "theory/rewriter.h"
27
28 using namespace std;
29 using namespace CVC4::kind;
30 using namespace CVC4::theory;
31 using namespace CVC4::theory::quantifiers;
32
33 namespace CVC4 {
34 namespace preprocessing {
35 namespace passes {
36
37 FunDefFmf::FunDefFmf(PreprocessingPassContext* preprocContext)
38 : PreprocessingPass(preprocContext, "fun-def-fmf"),
39 d_fmfRecFunctionsDefined(nullptr)
40 {
41 d_fmfRecFunctionsDefined =
42 new (true) NodeList(preprocContext->getUserContext());
43 }
44
45 FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined->deleteSelf(); }
46
47 PreprocessingPassResult FunDefFmf::applyInternal(
48 AssertionPipeline* assertionsToPreprocess)
49 {
50 Assert(d_fmfRecFunctionsDefined != nullptr);
51 // reset
52 d_sorts.clear();
53 d_input_arg_inj.clear();
54 d_funcs.clear();
55 // must carry over current definitions (in case of incremental)
56 for (context::CDList<Node>::const_iterator fit =
57 d_fmfRecFunctionsDefined->begin();
58 fit != d_fmfRecFunctionsDefined->end();
59 ++fit)
60 {
61 Node f = (*fit);
62 Assert(d_fmfRecFunctionsAbs.find(f) != d_fmfRecFunctionsAbs.end());
63 TypeNode ft = d_fmfRecFunctionsAbs[f];
64 d_sorts[f] = ft;
65 std::map<Node, std::vector<Node>>::iterator fcit =
66 d_fmfRecFunctionsConcrete.find(f);
67 Assert(fcit != d_fmfRecFunctionsConcrete.end());
68 for (const Node& fcc : fcit->second)
69 {
70 d_input_arg_inj[f].push_back(fcc);
71 }
72 }
73 process(assertionsToPreprocess);
74 // must store new definitions (in case of incremental)
75 for (const Node& f : d_funcs)
76 {
77 d_fmfRecFunctionsAbs[f] = d_sorts[f];
78 d_fmfRecFunctionsConcrete[f].clear();
79 for (const Node& fcc : d_input_arg_inj[f])
80 {
81 d_fmfRecFunctionsConcrete[f].push_back(fcc);
82 }
83 d_fmfRecFunctionsDefined->push_back(f);
84 }
85 return PreprocessingPassResult::NO_CONFLICT;
86 }
87
88 void FunDefFmf::process(AssertionPipeline* assertionsToPreprocess)
89 {
90 const std::vector<Node>& assertions = assertionsToPreprocess->ref();
91 std::vector<int> fd_assertions;
92 std::map<int, Node> subs_head;
93 // first pass : find defined functions, transform quantifiers
94 NodeManager* nm = NodeManager::currentNM();
95 for (size_t i = 0, asize = assertions.size(); i < asize; i++)
96 {
97 Node n = QuantAttributes::getFunDefHead(assertions[i]);
98 if (!n.isNull())
99 {
100 Assert(n.getKind() == APPLY_UF);
101 Node f = n.getOperator();
102
103 // check if already defined, if so, throw error
104 if (d_sorts.find(f) != d_sorts.end())
105 {
106 Unhandled() << "Cannot define function " << f << " more than once.";
107 }
108
109 Node bd = QuantAttributes::getFunDefBody(assertions[i]);
110 Trace("fmf-fun-def-debug")
111 << "Process function " << n << ", body = " << bd << std::endl;
112 if (!bd.isNull())
113 {
114 d_funcs.push_back(f);
115 bd = nm->mkNode(EQUAL, n, bd);
116
117 // create a sort S that represents the inputs of the function
118 std::stringstream ss;
119 ss << "I_" << f;
120 TypeNode iType = nm->mkSort(ss.str());
121 AbsTypeFunDefAttribute atfda;
122 iType.setAttribute(atfda, true);
123 d_sorts[f] = iType;
124
125 // create functions f1...fn mapping from this sort to concrete elements
126 size_t nchildn = n.getNumChildren();
127 for (size_t j = 0; j < nchildn; j++)
128 {
129 TypeNode typ = nm->mkFunctionType(iType, n[j].getType());
130 std::stringstream ssf;
131 ssf << f << "_arg_" << j;
132 d_input_arg_inj[f].push_back(
133 nm->mkSkolem(ssf.str(), typ, "op created during fun def fmf"));
134 }
135
136 // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
137 std::vector<Node> children;
138 Node bv = nm->mkBoundVar("?i", iType);
139 Node bvl = nm->mkNode(BOUND_VAR_LIST, bv);
140 std::vector<Node> subs;
141 std::vector<Node> vars;
142 for (size_t j = 0; j < nchildn; j++)
143 {
144 vars.push_back(n[j]);
145 subs.push_back(nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], bv));
146 }
147 bd = bd.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
148 subs_head[i] =
149 n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
150
151 Trace("fmf-fun-def")
152 << "FMF fun def: FUNCTION : rewrite " << assertions[i] << std::endl;
153 Trace("fmf-fun-def") << " to " << std::endl;
154 Node new_q = nm->mkNode(FORALL, bvl, bd);
155 new_q = Rewriter::rewrite(new_q);
156 assertionsToPreprocess->replace(i, new_q);
157 Trace("fmf-fun-def") << " " << assertions[i] << std::endl;
158 fd_assertions.push_back(i);
159 }
160 else
161 {
162 // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
163 // f(x)=f(x)+1
164 }
165 }
166 }
167 // second pass : rewrite assertions
168 std::map<int, std::map<Node, Node>> visited;
169 std::map<int, std::map<Node, Node>> visited_cons;
170 for (size_t i = 0, asize = assertions.size(); i < asize; i++)
171 {
172 bool is_fd = std::find(fd_assertions.begin(), fd_assertions.end(), i)
173 != fd_assertions.end();
174 std::vector<Node> constraints;
175 Trace("fmf-fun-def-rewrite")
176 << "Rewriting " << assertions[i]
177 << ", is function definition = " << is_fd << std::endl;
178 Node n = simplifyFormula(assertions[i],
179 true,
180 true,
181 constraints,
182 is_fd ? subs_head[i] : Node::null(),
183 is_fd,
184 visited,
185 visited_cons);
186 Assert(constraints.empty());
187 if (n != assertions[i])
188 {
189 n = Rewriter::rewrite(n);
190 Trace("fmf-fun-def-rewrite")
191 << "FMF fun def : rewrite " << assertions[i] << std::endl;
192 Trace("fmf-fun-def-rewrite") << " to " << std::endl;
193 Trace("fmf-fun-def-rewrite") << " " << n << std::endl;
194 assertionsToPreprocess->replace(i, n);
195 }
196 }
197 }
198
199 Node FunDefFmf::simplifyFormula(
200 Node n,
201 bool pol,
202 bool hasPol,
203 std::vector<Node>& constraints,
204 Node hd,
205 bool is_fun_def,
206 std::map<int, std::map<Node, Node>>& visited,
207 std::map<int, std::map<Node, Node>>& visited_cons)
208 {
209 Assert(constraints.empty());
210 int index = (is_fun_def ? 1 : 0) + 2 * (hasPol ? (pol ? 1 : -1) : 0);
211 std::map<Node, Node>::iterator itv = visited[index].find(n);
212 if (itv != visited[index].end())
213 {
214 // constraints.insert( visited_cons[index]
215 std::map<Node, Node>::iterator itvc = visited_cons[index].find(n);
216 if (itvc != visited_cons[index].end())
217 {
218 constraints.push_back(itvc->second);
219 }
220 return itv->second;
221 }
222 NodeManager* nm = NodeManager::currentNM();
223 Node ret;
224 Trace("fmf-fun-def-debug2") << "Simplify " << n << " " << pol << " " << hasPol
225 << " " << is_fun_def << std::endl;
226 if (n.getKind() == FORALL)
227 {
228 Node c = simplifyFormula(
229 n[1], pol, hasPol, constraints, hd, is_fun_def, visited, visited_cons);
230 // append prenex to constraints
231 for (unsigned i = 0; i < constraints.size(); i++)
232 {
233 constraints[i] = nm->mkNode(FORALL, n[0], constraints[i]);
234 constraints[i] = Rewriter::rewrite(constraints[i]);
235 }
236 if (c != n[1])
237 {
238 ret = nm->mkNode(FORALL, n[0], c);
239 }
240 else
241 {
242 ret = n;
243 }
244 }
245 else
246 {
247 Node nn = n;
248 bool isBool = n.getType().isBoolean();
249 if (isBool && n.getKind() != APPLY_UF)
250 {
251 std::vector<Node> children;
252 bool childChanged = false;
253 // are we at a branch position (not all children are necessarily
254 // relevant)?
255 bool branch_pos =
256 (n.getKind() == ITE || n.getKind() == OR || n.getKind() == AND);
257 std::vector<Node> branch_constraints;
258 for (unsigned i = 0; i < n.getNumChildren(); i++)
259 {
260 Node c = n[i];
261 // do not process LHS of definition
262 if (!is_fun_def || c != hd)
263 {
264 bool newHasPol;
265 bool newPol;
266 QuantPhaseReq::getPolarity(n, i, hasPol, pol, newHasPol, newPol);
267 // get child constraints
268 std::vector<Node> cconstraints;
269 c = simplifyFormula(n[i],
270 newPol,
271 newHasPol,
272 cconstraints,
273 hd,
274 false,
275 visited,
276 visited_cons);
277 if (branch_pos)
278 {
279 // if at a branching position, the other constraints don't matter
280 // if this is satisfied
281 Node bcons = nm->mkAnd(cconstraints);
282 branch_constraints.push_back(bcons);
283 Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
284 << " is " << bcons << std::endl;
285 }
286 constraints.insert(
287 constraints.end(), cconstraints.begin(), cconstraints.end());
288 }
289 children.push_back(c);
290 childChanged = c != n[i] || childChanged;
291 }
292 if (childChanged)
293 {
294 nn = nm->mkNode(n.getKind(), children);
295 }
296 if (branch_pos && !constraints.empty())
297 {
298 // if we are at a branching position in the formula, we can
299 // minimize recursive constraints on recursively defined predicates if
300 // we know one child forces the overall evaluation of this formula.
301 Node branch_cond;
302 if (n.getKind() == ITE)
303 {
304 // always care about constraints on the head of the ITE, but only
305 // care about one of the children depending on how it evaluates
306 branch_cond = nm->mkNode(
307 AND,
308 branch_constraints[0],
309 nm->mkNode(
310 ITE, n[0], branch_constraints[1], branch_constraints[2]));
311 }
312 else
313 {
314 // in the default case, we care about all conditions
315 branch_cond = nm->mkAnd(constraints);
316 for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
317 {
318 // if this child holds with forcing polarity (true child of OR or
319 // false child of AND), then we only care about its associated
320 // recursive conditions
321 branch_cond = nm->mkNode(ITE,
322 (n.getKind() == OR ? n[i] : n[i].negate()),
323 branch_constraints[i],
324 branch_cond);
325 }
326 }
327 Trace("fmf-fun-def-debug2")
328 << "Made branching condition " << branch_cond << std::endl;
329 constraints.clear();
330 constraints.push_back(branch_cond);
331 }
332 }
333 else
334 {
335 // simplify term
336 std::map<Node, Node> visitedT;
337 getConstraints(n, constraints, visitedT);
338 }
339 if (!constraints.empty() && isBool && hasPol)
340 {
341 // conjoin with current
342 Node cons = nm->mkAnd(constraints);
343 if (pol)
344 {
345 ret = nm->mkNode(AND, nn, cons);
346 }
347 else
348 {
349 ret = nm->mkNode(OR, nn, cons.negate());
350 }
351 Trace("fmf-fun-def-debug2")
352 << "Add constraint to obtain " << ret << std::endl;
353 constraints.clear();
354 }
355 else
356 {
357 ret = nn;
358 }
359 }
360 if (!constraints.empty())
361 {
362 Node cons;
363 // flatten to AND node for the purposes of caching
364 if (constraints.size() > 1)
365 {
366 cons = nm->mkNode(AND, constraints);
367 cons = Rewriter::rewrite(cons);
368 constraints.clear();
369 constraints.push_back(cons);
370 }
371 else
372 {
373 cons = constraints[0];
374 }
375 visited_cons[index][n] = cons;
376 Assert(constraints.size() == 1 && constraints[0] == cons);
377 }
378 visited[index][n] = ret;
379 return ret;
380 }
381
382 void FunDefFmf::getConstraints(Node n,
383 std::vector<Node>& constraints,
384 std::map<Node, Node>& visited)
385 {
386 std::map<Node, Node>::iterator itv = visited.find(n);
387 if (itv != visited.end())
388 {
389 // already visited
390 if (!itv->second.isNull())
391 {
392 // add the cached constraint if it does not already occur
393 if (std::find(constraints.begin(), constraints.end(), itv->second)
394 == constraints.end())
395 {
396 constraints.push_back(itv->second);
397 }
398 }
399 return;
400 }
401 visited[n] = Node::null();
402 std::vector<Node> currConstraints;
403 NodeManager* nm = NodeManager::currentNM();
404 if (n.getKind() == ITE)
405 {
406 // collect constraints for the condition
407 getConstraints(n[0], currConstraints, visited);
408 // collect constraints for each branch
409 Node cs[2];
410 for (unsigned i = 0; i < 2; i++)
411 {
412 std::vector<Node> ccons;
413 getConstraints(n[i + 1], ccons, visited);
414 cs[i] = nm->mkAnd(ccons);
415 }
416 if (!cs[0].isConst() || !cs[1].isConst())
417 {
418 Node itec = nm->mkNode(ITE, n[0], cs[0], cs[1]);
419 currConstraints.push_back(itec);
420 Trace("fmf-fun-def-debug")
421 << "---> add constraint " << itec << " for " << n << std::endl;
422 }
423 }
424 else
425 {
426 if (n.getKind() == APPLY_UF)
427 {
428 // check if f is defined, if so, we must enforce domain constraints for
429 // this f-application
430 Node f = n.getOperator();
431 std::map<Node, TypeNode>::iterator it = d_sorts.find(f);
432 if (it != d_sorts.end())
433 {
434 // create existential
435 Node z = nm->mkBoundVar("?z", it->second);
436 Node bvl = nm->mkNode(BOUND_VAR_LIST, z);
437 std::vector<Node> children;
438 for (unsigned j = 0, size = n.getNumChildren(); j < size; j++)
439 {
440 Node uz = nm->mkNode(APPLY_UF, d_input_arg_inj[f][j], z);
441 children.push_back(uz.eqNode(n[j]));
442 }
443 Node bd = nm->mkAnd(children);
444 bd = bd.negate();
445 Node ex = nm->mkNode(FORALL, bvl, bd);
446 ex = ex.negate();
447 currConstraints.push_back(ex);
448 Trace("fmf-fun-def-debug")
449 << "---> add constraint " << ex << " for " << n << std::endl;
450 }
451 }
452 for (const Node& cn : n)
453 {
454 getConstraints(cn, currConstraints, visited);
455 }
456 }
457 // set the visited cache
458 if (!currConstraints.empty())
459 {
460 Node finalc = nm->mkAnd(currConstraints);
461 visited[n] = finalc;
462 // add to constraints
463 getConstraints(n, constraints, visited);
464 }
465 }
466
467 } // namespace passes
468 } // namespace preprocessing
469 } // namespace CVC4