From: Andres Noetzli Date: Tue, 26 Jun 2018 23:09:03 +0000 (-0700) Subject: sygusComp2018: Add evaluator (#2090) X-Git-Tag: cvc5-1.0.0~4942 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=81bb4147ad681641dc99a62fc1a8605f99c05f2d;p=cvc5.git sygusComp2018: Add evaluator (#2090) This commit adds the Evaluator class that can be used to quickly evaluate terms under a given substitution without going through the rewriter. This has been shown to lead to significant performance improvements on SyGuS PBE problems with a large number of inputs because candidate solutions are evaluated on those inputs. With this commit, the evaluator only gets called from the SyGuS solver but there are potentially other places in the code that could profit from it. --- diff --git a/src/Makefile.am b/src/Makefile.am index ce9f74d9e..b36c453e1 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -179,6 +179,8 @@ libcvc4_la_SOURCES = \ theory/atom_requests.cpp \ theory/atom_requests.h \ theory/care_graph.h \ + theory/evaluator.cpp \ + theory/evaluator.h \ theory/interrupted.h \ theory/ite_utilities.cpp \ theory/ite_utilities.h \ diff --git a/src/options/quantifiers_options.toml b/src/options/quantifiers_options.toml index 69868ad8d..be4c66b27 100644 --- a/src/options/quantifiers_options.toml +++ b/src/options/quantifiers_options.toml @@ -1118,6 +1118,14 @@ header = "options/quantifiers_options.h" includes = ["options/quantifiers_modes.h"] help = "mode for using samples in the counterexample-guided inductive synthesis loop" +[[option]] + name = "sygusEvalOpt" + category = "regular" + long = "sygus-eval-opt" + type = "bool" + default = "true" + help = "use optimized approach for evaluation in sygus" + # Internal uses of sygus [[option]] diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp new file mode 100644 index 000000000..ca2140ed5 --- /dev/null +++ b/src/theory/evaluator.cpp @@ -0,0 +1,597 @@ +/********************* */ +/*! \file evaluator.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andres Noetzli + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The Evaluator class + ** + ** The Evaluator class. + **/ + +#include "theory/evaluator.h" + +#include "theory/bv/theory_bv_utils.h" +#include "theory/theory.h" +#include "util/integer.h" + +namespace CVC4 { +namespace theory { + +EvalResult::EvalResult(const EvalResult& other) +{ + d_tag = other.d_tag; + switch (d_tag) + { + case BOOL: d_bool = other.d_bool; break; + case BITVECTOR: + new (&d_bv) BitVector; + d_bv = other.d_bv; + break; + case RATIONAL: + new (&d_rat) Rational; + d_rat = other.d_rat; + break; + case STRING: + new (&d_str) String; + d_str = other.d_str; + break; + case INVALID: break; + } +} + +EvalResult& EvalResult::operator=(const EvalResult& other) +{ + if (this != &other) + { + d_tag = other.d_tag; + switch (d_tag) + { + case BOOL: d_bool = other.d_bool; break; + case BITVECTOR: + new (&d_bv) BitVector; + d_bv = other.d_bv; + break; + case RATIONAL: + new (&d_rat) Rational; + d_rat = other.d_rat; + break; + case STRING: + new (&d_str) String; + d_str = other.d_str; + break; + case INVALID: break; + } + } + return *this; +} + +EvalResult::~EvalResult() +{ + switch (d_tag) + { + case BITVECTOR: + { + d_bv.~BitVector(); + break; + } + case RATIONAL: + { + d_rat.~Rational(); + break; + } + case STRING: + { + d_str.~String(); + break; + + default: break; + } + } +} + +Node EvalResult::toNode() const +{ + NodeManager* nm = NodeManager::currentNM(); + switch (d_tag) + { + case EvalResult::BOOL: return nm->mkConst(d_bool); + case EvalResult::BITVECTOR: return nm->mkConst(d_bv); + case EvalResult::RATIONAL: return nm->mkConst(d_rat); + case EvalResult::STRING: return nm->mkConst(d_str); + default: + { + Trace("evaluator") << "Missing conversion from " << d_tag << " to node" + << std::endl; + return Node(); + } + } + + return Node(); +} + +Node Evaluator::eval(TNode n, + const std::vector& args, + const std::vector& vals) +{ + Trace("evaluator") << "Evaluating " << n << " under substitution " << args + << " " << vals << std::endl; + return evalInternal(n, args, vals).toNode(); +} + +EvalResult Evaluator::evalInternal(TNode n, + const std::vector& args, + const std::vector& vals) +{ + std::unordered_map results; + std::vector queue; + queue.emplace_back(n); + + while (queue.size() != 0) + { + TNode currNode = queue.back(); + + if (results.find(currNode) != results.end()) + { + queue.pop_back(); + continue; + } + + bool doEval = true; + for (const auto& currNodeChild : currNode) + { + if (results.find(currNodeChild) == results.end()) + { + queue.emplace_back(currNodeChild); + doEval = false; + } + } + + if (doEval) + { + queue.pop_back(); + + Node currNodeVal = currNode; + if (currNode.isVar()) + { + const auto& it = std::find(args.begin(), args.end(), currNode); + + if (it == args.end()) + { + return EvalResult(); + } + + ptrdiff_t pos = std::distance(args.begin(), it); + currNodeVal = vals[pos]; + } + else if (currNode.getKind() == kind::APPLY_UF + && currNode.getOperator().getKind() == kind::LAMBDA) + { + // Create a copy of the current substitutions + std::vector lambdaArgs(args); + std::vector lambdaVals(vals); + + // Add the values for the arguments of the lambda as substitutions at + // the beginning of the vector to shadow variables from outer scopes + // with the same name + Node op = currNode.getOperator(); + for (const auto& lambdaArg : op[0]) + { + lambdaArgs.insert(lambdaArgs.begin(), lambdaArg); + } + + for (const auto& lambdaVal : currNode) + { + lambdaVals.insert(lambdaVals.begin(), results[lambdaVal].toNode()); + } + + // Lambdas are evaluated in a recursive fashion because each evaluation + // requires different substitutions + results[currNode] = evalInternal(op[1], lambdaArgs, lambdaVals); + continue; + } + + switch (currNodeVal.getKind()) + { + case kind::CONST_BOOLEAN: + results[currNode] = EvalResult(currNodeVal.getConst()); + break; + + case kind::NOT: + { + results[currNode] = EvalResult(!(results[currNode[0]].d_bool)); + break; + } + + case kind::AND: + { + bool res = results[currNode[0]].d_bool; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res && results[currNode[i]].d_bool; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::OR: + { + bool res = results[currNode[0]].d_bool; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res || results[currNode[i]].d_bool; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::CONST_RATIONAL: + { + const Rational& r = currNodeVal.getConst(); + results[currNode] = EvalResult(r); + break; + } + + case kind::PLUS: + { + Rational res = results[currNode[0]].d_rat; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res + results[currNode[i]].d_rat; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::MINUS: + { + const Rational& x = results[currNode[0]].d_rat; + const Rational& y = results[currNode[1]].d_rat; + results[currNode] = EvalResult(x - y); + break; + } + + case kind::MULT: + { + Rational res = results[currNode[0]].d_rat; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res * results[currNode[i]].d_rat; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::GEQ: + { + const Rational& x = results[currNode[0]].d_rat; + const Rational& y = results[currNode[1]].d_rat; + results[currNode] = EvalResult(x >= y); + break; + } + + case kind::CONST_STRING: + results[currNode] = EvalResult(currNodeVal.getConst()); + break; + + case kind::STRING_CONCAT: + { + String res = results[currNode[0]].d_str; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res.concat(results[currNode[i]].d_str); + } + results[currNode] = EvalResult(res); + break; + } + + case kind::STRING_LENGTH: + { + const String& s = results[currNode[0]].d_str; + results[currNode] = EvalResult(Rational(s.size())); + break; + } + + case kind::STRING_SUBSTR: + { + const String& s = results[currNode[0]].d_str; + Integer s_len(s.size()); + Integer i = results[currNode[1]].d_rat.getNumerator(); + Integer j = results[currNode[2]].d_rat.getNumerator(); + + if (i.strictlyNegative() || j.strictlyNegative() || i >= s_len) + { + results[currNode] = EvalResult(String("")); + } + else if (i + j > s_len) + { + results[currNode] = + EvalResult(s.suffix((s_len - i).toUnsignedInt())); + } + else + { + results[currNode] = + EvalResult(s.substr(i.toUnsignedInt(), j.toUnsignedInt())); + } + break; + } + + case kind::STRING_CHARAT: + { + const String& s = results[currNode[0]].d_str; + Integer s_len(s.size()); + Integer i = results[currNode[1]].d_rat.getNumerator(); + if (i.strictlyNegative() || i >= s_len) + { + results[currNode] = EvalResult(String("")); + } + else + { + results[currNode] = EvalResult(s.substr(i.toUnsignedInt(), 1)); + } + break; + } + + case kind::STRING_STRCTN: + { + const String& s = results[currNode[0]].d_str; + const String& t = results[currNode[1]].d_str; + results[currNode] = EvalResult(s.find(t) != std::string::npos); + break; + } + + case kind::STRING_STRIDOF: + { + const String& s = results[currNode[0]].d_str; + Integer s_len(s.size()); + const String& x = results[currNode[1]].d_str; + Integer i = results[currNode[2]].d_rat.getNumerator(); + + if (i.strictlyNegative() || i >= s_len) + { + results[currNode] = EvalResult(Rational(-1)); + } + else + { + size_t r = s.find(x, i.toUnsignedInt()); + if (r == std::string::npos) + { + results[currNode] = EvalResult(Rational(-1)); + } + else + { + results[currNode] = EvalResult(Rational(r)); + } + } + break; + } + + case kind::STRING_STRREPL: + { + const String& s = results[currNode[0]].d_str; + const String& x = results[currNode[1]].d_str; + const String& y = results[currNode[2]].d_str; + results[currNode] = EvalResult(s.replace(x, y)); + break; + } + + case kind::STRING_PREFIX: + { + const String& t = results[currNode[0]].d_str; + const String& s = results[currNode[1]].d_str; + if (s.size() < t.size()) + { + results[currNode] = EvalResult(false); + } + else + { + results[currNode] = EvalResult(s.prefix(t.size()) == t); + } + break; + } + + case kind::STRING_SUFFIX: + { + const String& t = results[currNode[0]].d_str; + const String& s = results[currNode[1]].d_str; + if (s.size() < t.size()) + { + results[currNode] = EvalResult(false); + } + else + { + results[currNode] = EvalResult(s.suffix(t.size()) == t); + } + break; + } + + case kind::STRING_ITOS: + { + Integer i = results[currNode[0]].d_rat.getNumerator(); + if (i.strictlyNegative()) + { + results[currNode] = EvalResult(String("")); + } + else + { + results[currNode] = EvalResult(String(i.toString())); + } + break; + } + + case kind::STRING_STOI: + { + const String& s = results[currNode[0]].d_str; + if (s.isNumber()) + { + results[currNode] = EvalResult(Rational(-1)); + } + else + { + results[currNode] = EvalResult(Rational(s.toNumber())); + } + break; + } + + case kind::CONST_BITVECTOR: + results[currNode] = EvalResult(currNodeVal.getConst()); + break; + + case kind::BITVECTOR_NOT: + results[currNode] = EvalResult(~results[currNode[0]].d_bv); + break; + + case kind::BITVECTOR_NEG: + results[currNode] = EvalResult(-results[currNode[0]].d_bv); + break; + + case kind::BITVECTOR_EXTRACT: + { + unsigned lo = bv::utils::getExtractLow(currNodeVal); + unsigned hi = bv::utils::getExtractHigh(currNodeVal); + results[currNode] = + EvalResult(results[currNode[0]].d_bv.extract(hi, lo)); + break; + } + + case kind::BITVECTOR_CONCAT: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res.concat(results[currNode[i]].d_bv); + } + results[currNode] = EvalResult(res); + break; + } + + case kind::BITVECTOR_PLUS: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res + results[currNode[i]].d_bv; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::BITVECTOR_MULT: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res * results[currNode[i]].d_bv; + } + results[currNode] = EvalResult(res); + break; + } + case kind::BITVECTOR_AND: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res & results[currNode[i]].d_bv; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::BITVECTOR_OR: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res | results[currNode[i]].d_bv; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::BITVECTOR_XOR: + { + BitVector res = results[currNode[0]].d_bv; + for (size_t i = 1, end = currNode.getNumChildren(); i < end; i++) + { + res = res ^ results[currNode[i]].d_bv; + } + results[currNode] = EvalResult(res); + break; + } + + case kind::EQUAL: + { + EvalResult lhs = results[currNode[0]]; + EvalResult rhs = results[currNode[1]]; + + switch (lhs.d_tag) + { + case EvalResult::BOOL: + { + results[currNode] = EvalResult(lhs.d_bool == rhs.d_bool); + break; + } + + case EvalResult::BITVECTOR: + { + results[currNode] = EvalResult(lhs.d_bv == rhs.d_bv); + break; + } + + case EvalResult::RATIONAL: + { + results[currNode] = EvalResult(lhs.d_rat == rhs.d_rat); + break; + } + + case EvalResult::STRING: + { + results[currNode] = EvalResult(lhs.d_str == rhs.d_str); + break; + } + + default: + { + Trace("evaluator") << "Theory " << Theory::theoryOf(currNode[0]) + << " not supported" << std::endl; + return EvalResult(); + break; + } + } + + break; + } + + case kind::ITE: + { + if (results[currNode[0]].d_bool) + { + results[currNode] = results[currNode[1]]; + } + else + { + results[currNode] = results[currNode[2]]; + } + break; + } + + default: + { + Trace("evaluator") << "Kind " << currNodeVal.getKind() + << " not supported" << std::endl; + return EvalResult(); + } + } + } + } + + return results[n]; +} + +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h new file mode 100644 index 000000000..0d7ddbec8 --- /dev/null +++ b/src/theory/evaluator.h @@ -0,0 +1,113 @@ +/********************* */ +/*! \file evaluator.h + ** \verbatim + ** Top contributors (to current version): + ** Andres Noetzli + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The Evaluator class + ** + ** The Evaluator class can be used to evaluate terms with constant leaves + ** quickly, without going through the rewriter. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__EVALUATOR_H +#define __CVC4__THEORY__EVALUATOR_H + +#include +#include + +#include "base/output.h" +#include "expr/node.h" +#include "util/bitvector.h" +#include "util/rational.h" +#include "util/regexp.h" + +namespace CVC4 { +namespace theory { + +/** + * Struct that holds the result of an evaluation. The actual value is stored in + * a union to avoid the overhead of a class hierarchy with virtual methods. + */ +struct EvalResult +{ + /* Describes which type of result is being stored */ + enum + { + BOOL, + BITVECTOR, + RATIONAL, + STRING, + INVALID + } d_tag; + + /* Stores the actual result */ + union + { + bool d_bool; + BitVector d_bv; + Rational d_rat; + String d_str; + }; + + EvalResult(const EvalResult& other); + EvalResult() : d_tag(INVALID) {} + EvalResult(bool b) : d_tag(BOOL), d_bool(b) {} + EvalResult(const BitVector& bv) : d_tag(BITVECTOR), d_bv(bv) {} + EvalResult(const Rational& i) : d_tag(RATIONAL), d_rat(i) {} + EvalResult(const String& str) : d_tag(STRING), d_str(str) {} + + EvalResult& operator=(const EvalResult& other); + + ~EvalResult(); + + /** + * Converts the result to a Node. If the result is not valid, this function + * returns the null node. + */ + Node toNode() const; +}; + +/** + * The class that performs the actual evaluation of a term under a + * substitution. Right now, the class does not cache anything between different + * calls to `eval` but this might change in the future. + */ +class Evaluator +{ + public: + /** + * Evaluates node `n` under the substitution described by the variable names + * `args` and the corresponding values `vals`. The function returns a null + * node if there is a subterm that is not constant under the substitution or + * if an operator is not supported by the evaluator. + */ + Node eval(TNode n, + const std::vector& args, + const std::vector& vals); + + private: + /** + * Evaluates node `n` under the substitution described by the variable names + * `args` and the corresponding values `vals`. The internal version returns + * an EvalResult which has slightly less overhead for recursive calls. The + * function returns an invalid result if there is a subterm that is not + * constant under the substitution or if an operator is not supported by the + * evaluator. + */ + EvalResult evalInternal(TNode n, + const std::vector& args, + const std::vector& vals); +}; + +} // namespace theory +} // namespace CVC4 + +#endif /* __CVC4__THEORY__EVALUATOR_H */ diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index d4349745b..26f26a145 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -22,6 +22,8 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" +#include "options/base_options.h" +#include "printer/printer.h" using namespace CVC4::kind; @@ -33,6 +35,7 @@ TermDbSygus::TermDbSygus(context::Context* c, QuantifiersEngine* qe) : d_quantEngine(qe), d_syexp(new SygusExplain(this)), d_ext_rw(new ExtendedRewriter(true)), + d_eval(new Evaluator), d_eval_unfold(new SygusEvalUnfold(this)) { d_true = NodeManager::currentNM()->mkConst( true ); @@ -1548,13 +1551,39 @@ Node TermDbSygus::getEagerUnfold( Node n, std::map< Node, Node >& visited ) { } } - -Node TermDbSygus::evaluateBuiltin( TypeNode tn, Node bn, std::vector< Node >& args ) { +Node TermDbSygus::evaluateBuiltin(TypeNode tn, + Node bn, + std::vector& args, + bool tryEval) +{ if( !args.empty() ){ std::map< TypeNode, std::vector< Node > >::iterator it = d_var_list.find( tn ); Assert( it!=d_var_list.end() ); Assert( it->second.size()==args.size() ); - return Rewriter::rewrite( bn.substitute( it->second.begin(), it->second.end(), args.begin(), args.end() ) ); + + Node res; + if (tryEval && options::sygusEvalOpt()) + { + // Try evaluating, which is much faster than substitution+rewriting. + // This may fail if there is a subterm of bn under the + // substitution that is not constant, or if an operator in bn is not + // supported by the evaluator + res = d_eval->eval(bn, it->second, args); + } + if (!res.isNull()) + { + Assert(res + == Rewriter::rewrite(bn.substitute(it->second.begin(), + it->second.end(), + args.begin(), + args.end()))); + return res; + } + else + { + return Rewriter::rewrite(bn.substitute( + it->second.begin(), it->second.end(), args.begin(), args.end())); + } }else{ return Rewriter::rewrite( bn ); } diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index be35d07f3..286533570 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -19,6 +19,7 @@ #include +#include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/sygus/sygus_eval_unfold.h" #include "theory/quantifiers/sygus/sygus_explain.h" @@ -53,6 +54,8 @@ class TermDbSygus { SygusExplain* getExplain() { return d_syexp.get(); } /** get the extended rewrite utility */ ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } + /** get the evaluator */ + Evaluator* getEvaluator() { return d_eval.get(); } /** evaluation unfolding utility */ SygusEvalUnfold* getEvalUnfold() { return d_eval_unfold.get(); } //------------------------------end utilities @@ -182,7 +185,8 @@ class TermDbSygus { * form of bn [ args / vars(tn) ], where vars(tn) is the sygus variable * list for type tn (see Datatype::getSygusVarList). */ - Node evaluateBuiltin(TypeNode tn, Node bn, std::vector& args); + Node evaluateBuiltin(TypeNode tn, Node bn, std::vector& args, +bool tryEval = true); /** evaluate with unfolding * * n is any term that may involve sygus evaluation functions. This function @@ -222,6 +226,8 @@ class TermDbSygus { std::unique_ptr d_syexp; /** extended rewriter */ std::unique_ptr d_ext_rw; + /** evaluator */ + std::unique_ptr d_eval; /** evaluation function unfolding utility */ std::unique_ptr d_eval_unfold; //------------------------------end utilities diff --git a/test/unit/Makefile.am b/test/unit/Makefile.am index cc9f6fb1b..0ab305039 100644 --- a/test/unit/Makefile.am +++ b/test/unit/Makefile.am @@ -6,6 +6,7 @@ UNIT_TESTS = \ util/cardinality_public if WHITE_AND_BLACK_TESTS UNIT_TESTS += \ + theory/evaluator_white \ theory/logic_info_white \ theory/theory_arith_white \ theory/theory_black \ diff --git a/test/unit/theory/evaluator_white.h b/test/unit/theory/evaluator_white.h new file mode 100644 index 000000000..4416ee00a --- /dev/null +++ b/test/unit/theory/evaluator_white.h @@ -0,0 +1,122 @@ +/********************* */ +/*! \file evaluator_white.h + ** \verbatim + ** Top contributors (to current version): + ** Andres Noetzli + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include +#include + +#include "expr/node.h" +#include "expr/node_manager.h" +#include "smt/smt_engine.h" +#include "smt/smt_engine_scope.h" +#include "theory/bv/theory_bv_utils.h" +#include "theory/evaluator.h" +#include "theory/rewriter.h" +#include "theory/theory_test_utils.h" + +using namespace CVC4; +using namespace CVC4::smt; +using namespace CVC4::theory; + +using namespace std; + +class TheoryEvaluatorWhite : public CxxTest::TestSuite +{ + ExprManager *d_em; + NodeManager *d_nm; + SmtEngine *d_smt; + SmtScope *d_scope; + + public: + TheoryEvaluatorWhite() {} + + void setUp() + { + Options opts; + opts.setOutputLanguage(language::output::LANG_SMTLIB_V2); + d_em = new ExprManager(opts); + d_nm = NodeManager::fromExprManager(d_em); + d_smt = new SmtEngine(d_em); + d_scope = new SmtScope(d_smt); + } + + void tearDown() + { + delete d_scope; + delete d_smt; + delete d_em; + } + + void testSimple() + { + TypeNode bv64Type = d_nm->mkBitVectorType(64); + + Node w = d_nm->mkVar("w", bv64Type); + Node x = d_nm->mkVar("x", bv64Type); + Node y = d_nm->mkVar("y", bv64Type); + Node z = d_nm->mkVar("z", bv64Type); + + Node zero = d_nm->mkConst(BitVector(64, (unsigned int)0)); + Node one = d_nm->mkConst(BitVector(64, (unsigned int)1)); + Node c1 = d_nm->mkConst(BitVector( + 64, + (unsigned int)0b0000000100000101001110111001101000101110011101011011110011100111)); + Node c2 = d_nm->mkConst(BitVector( + 64, + (unsigned int)0b0000000100000101001110111001101000101110011101011011110011100111)); + + Node t = d_nm->mkNode(kind::ITE, d_nm->mkNode(kind::EQUAL, y, one), x, w); + + std::vector args = {w, x, y, z}; + std::vector vals = {c1, zero, one, c1}; + + Evaluator eval; + Node r = eval.eval(t, args, vals); + TS_ASSERT_EQUALS(r, + Rewriter::rewrite(t.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); + } + + void testLoop() + { + TypeNode bv64Type = d_nm->mkBitVectorType(64); + + Node w = d_nm->mkBoundVar(bv64Type); + Node x = d_nm->mkVar("x", bv64Type); + + Node zero = d_nm->mkConst(BitVector(1, (unsigned int)0)); + Node one = d_nm->mkConst(BitVector(64, (unsigned int)1)); + Node c = d_nm->mkConst(BitVector( + 64, + (unsigned int)0b0001111000010111110000110110001101011110111001101100000101010100)); + + Node largs = d_nm->mkNode(kind::BOUND_VAR_LIST, w); + Node lbody = d_nm->mkNode( + kind::BITVECTOR_CONCAT, bv::utils::mkExtract(w, 62, 0), zero); + Node lambda = d_nm->mkNode(kind::LAMBDA, largs, lbody); + Node t = d_nm->mkNode(kind::BITVECTOR_AND, + d_nm->mkNode(kind::APPLY_UF, lambda, one), + d_nm->mkNode(kind::APPLY_UF, lambda, x)); + + std::vector args = {x}; + std::vector vals = {c}; + Evaluator eval; + Node r = eval.eval(t, args, vals); + TS_ASSERT_EQUALS(r, + Rewriter::rewrite(t.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); + } +};