Add some missing cases in evaluator (#3133)
[cvc5.git] / src / theory / evaluator.cpp
1 /********************* */
2 /*! \file evaluator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andres Noetzli, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The Evaluator class
13 **
14 ** The Evaluator class.
15 **/
16
17 #include "theory/evaluator.h"
18
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/theory.h"
21 #include "util/integer.h"
22
23 namespace CVC4 {
24 namespace theory {
25
26 EvalResult::EvalResult(const EvalResult& other)
27 {
28 d_tag = other.d_tag;
29 switch (d_tag)
30 {
31 case BOOL: d_bool = other.d_bool; break;
32 case BITVECTOR:
33 new (&d_bv) BitVector;
34 d_bv = other.d_bv;
35 break;
36 case RATIONAL:
37 new (&d_rat) Rational;
38 d_rat = other.d_rat;
39 break;
40 case STRING:
41 new (&d_str) String;
42 d_str = other.d_str;
43 break;
44 case INVALID: break;
45 }
46 }
47
48 EvalResult& EvalResult::operator=(const EvalResult& other)
49 {
50 if (this != &other)
51 {
52 d_tag = other.d_tag;
53 switch (d_tag)
54 {
55 case BOOL: d_bool = other.d_bool; break;
56 case BITVECTOR:
57 new (&d_bv) BitVector;
58 d_bv = other.d_bv;
59 break;
60 case RATIONAL:
61 new (&d_rat) Rational;
62 d_rat = other.d_rat;
63 break;
64 case STRING:
65 new (&d_str) String;
66 d_str = other.d_str;
67 break;
68 case INVALID: break;
69 }
70 }
71 return *this;
72 }
73
74 EvalResult::~EvalResult()
75 {
76 switch (d_tag)
77 {
78 case BITVECTOR:
79 {
80 d_bv.~BitVector();
81 break;
82 }
83 case RATIONAL:
84 {
85 d_rat.~Rational();
86 break;
87 }
88 case STRING:
89 {
90 d_str.~String();
91 break;
92
93 default: break;
94 }
95 }
96 }
97
98 Node EvalResult::toNode() const
99 {
100 NodeManager* nm = NodeManager::currentNM();
101 switch (d_tag)
102 {
103 case EvalResult::BOOL: return nm->mkConst(d_bool);
104 case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
105 case EvalResult::RATIONAL: return nm->mkConst(d_rat);
106 case EvalResult::STRING: return nm->mkConst(d_str);
107 default:
108 {
109 Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
110 << std::endl;
111 return Node();
112 }
113 }
114 }
115
116 Node Evaluator::eval(TNode n,
117 const std::vector<Node>& args,
118 const std::vector<Node>& vals)
119 {
120 Trace("evaluator") << "Evaluating " << n << " under substitution " << args
121 << " " << vals << std::endl;
122 return evalInternal(n, args, vals).toNode();
123 }
124
125 EvalResult Evaluator::evalInternal(TNode n,
126 const std::vector<Node>& args,
127 const std::vector<Node>& vals)
128 {
129 std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
130 std::vector<TNode> queue;
131 queue.emplace_back(n);
132
133 while (queue.size() != 0)
134 {
135 TNode currNode = queue.back();
136
137 if (results.find(currNode) != results.end())
138 {
139 queue.pop_back();
140 continue;
141 }
142
143 bool doEval = true;
144 for (const auto& currNodeChild : currNode)
145 {
146 if (results.find(currNodeChild) == results.end())
147 {
148 queue.emplace_back(currNodeChild);
149 doEval = false;
150 }
151 }
152
153 if (doEval)
154 {
155 queue.pop_back();
156
157 Node currNodeVal = currNode;
158 if (currNode.isVar())
159 {
160 const auto& it = std::find(args.begin(), args.end(), currNode);
161
162 if (it == args.end())
163 {
164 return EvalResult();
165 }
166
167 ptrdiff_t pos = std::distance(args.begin(), it);
168 currNodeVal = vals[pos];
169 }
170 else if (currNode.getKind() == kind::APPLY_UF
171 && currNode.getOperator().getKind() == kind::LAMBDA)
172 {
173 // Create a copy of the current substitutions
174 std::vector<Node> lambdaArgs(args);
175 std::vector<Node> lambdaVals(vals);
176
177 // Add the values for the arguments of the lambda as substitutions at
178 // the beginning of the vector to shadow variables from outer scopes
179 // with the same name
180 Node op = currNode.getOperator();
181 for (const auto& lambdaArg : op[0])
182 {
183 lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
184 }
185
186 for (const auto& lambdaVal : currNode)
187 {
188 lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
189 }
190
191 // Lambdas are evaluated in a recursive fashion because each evaluation
192 // requires different substitutions
193 results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals);
194 if (results[currNode].d_tag == EvalResult::INVALID)
195 {
196 // evaluation was invalid, we fail
197 return results[currNode];
198 }
199 continue;
200 }
201
202 switch (currNodeVal.getKind())
203 {
204 case kind::CONST_BOOLEAN:
205 results[currNode] = EvalResult(currNodeVal.getConst<bool>());
206 break;
207
208 case kind::NOT:
209 {
210 results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
211 break;
212 }
213
214 case kind::AND:
215 {
216 bool res = results[currNode[0]].d_bool;
217 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
218 {
219 res = res && results[currNode[i]].d_bool;
220 }
221 results[currNode] = EvalResult(res);
222 break;
223 }
224
225 case kind::OR:
226 {
227 bool res = results[currNode[0]].d_bool;
228 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
229 {
230 res = res || results[currNode[i]].d_bool;
231 }
232 results[currNode] = EvalResult(res);
233 break;
234 }
235
236 case kind::CONST_RATIONAL:
237 {
238 const Rational& r = currNodeVal.getConst<Rational>();
239 results[currNode] = EvalResult(r);
240 break;
241 }
242
243 case kind::PLUS:
244 {
245 Rational res = results[currNode[0]].d_rat;
246 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
247 {
248 res = res + results[currNode[i]].d_rat;
249 }
250 results[currNode] = EvalResult(res);
251 break;
252 }
253
254 case kind::MINUS:
255 {
256 const Rational& x = results[currNode[0]].d_rat;
257 const Rational& y = results[currNode[1]].d_rat;
258 results[currNode] = EvalResult(x - y);
259 break;
260 }
261
262 case kind::UMINUS:
263 {
264 const Rational& x = results[currNode[0]].d_rat;
265 results[currNode] = EvalResult(-x);
266 break;
267 }
268 case kind::MULT:
269 {
270 Rational res = results[currNode[0]].d_rat;
271 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
272 {
273 res = res * results[currNode[i]].d_rat;
274 }
275 results[currNode] = EvalResult(res);
276 break;
277 }
278
279 case kind::GEQ:
280 {
281 const Rational& x = results[currNode[0]].d_rat;
282 const Rational& y = results[currNode[1]].d_rat;
283 results[currNode] = EvalResult(x >= y);
284 break;
285 }
286 case kind::LEQ:
287 {
288 const Rational& x = results[currNode[0]].d_rat;
289 const Rational& y = results[currNode[1]].d_rat;
290 results[currNode] = EvalResult(x <= y);
291 break;
292 }
293 case kind::GT:
294 {
295 const Rational& x = results[currNode[0]].d_rat;
296 const Rational& y = results[currNode[1]].d_rat;
297 results[currNode] = EvalResult(x > y);
298 break;
299 }
300 case kind::LT:
301 {
302 const Rational& x = results[currNode[0]].d_rat;
303 const Rational& y = results[currNode[1]].d_rat;
304 results[currNode] = EvalResult(x < y);
305 break;
306 }
307
308 case kind::CONST_STRING:
309 results[currNode] = EvalResult(currNodeVal.getConst<String>());
310 break;
311
312 case kind::STRING_CONCAT:
313 {
314 String res = results[currNode[0]].d_str;
315 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
316 {
317 res = res.concat(results[currNode[i]].d_str);
318 }
319 results[currNode] = EvalResult(res);
320 break;
321 }
322
323 case kind::STRING_LENGTH:
324 {
325 const String& s = results[currNode[0]].d_str;
326 results[currNode] = EvalResult(Rational(s.size()));
327 break;
328 }
329
330 case kind::STRING_SUBSTR:
331 {
332 const String& s = results[currNode[0]].d_str;
333 Integer s_len(s.size());
334 Integer i = results[currNode[1]].d_rat.getNumerator();
335 Integer j = results[currNode[2]].d_rat.getNumerator();
336
337 if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
338 {
339 results[currNode] = EvalResult(String(""));
340 }
341 else if (i + j > s_len)
342 {
343 results[currNode] =
344 EvalResult(s.suffix((s_len - i).toUnsignedInt()));
345 }
346 else
347 {
348 results[currNode] =
349 EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
350 }
351 break;
352 }
353
354 case kind::STRING_CHARAT:
355 {
356 const String& s = results[currNode[0]].d_str;
357 Integer s_len(s.size());
358 Integer i = results[currNode[1]].d_rat.getNumerator();
359 if (i.strictlyNegative() || i >= s_len)
360 {
361 results[currNode] = EvalResult(String(""));
362 }
363 else
364 {
365 results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
366 }
367 break;
368 }
369
370 case kind::STRING_STRCTN:
371 {
372 const String& s = results[currNode[0]].d_str;
373 const String& t = results[currNode[1]].d_str;
374 results[currNode] = EvalResult(s.find(t) != std::string::npos);
375 break;
376 }
377
378 case kind::STRING_STRIDOF:
379 {
380 const String& s = results[currNode[0]].d_str;
381 Integer s_len(s.size());
382 const String& x = results[currNode[1]].d_str;
383 Integer i = results[currNode[2]].d_rat.getNumerator();
384
385 if (i.strictlyNegative())
386 {
387 results[currNode] = EvalResult(Rational(-1));
388 }
389 else
390 {
391 size_t r = s.find(x, i.toUnsignedInt());
392 if (r == std::string::npos)
393 {
394 results[currNode] = EvalResult(Rational(-1));
395 }
396 else
397 {
398 results[currNode] = EvalResult(Rational(r));
399 }
400 }
401 break;
402 }
403
404 case kind::STRING_STRREPL:
405 {
406 const String& s = results[currNode[0]].d_str;
407 const String& x = results[currNode[1]].d_str;
408 const String& y = results[currNode[2]].d_str;
409 results[currNode] = EvalResult(s.replace(x, y));
410 break;
411 }
412
413 case kind::STRING_PREFIX:
414 {
415 const String& t = results[currNode[0]].d_str;
416 const String& s = results[currNode[1]].d_str;
417 if (s.size() < t.size())
418 {
419 results[currNode] = EvalResult(false);
420 }
421 else
422 {
423 results[currNode] = EvalResult(s.prefix(t.size()) == t);
424 }
425 break;
426 }
427
428 case kind::STRING_SUFFIX:
429 {
430 const String& t = results[currNode[0]].d_str;
431 const String& s = results[currNode[1]].d_str;
432 if (s.size() < t.size())
433 {
434 results[currNode] = EvalResult(false);
435 }
436 else
437 {
438 results[currNode] = EvalResult(s.suffix(t.size()) == t);
439 }
440 break;
441 }
442
443 case kind::STRING_ITOS:
444 {
445 Integer i = results[currNode[0]].d_rat.getNumerator();
446 if (i.strictlyNegative())
447 {
448 results[currNode] = EvalResult(String(""));
449 }
450 else
451 {
452 results[currNode] = EvalResult(String(i.toString()));
453 }
454 break;
455 }
456
457 case kind::STRING_STOI:
458 {
459 const String& s = results[currNode[0]].d_str;
460 if (s.isNumber())
461 {
462 results[currNode] = EvalResult(Rational(s.toNumber()));
463 }
464 else
465 {
466 results[currNode] = EvalResult(Rational(-1));
467 }
468 break;
469 }
470
471 case kind::STRING_CODE:
472 {
473 const String& s = results[currNode[0]].d_str;
474 if (s.size() == 1)
475 {
476 results[currNode] = EvalResult(
477 Rational(String::convertUnsignedIntToCode(s.getVec()[0])));
478 }
479 else
480 {
481 results[currNode] = EvalResult(Rational(-1));
482 }
483 break;
484 }
485
486 case kind::CONST_BITVECTOR:
487 results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
488 break;
489
490 case kind::BITVECTOR_NOT:
491 results[currNode] = EvalResult(~results[currNode[0]].d_bv);
492 break;
493
494 case kind::BITVECTOR_NEG:
495 results[currNode] = EvalResult(-results[currNode[0]].d_bv);
496 break;
497
498 case kind::BITVECTOR_EXTRACT:
499 {
500 unsigned lo = bv::utils::getExtractLow(currNodeVal);
501 unsigned hi = bv::utils::getExtractHigh(currNodeVal);
502 results[currNode] =
503 EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
504 break;
505 }
506
507 case kind::BITVECTOR_CONCAT:
508 {
509 BitVector res = results[currNode[0]].d_bv;
510 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
511 {
512 res = res.concat(results[currNode[i]].d_bv);
513 }
514 results[currNode] = EvalResult(res);
515 break;
516 }
517
518 case kind::BITVECTOR_PLUS:
519 {
520 BitVector res = results[currNode[0]].d_bv;
521 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
522 {
523 res = res + results[currNode[i]].d_bv;
524 }
525 results[currNode] = EvalResult(res);
526 break;
527 }
528
529 case kind::BITVECTOR_MULT:
530 {
531 BitVector res = results[currNode[0]].d_bv;
532 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
533 {
534 res = res * results[currNode[i]].d_bv;
535 }
536 results[currNode] = EvalResult(res);
537 break;
538 }
539 case kind::BITVECTOR_AND:
540 {
541 BitVector res = results[currNode[0]].d_bv;
542 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
543 {
544 res = res & results[currNode[i]].d_bv;
545 }
546 results[currNode] = EvalResult(res);
547 break;
548 }
549
550 case kind::BITVECTOR_OR:
551 {
552 BitVector res = results[currNode[0]].d_bv;
553 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
554 {
555 res = res | results[currNode[i]].d_bv;
556 }
557 results[currNode] = EvalResult(res);
558 break;
559 }
560
561 case kind::BITVECTOR_XOR:
562 {
563 BitVector res = results[currNode[0]].d_bv;
564 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
565 {
566 res = res ^ results[currNode[i]].d_bv;
567 }
568 results[currNode] = EvalResult(res);
569 break;
570 }
571
572 case kind::EQUAL:
573 {
574 EvalResult lhs = results[currNode[0]];
575 EvalResult rhs = results[currNode[1]];
576
577 switch (lhs.d_tag)
578 {
579 case EvalResult::BOOL:
580 {
581 results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
582 break;
583 }
584
585 case EvalResult::BITVECTOR:
586 {
587 results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
588 break;
589 }
590
591 case EvalResult::RATIONAL:
592 {
593 results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
594 break;
595 }
596
597 case EvalResult::STRING:
598 {
599 results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
600 break;
601 }
602
603 default:
604 {
605 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
606 << " not supported" << std::endl;
607 return EvalResult();
608 break;
609 }
610 }
611
612 break;
613 }
614
615 case kind::ITE:
616 {
617 if (results[currNode[0]].d_bool)
618 {
619 results[currNode] = results[currNode[1]];
620 }
621 else
622 {
623 results[currNode] = results[currNode[2]];
624 }
625 break;
626 }
627
628 default:
629 {
630 Trace("evaluator") << "Kind " << currNodeVal.getKind()
631 << " not supported" << std::endl;
632 return EvalResult();
633 }
634 }
635 }
636 }
637
638 return results[n];
639 }
640
641 } // namespace theory
642 } // namespace CVC4