0a0176f25b5a28a1b8281b709206d02766783a07
[cvc5.git] / src / theory / evaluator.cpp
1 /********************* */
2 /*! \file evaluator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The Evaluator class
13 **
14 ** The Evaluator class.
15 **/
16
17 #include "theory/evaluator.h"
18
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/theory.h"
21 #include "util/integer.h"
22
23 namespace CVC4 {
24 namespace theory {
25
26 EvalResult::EvalResult(const EvalResult& other)
27 {
28 d_tag = other.d_tag;
29 switch (d_tag)
30 {
31 case BOOL: d_bool = other.d_bool; break;
32 case BITVECTOR:
33 new (&d_bv) BitVector;
34 d_bv = other.d_bv;
35 break;
36 case RATIONAL:
37 new (&d_rat) Rational;
38 d_rat = other.d_rat;
39 break;
40 case STRING:
41 new (&d_str) String;
42 d_str = other.d_str;
43 break;
44 case INVALID: break;
45 }
46 }
47
48 EvalResult& EvalResult::operator=(const EvalResult& other)
49 {
50 if (this != &other)
51 {
52 d_tag = other.d_tag;
53 switch (d_tag)
54 {
55 case BOOL: d_bool = other.d_bool; break;
56 case BITVECTOR:
57 new (&d_bv) BitVector;
58 d_bv = other.d_bv;
59 break;
60 case RATIONAL:
61 new (&d_rat) Rational;
62 d_rat = other.d_rat;
63 break;
64 case STRING:
65 new (&d_str) String;
66 d_str = other.d_str;
67 break;
68 case INVALID: break;
69 }
70 }
71 return *this;
72 }
73
74 EvalResult::~EvalResult()
75 {
76 switch (d_tag)
77 {
78 case BITVECTOR:
79 {
80 d_bv.~BitVector();
81 break;
82 }
83 case RATIONAL:
84 {
85 d_rat.~Rational();
86 break;
87 }
88 case STRING:
89 {
90 d_str.~String();
91 break;
92
93 default: break;
94 }
95 }
96 }
97
98 Node EvalResult::toNode() const
99 {
100 NodeManager* nm = NodeManager::currentNM();
101 switch (d_tag)
102 {
103 case EvalResult::BOOL: return nm->mkConst(d_bool);
104 case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
105 case EvalResult::RATIONAL: return nm->mkConst(d_rat);
106 case EvalResult::STRING: return nm->mkConst(d_str);
107 default:
108 {
109 Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
110 << std::endl;
111 return Node();
112 }
113 }
114 }
115
116 Node Evaluator::eval(TNode n,
117 const std::vector<Node>& args,
118 const std::vector<Node>& vals)
119 {
120 Trace("evaluator") << "Evaluating " << n << " under substitution " << args
121 << " " << vals << std::endl;
122 return evalInternal(n, args, vals).toNode();
123 }
124
125 EvalResult Evaluator::evalInternal(TNode n,
126 const std::vector<Node>& args,
127 const std::vector<Node>& vals)
128 {
129 std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
130 std::vector<TNode> queue;
131 queue.emplace_back(n);
132
133 while (queue.size() != 0)
134 {
135 TNode currNode = queue.back();
136
137 if (results.find(currNode) != results.end())
138 {
139 queue.pop_back();
140 continue;
141 }
142
143 bool doEval = true;
144 for (const auto& currNodeChild : currNode)
145 {
146 if (results.find(currNodeChild) == results.end())
147 {
148 queue.emplace_back(currNodeChild);
149 doEval = false;
150 }
151 }
152
153 if (doEval)
154 {
155 queue.pop_back();
156
157 Node currNodeVal = currNode;
158 if (currNode.isVar())
159 {
160 const auto& it = std::find(args.begin(), args.end(), currNode);
161
162 if (it == args.end())
163 {
164 return EvalResult();
165 }
166
167 ptrdiff_t pos = std::distance(args.begin(), it);
168 currNodeVal = vals[pos];
169 }
170 else if (currNode.getKind() == kind::APPLY_UF
171 && currNode.getOperator().getKind() == kind::LAMBDA)
172 {
173 // Create a copy of the current substitutions
174 std::vector<Node> lambdaArgs(args);
175 std::vector<Node> lambdaVals(vals);
176
177 // Add the values for the arguments of the lambda as substitutions at
178 // the beginning of the vector to shadow variables from outer scopes
179 // with the same name
180 Node op = currNode.getOperator();
181 for (const auto& lambdaArg : op[0])
182 {
183 lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
184 }
185
186 for (const auto& lambdaVal : currNode)
187 {
188 lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
189 }
190
191 // Lambdas are evaluated in a recursive fashion because each evaluation
192 // requires different substitutions
193 results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals);
194 if (results[currNode].d_tag == EvalResult::INVALID)
195 {
196 // evaluation was invalid, we fail
197 return results[currNode];
198 }
199 continue;
200 }
201
202 switch (currNodeVal.getKind())
203 {
204 case kind::CONST_BOOLEAN:
205 results[currNode] = EvalResult(currNodeVal.getConst<bool>());
206 break;
207
208 case kind::NOT:
209 {
210 results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
211 break;
212 }
213
214 case kind::AND:
215 {
216 bool res = results[currNode[0]].d_bool;
217 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
218 {
219 res = res && results[currNode[i]].d_bool;
220 }
221 results[currNode] = EvalResult(res);
222 break;
223 }
224
225 case kind::OR:
226 {
227 bool res = results[currNode[0]].d_bool;
228 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
229 {
230 res = res || results[currNode[i]].d_bool;
231 }
232 results[currNode] = EvalResult(res);
233 break;
234 }
235
236 case kind::CONST_RATIONAL:
237 {
238 const Rational& r = currNodeVal.getConst<Rational>();
239 results[currNode] = EvalResult(r);
240 break;
241 }
242
243 case kind::PLUS:
244 {
245 Rational res = results[currNode[0]].d_rat;
246 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
247 {
248 res = res + results[currNode[i]].d_rat;
249 }
250 results[currNode] = EvalResult(res);
251 break;
252 }
253
254 case kind::MINUS:
255 {
256 const Rational& x = results[currNode[0]].d_rat;
257 const Rational& y = results[currNode[1]].d_rat;
258 results[currNode] = EvalResult(x - y);
259 break;
260 }
261
262 case kind::MULT:
263 {
264 Rational res = results[currNode[0]].d_rat;
265 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
266 {
267 res = res * results[currNode[i]].d_rat;
268 }
269 results[currNode] = EvalResult(res);
270 break;
271 }
272
273 case kind::GEQ:
274 {
275 const Rational& x = results[currNode[0]].d_rat;
276 const Rational& y = results[currNode[1]].d_rat;
277 results[currNode] = EvalResult(x >= y);
278 break;
279 }
280
281 case kind::CONST_STRING:
282 results[currNode] = EvalResult(currNodeVal.getConst<String>());
283 break;
284
285 case kind::STRING_CONCAT:
286 {
287 String res = results[currNode[0]].d_str;
288 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
289 {
290 res = res.concat(results[currNode[i]].d_str);
291 }
292 results[currNode] = EvalResult(res);
293 break;
294 }
295
296 case kind::STRING_LENGTH:
297 {
298 const String& s = results[currNode[0]].d_str;
299 results[currNode] = EvalResult(Rational(s.size()));
300 break;
301 }
302
303 case kind::STRING_SUBSTR:
304 {
305 const String& s = results[currNode[0]].d_str;
306 Integer s_len(s.size());
307 Integer i = results[currNode[1]].d_rat.getNumerator();
308 Integer j = results[currNode[2]].d_rat.getNumerator();
309
310 if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
311 {
312 results[currNode] = EvalResult(String(""));
313 }
314 else if (i + j > s_len)
315 {
316 results[currNode] =
317 EvalResult(s.suffix((s_len - i).toUnsignedInt()));
318 }
319 else
320 {
321 results[currNode] =
322 EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
323 }
324 break;
325 }
326
327 case kind::STRING_CHARAT:
328 {
329 const String& s = results[currNode[0]].d_str;
330 Integer s_len(s.size());
331 Integer i = results[currNode[1]].d_rat.getNumerator();
332 if (i.strictlyNegative() || i >= s_len)
333 {
334 results[currNode] = EvalResult(String(""));
335 }
336 else
337 {
338 results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
339 }
340 break;
341 }
342
343 case kind::STRING_STRCTN:
344 {
345 const String& s = results[currNode[0]].d_str;
346 const String& t = results[currNode[1]].d_str;
347 results[currNode] = EvalResult(s.find(t) != std::string::npos);
348 break;
349 }
350
351 case kind::STRING_STRIDOF:
352 {
353 const String& s = results[currNode[0]].d_str;
354 Integer s_len(s.size());
355 const String& x = results[currNode[1]].d_str;
356 Integer i = results[currNode[2]].d_rat.getNumerator();
357
358 if (i.strictlyNegative())
359 {
360 results[currNode] = EvalResult(Rational(-1));
361 }
362 else
363 {
364 size_t r = s.find(x, i.toUnsignedInt());
365 if (r == std::string::npos)
366 {
367 results[currNode] = EvalResult(Rational(-1));
368 }
369 else
370 {
371 results[currNode] = EvalResult(Rational(r));
372 }
373 }
374 break;
375 }
376
377 case kind::STRING_STRREPL:
378 {
379 const String& s = results[currNode[0]].d_str;
380 const String& x = results[currNode[1]].d_str;
381 const String& y = results[currNode[2]].d_str;
382 results[currNode] = EvalResult(s.replace(x, y));
383 break;
384 }
385
386 case kind::STRING_PREFIX:
387 {
388 const String& t = results[currNode[0]].d_str;
389 const String& s = results[currNode[1]].d_str;
390 if (s.size() < t.size())
391 {
392 results[currNode] = EvalResult(false);
393 }
394 else
395 {
396 results[currNode] = EvalResult(s.prefix(t.size()) == t);
397 }
398 break;
399 }
400
401 case kind::STRING_SUFFIX:
402 {
403 const String& t = results[currNode[0]].d_str;
404 const String& s = results[currNode[1]].d_str;
405 if (s.size() < t.size())
406 {
407 results[currNode] = EvalResult(false);
408 }
409 else
410 {
411 results[currNode] = EvalResult(s.suffix(t.size()) == t);
412 }
413 break;
414 }
415
416 case kind::STRING_ITOS:
417 {
418 Integer i = results[currNode[0]].d_rat.getNumerator();
419 if (i.strictlyNegative())
420 {
421 results[currNode] = EvalResult(String(""));
422 }
423 else
424 {
425 results[currNode] = EvalResult(String(i.toString()));
426 }
427 break;
428 }
429
430 case kind::STRING_STOI:
431 {
432 const String& s = results[currNode[0]].d_str;
433 if (s.isNumber())
434 {
435 results[currNode] = EvalResult(Rational(s.toNumber()));
436 }
437 else
438 {
439 results[currNode] = EvalResult(Rational(-1));
440 }
441 break;
442 }
443
444 case kind::STRING_CODE:
445 {
446 const String& s = results[currNode[0]].d_str;
447 if (s.size() == 1)
448 {
449 results[currNode] = EvalResult(
450 Rational(String::convertUnsignedIntToCode(s.getVec()[0])));
451 }
452 else
453 {
454 results[currNode] = EvalResult(Rational(-1));
455 }
456 break;
457 }
458
459 case kind::CONST_BITVECTOR:
460 results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
461 break;
462
463 case kind::BITVECTOR_NOT:
464 results[currNode] = EvalResult(~results[currNode[0]].d_bv);
465 break;
466
467 case kind::BITVECTOR_NEG:
468 results[currNode] = EvalResult(-results[currNode[0]].d_bv);
469 break;
470
471 case kind::BITVECTOR_EXTRACT:
472 {
473 unsigned lo = bv::utils::getExtractLow(currNodeVal);
474 unsigned hi = bv::utils::getExtractHigh(currNodeVal);
475 results[currNode] =
476 EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
477 break;
478 }
479
480 case kind::BITVECTOR_CONCAT:
481 {
482 BitVector res = results[currNode[0]].d_bv;
483 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
484 {
485 res = res.concat(results[currNode[i]].d_bv);
486 }
487 results[currNode] = EvalResult(res);
488 break;
489 }
490
491 case kind::BITVECTOR_PLUS:
492 {
493 BitVector res = results[currNode[0]].d_bv;
494 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
495 {
496 res = res + results[currNode[i]].d_bv;
497 }
498 results[currNode] = EvalResult(res);
499 break;
500 }
501
502 case kind::BITVECTOR_MULT:
503 {
504 BitVector res = results[currNode[0]].d_bv;
505 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
506 {
507 res = res * results[currNode[i]].d_bv;
508 }
509 results[currNode] = EvalResult(res);
510 break;
511 }
512 case kind::BITVECTOR_AND:
513 {
514 BitVector res = results[currNode[0]].d_bv;
515 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
516 {
517 res = res & results[currNode[i]].d_bv;
518 }
519 results[currNode] = EvalResult(res);
520 break;
521 }
522
523 case kind::BITVECTOR_OR:
524 {
525 BitVector res = results[currNode[0]].d_bv;
526 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
527 {
528 res = res | results[currNode[i]].d_bv;
529 }
530 results[currNode] = EvalResult(res);
531 break;
532 }
533
534 case kind::BITVECTOR_XOR:
535 {
536 BitVector res = results[currNode[0]].d_bv;
537 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
538 {
539 res = res ^ results[currNode[i]].d_bv;
540 }
541 results[currNode] = EvalResult(res);
542 break;
543 }
544
545 case kind::EQUAL:
546 {
547 EvalResult lhs = results[currNode[0]];
548 EvalResult rhs = results[currNode[1]];
549
550 switch (lhs.d_tag)
551 {
552 case EvalResult::BOOL:
553 {
554 results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
555 break;
556 }
557
558 case EvalResult::BITVECTOR:
559 {
560 results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
561 break;
562 }
563
564 case EvalResult::RATIONAL:
565 {
566 results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
567 break;
568 }
569
570 case EvalResult::STRING:
571 {
572 results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
573 break;
574 }
575
576 default:
577 {
578 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
579 << " not supported" << std::endl;
580 return EvalResult();
581 break;
582 }
583 }
584
585 break;
586 }
587
588 case kind::ITE:
589 {
590 if (results[currNode[0]].d_bool)
591 {
592 results[currNode] = results[currNode[1]];
593 }
594 else
595 {
596 results[currNode] = results[currNode[2]];
597 }
598 break;
599 }
600
601 default:
602 {
603 Trace("evaluator") << "Kind " << currNodeVal.getKind()
604 << " not supported" << std::endl;
605 return EvalResult();
606 }
607 }
608 }
609 }
610
611 return results[n];
612 }
613
614 } // namespace theory
615 } // namespace CVC4