Fix a few issues in the sygus sampler related to evaluation (#2215)
[cvc5.git] / src / theory / evaluator.cpp
1 /********************* */
2 /*! \file evaluator.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief The Evaluator class
13 **
14 ** The Evaluator class.
15 **/
16
17 #include "theory/evaluator.h"
18
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/theory.h"
21 #include "util/integer.h"
22
23 namespace CVC4 {
24 namespace theory {
25
26 EvalResult::EvalResult(const EvalResult& other)
27 {
28 d_tag = other.d_tag;
29 switch (d_tag)
30 {
31 case BOOL: d_bool = other.d_bool; break;
32 case BITVECTOR:
33 new (&d_bv) BitVector;
34 d_bv = other.d_bv;
35 break;
36 case RATIONAL:
37 new (&d_rat) Rational;
38 d_rat = other.d_rat;
39 break;
40 case STRING:
41 new (&d_str) String;
42 d_str = other.d_str;
43 break;
44 case INVALID: break;
45 }
46 }
47
48 EvalResult& EvalResult::operator=(const EvalResult& other)
49 {
50 if (this != &other)
51 {
52 d_tag = other.d_tag;
53 switch (d_tag)
54 {
55 case BOOL: d_bool = other.d_bool; break;
56 case BITVECTOR:
57 new (&d_bv) BitVector;
58 d_bv = other.d_bv;
59 break;
60 case RATIONAL:
61 new (&d_rat) Rational;
62 d_rat = other.d_rat;
63 break;
64 case STRING:
65 new (&d_str) String;
66 d_str = other.d_str;
67 break;
68 case INVALID: break;
69 }
70 }
71 return *this;
72 }
73
74 EvalResult::~EvalResult()
75 {
76 switch (d_tag)
77 {
78 case BITVECTOR:
79 {
80 d_bv.~BitVector();
81 break;
82 }
83 case RATIONAL:
84 {
85 d_rat.~Rational();
86 break;
87 }
88 case STRING:
89 {
90 d_str.~String();
91 break;
92
93 default: break;
94 }
95 }
96 }
97
98 Node EvalResult::toNode() const
99 {
100 NodeManager* nm = NodeManager::currentNM();
101 switch (d_tag)
102 {
103 case EvalResult::BOOL: return nm->mkConst(d_bool);
104 case EvalResult::BITVECTOR: return nm->mkConst(d_bv);
105 case EvalResult::RATIONAL: return nm->mkConst(d_rat);
106 case EvalResult::STRING: return nm->mkConst(d_str);
107 default:
108 {
109 Trace("evaluator") << "Missing conversion from " << d_tag << " to node"
110 << std::endl;
111 return Node();
112 }
113 }
114
115 return Node();
116 }
117
118 Node Evaluator::eval(TNode n,
119 const std::vector<Node>& args,
120 const std::vector<Node>& vals)
121 {
122 Trace("evaluator") << "Evaluating " << n << " under substitution " << args
123 << " " << vals << std::endl;
124 return evalInternal(n, args, vals).toNode();
125 }
126
127 EvalResult Evaluator::evalInternal(TNode n,
128 const std::vector<Node>& args,
129 const std::vector<Node>& vals)
130 {
131 std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
132 std::vector<TNode> queue;
133 queue.emplace_back(n);
134
135 while (queue.size() != 0)
136 {
137 TNode currNode = queue.back();
138
139 if (results.find(currNode) != results.end())
140 {
141 queue.pop_back();
142 continue;
143 }
144
145 bool doEval = true;
146 for (const auto& currNodeChild : currNode)
147 {
148 if (results.find(currNodeChild) == results.end())
149 {
150 queue.emplace_back(currNodeChild);
151 doEval = false;
152 }
153 }
154
155 if (doEval)
156 {
157 queue.pop_back();
158
159 Node currNodeVal = currNode;
160 if (currNode.isVar())
161 {
162 const auto& it = std::find(args.begin(), args.end(), currNode);
163
164 if (it == args.end())
165 {
166 return EvalResult();
167 }
168
169 ptrdiff_t pos = std::distance(args.begin(), it);
170 currNodeVal = vals[pos];
171 }
172 else if (currNode.getKind() == kind::APPLY_UF
173 && currNode.getOperator().getKind() == kind::LAMBDA)
174 {
175 // Create a copy of the current substitutions
176 std::vector<Node> lambdaArgs(args);
177 std::vector<Node> lambdaVals(vals);
178
179 // Add the values for the arguments of the lambda as substitutions at
180 // the beginning of the vector to shadow variables from outer scopes
181 // with the same name
182 Node op = currNode.getOperator();
183 for (const auto& lambdaArg : op[0])
184 {
185 lambdaArgs.insert(lambdaArgs.begin(), lambdaArg);
186 }
187
188 for (const auto& lambdaVal : currNode)
189 {
190 lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode());
191 }
192
193 // Lambdas are evaluated in a recursive fashion because each evaluation
194 // requires different substitutions
195 results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals);
196 if (results[currNode].d_tag == EvalResult::INVALID)
197 {
198 // evaluation was invalid, we fail
199 return results[currNode];
200 }
201 continue;
202 }
203
204 switch (currNodeVal.getKind())
205 {
206 case kind::CONST_BOOLEAN:
207 results[currNode] = EvalResult(currNodeVal.getConst<bool>());
208 break;
209
210 case kind::NOT:
211 {
212 results[currNode] = EvalResult(!(results[currNode[0]].d_bool));
213 break;
214 }
215
216 case kind::AND:
217 {
218 bool res = results[currNode[0]].d_bool;
219 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
220 {
221 res = res && results[currNode[i]].d_bool;
222 }
223 results[currNode] = EvalResult(res);
224 break;
225 }
226
227 case kind::OR:
228 {
229 bool res = results[currNode[0]].d_bool;
230 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
231 {
232 res = res || results[currNode[i]].d_bool;
233 }
234 results[currNode] = EvalResult(res);
235 break;
236 }
237
238 case kind::CONST_RATIONAL:
239 {
240 const Rational& r = currNodeVal.getConst<Rational>();
241 results[currNode] = EvalResult(r);
242 break;
243 }
244
245 case kind::PLUS:
246 {
247 Rational res = results[currNode[0]].d_rat;
248 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
249 {
250 res = res + results[currNode[i]].d_rat;
251 }
252 results[currNode] = EvalResult(res);
253 break;
254 }
255
256 case kind::MINUS:
257 {
258 const Rational& x = results[currNode[0]].d_rat;
259 const Rational& y = results[currNode[1]].d_rat;
260 results[currNode] = EvalResult(x - y);
261 break;
262 }
263
264 case kind::MULT:
265 {
266 Rational res = results[currNode[0]].d_rat;
267 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
268 {
269 res = res * results[currNode[i]].d_rat;
270 }
271 results[currNode] = EvalResult(res);
272 break;
273 }
274
275 case kind::GEQ:
276 {
277 const Rational& x = results[currNode[0]].d_rat;
278 const Rational& y = results[currNode[1]].d_rat;
279 results[currNode] = EvalResult(x >= y);
280 break;
281 }
282
283 case kind::CONST_STRING:
284 results[currNode] = EvalResult(currNodeVal.getConst<String>());
285 break;
286
287 case kind::STRING_CONCAT:
288 {
289 String res = results[currNode[0]].d_str;
290 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
291 {
292 res = res.concat(results[currNode[i]].d_str);
293 }
294 results[currNode] = EvalResult(res);
295 break;
296 }
297
298 case kind::STRING_LENGTH:
299 {
300 const String& s = results[currNode[0]].d_str;
301 results[currNode] = EvalResult(Rational(s.size()));
302 break;
303 }
304
305 case kind::STRING_SUBSTR:
306 {
307 const String& s = results[currNode[0]].d_str;
308 Integer s_len(s.size());
309 Integer i = results[currNode[1]].d_rat.getNumerator();
310 Integer j = results[currNode[2]].d_rat.getNumerator();
311
312 if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len)
313 {
314 results[currNode] = EvalResult(String(""));
315 }
316 else if (i + j > s_len)
317 {
318 results[currNode] =
319 EvalResult(s.suffix((s_len - i).toUnsignedInt()));
320 }
321 else
322 {
323 results[currNode] =
324 EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt()));
325 }
326 break;
327 }
328
329 case kind::STRING_CHARAT:
330 {
331 const String& s = results[currNode[0]].d_str;
332 Integer s_len(s.size());
333 Integer i = results[currNode[1]].d_rat.getNumerator();
334 if (i.strictlyNegative() || i >= s_len)
335 {
336 results[currNode] = EvalResult(String(""));
337 }
338 else
339 {
340 results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1));
341 }
342 break;
343 }
344
345 case kind::STRING_STRCTN:
346 {
347 const String& s = results[currNode[0]].d_str;
348 const String& t = results[currNode[1]].d_str;
349 results[currNode] = EvalResult(s.find(t) != std::string::npos);
350 break;
351 }
352
353 case kind::STRING_STRIDOF:
354 {
355 const String& s = results[currNode[0]].d_str;
356 Integer s_len(s.size());
357 const String& x = results[currNode[1]].d_str;
358 Integer i = results[currNode[2]].d_rat.getNumerator();
359
360 if (i.strictlyNegative() || i >= s_len)
361 {
362 results[currNode] = EvalResult(Rational(-1));
363 }
364 else
365 {
366 size_t r = s.find(x, i.toUnsignedInt());
367 if (r == std::string::npos)
368 {
369 results[currNode] = EvalResult(Rational(-1));
370 }
371 else
372 {
373 results[currNode] = EvalResult(Rational(r));
374 }
375 }
376 break;
377 }
378
379 case kind::STRING_STRREPL:
380 {
381 const String& s = results[currNode[0]].d_str;
382 const String& x = results[currNode[1]].d_str;
383 const String& y = results[currNode[2]].d_str;
384 results[currNode] = EvalResult(s.replace(x, y));
385 break;
386 }
387
388 case kind::STRING_PREFIX:
389 {
390 const String& t = results[currNode[0]].d_str;
391 const String& s = results[currNode[1]].d_str;
392 if (s.size() < t.size())
393 {
394 results[currNode] = EvalResult(false);
395 }
396 else
397 {
398 results[currNode] = EvalResult(s.prefix(t.size()) == t);
399 }
400 break;
401 }
402
403 case kind::STRING_SUFFIX:
404 {
405 const String& t = results[currNode[0]].d_str;
406 const String& s = results[currNode[1]].d_str;
407 if (s.size() < t.size())
408 {
409 results[currNode] = EvalResult(false);
410 }
411 else
412 {
413 results[currNode] = EvalResult(s.suffix(t.size()) == t);
414 }
415 break;
416 }
417
418 case kind::STRING_ITOS:
419 {
420 Integer i = results[currNode[0]].d_rat.getNumerator();
421 if (i.strictlyNegative())
422 {
423 results[currNode] = EvalResult(String(""));
424 }
425 else
426 {
427 results[currNode] = EvalResult(String(i.toString()));
428 }
429 break;
430 }
431
432 case kind::STRING_STOI:
433 {
434 const String& s = results[currNode[0]].d_str;
435 if (s.isNumber())
436 {
437 results[currNode] = EvalResult(Rational(-1));
438 }
439 else
440 {
441 results[currNode] = EvalResult(Rational(s.toNumber()));
442 }
443 break;
444 }
445
446 case kind::CONST_BITVECTOR:
447 results[currNode] = EvalResult(currNodeVal.getConst<BitVector>());
448 break;
449
450 case kind::BITVECTOR_NOT:
451 results[currNode] = EvalResult(~results[currNode[0]].d_bv);
452 break;
453
454 case kind::BITVECTOR_NEG:
455 results[currNode] = EvalResult(-results[currNode[0]].d_bv);
456 break;
457
458 case kind::BITVECTOR_EXTRACT:
459 {
460 unsigned lo = bv::utils::getExtractLow(currNodeVal);
461 unsigned hi = bv::utils::getExtractHigh(currNodeVal);
462 results[currNode] =
463 EvalResult(results[currNode[0]].d_bv.extract(hi, lo));
464 break;
465 }
466
467 case kind::BITVECTOR_CONCAT:
468 {
469 BitVector res = results[currNode[0]].d_bv;
470 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
471 {
472 res = res.concat(results[currNode[i]].d_bv);
473 }
474 results[currNode] = EvalResult(res);
475 break;
476 }
477
478 case kind::BITVECTOR_PLUS:
479 {
480 BitVector res = results[currNode[0]].d_bv;
481 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
482 {
483 res = res + results[currNode[i]].d_bv;
484 }
485 results[currNode] = EvalResult(res);
486 break;
487 }
488
489 case kind::BITVECTOR_MULT:
490 {
491 BitVector res = results[currNode[0]].d_bv;
492 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
493 {
494 res = res * results[currNode[i]].d_bv;
495 }
496 results[currNode] = EvalResult(res);
497 break;
498 }
499 case kind::BITVECTOR_AND:
500 {
501 BitVector res = results[currNode[0]].d_bv;
502 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
503 {
504 res = res & results[currNode[i]].d_bv;
505 }
506 results[currNode] = EvalResult(res);
507 break;
508 }
509
510 case kind::BITVECTOR_OR:
511 {
512 BitVector res = results[currNode[0]].d_bv;
513 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
514 {
515 res = res | results[currNode[i]].d_bv;
516 }
517 results[currNode] = EvalResult(res);
518 break;
519 }
520
521 case kind::BITVECTOR_XOR:
522 {
523 BitVector res = results[currNode[0]].d_bv;
524 for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++)
525 {
526 res = res ^ results[currNode[i]].d_bv;
527 }
528 results[currNode] = EvalResult(res);
529 break;
530 }
531
532 case kind::EQUAL:
533 {
534 EvalResult lhs = results[currNode[0]];
535 EvalResult rhs = results[currNode[1]];
536
537 switch (lhs.d_tag)
538 {
539 case EvalResult::BOOL:
540 {
541 results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool);
542 break;
543 }
544
545 case EvalResult::BITVECTOR:
546 {
547 results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv);
548 break;
549 }
550
551 case EvalResult::RATIONAL:
552 {
553 results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat);
554 break;
555 }
556
557 case EvalResult::STRING:
558 {
559 results[currNode] = EvalResult(lhs.d_str == rhs.d_str);
560 break;
561 }
562
563 default:
564 {
565 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0])
566 << " not supported" << std::endl;
567 return EvalResult();
568 break;
569 }
570 }
571
572 break;
573 }
574
575 case kind::ITE:
576 {
577 if (results[currNode[0]].d_bool)
578 {
579 results[currNode] = results[currNode[1]];
580 }
581 else
582 {
583 results[currNode] = results[currNode[2]];
584 }
585 break;
586 }
587
588 default:
589 {
590 Trace("evaluator") << "Kind " << currNodeVal.getKind()
591 << " not supported" << std::endl;
592 return EvalResult();
593 }
594 }
595 }
596 }
597
598 return results[n];
599 }
600
601 } // namespace theory
602 } // namespace CVC4