1 /********************* */
2 /*! \file evaluator.cpp
4 ** Top contributors (to current version):
5 ** Andres Noetzli, Andrew Reynolds
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief The Evaluator class
14 ** The Evaluator class.
17 #include "theory/evaluator.h"
19 #include "theory/bv/theory_bv_utils.h"
20 #include "theory/theory.h"
21 #include "util/integer.h"
26 EvalResult::EvalResult(const EvalResult
& other
)
31 case BOOL
: d_bool
= other
.d_bool
; break;
33 new (&d_bv
) BitVector
;
37 new (&d_rat
) Rational
;
48 EvalResult
& EvalResult::operator=(const EvalResult
& other
)
55 case BOOL
: d_bool
= other
.d_bool
; break;
57 new (&d_bv
) BitVector
;
61 new (&d_rat
) Rational
;
74 EvalResult::~EvalResult()
98 Node
EvalResult::toNode() const
100 NodeManager
* nm
= NodeManager::currentNM();
103 case EvalResult::BOOL
: return nm
->mkConst(d_bool
);
104 case EvalResult::BITVECTOR
: return nm
->mkConst(d_bv
);
105 case EvalResult::RATIONAL
: return nm
->mkConst(d_rat
);
106 case EvalResult::STRING
: return nm
->mkConst(d_str
);
109 Trace("evaluator") << "Missing conversion from " << d_tag
<< " to node"
116 Node
Evaluator::eval(TNode n
,
117 const std::vector
<Node
>& args
,
118 const std::vector
<Node
>& vals
)
120 Trace("evaluator") << "Evaluating " << n
<< " under substitution " << args
121 << " " << vals
<< std::endl
;
122 return evalInternal(n
, args
, vals
).toNode();
125 EvalResult
Evaluator::evalInternal(TNode n
,
126 const std::vector
<Node
>& args
,
127 const std::vector
<Node
>& vals
)
129 std::unordered_map
<TNode
, EvalResult
, TNodeHashFunction
> results
;
130 std::vector
<TNode
> queue
;
131 queue
.emplace_back(n
);
133 while (queue
.size() != 0)
135 TNode currNode
= queue
.back();
137 if (results
.find(currNode
) != results
.end())
144 for (const auto& currNodeChild
: currNode
)
146 if (results
.find(currNodeChild
) == results
.end())
148 queue
.emplace_back(currNodeChild
);
157 Node currNodeVal
= currNode
;
158 if (currNode
.isVar())
160 const auto& it
= std::find(args
.begin(), args
.end(), currNode
);
162 if (it
== args
.end())
167 ptrdiff_t pos
= std::distance(args
.begin(), it
);
168 currNodeVal
= vals
[pos
];
170 else if (currNode
.getKind() == kind::APPLY_UF
171 && currNode
.getOperator().getKind() == kind::LAMBDA
)
173 // Create a copy of the current substitutions
174 std::vector
<Node
> lambdaArgs(args
);
175 std::vector
<Node
> lambdaVals(vals
);
177 // Add the values for the arguments of the lambda as substitutions at
178 // the beginning of the vector to shadow variables from outer scopes
179 // with the same name
180 Node op
= currNode
.getOperator();
181 for (const auto& lambdaArg
: op
[0])
183 lambdaArgs
.insert(lambdaArgs
.begin(), lambdaArg
);
186 for (const auto& lambdaVal
: currNode
)
188 lambdaVals
.insert(lambdaVals
.begin(), results
[lambdaVal
].toNode());
191 // Lambdas are evaluated in a recursive fashion because each evaluation
192 // requires different substitutions
193 results
[currNode
] = evalInternal(op
[1], lambdaArgs
, lambdaVals
);
194 if (results
[currNode
].d_tag
== EvalResult::INVALID
)
196 // evaluation was invalid, we fail
197 return results
[currNode
];
202 switch (currNodeVal
.getKind())
204 case kind::CONST_BOOLEAN
:
205 results
[currNode
] = EvalResult(currNodeVal
.getConst
<bool>());
210 results
[currNode
] = EvalResult(!(results
[currNode
[0]].d_bool
));
216 bool res
= results
[currNode
[0]].d_bool
;
217 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
219 res
= res
&& results
[currNode
[i
]].d_bool
;
221 results
[currNode
] = EvalResult(res
);
227 bool res
= results
[currNode
[0]].d_bool
;
228 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
230 res
= res
|| results
[currNode
[i
]].d_bool
;
232 results
[currNode
] = EvalResult(res
);
236 case kind::CONST_RATIONAL
:
238 const Rational
& r
= currNodeVal
.getConst
<Rational
>();
239 results
[currNode
] = EvalResult(r
);
245 Rational res
= results
[currNode
[0]].d_rat
;
246 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
248 res
= res
+ results
[currNode
[i
]].d_rat
;
250 results
[currNode
] = EvalResult(res
);
256 const Rational
& x
= results
[currNode
[0]].d_rat
;
257 const Rational
& y
= results
[currNode
[1]].d_rat
;
258 results
[currNode
] = EvalResult(x
- y
);
264 const Rational
& x
= results
[currNode
[0]].d_rat
;
265 results
[currNode
] = EvalResult(-x
);
270 Rational res
= results
[currNode
[0]].d_rat
;
271 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
273 res
= res
* results
[currNode
[i
]].d_rat
;
275 results
[currNode
] = EvalResult(res
);
281 const Rational
& x
= results
[currNode
[0]].d_rat
;
282 const Rational
& y
= results
[currNode
[1]].d_rat
;
283 results
[currNode
] = EvalResult(x
>= y
);
288 const Rational
& x
= results
[currNode
[0]].d_rat
;
289 const Rational
& y
= results
[currNode
[1]].d_rat
;
290 results
[currNode
] = EvalResult(x
<= y
);
295 const Rational
& x
= results
[currNode
[0]].d_rat
;
296 const Rational
& y
= results
[currNode
[1]].d_rat
;
297 results
[currNode
] = EvalResult(x
> y
);
302 const Rational
& x
= results
[currNode
[0]].d_rat
;
303 const Rational
& y
= results
[currNode
[1]].d_rat
;
304 results
[currNode
] = EvalResult(x
< y
);
308 case kind::CONST_STRING
:
309 results
[currNode
] = EvalResult(currNodeVal
.getConst
<String
>());
312 case kind::STRING_CONCAT
:
314 String res
= results
[currNode
[0]].d_str
;
315 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
317 res
= res
.concat(results
[currNode
[i
]].d_str
);
319 results
[currNode
] = EvalResult(res
);
323 case kind::STRING_LENGTH
:
325 const String
& s
= results
[currNode
[0]].d_str
;
326 results
[currNode
] = EvalResult(Rational(s
.size()));
330 case kind::STRING_SUBSTR
:
332 const String
& s
= results
[currNode
[0]].d_str
;
333 Integer
s_len(s
.size());
334 Integer i
= results
[currNode
[1]].d_rat
.getNumerator();
335 Integer j
= results
[currNode
[2]].d_rat
.getNumerator();
337 if (i
.strictlyNegative() || j
.strictlyNegative() || i
>= s_len
)
339 results
[currNode
] = EvalResult(String(""));
341 else if (i
+ j
> s_len
)
344 EvalResult(s
.suffix((s_len
- i
).toUnsignedInt()));
349 EvalResult(s
.substr(i
.toUnsignedInt(), j
.toUnsignedInt()));
354 case kind::STRING_CHARAT
:
356 const String
& s
= results
[currNode
[0]].d_str
;
357 Integer
s_len(s
.size());
358 Integer i
= results
[currNode
[1]].d_rat
.getNumerator();
359 if (i
.strictlyNegative() || i
>= s_len
)
361 results
[currNode
] = EvalResult(String(""));
365 results
[currNode
] = EvalResult(s
.substr(i
.toUnsignedInt(), 1));
370 case kind::STRING_STRCTN
:
372 const String
& s
= results
[currNode
[0]].d_str
;
373 const String
& t
= results
[currNode
[1]].d_str
;
374 results
[currNode
] = EvalResult(s
.find(t
) != std::string::npos
);
378 case kind::STRING_STRIDOF
:
380 const String
& s
= results
[currNode
[0]].d_str
;
381 Integer
s_len(s
.size());
382 const String
& x
= results
[currNode
[1]].d_str
;
383 Integer i
= results
[currNode
[2]].d_rat
.getNumerator();
385 if (i
.strictlyNegative())
387 results
[currNode
] = EvalResult(Rational(-1));
391 size_t r
= s
.find(x
, i
.toUnsignedInt());
392 if (r
== std::string::npos
)
394 results
[currNode
] = EvalResult(Rational(-1));
398 results
[currNode
] = EvalResult(Rational(r
));
404 case kind::STRING_STRREPL
:
406 const String
& s
= results
[currNode
[0]].d_str
;
407 const String
& x
= results
[currNode
[1]].d_str
;
408 const String
& y
= results
[currNode
[2]].d_str
;
409 results
[currNode
] = EvalResult(s
.replace(x
, y
));
413 case kind::STRING_PREFIX
:
415 const String
& t
= results
[currNode
[0]].d_str
;
416 const String
& s
= results
[currNode
[1]].d_str
;
417 if (s
.size() < t
.size())
419 results
[currNode
] = EvalResult(false);
423 results
[currNode
] = EvalResult(s
.prefix(t
.size()) == t
);
428 case kind::STRING_SUFFIX
:
430 const String
& t
= results
[currNode
[0]].d_str
;
431 const String
& s
= results
[currNode
[1]].d_str
;
432 if (s
.size() < t
.size())
434 results
[currNode
] = EvalResult(false);
438 results
[currNode
] = EvalResult(s
.suffix(t
.size()) == t
);
443 case kind::STRING_ITOS
:
445 Integer i
= results
[currNode
[0]].d_rat
.getNumerator();
446 if (i
.strictlyNegative())
448 results
[currNode
] = EvalResult(String(""));
452 results
[currNode
] = EvalResult(String(i
.toString()));
457 case kind::STRING_STOI
:
459 const String
& s
= results
[currNode
[0]].d_str
;
462 results
[currNode
] = EvalResult(Rational(s
.toNumber()));
466 results
[currNode
] = EvalResult(Rational(-1));
471 case kind::STRING_CODE
:
473 const String
& s
= results
[currNode
[0]].d_str
;
476 results
[currNode
] = EvalResult(
477 Rational(String::convertUnsignedIntToCode(s
.getVec()[0])));
481 results
[currNode
] = EvalResult(Rational(-1));
486 case kind::CONST_BITVECTOR
:
487 results
[currNode
] = EvalResult(currNodeVal
.getConst
<BitVector
>());
490 case kind::BITVECTOR_NOT
:
491 results
[currNode
] = EvalResult(~results
[currNode
[0]].d_bv
);
494 case kind::BITVECTOR_NEG
:
495 results
[currNode
] = EvalResult(-results
[currNode
[0]].d_bv
);
498 case kind::BITVECTOR_EXTRACT
:
500 unsigned lo
= bv::utils::getExtractLow(currNodeVal
);
501 unsigned hi
= bv::utils::getExtractHigh(currNodeVal
);
503 EvalResult(results
[currNode
[0]].d_bv
.extract(hi
, lo
));
507 case kind::BITVECTOR_CONCAT
:
509 BitVector res
= results
[currNode
[0]].d_bv
;
510 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
512 res
= res
.concat(results
[currNode
[i
]].d_bv
);
514 results
[currNode
] = EvalResult(res
);
518 case kind::BITVECTOR_PLUS
:
520 BitVector res
= results
[currNode
[0]].d_bv
;
521 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
523 res
= res
+ results
[currNode
[i
]].d_bv
;
525 results
[currNode
] = EvalResult(res
);
529 case kind::BITVECTOR_MULT
:
531 BitVector res
= results
[currNode
[0]].d_bv
;
532 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
534 res
= res
* results
[currNode
[i
]].d_bv
;
536 results
[currNode
] = EvalResult(res
);
539 case kind::BITVECTOR_AND
:
541 BitVector res
= results
[currNode
[0]].d_bv
;
542 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
544 res
= res
& results
[currNode
[i
]].d_bv
;
546 results
[currNode
] = EvalResult(res
);
550 case kind::BITVECTOR_OR
:
552 BitVector res
= results
[currNode
[0]].d_bv
;
553 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
555 res
= res
| results
[currNode
[i
]].d_bv
;
557 results
[currNode
] = EvalResult(res
);
561 case kind::BITVECTOR_XOR
:
563 BitVector res
= results
[currNode
[0]].d_bv
;
564 for (size_t i
= 1, end
= currNode
.getNumChildren(); i
< end
; i
++)
566 res
= res
^ results
[currNode
[i
]].d_bv
;
568 results
[currNode
] = EvalResult(res
);
574 EvalResult lhs
= results
[currNode
[0]];
575 EvalResult rhs
= results
[currNode
[1]];
579 case EvalResult::BOOL
:
581 results
[currNode
] = EvalResult(lhs
.d_bool
== rhs
.d_bool
);
585 case EvalResult::BITVECTOR
:
587 results
[currNode
] = EvalResult(lhs
.d_bv
== rhs
.d_bv
);
591 case EvalResult::RATIONAL
:
593 results
[currNode
] = EvalResult(lhs
.d_rat
== rhs
.d_rat
);
597 case EvalResult::STRING
:
599 results
[currNode
] = EvalResult(lhs
.d_str
== rhs
.d_str
);
605 Trace("evaluator") << "Theory " << Theory::theoryOf(currNode
[0])
606 << " not supported" << std::endl
;
617 if (results
[currNode
[0]].d_bool
)
619 results
[currNode
] = results
[currNode
[1]];
623 results
[currNode
] = results
[currNode
[2]];
630 Trace("evaluator") << "Kind " << currNodeVal
.getKind()
631 << " not supported" << std::endl
;
641 } // namespace theory