Fix corner case of wrongly applied selector as trigger (#5786)
[cvc5.git] / src / theory / subs_minimize.cpp
1 /********************* */
2 /*! \file subs_minimize.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of substitution minimization.
13 **/
14
15 #include "theory/subs_minimize.h"
16
17 #include "expr/node_algorithm.h"
18 #include "theory/bv/theory_bv_utils.h"
19 #include "theory/rewriter.h"
20 #include "theory/strings/word.h"
21
22 using namespace std;
23 using namespace CVC4::kind;
24
25 namespace CVC4 {
26 namespace theory {
27
28 SubstitutionMinimize::SubstitutionMinimize() {}
29
30 bool SubstitutionMinimize::find(Node t,
31 Node target,
32 const std::vector<Node>& vars,
33 const std::vector<Node>& subs,
34 std::vector<Node>& reqVars)
35 {
36 return findInternal(t, target, vars, subs, reqVars);
37 }
38
39 void getConjuncts(Node n, std::vector<Node>& conj)
40 {
41 if (n.getKind() == AND)
42 {
43 for (const Node& nc : n)
44 {
45 conj.push_back(nc);
46 }
47 }
48 else
49 {
50 conj.push_back(n);
51 }
52 }
53
54 bool SubstitutionMinimize::findWithImplied(Node t,
55 const std::vector<Node>& vars,
56 const std::vector<Node>& subs,
57 std::vector<Node>& reqVars,
58 std::vector<Node>& impliedVars)
59 {
60 NodeManager* nm = NodeManager::currentNM();
61 Node truen = nm->mkConst(true);
62 if (!findInternal(t, truen, vars, subs, reqVars))
63 {
64 return false;
65 }
66 if (reqVars.empty())
67 {
68 return true;
69 }
70
71 // map from conjuncts of t to whether they may be used to show an implied var
72 std::vector<Node> tconj;
73 getConjuncts(t, tconj);
74 // map from conjuncts to their free symbols
75 std::map<Node, std::unordered_set<Node, NodeHashFunction> > tcFv;
76
77 std::unordered_set<Node, NodeHashFunction> reqSet;
78 std::vector<Node> reqSubs;
79 std::map<Node, unsigned> reqVarToIndex;
80 for (const Node& v : reqVars)
81 {
82 reqVarToIndex[v] = reqSubs.size();
83 const std::vector<Node>::const_iterator& it =
84 std::find(vars.begin(), vars.end(), v);
85 Assert(it != vars.end());
86 ptrdiff_t pos = std::distance(vars.begin(), it);
87 reqSubs.push_back(subs[pos]);
88 }
89 std::vector<Node> finalReqVars;
90 for (const Node& v : vars)
91 {
92 if (reqVarToIndex.find(v) == reqVarToIndex.end())
93 {
94 // not a required variable, nothing to do
95 continue;
96 }
97 unsigned vindex = reqVarToIndex[v];
98 Node prev = reqSubs[vindex];
99 // make identity substitution
100 reqSubs[vindex] = v;
101 bool madeImplied = false;
102 // it is a required variable, can we make an implied variable?
103 for (const Node& tc : tconj)
104 {
105 // ensure we've computed its free symbols
106 std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator
107 itf = tcFv.find(tc);
108 if (itf == tcFv.end())
109 {
110 expr::getSymbols(tc, tcFv[tc]);
111 itf = tcFv.find(tc);
112 }
113 // only have a chance if contains v
114 if (itf->second.find(v) == itf->second.end())
115 {
116 continue;
117 }
118 // try the current substitution
119 Node tcs = tc.substitute(
120 reqVars.begin(), reqVars.end(), reqSubs.begin(), reqSubs.end());
121 Node tcsr = Rewriter::rewrite(tcs);
122 std::vector<Node> tcsrConj;
123 getConjuncts(tcsr, tcsrConj);
124 for (const Node& tcc : tcsrConj)
125 {
126 if (tcc.getKind() == EQUAL)
127 {
128 for (unsigned r = 0; r < 2; r++)
129 {
130 if (tcc[r] == v)
131 {
132 Node res = tcc[1 - r];
133 if (res.isConst())
134 {
135 Assert(res == prev);
136 madeImplied = true;
137 break;
138 }
139 }
140 }
141 }
142 if (madeImplied)
143 {
144 break;
145 }
146 }
147 if (madeImplied)
148 {
149 break;
150 }
151 }
152 if (!madeImplied)
153 {
154 // revert the substitution
155 reqSubs[vindex] = prev;
156 finalReqVars.push_back(v);
157 }
158 else
159 {
160 impliedVars.push_back(v);
161 }
162 }
163 reqVars.clear();
164 reqVars.insert(reqVars.end(), finalReqVars.begin(), finalReqVars.end());
165
166 return true;
167 }
168
169 bool SubstitutionMinimize::findInternal(Node n,
170 Node target,
171 const std::vector<Node>& vars,
172 const std::vector<Node>& subs,
173 std::vector<Node>& reqVars)
174 {
175 Trace("subs-min") << "Substitution minimize : " << std::endl;
176 Trace("subs-min") << " substitution : " << vars << " -> " << subs
177 << std::endl;
178 Trace("subs-min") << " node : " << n << std::endl;
179 Trace("subs-min") << " target : " << target << std::endl;
180
181 Trace("subs-min") << "--- Compute values for subterms..." << std::endl;
182 // the value of each subterm in n under the substitution
183 std::unordered_map<TNode, Node, TNodeHashFunction> value;
184 std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
185 std::vector<TNode> visit;
186 TNode cur;
187 visit.push_back(n);
188 do
189 {
190 cur = visit.back();
191 visit.pop_back();
192 it = value.find(cur);
193
194 if (it == value.end())
195 {
196 if (cur.isVar())
197 {
198 const std::vector<Node>::const_iterator& iit =
199 std::find(vars.begin(), vars.end(), cur);
200 if (iit == vars.end())
201 {
202 value[cur] = cur;
203 }
204 else
205 {
206 ptrdiff_t pos = std::distance(vars.begin(), iit);
207 value[cur] = subs[pos];
208 }
209 }
210 else
211 {
212 value[cur] = Node::null();
213 visit.push_back(cur);
214 if (cur.getKind() == APPLY_UF)
215 {
216 visit.push_back(cur.getOperator());
217 }
218 visit.insert(visit.end(), cur.begin(), cur.end());
219 }
220 }
221 else if (it->second.isNull())
222 {
223 Node ret = cur;
224 if (cur.getNumChildren() > 0)
225 {
226 std::vector<Node> children;
227 NodeBuilder<> nb(cur.getKind());
228 if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
229 {
230 if (cur.getKind() == APPLY_UF)
231 {
232 children.push_back(cur.getOperator());
233 }
234 else
235 {
236 nb << cur.getOperator();
237 }
238 }
239 children.insert(children.end(), cur.begin(), cur.end());
240 for (const Node& cn : children)
241 {
242 it = value.find(cn);
243 Assert(it != value.end());
244 Assert(!it->second.isNull());
245 nb << it->second;
246 }
247 ret = nb.constructNode();
248 ret = Rewriter::rewrite(ret);
249 }
250 value[cur] = ret;
251 }
252 } while (!visit.empty());
253 Assert(value.find(n) != value.end());
254 Assert(!value.find(n)->second.isNull());
255
256 Trace("subs-min") << "... got " << value[n] << std::endl;
257 if (value[n] != target)
258 {
259 Trace("subs-min") << "... not equal to target " << target << std::endl;
260 return false;
261 }
262
263 Trace("subs-min") << "--- Compute relevant variables..." << std::endl;
264 std::unordered_set<Node, NodeHashFunction> rlvFv;
265 // only variables that occur in assertions are relevant
266
267 visit.push_back(n);
268 std::unordered_set<TNode, TNodeHashFunction> visited;
269 std::unordered_set<TNode, TNodeHashFunction>::iterator itv;
270 do
271 {
272 cur = visit.back();
273 visit.pop_back();
274 itv = visited.find(cur);
275 if (itv == visited.end())
276 {
277 visited.insert(cur);
278 it = value.find(cur);
279 if (it->second == cur)
280 {
281 // if its value is the same as current, there is nothing to do
282 }
283 else if (cur.isVar())
284 {
285 // must include
286 rlvFv.insert(cur);
287 }
288 else if (cur.getKind() == ITE)
289 {
290 // only recurse on relevant branch
291 Node bval = value[cur[0]];
292 Assert(!bval.isNull() && bval.isConst());
293 unsigned cindex = bval.getConst<bool>() ? 1 : 2;
294 visit.push_back(cur[0]);
295 visit.push_back(cur[cindex]);
296 }
297 else if (cur.getNumChildren() > 0)
298 {
299 Kind ck = cur.getKind();
300 bool alreadyJustified = false;
301
302 // if the operator is an apply uf, check its value
303 if (cur.getKind() == APPLY_UF)
304 {
305 Node op = cur.getOperator();
306 it = value.find(op);
307 Assert(it != value.end());
308 TNode vop = it->second;
309 if (vop.getKind() == LAMBDA)
310 {
311 visit.push_back(op);
312 // do iterative partial evaluation on the body of the lambda
313 Node curr = vop[1];
314 for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
315 {
316 it = value.find(cur[i]);
317 Assert(it != value.end());
318 Node scurr = curr.substitute(vop[0][i], it->second);
319 // if the valuation of the i^th argument changes the
320 // interpretation of the body of the lambda, then the i^th
321 // argument is relevant to the substitution. Hence, we add
322 // i to visit, and update curr below.
323 if (scurr != curr)
324 {
325 curr = Rewriter::rewrite(scurr);
326 visit.push_back(cur[i]);
327 }
328 }
329 alreadyJustified = true;
330 }
331 }
332 if (!alreadyJustified)
333 {
334 // a subset of the arguments of cur that fully justify the evaluation
335 std::vector<unsigned> justifyArgs;
336 if (cur.getNumChildren() > 1)
337 {
338 for (unsigned i = 0, size = cur.getNumChildren(); i < size; i++)
339 {
340 Node cn = cur[i];
341 it = value.find(cn);
342 Assert(it != value.end());
343 Assert(!it->second.isNull());
344 if (isSingularArg(it->second, ck, i))
345 {
346 // have we seen this argument already? if so, we are done
347 if (visited.find(cn) != visited.end())
348 {
349 alreadyJustified = true;
350 break;
351 }
352 justifyArgs.push_back(i);
353 }
354 }
355 }
356 // we need to recurse on at most one child
357 if (!alreadyJustified && !justifyArgs.empty())
358 {
359 unsigned sindex = justifyArgs[0];
360 // could choose a best index, for now, we just take the first
361 visit.push_back(cur[sindex]);
362 alreadyJustified = true;
363 }
364 }
365 if (!alreadyJustified)
366 {
367 // must recurse on all arguments, including operator
368 if (cur.getKind() == APPLY_UF)
369 {
370 visit.push_back(cur.getOperator());
371 }
372 for (const Node& cn : cur)
373 {
374 visit.push_back(cn);
375 }
376 }
377 }
378 }
379 } while (!visit.empty());
380
381 for (const Node& v : rlvFv)
382 {
383 Assert(std::find(vars.begin(), vars.end(), v) != vars.end());
384 reqVars.push_back(v);
385 }
386
387 Trace("subs-min") << "... requires " << reqVars.size() << "/" << vars.size()
388 << " : " << reqVars << std::endl;
389
390 return true;
391 }
392
393 bool SubstitutionMinimize::isSingularArg(Node n, Kind k, unsigned arg)
394 {
395 // Notice that this function is hardcoded. We could compute this function
396 // in a theory-independent way using partial evaluation. However, we
397 // prefer performance to generality here.
398
399 // TODO: a variant of this code is implemented in quantifiers::TermUtil.
400 // These implementations should be merged (see #1216).
401 if (!n.isConst())
402 {
403 return false;
404 }
405 if (k == AND)
406 {
407 return !n.getConst<bool>();
408 }
409 else if (k == OR)
410 {
411 return n.getConst<bool>();
412 }
413 else if (k == IMPLIES)
414 {
415 return arg == (n.getConst<bool>() ? 1 : 0);
416 }
417 if (k == MULT
418 || (arg == 0
419 && (k == DIVISION_TOTAL || k == INTS_DIVISION_TOTAL
420 || k == INTS_MODULUS_TOTAL))
421 || (arg == 2 && k == STRING_SUBSTR))
422 {
423 // zero
424 if (n.getConst<Rational>().sgn() == 0)
425 {
426 return true;
427 }
428 }
429 if (k == BITVECTOR_AND || k == BITVECTOR_MULT || k == BITVECTOR_UDIV_TOTAL
430 || k == BITVECTOR_UREM_TOTAL
431 || (arg == 0
432 && (k == BITVECTOR_SHL || k == BITVECTOR_LSHR
433 || k == BITVECTOR_ASHR)))
434 {
435 if (bv::utils::isZero(n))
436 {
437 return true;
438 }
439 }
440 if (k == BITVECTOR_OR)
441 {
442 // bit-vector ones
443 if (bv::utils::isOnes(n))
444 {
445 return true;
446 }
447 }
448
449 if ((arg == 1 && k == STRING_STRCTN) || (arg == 0 && k == STRING_SUBSTR))
450 {
451 // empty string
452 if (strings::Word::getLength(n) == 0)
453 {
454 return true;
455 }
456 }
457 if ((arg != 0 && k == STRING_SUBSTR) || (arg == 2 && k == STRING_STRIDOF))
458 {
459 // negative integer
460 if (n.getConst<Rational>().sgn() < 0)
461 {
462 return true;
463 }
464 }
465 return false;
466 }
467
468 } // namespace theory
469 } // namespace CVC4