Incorporate rewriting on demand in the evaluator (#3549)
[cvc5.git] / src / theory / evaluator.cpp
1 /********************* */
2 /*! \file evaluator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andres Noetzli, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The Evaluator class
13 **
14 ** The Evaluator class.
15 **/
16
17 #include "theory/evaluator.h"
18
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/rewriter.h"
21 #include "theory/theory.h"
22 #include "util/integer.h"
23
24 namespace CVC4 {
25 namespace theory {
26
27 EvalResult::EvalResult(const EvalResult& other)
28 {
29 d_tag = other.d_tag;
30 switch (d_tag)
31 {
32 case BOOL: d_bool = other.d_bool; break;
33 case BITVECTOR:
34 new (&d_bv) BitVector;
35 d_bv = other.d_bv;
36 break;
37 case RATIONAL:
38 new (&d_rat) Rational;
39 d_rat = other.d_rat;
40 break;
41 case STRING:
42 new (&d_str) String;
43 d_str = other.d_str;
44 break;
45 case INVALID: break;
46 }
47 }
48
49 EvalResult& EvalResult::operator=(const EvalResult& other)
50 {
51 if (this != &other)
52 {
53 d_tag = other.d_tag;
54 switch (d_tag)
55 {
56 case BOOL: d_bool = other.d_bool; break;
57 case BITVECTOR:
58 new (&d_bv) BitVector;
59 d_bv = other.d_bv;
60 break;
61 case RATIONAL:
62 new (&d_rat) Rational;
63 d_rat = other.d_rat;
64 break;
65 case STRING:
66 new (&d_str) String;
67 d_str = other.d_str;
68 break;
69 case INVALID: break;
70 }
71 }
72 return *this;
73 }
74
75 EvalResult::~EvalResult()
76 {
77 switch (d_tag)
78 {
79 case BITVECTOR:
80 {
81 d_bv.~BitVector();
82 break;
83 }
84 case RATIONAL:
85 {
86 d_rat.~Rational();
87 break;
88 }
89 case STRING:
90 {
91 d_str.~String();
92 break;
93
94 default: break;
95 }
96 }
97 }
98
99 Node EvalResult::toNode() const
100 {
101 NodeManager* nm = NodeManager::currentNM();
102 switch (d_tag)
103 {
104 case EvalResult::BOOL: return nm->mkConst(d_bool);
105 case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
106 case EvalResult::RATIONAL: return nm->mkConst(d_rat);
107 case EvalResult::STRING: return nm->mkConst(d_str);
108 default:
109 {
110 Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
111 << std::endl;
112 return Node();
113 }
114 }
115 }
116
117 Node Evaluator::eval(TNode n,
118 const std::vector<Node>& args,
119 const std::vector<Node>& vals)
120 {
121 Trace("evaluator") << "Evaluating " << n << " under substitution " << args
122 << " " << vals << std::endl;
123 std::unordered_map<TNode, Node, NodeHashFunction> evalAsNode;
124 Node ret = evalInternal(n, args, vals, evalAsNode).toNode();
125 if (!ret.isNull())
126 {
127 // maybe it was stored in the evaluation-as-node map
128 std::unordered_map<TNode, Node, NodeHashFunction>::iterator itn =
129 evalAsNode.find(n);
130 if (itn != evalAsNode.end())
131 {
132 return itn->second;
133 }
134 }
135 return ret;
136 }
137
138 EvalResult Evaluator::evalInternal(
139 TNode n,
140 const std::vector<Node>& args,
141 const std::vector<Node>& vals,
142 std::unordered_map<TNode, Node, NodeHashFunction>& evalAsNode)
143 {
144 std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
145 std::vector<TNode> queue;
146 queue.emplace_back(n);
147 std::unordered_map<TNode, Node, NodeHashFunction>::iterator itn;
148 std::unordered_map<TNode, EvalResult, TNodeHashFunction>::iterator itr;
149 NodeManager* nm = NodeManager::currentNM();
150
151 while (queue.size() != 0)
152 {
153 TNode currNode = queue.back();
154
155 if (results.find(currNode) != results.end())
156 {
157 queue.pop_back();
158 continue;
159 }
160
161 bool doProcess = true;
162 bool doEval = true;
163 for (const auto& currNodeChild : currNode)
164 {
165 itr = results.find(currNodeChild);
166 if (itr == results.end())
167 {
168 queue.emplace_back(currNodeChild);
169 doProcess = false;
170 }
171 else if (itr->second.d_tag == EvalResult::INVALID)
172 {
173 // we cannot evaluate since there was an invalid child
174 doEval = false;
175 }
176 }
177 Trace("evaluator") << "Evaluator: visit " << currNode
178 << ", process = " << doProcess
179 << ", evaluate = " << doEval << std::endl;
180
181 if (doProcess)
182 {
183 queue.pop_back();
184
185 Node currNodeVal = currNode;
186
187 // The code below should either:
188 // (1) store a valid EvalResult into results[currNode], or
189 // (2) store an invalid EvalResult into results[currNode] and
190 // store the result of substitution + rewriting currNode { args -> vals }
191 // into evalAsNode[currNode].
192
193 // If we did not successfully evaluate all children
194 if (!doEval)
195 {
196 // Reconstruct the node with a combination of the children that
197 // successfully evaluated, and the children that did not.
198 Trace("evaluator") << "Evaluator: collect arguments" << std::endl;
199 std::vector<Node> echildren;
200 if (currNode.getMetaKind() == kind::metakind::PARAMETERIZED)
201 {
202 echildren.push_back(currNode.getOperator());
203 }
204 for (const auto& currNodeChild : currNode)
205 {
206 itr = results.find(currNodeChild);
207 if (itr->second.d_tag == EvalResult::INVALID)
208 {
209 // could not evaluate this child, look in the node cache
210 itn = evalAsNode.find(currNodeChild);
211 Assert(itn != evalAsNode.end());
212 echildren.push_back(itn->second);
213 }
214 else
215 {
216 // otherwise, use the evaluation
217 echildren.push_back(itr->second.toNode());
218 }
219 }
220 // The value is the result of our (partially) successful evaluation
221 // of the children.
222 currNodeVal = nm->mkNode(currNode.getKind(), echildren);
223 Trace("evaluator") << "Evaluator: partially evaluated " << currNodeVal
224 << std::endl;
225 // Use rewriting. Notice we do not need to substitute here since
226 // all substitutions should already have been applied recursively.
227 currNodeVal = Rewriter::rewrite(currNodeVal);
228 Trace("evaluator") << "Evaluator: now after substitution + rewriting: "
229 << currNodeVal << std::endl;
230 if (currNodeVal.getNumChildren() > 0)
231 {
232 // We may continue with a valid EvalResult at this point only if
233 // we have no children. We must otherwise fail here since some of
234 // our children may not have successful evaluations.
235 results[currNode] = EvalResult();
236 evalAsNode[currNode] = currNodeVal;
237 continue;
238 }
239 // Otherwise, we may be able to turn the overall result into an
240 // valid EvalResult and continue. We fallthrough and continue with the
241 // block of code below.
242 }
243
244 if (currNode.isVar())
245 {
246 const auto& it = std::find(args.begin(), args.end(), currNode);
247 if (it == args.end())
248 {
249 evalAsNode[currNode] = currNode;
250 results[currNode] = EvalResult();
251 continue;
252 }
253 ptrdiff_t pos = std::distance(args.begin(), it);
254 currNodeVal = vals[pos];
255 }
256 else if (currNode.getKind() == kind::APPLY_UF
257 && currNode.getOperator().getKind() == kind::LAMBDA)
258 {
259 // Create a copy of the current substitutions
260 std::vector<Node> lambdaArgs(args);
261 std::vector<Node> lambdaVals(vals);
262
263 // Add the values for the arguments of the lambda as substitutions at
264 // the beginning of the vector to shadow variables from outer scopes
265 // with the same name
266 Node op = currNode.getOperator();
267 for (const auto& lambdaArg : op[0])
268 {
269 lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
270 }
271
272 for (const auto& lambdaVal : currNode)
273 {
274 lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
275 }
276
277 // Lambdas are evaluated in a recursive fashion because each evaluation
278 // requires different substitutions. We use a fresh cache since the
279 // evaluation of op[1] is under a new substitution and thus should not
280 // be cached. We could alternatively copy evalAsNode to evalAsNodeC but
281 // favor avoiding this copy for performance reasons.
282 std::unordered_map<TNode, Node, NodeHashFunction> evalAsNodeC;
283 results[currNode] =
284 evalInternal(op[1], lambdaArgs, lambdaVals, evalAsNodeC);
285 if (results[currNode].d_tag == EvalResult::INVALID)
286 {
287 // evaluation was invalid, we take the node of op[1] as the result
288 evalAsNode[currNode] = evalAsNodeC[op[1]];
289 }
290 continue;
291 }
292
293 switch (currNodeVal.getKind())
294 {
295 case kind::CONST_BOOLEAN:
296 results[currNode] = EvalResult(currNodeVal.getConst<bool>());
297 break;
298
299 case kind::NOT:
300 {
301 results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
302 break;
303 }
304
305 case kind::AND:
306 {
307 bool res = results[currNode[0]].d_bool;
308 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
309 {
310 res = res && results[currNode[i]].d_bool;
311 }
312 results[currNode] = EvalResult(res);
313 break;
314 }
315
316 case kind::OR:
317 {
318 bool res = results[currNode[0]].d_bool;
319 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
320 {
321 res = res || results[currNode[i]].d_bool;
322 }
323 results[currNode] = EvalResult(res);
324 break;
325 }
326
327 case kind::CONST_RATIONAL:
328 {
329 const Rational& r = currNodeVal.getConst<Rational>();
330 results[currNode] = EvalResult(r);
331 break;
332 }
333
334 case kind::PLUS:
335 {
336 Rational res = results[currNode[0]].d_rat;
337 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
338 {
339 res = res + results[currNode[i]].d_rat;
340 }
341 results[currNode] = EvalResult(res);
342 break;
343 }
344
345 case kind::MINUS:
346 {
347 const Rational& x = results[currNode[0]].d_rat;
348 const Rational& y = results[currNode[1]].d_rat;
349 results[currNode] = EvalResult(x - y);
350 break;
351 }
352
353 case kind::UMINUS:
354 {
355 const Rational& x = results[currNode[0]].d_rat;
356 results[currNode] = EvalResult(-x);
357 break;
358 }
359 case kind::MULT:
360 {
361 Rational res = results[currNode[0]].d_rat;
362 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
363 {
364 res = res * results[currNode[i]].d_rat;
365 }
366 results[currNode] = EvalResult(res);
367 break;
368 }
369
370 case kind::GEQ:
371 {
372 const Rational& x = results[currNode[0]].d_rat;
373 const Rational& y = results[currNode[1]].d_rat;
374 results[currNode] = EvalResult(x >= y);
375 break;
376 }
377 case kind::LEQ:
378 {
379 const Rational& x = results[currNode[0]].d_rat;
380 const Rational& y = results[currNode[1]].d_rat;
381 results[currNode] = EvalResult(x <= y);
382 break;
383 }
384 case kind::GT:
385 {
386 const Rational& x = results[currNode[0]].d_rat;
387 const Rational& y = results[currNode[1]].d_rat;
388 results[currNode] = EvalResult(x > y);
389 break;
390 }
391 case kind::LT:
392 {
393 const Rational& x = results[currNode[0]].d_rat;
394 const Rational& y = results[currNode[1]].d_rat;
395 results[currNode] = EvalResult(x < y);
396 break;
397 }
398 case kind::ABS:
399 {
400 const Rational& x = results[currNode[0]].d_rat;
401 results[currNode] = EvalResult(x.abs());
402 break;
403 }
404 case kind::CONST_STRING:
405 results[currNode] = EvalResult(currNodeVal.getConst<String>());
406 break;
407
408 case kind::STRING_CONCAT:
409 {
410 String res = results[currNode[0]].d_str;
411 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
412 {
413 res = res.concat(results[currNode[i]].d_str);
414 }
415 results[currNode] = EvalResult(res);
416 break;
417 }
418
419 case kind::STRING_LENGTH:
420 {
421 const String& s = results[currNode[0]].d_str;
422 results[currNode] = EvalResult(Rational(s.size()));
423 break;
424 }
425
426 case kind::STRING_SUBSTR:
427 {
428 const String& s = results[currNode[0]].d_str;
429 Integer s_len(s.size());
430 Integer i = results[currNode[1]].d_rat.getNumerator();
431 Integer j = results[currNode[2]].d_rat.getNumerator();
432
433 if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
434 {
435 results[currNode] = EvalResult(String(""));
436 }
437 else if (i + j > s_len)
438 {
439 results[currNode] =
440 EvalResult(s.suffix((s_len - i).toUnsignedInt()));
441 }
442 else
443 {
444 results[currNode] =
445 EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
446 }
447 break;
448 }
449
450 case kind::STRING_CHARAT:
451 {
452 const String& s = results[currNode[0]].d_str;
453 Integer s_len(s.size());
454 Integer i = results[currNode[1]].d_rat.getNumerator();
455 if (i.strictlyNegative() || i >= s_len)
456 {
457 results[currNode] = EvalResult(String(""));
458 }
459 else
460 {
461 results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
462 }
463 break;
464 }
465
466 case kind::STRING_STRCTN:
467 {
468 const String& s = results[currNode[0]].d_str;
469 const String& t = results[currNode[1]].d_str;
470 results[currNode] = EvalResult(s.find(t) != std::string::npos);
471 break;
472 }
473
474 case kind::STRING_STRIDOF:
475 {
476 const String& s = results[currNode[0]].d_str;
477 Integer s_len(s.size());
478 const String& x = results[currNode[1]].d_str;
479 Integer i = results[currNode[2]].d_rat.getNumerator();
480
481 if (i.strictlyNegative())
482 {
483 results[currNode] = EvalResult(Rational(-1));
484 }
485 else
486 {
487 size_t r = s.find(x, i.toUnsignedInt());
488 if (r == std::string::npos)
489 {
490 results[currNode] = EvalResult(Rational(-1));
491 }
492 else
493 {
494 results[currNode] = EvalResult(Rational(r));
495 }
496 }
497 break;
498 }
499
500 case kind::STRING_STRREPL:
501 {
502 const String& s = results[currNode[0]].d_str;
503 const String& x = results[currNode[1]].d_str;
504 const String& y = results[currNode[2]].d_str;
505 results[currNode] = EvalResult(s.replace(x, y));
506 break;
507 }
508
509 case kind::STRING_PREFIX:
510 {
511 const String& t = results[currNode[0]].d_str;
512 const String& s = results[currNode[1]].d_str;
513 if (s.size() < t.size())
514 {
515 results[currNode] = EvalResult(false);
516 }
517 else
518 {
519 results[currNode] = EvalResult(s.prefix(t.size()) == t);
520 }
521 break;
522 }
523
524 case kind::STRING_SUFFIX:
525 {
526 const String& t = results[currNode[0]].d_str;
527 const String& s = results[currNode[1]].d_str;
528 if (s.size() < t.size())
529 {
530 results[currNode] = EvalResult(false);
531 }
532 else
533 {
534 results[currNode] = EvalResult(s.suffix(t.size()) == t);
535 }
536 break;
537 }
538
539 case kind::STRING_ITOS:
540 {
541 Integer i = results[currNode[0]].d_rat.getNumerator();
542 if (i.strictlyNegative())
543 {
544 results[currNode] = EvalResult(String(""));
545 }
546 else
547 {
548 results[currNode] = EvalResult(String(i.toString()));
549 }
550 break;
551 }
552
553 case kind::STRING_STOI:
554 {
555 const String& s = results[currNode[0]].d_str;
556 if (s.isNumber())
557 {
558 results[currNode] = EvalResult(Rational(s.toNumber()));
559 }
560 else
561 {
562 results[currNode] = EvalResult(Rational(-1));
563 }
564 break;
565 }
566
567 case kind::STRING_CODE:
568 {
569 const String& s = results[currNode[0]].d_str;
570 if (s.size() == 1)
571 {
572 results[currNode] = EvalResult(
573 Rational(String::convertUnsignedIntToCode(s.getVec()[0])));
574 }
575 else
576 {
577 results[currNode] = EvalResult(Rational(-1));
578 }
579 break;
580 }
581
582 case kind::CONST_BITVECTOR:
583 results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
584 break;
585
586 case kind::BITVECTOR_NOT:
587 results[currNode] = EvalResult(~results[currNode[0]].d_bv);
588 break;
589
590 case kind::BITVECTOR_NEG:
591 results[currNode] = EvalResult(-results[currNode[0]].d_bv);
592 break;
593
594 case kind::BITVECTOR_EXTRACT:
595 {
596 unsigned lo = bv::utils::getExtractLow(currNodeVal);
597 unsigned hi = bv::utils::getExtractHigh(currNodeVal);
598 results[currNode] =
599 EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
600 break;
601 }
602
603 case kind::BITVECTOR_CONCAT:
604 {
605 BitVector res = results[currNode[0]].d_bv;
606 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
607 {
608 res = res.concat(results[currNode[i]].d_bv);
609 }
610 results[currNode] = EvalResult(res);
611 break;
612 }
613
614 case kind::BITVECTOR_PLUS:
615 {
616 BitVector res = results[currNode[0]].d_bv;
617 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
618 {
619 res = res + results[currNode[i]].d_bv;
620 }
621 results[currNode] = EvalResult(res);
622 break;
623 }
624
625 case kind::BITVECTOR_MULT:
626 {
627 BitVector res = results[currNode[0]].d_bv;
628 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
629 {
630 res = res * results[currNode[i]].d_bv;
631 }
632 results[currNode] = EvalResult(res);
633 break;
634 }
635 case kind::BITVECTOR_AND:
636 {
637 BitVector res = results[currNode[0]].d_bv;
638 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
639 {
640 res = res & results[currNode[i]].d_bv;
641 }
642 results[currNode] = EvalResult(res);
643 break;
644 }
645
646 case kind::BITVECTOR_OR:
647 {
648 BitVector res = results[currNode[0]].d_bv;
649 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
650 {
651 res = res | results[currNode[i]].d_bv;
652 }
653 results[currNode] = EvalResult(res);
654 break;
655 }
656
657 case kind::BITVECTOR_XOR:
658 {
659 BitVector res = results[currNode[0]].d_bv;
660 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
661 {
662 res = res ^ results[currNode[i]].d_bv;
663 }
664 results[currNode] = EvalResult(res);
665 break;
666 }
667 case kind::BITVECTOR_UDIV:
668 case kind::BITVECTOR_UDIV_TOTAL:
669 {
670 if (currNodeVal.getKind() == kind::BITVECTOR_UDIV_TOTAL
671 || results[currNode[1]].d_bv.getValue() != 0)
672 {
673 BitVector res = results[currNode[0]].d_bv;
674 res = res.unsignedDivTotal(results[currNode[1]].d_bv);
675 results[currNode] = EvalResult(res);
676 }
677 else
678 {
679 results[currNode] = EvalResult();
680 evalAsNode[currNode] = currNodeVal;
681 }
682 break;
683 }
684 case kind::BITVECTOR_UREM:
685 case kind::BITVECTOR_UREM_TOTAL:
686 {
687 if (currNodeVal.getKind() == kind::BITVECTOR_UREM_TOTAL
688 || results[currNode[1]].d_bv.getValue() != 0)
689 {
690 BitVector res = results[currNode[0]].d_bv;
691 res = res.unsignedRemTotal(results[currNode[1]].d_bv);
692 results[currNode] = EvalResult(res);
693 }
694 else
695 {
696 results[currNode] = EvalResult();
697 evalAsNode[currNode] = currNodeVal;
698 }
699 break;
700 }
701
702 case kind::EQUAL:
703 {
704 EvalResult lhs = results[currNode[0]];
705 EvalResult rhs = results[currNode[1]];
706
707 switch (lhs.d_tag)
708 {
709 case EvalResult::BOOL:
710 {
711 results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
712 break;
713 }
714
715 case EvalResult::BITVECTOR:
716 {
717 results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
718 break;
719 }
720
721 case EvalResult::RATIONAL:
722 {
723 results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
724 break;
725 }
726
727 case EvalResult::STRING:
728 {
729 results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
730 break;
731 }
732
733 default:
734 {
735 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
736 << " not supported" << std::endl;
737 results[currNode] = EvalResult();
738 evalAsNode[currNode] = currNodeVal;
739 break;
740 }
741 }
742
743 break;
744 }
745
746 case kind::ITE:
747 {
748 if (results[currNode[0]].d_bool)
749 {
750 results[currNode] = results[currNode[1]];
751 }
752 else
753 {
754 results[currNode] = results[currNode[2]];
755 }
756 break;
757 }
758
759 default:
760 {
761 Trace("evaluator") << "Kind " << currNodeVal.getKind()
762 << " not supported" << std::endl;
763 results[currNode] = EvalResult();
764 evalAsNode[currNode] = currNodeVal;
765 }
766 }
767 }
768 }
769
770 return results[n];
771 }
772
773 } // namespace theory
774 } // namespace CVC4