Fix combinations of cegqi and non-standard triggers (#4271)
[cvc5.git] / src / theory / subs_minimize.cpp
1 /********************* */
2 /*! \file subs_minimize.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of substitution minimization.
13 **/
14
15 #include "theory/subs_minimize.h"
16
17 #include "expr/node_algorithm.h"
18 #include "theory/bv/theory_bv_utils.h"
19 #include "theory/rewriter.h"
20
21 using namespace std;
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26
27 SubstitutionMinimize::SubstitutionMinimize() {}
28
29 bool SubstitutionMinimize::find(Node t,
30 Node target,
31 const std::vector<Node>& vars,
32 const std::vector<Node>& subs,
33 std::vector<Node>& reqVars)
34 {
35 return findInternal(t, target, vars, subs, reqVars);
36 }
37
38 void getConjuncts(Node n, std::vector<Node>& conj)
39 {
40 if (n.getKind() == AND)
41 {
42 for (const Node& nc : n)
43 {
44 conj.push_back(nc);
45 }
46 }
47 else
48 {
49 conj.push_back(n);
50 }
51 }
52
53 bool SubstitutionMinimize::findWithImplied(Node t,
54 const std::vector<Node>& vars,
55 const std::vector<Node>& subs,
56 std::vector<Node>& reqVars,
57 std::vector<Node>& impliedVars)
58 {
59 NodeManager* nm = NodeManager::currentNM();
60 Node truen = nm->mkConst(true);
61 if (!findInternal(t, truen, vars, subs, reqVars))
62 {
63 return false;
64 }
65 if (reqVars.empty())
66 {
67 return true;
68 }
69
70 // map from conjuncts of t to whether they may be used to show an implied var
71 std::vector<Node> tconj;
72 getConjuncts(t, tconj);
73 // map from conjuncts to their free symbols
74 std::map<Node, std::unordered_set<Node, NodeHashFunction> > tcFv;
75
76 std::unordered_set<Node, NodeHashFunction> reqSet;
77 std::vector<Node> reqSubs;
78 std::map<Node, unsigned> reqVarToIndex;
79 for (const Node& v : reqVars)
80 {
81 reqVarToIndex[v] = reqSubs.size();
82 const std::vector<Node>::const_iterator& it =
83 std::find(vars.begin(), vars.end(), v);
84 Assert(it != vars.end());
85 ptrdiff_t pos = std::distance(vars.begin(), it);
86 reqSubs.push_back(subs[pos]);
87 }
88 std::vector<Node> finalReqVars;
89 for (const Node& v : vars)
90 {
91 if (reqVarToIndex.find(v) == reqVarToIndex.end())
92 {
93 // not a required variable, nothing to do
94 continue;
95 }
96 unsigned vindex = reqVarToIndex[v];
97 Node prev = reqSubs[vindex];
98 // make identity substitution
99 reqSubs[vindex] = v;
100 bool madeImplied = false;
101 // it is a required variable, can we make an implied variable?
102 for (const Node& tc : tconj)
103 {
104 // ensure we've computed its free symbols
105 std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator
106 itf = tcFv.find(tc);
107 if (itf == tcFv.end())
108 {
109 expr::getSymbols(tc, tcFv[tc]);
110 itf = tcFv.find(tc);
111 }
112 // only have a chance if contains v
113 if (itf->second.find(v) == itf->second.end())
114 {
115 continue;
116 }
117 // try the current substitution
118 Node tcs = tc.substitute(
119 reqVars.begin(), reqVars.end(), reqSubs.begin(), reqSubs.end());
120 Node tcsr = Rewriter::rewrite(tcs);
121 std::vector<Node> tcsrConj;
122 getConjuncts(tcsr, tcsrConj);
123 for (const Node& tcc : tcsrConj)
124 {
125 if (tcc.getKind() == EQUAL)
126 {
127 for (unsigned r = 0; r < 2; r++)
128 {
129 if (tcc[r] == v)
130 {
131 Node res = tcc[1 - r];
132 if (res.isConst())
133 {
134 Assert(res == prev);
135 madeImplied = true;
136 break;
137 }
138 }
139 }
140 }
141 if (madeImplied)
142 {
143 break;
144 }
145 }
146 if (madeImplied)
147 {
148 break;
149 }
150 }
151 if (!madeImplied)
152 {
153 // revert the substitution
154 reqSubs[vindex] = prev;
155 finalReqVars.push_back(v);
156 }
157 else
158 {
159 impliedVars.push_back(v);
160 }
161 }
162 reqVars.clear();
163 reqVars.insert(reqVars.end(), finalReqVars.begin(), finalReqVars.end());
164
165 return true;
166 }
167
168 bool SubstitutionMinimize::findInternal(Node n,
169 Node target,
170 const std::vector<Node>& vars,
171 const std::vector<Node>& subs,
172 std::vector<Node>& reqVars)
173 {
174 Trace("subs-min") << "Substitution minimize : " << std::endl;
175 Trace("subs-min") << " substitution : " << vars << " -> " << subs
176 << std::endl;
177 Trace("subs-min") << " node : " << n << std::endl;
178 Trace("subs-min") << " target : " << target << std::endl;
179
180 Trace("subs-min") << "--- Compute values for subterms..." << std::endl;
181 // the value of each subterm in n under the substitution
182 std::unordered_map<TNode, Node, TNodeHashFunction> value;
183 std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
184 std::vector<TNode> visit;
185 TNode cur;
186 visit.push_back(n);
187 do
188 {
189 cur = visit.back();
190 visit.pop_back();
191 it = value.find(cur);
192
193 if (it == value.end())
194 {
195 if (cur.isVar())
196 {
197 const std::vector<Node>::const_iterator& iit =
198 std::find(vars.begin(), vars.end(), cur);
199 if (iit == vars.end())
200 {
201 value[cur] = cur;
202 }
203 else
204 {
205 ptrdiff_t pos = std::distance(vars.begin(), iit);
206 value[cur] = subs[pos];
207 }
208 }
209 else
210 {
211 value[cur] = Node::null();
212 visit.push_back(cur);
213 if (cur.getKind() == APPLY_UF)
214 {
215 visit.push_back(cur.getOperator());
216 }
217 visit.insert(visit.end(), cur.begin(), cur.end());
218 }
219 }
220 else if (it->second.isNull())
221 {
222 Node ret = cur;
223 if (cur.getNumChildren() > 0)
224 {
225 std::vector<Node> children;
226 NodeBuilder<> nb(cur.getKind());
227 if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
228 {
229 if (cur.getKind() == APPLY_UF)
230 {
231 children.push_back(cur.getOperator());
232 }
233 else
234 {
235 nb << cur.getOperator();
236 }
237 }
238 children.insert(children.end(), cur.begin(), cur.end());
239 for (const Node& cn : children)
240 {
241 it = value.find(cn);
242 Assert(it != value.end());
243 Assert(!it->second.isNull());
244 nb << it->second;
245 }
246 ret = nb.constructNode();
247 ret = Rewriter::rewrite(ret);
248 }
249 value[cur] = ret;
250 }
251 } while (!visit.empty());
252 Assert(value.find(n) != value.end());
253 Assert(!value.find(n)->second.isNull());
254
255 Trace("subs-min") << "... got " << value[n] << std::endl;
256 if (value[n] != target)
257 {
258 Trace("subs-min") << "... not equal to target " << target << std::endl;
259 return false;
260 }
261
262 Trace("subs-min") << "--- Compute relevant variables..." << std::endl;
263 std::unordered_set<Node, NodeHashFunction> rlvFv;
264 // only variables that occur in assertions are relevant
265
266 visit.push_back(n);
267 std::unordered_set<TNode, TNodeHashFunction> visited;
268 std::unordered_set<TNode, TNodeHashFunction>::iterator itv;
269 do
270 {
271 cur = visit.back();
272 visit.pop_back();
273 itv = visited.find(cur);
274 if (itv == visited.end())
275 {
276 visited.insert(cur);
277 it = value.find(cur);
278 if (it->second == cur)
279 {
280 // if its value is the same as current, there is nothing to do
281 }
282 else if (cur.isVar())
283 {
284 // must include
285 rlvFv.insert(cur);
286 }
287 else if (cur.getKind() == ITE)
288 {
289 // only recurse on relevant branch
290 Node bval = value[cur[0]];
291 Assert(!bval.isNull() && bval.isConst());
292 unsigned cindex = bval.getConst<bool>() ? 1 : 2;
293 visit.push_back(cur[0]);
294 visit.push_back(cur[cindex]);
295 }
296 else if (cur.getNumChildren() > 0)
297 {
298 Kind ck = cur.getKind();
299 bool alreadyJustified = false;
300
301 // if the operator is an apply uf, check its value
302 if (cur.getKind() == APPLY_UF)
303 {
304 Node op = cur.getOperator();
305 it = value.find(op);
306 Assert(it != value.end());
307 TNode vop = it->second;
308 if (vop.getKind() == LAMBDA)
309 {
310 visit.push_back(op);
311 // do iterative partial evaluation on the body of the lambda
312 Node curr = vop[1];
313 for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
314 {
315 it = value.find(cur[i]);
316 Assert(it != value.end());
317 Node scurr = curr.substitute(vop[0][i], it->second);
318 // if the valuation of the i^th argument changes the
319 // interpretation of the body of the lambda, then the i^th
320 // argument is relevant to the substitution. Hence, we add
321 // i to visit, and update curr below.
322 if (scurr != curr)
323 {
324 curr = Rewriter::rewrite(scurr);
325 visit.push_back(cur[i]);
326 }
327 }
328 alreadyJustified = true;
329 }
330 }
331 if (!alreadyJustified)
332 {
333 // a subset of the arguments of cur that fully justify the evaluation
334 std::vector<unsigned> justifyArgs;
335 if (cur.getNumChildren() > 1)
336 {
337 for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
338 {
339 Node cn = cur[i];
340 it = value.find(cn);
341 Assert(it != value.end());
342 Assert(!it->second.isNull());
343 if (isSingularArg(it->second, ck, i))
344 {
345 // have we seen this argument already? if so, we are done
346 if (visited.find(cn) != visited.end())
347 {
348 alreadyJustified = true;
349 break;
350 }
351 justifyArgs.push_back(i);
352 }
353 }
354 }
355 // we need to recurse on at most one child
356 if (!alreadyJustified && !justifyArgs.empty())
357 {
358 unsigned sindex = justifyArgs[0];
359 // could choose a best index, for now, we just take the first
360 visit.push_back(cur[sindex]);
361 alreadyJustified = true;
362 }
363 }
364 if (!alreadyJustified)
365 {
366 // must recurse on all arguments, including operator
367 if (cur.getKind() == APPLY_UF)
368 {
369 visit.push_back(cur.getOperator());
370 }
371 for (const Node& cn : cur)
372 {
373 visit.push_back(cn);
374 }
375 }
376 }
377 }
378 } while (!visit.empty());
379
380 for (const Node& v : rlvFv)
381 {
382 Assert(std::find(vars.begin(), vars.end(), v) != vars.end());
383 reqVars.push_back(v);
384 }
385
386 Trace("subs-min") << "... requires " << reqVars.size() << "/" << vars.size()
387 << " : " << reqVars << std::endl;
388
389 return true;
390 }
391
392 bool SubstitutionMinimize::isSingularArg(Node n, Kind k, unsigned arg)
393 {
394 // Notice that this function is hardcoded. We could compute this function
395 // in a theory-independent way using partial evaluation. However, we
396 // prefer performance to generality here.
397
398 // TODO: a variant of this code is implemented in quantifiers::TermUtil.
399 // These implementations should be merged (see #1216).
400 if (!n.isConst())
401 {
402 return false;
403 }
404 if (k == AND)
405 {
406 return !n.getConst<bool>();
407 }
408 else if (k == OR)
409 {
410 return n.getConst<bool>();
411 }
412 else if (k == IMPLIES)
413 {
414 return arg == (n.getConst<bool>() ? 1 : 0);
415 }
416 if (k == MULT
417 || (arg == 0
418 && (k == DIVISION_TOTAL || k == INTS_DIVISION_TOTAL
419 || k == INTS_MODULUS_TOTAL))
420 || (arg == 2 && k == STRING_SUBSTR))
421 {
422 // zero
423 if (n.getConst<Rational>().sgn() == 0)
424 {
425 return true;
426 }
427 }
428 if (k == BITVECTOR_AND || k == BITVECTOR_MULT || k == BITVECTOR_UDIV_TOTAL
429 || k == BITVECTOR_UREM_TOTAL
430 || (arg == 0
431 && (k == BITVECTOR_SHL || k == BITVECTOR_LSHR
432 || k == BITVECTOR_ASHR)))
433 {
434 if (bv::utils::isZero(n))
435 {
436 return true;
437 }
438 }
439 if (k == BITVECTOR_OR)
440 {
441 // bit-vector ones
442 if (bv::utils::isOnes(n))
443 {
444 return true;
445 }
446 }
447
448 if ((arg == 1 && k == STRING_STRCTN) || (arg == 0 && k == STRING_SUBSTR))
449 {
450 // empty string
451 if (n.getConst<String>().size() == 0)
452 {
453 return true;
454 }
455 }
456 if ((arg != 0 && k == STRING_SUBSTR) || (arg == 2 && k == STRING_STRIDOF))
457 {
458 // negative integer
459 if (n.getConst<Rational>().sgn() < 0)
460 {
461 return true;
462 }
463 }
464 return false;
465 }
466
467 } // namespace theory
468 } // namespace CVC4