sygusComp2018: Add evaluator (#2090)
[cvc5.git] / src / theory / evaluator.cpp
1 /********************* */
2 /*! \file evaluator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The Evaluator class
13 **
14 ** The Evaluator class.
15 **/
16
17 #include "theory/evaluator.h"
18
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/theory.h"
21 #include "util/integer.h"
22
23 namespace CVC4 {
24 namespace theory {
25
26 EvalResult::EvalResult(const EvalResult& other)
27 {
28 d_tag = other.d_tag;
29 switch (d_tag)
30 {
31 case BOOL: d_bool = other.d_bool; break;
32 case BITVECTOR:
33 new (&d_bv) BitVector;
34 d_bv = other.d_bv;
35 break;
36 case RATIONAL:
37 new (&d_rat) Rational;
38 d_rat = other.d_rat;
39 break;
40 case STRING:
41 new (&d_str) String;
42 d_str = other.d_str;
43 break;
44 case INVALID: break;
45 }
46 }
47
48 EvalResult& EvalResult::operator=(const EvalResult& other)
49 {
50 if (this != &other)
51 {
52 d_tag = other.d_tag;
53 switch (d_tag)
54 {
55 case BOOL: d_bool = other.d_bool; break;
56 case BITVECTOR:
57 new (&d_bv) BitVector;
58 d_bv = other.d_bv;
59 break;
60 case RATIONAL:
61 new (&d_rat) Rational;
62 d_rat = other.d_rat;
63 break;
64 case STRING:
65 new (&d_str) String;
66 d_str = other.d_str;
67 break;
68 case INVALID: break;
69 }
70 }
71 return *this;
72 }
73
74 EvalResult::~EvalResult()
75 {
76 switch (d_tag)
77 {
78 case BITVECTOR:
79 {
80 d_bv.~BitVector();
81 break;
82 }
83 case RATIONAL:
84 {
85 d_rat.~Rational();
86 break;
87 }
88 case STRING:
89 {
90 d_str.~String();
91 break;
92
93 default: break;
94 }
95 }
96 }
97
98 Node EvalResult::toNode() const
99 {
100 NodeManager* nm = NodeManager::currentNM();
101 switch (d_tag)
102 {
103 case EvalResult::BOOL: return nm->mkConst(d_bool);
104 case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
105 case EvalResult::RATIONAL: return nm->mkConst(d_rat);
106 case EvalResult::STRING: return nm->mkConst(d_str);
107 default:
108 {
109 Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
110 << std::endl;
111 return Node();
112 }
113 }
114
115 return Node();
116 }
117
118 Node Evaluator::eval(TNode n,
119 const std::vector<Node>& args,
120 const std::vector<Node>& vals)
121 {
122 Trace("evaluator") << "Evaluating " << n << " under substitution " << args
123 << " " << vals << std::endl;
124 return evalInternal(n, args, vals).toNode();
125 }
126
127 EvalResult Evaluator::evalInternal(TNode n,
128 const std::vector<Node>& args,
129 const std::vector<Node>& vals)
130 {
131 std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
132 std::vector<TNode> queue;
133 queue.emplace_back(n);
134
135 while (queue.size() != 0)
136 {
137 TNode currNode = queue.back();
138
139 if (results.find(currNode) != results.end())
140 {
141 queue.pop_back();
142 continue;
143 }
144
145 bool doEval = true;
146 for (const auto& currNodeChild : currNode)
147 {
148 if (results.find(currNodeChild) == results.end())
149 {
150 queue.emplace_back(currNodeChild);
151 doEval = false;
152 }
153 }
154
155 if (doEval)
156 {
157 queue.pop_back();
158
159 Node currNodeVal = currNode;
160 if (currNode.isVar())
161 {
162 const auto& it = std::find(args.begin(), args.end(), currNode);
163
164 if (it == args.end())
165 {
166 return EvalResult();
167 }
168
169 ptrdiff_t pos = std::distance(args.begin(), it);
170 currNodeVal = vals[pos];
171 }
172 else if (currNode.getKind() == kind::APPLY_UF
173 && currNode.getOperator().getKind() == kind::LAMBDA)
174 {
175 // Create a copy of the current substitutions
176 std::vector<Node> lambdaArgs(args);
177 std::vector<Node> lambdaVals(vals);
178
179 // Add the values for the arguments of the lambda as substitutions at
180 // the beginning of the vector to shadow variables from outer scopes
181 // with the same name
182 Node op = currNode.getOperator();
183 for (const auto& lambdaArg : op[0])
184 {
185 lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
186 }
187
188 for (const auto& lambdaVal : currNode)
189 {
190 lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
191 }
192
193 // Lambdas are evaluated in a recursive fashion because each evaluation
194 // requires different substitutions
195 results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals);
196 continue;
197 }
198
199 switch (currNodeVal.getKind())
200 {
201 case kind::CONST_BOOLEAN:
202 results[currNode] = EvalResult(currNodeVal.getConst<bool>());
203 break;
204
205 case kind::NOT:
206 {
207 results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
208 break;
209 }
210
211 case kind::AND:
212 {
213 bool res = results[currNode[0]].d_bool;
214 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
215 {
216 res = res && results[currNode[i]].d_bool;
217 }
218 results[currNode] = EvalResult(res);
219 break;
220 }
221
222 case kind::OR:
223 {
224 bool res = results[currNode[0]].d_bool;
225 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
226 {
227 res = res || results[currNode[i]].d_bool;
228 }
229 results[currNode] = EvalResult(res);
230 break;
231 }
232
233 case kind::CONST_RATIONAL:
234 {
235 const Rational& r = currNodeVal.getConst<Rational>();
236 results[currNode] = EvalResult(r);
237 break;
238 }
239
240 case kind::PLUS:
241 {
242 Rational res = results[currNode[0]].d_rat;
243 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
244 {
245 res = res + results[currNode[i]].d_rat;
246 }
247 results[currNode] = EvalResult(res);
248 break;
249 }
250
251 case kind::MINUS:
252 {
253 const Rational& x = results[currNode[0]].d_rat;
254 const Rational& y = results[currNode[1]].d_rat;
255 results[currNode] = EvalResult(x - y);
256 break;
257 }
258
259 case kind::MULT:
260 {
261 Rational res = results[currNode[0]].d_rat;
262 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
263 {
264 res = res * results[currNode[i]].d_rat;
265 }
266 results[currNode] = EvalResult(res);
267 break;
268 }
269
270 case kind::GEQ:
271 {
272 const Rational& x = results[currNode[0]].d_rat;
273 const Rational& y = results[currNode[1]].d_rat;
274 results[currNode] = EvalResult(x >= y);
275 break;
276 }
277
278 case kind::CONST_STRING:
279 results[currNode] = EvalResult(currNodeVal.getConst<String>());
280 break;
281
282 case kind::STRING_CONCAT:
283 {
284 String res = results[currNode[0]].d_str;
285 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
286 {
287 res = res.concat(results[currNode[i]].d_str);
288 }
289 results[currNode] = EvalResult(res);
290 break;
291 }
292
293 case kind::STRING_LENGTH:
294 {
295 const String& s = results[currNode[0]].d_str;
296 results[currNode] = EvalResult(Rational(s.size()));
297 break;
298 }
299
300 case kind::STRING_SUBSTR:
301 {
302 const String& s = results[currNode[0]].d_str;
303 Integer s_len(s.size());
304 Integer i = results[currNode[1]].d_rat.getNumerator();
305 Integer j = results[currNode[2]].d_rat.getNumerator();
306
307 if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
308 {
309 results[currNode] = EvalResult(String(""));
310 }
311 else if (i + j > s_len)
312 {
313 results[currNode] =
314 EvalResult(s.suffix((s_len - i).toUnsignedInt()));
315 }
316 else
317 {
318 results[currNode] =
319 EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
320 }
321 break;
322 }
323
324 case kind::STRING_CHARAT:
325 {
326 const String& s = results[currNode[0]].d_str;
327 Integer s_len(s.size());
328 Integer i = results[currNode[1]].d_rat.getNumerator();
329 if (i.strictlyNegative() || i >= s_len)
330 {
331 results[currNode] = EvalResult(String(""));
332 }
333 else
334 {
335 results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
336 }
337 break;
338 }
339
340 case kind::STRING_STRCTN:
341 {
342 const String& s = results[currNode[0]].d_str;
343 const String& t = results[currNode[1]].d_str;
344 results[currNode] = EvalResult(s.find(t) != std::string::npos);
345 break;
346 }
347
348 case kind::STRING_STRIDOF:
349 {
350 const String& s = results[currNode[0]].d_str;
351 Integer s_len(s.size());
352 const String& x = results[currNode[1]].d_str;
353 Integer i = results[currNode[2]].d_rat.getNumerator();
354
355 if (i.strictlyNegative() || i >= s_len)
356 {
357 results[currNode] = EvalResult(Rational(-1));
358 }
359 else
360 {
361 size_t r = s.find(x, i.toUnsignedInt());
362 if (r == std::string::npos)
363 {
364 results[currNode] = EvalResult(Rational(-1));
365 }
366 else
367 {
368 results[currNode] = EvalResult(Rational(r));
369 }
370 }
371 break;
372 }
373
374 case kind::STRING_STRREPL:
375 {
376 const String& s = results[currNode[0]].d_str;
377 const String& x = results[currNode[1]].d_str;
378 const String& y = results[currNode[2]].d_str;
379 results[currNode] = EvalResult(s.replace(x, y));
380 break;
381 }
382
383 case kind::STRING_PREFIX:
384 {
385 const String& t = results[currNode[0]].d_str;
386 const String& s = results[currNode[1]].d_str;
387 if (s.size() < t.size())
388 {
389 results[currNode] = EvalResult(false);
390 }
391 else
392 {
393 results[currNode] = EvalResult(s.prefix(t.size()) == t);
394 }
395 break;
396 }
397
398 case kind::STRING_SUFFIX:
399 {
400 const String& t = results[currNode[0]].d_str;
401 const String& s = results[currNode[1]].d_str;
402 if (s.size() < t.size())
403 {
404 results[currNode] = EvalResult(false);
405 }
406 else
407 {
408 results[currNode] = EvalResult(s.suffix(t.size()) == t);
409 }
410 break;
411 }
412
413 case kind::STRING_ITOS:
414 {
415 Integer i = results[currNode[0]].d_rat.getNumerator();
416 if (i.strictlyNegative())
417 {
418 results[currNode] = EvalResult(String(""));
419 }
420 else
421 {
422 results[currNode] = EvalResult(String(i.toString()));
423 }
424 break;
425 }
426
427 case kind::STRING_STOI:
428 {
429 const String& s = results[currNode[0]].d_str;
430 if (s.isNumber())
431 {
432 results[currNode] = EvalResult(Rational(-1));
433 }
434 else
435 {
436 results[currNode] = EvalResult(Rational(s.toNumber()));
437 }
438 break;
439 }
440
441 case kind::CONST_BITVECTOR:
442 results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
443 break;
444
445 case kind::BITVECTOR_NOT:
446 results[currNode] = EvalResult(~results[currNode[0]].d_bv);
447 break;
448
449 case kind::BITVECTOR_NEG:
450 results[currNode] = EvalResult(-results[currNode[0]].d_bv);
451 break;
452
453 case kind::BITVECTOR_EXTRACT:
454 {
455 unsigned lo = bv::utils::getExtractLow(currNodeVal);
456 unsigned hi = bv::utils::getExtractHigh(currNodeVal);
457 results[currNode] =
458 EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
459 break;
460 }
461
462 case kind::BITVECTOR_CONCAT:
463 {
464 BitVector res = results[currNode[0]].d_bv;
465 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
466 {
467 res = res.concat(results[currNode[i]].d_bv);
468 }
469 results[currNode] = EvalResult(res);
470 break;
471 }
472
473 case kind::BITVECTOR_PLUS:
474 {
475 BitVector res = results[currNode[0]].d_bv;
476 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
477 {
478 res = res + results[currNode[i]].d_bv;
479 }
480 results[currNode] = EvalResult(res);
481 break;
482 }
483
484 case kind::BITVECTOR_MULT:
485 {
486 BitVector res = results[currNode[0]].d_bv;
487 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
488 {
489 res = res * results[currNode[i]].d_bv;
490 }
491 results[currNode] = EvalResult(res);
492 break;
493 }
494 case kind::BITVECTOR_AND:
495 {
496 BitVector res = results[currNode[0]].d_bv;
497 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
498 {
499 res = res & results[currNode[i]].d_bv;
500 }
501 results[currNode] = EvalResult(res);
502 break;
503 }
504
505 case kind::BITVECTOR_OR:
506 {
507 BitVector res = results[currNode[0]].d_bv;
508 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
509 {
510 res = res | results[currNode[i]].d_bv;
511 }
512 results[currNode] = EvalResult(res);
513 break;
514 }
515
516 case kind::BITVECTOR_XOR:
517 {
518 BitVector res = results[currNode[0]].d_bv;
519 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
520 {
521 res = res ^ results[currNode[i]].d_bv;
522 }
523 results[currNode] = EvalResult(res);
524 break;
525 }
526
527 case kind::EQUAL:
528 {
529 EvalResult lhs = results[currNode[0]];
530 EvalResult rhs = results[currNode[1]];
531
532 switch (lhs.d_tag)
533 {
534 case EvalResult::BOOL:
535 {
536 results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
537 break;
538 }
539
540 case EvalResult::BITVECTOR:
541 {
542 results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
543 break;
544 }
545
546 case EvalResult::RATIONAL:
547 {
548 results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
549 break;
550 }
551
552 case EvalResult::STRING:
553 {
554 results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
555 break;
556 }
557
558 default:
559 {
560 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
561 << " not supported" << std::endl;
562 return EvalResult();
563 break;
564 }
565 }
566
567 break;
568 }
569
570 case kind::ITE:
571 {
572 if (results[currNode[0]].d_bool)
573 {
574 results[currNode] = results[currNode[1]];
575 }
576 else
577 {
578 results[currNode] = results[currNode[2]];
579 }
580 break;
581 }
582
583 default:
584 {
585 Trace("evaluator") << "Kind " << currNodeVal.getKind()
586 << " not supported" << std::endl;
587 return EvalResult();
588 }
589 }
590 }
591 }
592
593 return results[n];
594 }
595
596 } // namespace theory
597 } // namespace CVC4