1 /********************* */
2 /*! \file fun_def_fmf.cpp
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Haniel Barbosa, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief Function definition processor for finite model finding
15 #include "preprocessing/passes/fun_def_fmf.h"
19 #include "options/smt_options.h"
20 #include "preprocessing/assertion_pipeline.h"
21 #include "preprocessing/preprocessing_pass_context.h"
22 #include "proof/proof_manager.h"
23 #include "theory/quantifiers/quantifiers_attributes.h"
24 #include "theory/quantifiers/term_database.h"
25 #include "theory/quantifiers/term_util.h"
26 #include "theory/rewriter.h"
29 using namespace CVC4::kind
;
30 using namespace CVC4::theory
;
31 using namespace CVC4::theory::quantifiers
;
34 namespace preprocessing
{
37 FunDefFmf::FunDefFmf(PreprocessingPassContext
* preprocContext
)
38 : PreprocessingPass(preprocContext
, "fun-def-fmf"),
39 d_fmfRecFunctionsDefined(nullptr)
41 d_fmfRecFunctionsDefined
=
42 new (true) NodeList(preprocContext
->getUserContext());
45 FunDefFmf::~FunDefFmf() { d_fmfRecFunctionsDefined
->deleteSelf(); }
/**
 * Applies the pass to the given assertions.
 *
 * Resets d_input_arg_inj, then re-populates it from every function recorded
 * in d_fmfRecFunctionsDefined (using the abstract sort stored in
 * d_fmfRecFunctionsAbs and the argument functions stored in
 * d_fmfRecFunctionsConcrete). This keeps definitions from earlier calls in
 * effect in incremental mode. It then runs process() and, for each function
 * in d_funcs, stores its sort from d_sorts and its argument functions so that
 * later calls can restore them.
 *
 * @param assertionsToPreprocess the assertion pipeline, modified in place
 * @return always PreprocessingPassResult::NO_CONFLICT
 */
47 PreprocessingPassResult
FunDefFmf::applyInternal(
48 AssertionPipeline
* assertionsToPreprocess
)
50 Assert(d_fmfRecFunctionsDefined
!= nullptr);
// per-call state; rebuilt below from the stored definitions
53 d_input_arg_inj
.clear();
55 // must carry over current definitions (in case of incremental)
56 for (context::CDList
<Node
>::const_iterator fit
=
57 d_fmfRecFunctionsDefined
->begin();
58 fit
!= d_fmfRecFunctionsDefined
->end();
62 Assert(d_fmfRecFunctionsAbs
.find(f
) != d_fmfRecFunctionsAbs
.end());
63 TypeNode ft
= d_fmfRecFunctionsAbs
[f
];
65 std::map
<Node
, std::vector
<Node
>>::iterator fcit
=
66 d_fmfRecFunctionsConcrete
.find(f
);
67 Assert(fcit
!= d_fmfRecFunctionsConcrete
.end());
68 for (const Node
& fcc
: fcit
->second
)
70 d_input_arg_inj
[f
].push_back(fcc
);
// process() records newly defined functions in d_funcs and their argument
// functions in d_input_arg_inj
73 process(assertionsToPreprocess
);
74 // must store new definitions (in case of incremental)
75 for (const Node
& f
: d_funcs
)
77 d_fmfRecFunctionsAbs
[f
] = d_sorts
[f
];
78 d_fmfRecFunctionsConcrete
[f
].clear();
79 for (const Node
& fcc
: d_input_arg_inj
[f
])
81 d_fmfRecFunctionsConcrete
[f
].push_back(fcc
);
83 d_fmfRecFunctionsDefined
->push_back(f
);
85 return PreprocessingPassResult::NO_CONFLICT
;
/**
 * Rewrites function definitions in the assertions for finite model finding.
 *
 * First pass: for each assertion that QuantAttributes::getFunDefHead
 * recognises as a definition with head n = f(x1, ..., xn), it does four things:
 *  - creates an uninterpreted sort S marked with AbsTypeFunDefAttribute;
 *  - creates skolems "<f>_arg_<j>" of type S -> (type of xj) and stores them
 *    in d_input_arg_inj[f];
 *  - replaces the assertion by the rewritten quantifier
 *      forall ?i:S. (n = body)[f_arg_1(?i)/x1, ..., f_arg_n(?i)/xn];
 *  - remembers its index in fd_assertions.
 * Defining the same f more than once is reported through Unhandled().
 *
 * Second pass: every assertion is passed through simplifyFormula. For
 * definitions, the head recorded in subs_head is passed as well. An assertion
 * is replaced (after rewriting) when the result differs from it.
 *
 * @param assertionsToPreprocess the assertion pipeline, modified in place
 */
88 void FunDefFmf::process(AssertionPipeline
* assertionsToPreprocess
)
90 const std::vector
<Node
>& assertions
= assertionsToPreprocess
->ref();
// indices of assertions rewritten as definitions in the first pass, and the
// head passed to simplifyFormula for each such index in the second pass
91 std::vector
<int> fd_assertions
;
92 std::map
<int, Node
> subs_head
;
93 // first pass : find defined functions, transform quantifiers
94 NodeManager
* nm
= NodeManager::currentNM();
95 for (size_t i
= 0, asize
= assertions
.size(); i
< asize
; i
++)
97 Node n
= QuantAttributes::getFunDefHead(assertions
[i
]);
100 Assert(n
.getKind() == APPLY_UF
);
101 Node f
= n
.getOperator();
103 // check if already defined, if so, throw error
104 if (d_sorts
.find(f
) != d_sorts
.end())
106 Unhandled() << "Cannot define function " << f
<< " more than once.";
109 Node bd
= QuantAttributes::getFunDefBody(assertions
[i
]);
110 Trace("fmf-fun-def-debug")
111 << "Process function " << n
<< ", body = " << bd
<< std::endl
;
114 d_funcs
.push_back(f
);
115 bd
= nm
->mkNode(EQUAL
, n
, bd
);
117 // create a sort S that represents the inputs of the function
118 std::stringstream ss
;
120 TypeNode iType
= nm
->mkSort(ss
.str());
121 AbsTypeFunDefAttribute atfda
;
122 iType
.setAttribute(atfda
, true);
125 // create functions f1...fn mapping from this sort to concrete elements
126 size_t nchildn
= n
.getNumChildren();
127 for (size_t j
= 0; j
< nchildn
; j
++)
129 TypeNode typ
= nm
->mkFunctionType(iType
, n
[j
].getType());
130 std::stringstream ssf
;
131 ssf
<< f
<< "_arg_" << j
;
132 d_input_arg_inj
[f
].push_back(
133 nm
->mkSkolem(ssf
.str(), typ
, "op created during fun def fmf"));
136 // construct new quantifier forall S. F[f1(S)/x1....fn(S)/xn]
137 std::vector
<Node
> children
;
138 Node bv
= nm
->mkBoundVar("?i", iType
);
139 Node bvl
= nm
->mkNode(BOUND_VAR_LIST
, bv
);
140 std::vector
<Node
> subs
;
141 std::vector
<Node
> vars
;
142 for (size_t j
= 0; j
< nchildn
; j
++)
144 vars
.push_back(n
[j
]);
145 subs
.push_back(nm
->mkNode(APPLY_UF
, d_input_arg_inj
[f
][j
], bv
));
147 bd
= bd
.substitute(vars
.begin(), vars
.end(), subs
.begin(), subs
.end());
149 n
.substitute(vars
.begin(), vars
.end(), subs
.begin(), subs
.end());
152 << "FMF fun def: FUNCTION : rewrite " << assertions
[i
] << std::endl
;
153 Trace("fmf-fun-def") << " to " << std::endl
;
154 Node new_q
= nm
->mkNode(FORALL
, bvl
, bd
);
155 new_q
= Rewriter::rewrite(new_q
);
156 assertionsToPreprocess
->replace(i
, new_q
);
157 Trace("fmf-fun-def") << " " << assertions
[i
] << std::endl
;
158 fd_assertions
.push_back(i
);
162 // can be, e.g. in corner cases forall x. f(x)=f(x), forall x.
167 // second pass : rewrite assertions
168 std::map
<int, std::map
<Node
, Node
>> visited
;
169 std::map
<int, std::map
<Node
, Node
>> visited_cons
;
170 for (size_t i
= 0, asize
= assertions
.size(); i
< asize
; i
++)
// NOTE(review): std::find scans fd_assertions linearly for every assertion,
// which is quadratic overall. It also compares int elements against the
// size_t index i. A std::set<size_t> (or a vector<bool> indexed by i) would
// avoid both.
172 bool is_fd
= std::find(fd_assertions
.begin(), fd_assertions
.end(), i
)
173 != fd_assertions
.end();
174 std::vector
<Node
> constraints
;
175 Trace("fmf-fun-def-rewrite")
176 << "Rewriting " << assertions
[i
]
177 << ", is function definition = " << is_fd
<< std::endl
;
178 Node n
= simplifyFormula(assertions
[i
],
182 is_fd
? subs_head
[i
] : Node::null(),
186 Assert(constraints
.empty());
187 if (n
!= assertions
[i
])
189 n
= Rewriter::rewrite(n
);
190 Trace("fmf-fun-def-rewrite")
191 << "FMF fun def : rewrite " << assertions
[i
] << std::endl
;
192 Trace("fmf-fun-def-rewrite") << " to " << std::endl
;
193 Trace("fmf-fun-def-rewrite") << " " << n
<< std::endl
;
194 assertionsToPreprocess
->replace(i
, n
);
/**
 * Returns a version of n in which applications of defined functions are
 * guarded by their domain constraints.
 *
 * n is visited under polarity (pol, hasPol). When is_fun_def holds, hd is the
 * definition head passed in by process(), and the child equal to it is not
 * processed. Constraints produced for n are returned through `constraints`.
 * That vector must be empty on entry; when non-empty on exit, it holds a
 * single node (several constraints are flattened into one AND for caching).
 * Results and constraints are memoized in visited / visited_cons, keyed by an
 * index built from is_fun_def and the polarity.
 *
 * Cases:
 *  - FORALL: each constraint of the body is closed over the same bound
 *    variables and rewritten.
 *  - Boolean, non-APPLY_UF terms: children are simplified recursively. At
 *    ITE/OR/AND (branch positions), the constraints are weakened to those of
 *    the child that forces the value.
 *  - Otherwise: constraints are collected via getConstraints.
 * When n is Boolean with known polarity, the collected constraints are folded
 * into the returned formula, as (nn AND cons) or (nn OR NOT cons).
 */
199 Node
FunDefFmf::simplifyFormula(
203 std::vector
<Node
>& constraints
,
206 std::map
<int, std::map
<Node
, Node
>>& visited
,
207 std::map
<int, std::map
<Node
, Node
>>& visited_cons
)
209 Assert(constraints
.empty());
// cache key: (is_fun_def ? 1 : 0) plus 2 * (+1 positive, -1 negative,
// 0 unknown polarity), so each combination gets its own cache slot
210 int index
= (is_fun_def
? 1 : 0) + 2 * (hasPol
? (pol
? 1 : -1) : 0);
211 std::map
<Node
, Node
>::iterator itv
= visited
[index
].find(n
);
212 if (itv
!= visited
[index
].end())
214 // constraints.insert( visited_cons[index]
215 std::map
<Node
, Node
>::iterator itvc
= visited_cons
[index
].find(n
);
216 if (itvc
!= visited_cons
[index
].end())
218 constraints
.push_back(itvc
->second
);
222 NodeManager
* nm
= NodeManager::currentNM();
224 Trace("fmf-fun-def-debug2") << "Simplify " << n
<< " " << pol
<< " " << hasPol
225 << " " << is_fun_def
<< std::endl
;
226 if (n
.getKind() == FORALL
)
228 Node c
= simplifyFormula(
229 n
[1], pol
, hasPol
, constraints
, hd
, is_fun_def
, visited
, visited_cons
);
230 // append prenex to constraints
231 for (unsigned i
= 0; i
< constraints
.size(); i
++)
233 constraints
[i
] = nm
->mkNode(FORALL
, n
[0], constraints
[i
]);
234 constraints
[i
] = Rewriter::rewrite(constraints
[i
]);
238 ret
= nm
->mkNode(FORALL
, n
[0], c
);
248 bool isBool
= n
.getType().isBoolean();
249 if (isBool
&& n
.getKind() != APPLY_UF
)
251 std::vector
<Node
> children
;
252 bool childChanged
= false;
253 // are we at a branch position (not all children are necessarily
256 (n
.getKind() == ITE
|| n
.getKind() == OR
|| n
.getKind() == AND
);
257 std::vector
<Node
> branch_constraints
;
258 for (unsigned i
= 0; i
< n
.getNumChildren(); i
++)
261 // do not process LHS of definition
262 if (!is_fun_def
|| c
!= hd
)
266 QuantPhaseReq::getPolarity(n
, i
, hasPol
, pol
, newHasPol
, newPol
);
267 // get child constraints
268 std::vector
<Node
> cconstraints
;
269 c
= simplifyFormula(n
[i
],
279 // if at a branching position, the other constraints don't matter
280 // if this is satisfied
281 Node bcons
= nm
->mkAnd(cconstraints
);
282 branch_constraints
.push_back(bcons
);
283 Trace("fmf-fun-def-debug2") << "Branching constraint at arg " << i
284 << " is " << bcons
<< std::endl
;
287 constraints
.end(), cconstraints
.begin(), cconstraints
.end());
289 children
.push_back(c
);
290 childChanged
= c
!= n
[i
] || childChanged
;
294 nn
= nm
->mkNode(n
.getKind(), children
);
296 if (branch_pos
&& !constraints
.empty())
298 // if we are at a branching position in the formula, we can
299 // minimize recursive constraints on recursively defined predicates if
300 // we know one child forces the overall evaluation of this formula.
302 if (n
.getKind() == ITE
)
304 // always care about constraints on the head of the ITE, but only
305 // care about one of the children depending on how it evaluates
306 branch_cond
= nm
->mkNode(
308 branch_constraints
[0],
310 ITE
, n
[0], branch_constraints
[1], branch_constraints
[2]));
314 // in the default case, we care about all conditions
315 branch_cond
= nm
->mkAnd(constraints
);
316 for (size_t i
= 0, nchild
= n
.getNumChildren(); i
< nchild
; i
++)
318 // if this child holds with forcing polarity (true child of OR or
319 // false child of AND), then we only care about its associated
320 // recursive conditions
321 branch_cond
= nm
->mkNode(ITE
,
322 (n
.getKind() == OR
? n
[i
] : n
[i
].negate()),
323 branch_constraints
[i
],
327 Trace("fmf-fun-def-debug2")
328 << "Made branching condition " << branch_cond
<< std::endl
;
330 constraints
.push_back(branch_cond
);
336 std::map
<Node
, Node
> visitedT
;
337 getConstraints(n
, constraints
, visitedT
);
339 if (!constraints
.empty() && isBool
&& hasPol
)
341 // conjoin with current
342 Node cons
= nm
->mkAnd(constraints
);
345 ret
= nm
->mkNode(AND
, nn
, cons
);
349 ret
= nm
->mkNode(OR
, nn
, cons
.negate());
351 Trace("fmf-fun-def-debug2")
352 << "Add constraint to obtain " << ret
<< std::endl
;
360 if (!constraints
.empty())
363 // flatten to AND node for the purposes of caching
364 if (constraints
.size() > 1)
366 cons
= nm
->mkNode(AND
, constraints
);
367 cons
= Rewriter::rewrite(cons
);
369 constraints
.push_back(cons
);
373 cons
= constraints
[0];
375 visited_cons
[index
][n
] = cons
;
376 Assert(constraints
.size() == 1 && constraints
[0] == cons
);
378 visited
[index
][n
] = ret
;
/**
 * Collects into `constraints` the domain constraints needed by the
 * applications of defined functions that occur in n.
 *
 * For an application f(t1, ..., tn) where f has an abstract sort in d_sorts,
 * it builds a quantified constraint over a fresh bound variable ?z of that
 * sort. The constraint relates d_input_arg_inj[f][j](?z) to tj for each
 * argument j. Some of the negation/quantifier wrapping lines are not visible
 * here (NOTE(review): confirm the exact form).
 * For an ITE, the condition's constraints are collected directly. The
 * branches' constraints are combined as ite(cond, c_then, c_else) unless both
 * are constant.
 * Children of n are also visited recursively.
 *
 * `visited` caches per-term results. A term is marked with a null entry when
 * first entered. A non-null cached constraint is appended to `constraints`
 * only if it is not already present.
 *
 * @param n the term to scan
 * @param constraints output vector the constraints for n are appended to
 * @param visited cache from term to its (conjoined) constraint
 */
382 void FunDefFmf::getConstraints(Node n
,
383 std::vector
<Node
>& constraints
,
384 std::map
<Node
, Node
>& visited
)
386 std::map
<Node
, Node
>::iterator itv
= visited
.find(n
);
387 if (itv
!= visited
.end())
390 if (!itv
->second
.isNull())
392 // add the cached constraint if it does not already occur
393 if (std::find(constraints
.begin(), constraints
.end(), itv
->second
)
394 == constraints
.end())
396 constraints
.push_back(itv
->second
);
// mark n as visited (no constraint yet) before recursing
401 visited
[n
] = Node::null();
402 std::vector
<Node
> currConstraints
;
403 NodeManager
* nm
= NodeManager::currentNM();
404 if (n
.getKind() == ITE
)
406 // collect constraints for the condition
407 getConstraints(n
[0], currConstraints
, visited
);
408 // collect constraints for each branch
410 for (unsigned i
= 0; i
< 2; i
++)
412 std::vector
<Node
> ccons
;
413 getConstraints(n
[i
+ 1], ccons
, visited
);
414 cs
[i
] = nm
->mkAnd(ccons
);
416 if (!cs
[0].isConst() || !cs
[1].isConst())
418 Node itec
= nm
->mkNode(ITE
, n
[0], cs
[0], cs
[1]);
419 currConstraints
.push_back(itec
);
420 Trace("fmf-fun-def-debug")
421 << "---> add constraint " << itec
<< " for " << n
<< std::endl
;
426 if (n
.getKind() == APPLY_UF
)
428 // check if f is defined, if so, we must enforce domain constraints for
429 // this f-application
430 Node f
= n
.getOperator();
431 std::map
<Node
, TypeNode
>::iterator it
= d_sorts
.find(f
);
432 if (it
!= d_sorts
.end())
434 // create existential
435 Node z
= nm
->mkBoundVar("?z", it
->second
);
436 Node bvl
= nm
->mkNode(BOUND_VAR_LIST
, z
);
437 std::vector
<Node
> children
;
438 for (unsigned j
= 0, size
= n
.getNumChildren(); j
< size
; j
++)
440 Node uz
= nm
->mkNode(APPLY_UF
, d_input_arg_inj
[f
][j
], z
);
441 children
.push_back(uz
.eqNode(n
[j
]));
443 Node bd
= nm
->mkAnd(children
);
445 Node ex
= nm
->mkNode(FORALL
, bvl
, bd
);
447 currConstraints
.push_back(ex
);
448 Trace("fmf-fun-def-debug")
449 << "---> add constraint " << ex
<< " for " << n
<< std::endl
;
452 for (const Node
& cn
: n
)
454 getConstraints(cn
, currConstraints
, visited
);
457 // set the visited cache
458 if (!currConstraints
.empty())
460 Node finalc
= nm
->mkAnd(currConstraints
);
462 // add to constraints
// re-enter on n so the cache-hit path above appends the cached constraint
463 getConstraints(n
, constraints
, visited
);
467 } // namespace passes
468 } // namespace preprocessing