From: Andrew Reynolds Date: Wed, 22 Sep 2021 20:05:44 +0000 (-0500) Subject: Towards standard usage of evaluator (#7189) X-Git-Tag: cvc5-1.0.0~1180 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ba259d66be877de3cc77e4f62083905ace942c82;p=cvc5.git Towards standard usage of evaluator (#7189) This makes the evaluator accessible via EnvObj through the Rewriter. It furthermore removes Rewriter::rewrite from inside the evaluator itself. Construction of Evaluator utilities is now discouraged. The include dependencies were cleaned slightly in this PR, leading to more precise includes throughout. This is work towards having a configurable cardinality for strings, as well as eliminating SmtEngineScope. --- diff --git a/src/smt/env.cpp b/src/smt/env.cpp index f42a51dd0..c77b8cfba 100644 --- a/src/smt/env.cpp +++ b/src/smt/env.cpp @@ -24,6 +24,7 @@ #include "proof/conv_proof_generator.h" #include "smt/dump_manager.h" #include "smt/smt_engine_stats.h" +#include "theory/evaluator.h" #include "theory/rewriter.h" #include "theory/trust_substitutions.h" #include "util/resource_manager.h" @@ -39,6 +40,8 @@ Env::Env(NodeManager* nm, const Options* opts) d_nodeManager(nm), d_proofNodeManager(nullptr), d_rewriter(new theory::Rewriter()), + d_evalRew(new theory::Evaluator(d_rewriter.get())), + d_eval(new theory::Evaluator(nullptr)), d_topLevelSubs(new theory::TrustSubstitutionMap(d_userContext.get())), d_dumpManager(new DumpManager(d_userContext.get())), d_logic(), @@ -132,4 +135,55 @@ const Printer& Env::getPrinter() std::ostream& Env::getDumpOut() { return *d_options.base.out; } +Node Env::evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + bool useRewriter) const +{ + std::unordered_map visited; + return evaluate(n, args, vals, visited, useRewriter); +} + +Node Env::evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + const std::unordered_map& visited, + bool useRewriter) const +{ + if (useRewriter) + { + return d_evalRew->eval(n, args, vals, visited); + } + return d_eval->eval(n, args, vals, visited); +} + +Node Env::rewriteViaMethod(TNode n, MethodId idr) +{ + if (idr == MethodId::RW_REWRITE) + { + return d_rewriter->rewrite(n); + } + if (idr == MethodId::RW_EXT_REWRITE) + { + return d_rewriter->extendedRewrite(n); + } + if (idr == MethodId::RW_REWRITE_EQ_EXT) + { + return d_rewriter->rewriteEqualityExt(n); + } + if (idr == MethodId::RW_EVALUATE) + { + return evaluate(n, {}, {}, false); + } + if (idr == MethodId::RW_IDENTITY) + { + // does nothing + return n; + } + // unknown rewriter + Unhandled() << "Env::rewriteViaMethod: no rewriter for " << idr + << std::endl; + return n; +} + } // namespace cvc5 diff --git a/src/smt/env.h b/src/smt/env.h index d95e70226..2f2fe19ce 100644 --- a/src/smt/env.h +++ b/src/smt/env.h @@ -22,6 +22,7 @@ #include #include "options/options.h" +#include "proof/method_id.h" #include "theory/logic_info.h" #include "util/statistics_registry.h" @@ -44,6 +45,7 @@ class PfManager; } namespace theory { +class Evaluator; class Rewriter; class TrustSubstitutionMap; } @@ -137,6 +139,39 @@ class Env */ std::ostream& getDumpOut(); + /* Rewrite helpers--------------------------------------------------------- */ + /** + * Evaluate node n under the substitution args -> vals. For details, see + * theory/evaluator.h. + * + * @param n The node to evaluate + * @param args The domain of the substitution + * @param vals The range of the substitution + * @param useRewriter if true, we use this rewriter to rewrite subterms of + * n that cannot be evaluated to a constant. + * @return the rewritten, evaluated form of n under the given substitution. + */ + Node evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + bool useRewriter) const; + /** Same as above, with a visited cache. */ + Node evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + const std::unordered_map& visited, + bool useRewriter = true) const; + /** + * Apply rewrite on n via the rewrite method identifier idr (see method_id.h). + * This encapsulates the exact behavior of a REWRITE step in a proof. + * + * @param n The node to rewrite, + * @param idr The method identifier of the rewriter, by default RW_REWRITE + * specifying a call to rewrite. + * @return The rewritten form of n. + */ + Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE); + private: /* Private initialization ------------------------------------------------- */ @@ -173,6 +208,10 @@ class Env * specific to an SmtEngine/TheoryEngine instance. */ std::unique_ptr d_rewriter; + /** Evaluator that also invokes the rewriter */ + std::unique_ptr d_evalRew; + /** Evaluator that does not invoke the rewriter */ + std::unique_ptr d_eval; /** The top level substitutions */ std::unique_ptr d_topLevelSubs; /** The dump manager */ diff --git a/src/smt/env_obj.cpp b/src/smt/env_obj.cpp index fcbcc92d2..b9aebbe83 100644 --- a/src/smt/env_obj.cpp +++ b/src/smt/env_obj.cpp @@ -33,6 +33,21 @@ Node EnvObj::extendedRewrite(TNode node, bool aggr) const { return d_env.getRewriter()->extendedRewrite(node, aggr); } +Node EnvObj::evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + bool useRewriter) const +{ + return d_env.evaluate(n, args, vals, useRewriter); +} +Node EnvObj::evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + const std::unordered_map& visited, + bool useRewriter) const +{ + return d_env.evaluate(n, args, vals, visited, useRewriter); +} const LogicInfo& EnvObj::logicInfo() const { return d_env.getLogicInfo(); } diff --git a/src/smt/env_obj.h b/src/smt/env_obj.h index ef9a82b17..75b97fda9 100644 --- a/src/smt/env_obj.h +++ b/src/smt/env_obj.h @@ -55,6 +55,20 @@ class EnvObj * This is a wrapper around theory::Rewriter::extendedRewrite via Env. */ Node extendedRewrite(TNode node, bool aggr = true) const; + /** + * Evaluate node n under the substitution args -> vals. + * This is a wrapper about theory::Rewriter::evaluate via Env. + */ + Node evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + bool useRewriter = true) const; + /** Same as above, with a visited cache. */ + Node evaluate(TNode n, + const std::vector& args, + const std::vector& vals, + const std::unordered_map& visited, + bool useRewriter = true) const; /** Get the current logic information. */ const LogicInfo& logicInfo() const; diff --git a/src/smt/proof_post_processor.cpp b/src/smt/proof_post_processor.cpp index 3866c1e0e..f5db349e1 100644 --- a/src/smt/proof_post_processor.cpp +++ b/src/smt/proof_post_processor.cpp @@ -478,8 +478,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, rargs.push_back(args[3]); } } - Rewriter* rr = d_env.getRewriter(); - Node tr = rr->rewriteViaMethod(ts, idr); + Node tr = d_env.rewriteViaMethod(ts, idr); Trace("smt-proof-pp-debug") << "...eq intro rewrite equality is " << ts << " == " << tr << ", from " << idr << std::endl; @@ -954,7 +953,7 @@ Node ProofPostprocessCallback::expandMacros(PfRule id, getMethodId(args[1], idr); } Rewriter* rr = d_env.getRewriter(); - Node ret = rr->rewriteViaMethod(args[0], idr); + Node ret = d_env.rewriteViaMethod(args[0], idr); Node eq = args[0].eqNode(ret); if (idr == MethodId::RW_REWRITE || idr == MethodId::RW_REWRITE_EQ_EXT) { diff --git a/src/theory/builtin/proof_checker.cpp b/src/theory/builtin/proof_checker.cpp index d71b3635b..1309a05f9 100644 --- a/src/theory/builtin/proof_checker.cpp +++ b/src/theory/builtin/proof_checker.cpp @@ -18,10 +18,10 @@ #include "expr/skolem_manager.h" #include "smt/env.h" #include "smt/term_formula_removal.h" -#include "theory/evaluator.h" #include "theory/rewriter.h" #include "theory/substitutions.h" #include "theory/theory.h" +#include "util/rational.h" using namespace cvc5::kind; @@ -67,7 +67,7 @@ Node BuiltinProofRuleChecker::applySubstitutionRewrite( MethodId idr) { Node nks = applySubstitution(n, exp, ids, ida); - return d_env.getRewriter()->rewriteViaMethod(nks, idr); + return d_env.rewriteViaMethod(nks, idr); } bool BuiltinProofRuleChecker::getSubstitutionForLit(Node exp, @@ -249,7 +249,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, { return Node::null(); } - Node res = d_env.getRewriter()->rewriteViaMethod(args[0], idr); + Node res = d_env.rewriteViaMethod(args[0], idr); if (res.isNull()) { return Node::null(); @@ -260,7 +260,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, { Assert(children.empty()); Assert(args.size() == 1); - Node res = d_env.getRewriter()->rewriteViaMethod(args[0], MethodId::RW_EVALUATE); + Node res = d_env.rewriteViaMethod(args[0], MethodId::RW_EVALUATE); if (res.isNull()) { return Node::null(); @@ -302,7 +302,7 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, << SkolemManager::getOriginalForm(res) << std::endl; // **** NOTE: can rewrite the witness form here. This enables certain lemmas // to be provable, e.g. (= k t) where k is a purification Skolem for t. - res = Rewriter::rewrite(SkolemManager::getOriginalForm(res)); + res = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res)); if (!res.isConst() || !res.getConst()) { Trace("builtin-pfcheck") @@ -349,8 +349,8 @@ Node BuiltinProofRuleChecker::checkInternal(PfRule id, if (res1 != res2) { // can rewrite the witness forms - res1 = Rewriter::rewrite(SkolemManager::getOriginalForm(res1)); - res2 = Rewriter::rewrite(SkolemManager::getOriginalForm(res2)); + res1 = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res1)); + res2 = d_env.getRewriter()->rewrite(SkolemManager::getOriginalForm(res2)); if (res1.isNull() || res1 != res2) { Trace("builtin-pfcheck") << "Failed to match results" << std::endl; diff --git a/src/theory/datatypes/sygus_datatype_utils.cpp b/src/theory/datatypes/sygus_datatype_utils.cpp index f1f7b45a4..12c255f57 100644 --- a/src/theory/datatypes/sygus_datatype_utils.cpp +++ b/src/theory/datatypes/sygus_datatype_utils.cpp @@ -391,7 +391,7 @@ Node sygusToBuiltin(Node n, bool isExternal) Node sygusToBuiltinEval(Node n, const std::vector& args) { NodeManager* nm = NodeManager::currentNM(); - Evaluator eval; + Evaluator eval(nullptr); // constant arguments? bool constArgs = true; for (const Node& a : args) diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index 2411013b2..d666cdac5 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -35,6 +35,7 @@ #include "theory/rewriter.h" #include "theory/theory_model.h" #include "theory/theory_state.h" +#include "util/rational.h" using namespace cvc5; using namespace cvc5::kind; @@ -1101,16 +1102,20 @@ Node SygusExtension::registerSearchValue(Node a, if (bv != bvr) { // add to the sampler database object - std::map::iterator its = - d_sampler[a].find(tn); - if (its == d_sampler[a].end()) + std::map>& smap = + d_sampler[a]; + std::map>::iterator its = + smap.find(tn); + if (its == smap.end()) { - d_sampler[a][tn].initializeSygus( + smap[tn].reset(new quantifiers::SygusSampler(d_env)); + smap[tn]->initializeSygus( d_tds, nv, options::sygusSamples(), false); its = d_sampler[a].find(tn); } // check equivalent - its->second.checkEquivalent(bv, bvr, *options().base.out); + its->second->checkEquivalent(bv, bvr, *options().base.out); } } diff --git a/src/theory/datatypes/sygus_extension.h b/src/theory/datatypes/sygus_extension.h index c7a9e7893..2fd0110b4 100644 --- a/src/theory/datatypes/sygus_extension.h +++ b/src/theory/datatypes/sygus_extension.h @@ -289,7 +289,8 @@ private: * This is used for the sygusRewVerify() option to verify the correctness of * the rewriter. */ - std::map> d_sampler; + std::map>> + d_sampler; /** Assert tester internal * * This function is called when the tester with index tindex is asserted for diff --git a/src/theory/datatypes/sygus_simple_sym.cpp b/src/theory/datatypes/sygus_simple_sym.cpp index 36dfc710b..63e60a478 100644 --- a/src/theory/datatypes/sygus_simple_sym.cpp +++ b/src/theory/datatypes/sygus_simple_sym.cpp @@ -17,6 +17,7 @@ #include "expr/dtype_cons.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 4a8976876..427e0251f 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -22,6 +22,7 @@ #include "expr/dtype_cons.h" #include "expr/kind.h" #include "expr/skolem_manager.h" +#include "expr/uninterpreted_constant.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "options/smt_options.h" @@ -38,6 +39,7 @@ #include "theory/theory_state.h" #include "theory/type_enumerator.h" #include "theory/valuation.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index 2a274426f..75c878065 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -127,19 +127,22 @@ Node EvalResult::toNode() const } } +Evaluator::Evaluator(Rewriter* rr) + : d_rr(rr), d_alphaCard(strings::utils::getAlphabetCardinality()) +{ +} + Node Evaluator::eval(TNode n, const std::vector& args, - const std::vector& vals, - bool useRewriter) const + const std::vector& vals) const { std::unordered_map visited; - return eval(n, args, vals, visited, useRewriter); + return eval(n, args, vals, visited); } Node Evaluator::eval(TNode n, const std::vector& args, const std::vector& vals, - const std::unordered_map& visited, - bool useRewriter) const + const std::unordered_map& visited) const { Trace("evaluator") << "Evaluating " << n << " under substitution " << args << " " << vals << " with visited size = " << visited.size() @@ -150,36 +153,36 @@ Node Evaluator::eval(TNode n, for (const std::pair& p : visited) { Trace("evaluator") << "Add " << p.first << " == " << p.second << std::endl; - results[p.first] = evalInternal(p.second, args, vals, evalAsNode, results, useRewriter); + results[p.first] = evalInternal(p.second, args, vals, evalAsNode, results); if (results[p.first].d_tag == EvalResult::INVALID) { // could not evaluate, use the evalAsNode map std::unordered_map::iterator itn = evalAsNode.find(p.second); Assert(itn != evalAsNode.end()); Node val = itn->second; - if (useRewriter) + if (d_rr != nullptr) { - val = Rewriter::rewrite(val); + val = d_rr->rewrite(val); } evalAsNode[p.first] = val; } } Trace("evaluator") << "Run eval internal..." << std::endl; - Node ret = evalInternal(n, args, vals, evalAsNode, results, useRewriter).toNode(); + Node ret = evalInternal(n, args, vals, evalAsNode, results).toNode(); // if we failed to evaluate - if (ret.isNull() && useRewriter) + if (ret.isNull() && d_rr != nullptr) { // should be stored in the evaluation-as-node map std::unordered_map::iterator itn = evalAsNode.find(n); Assert(itn != evalAsNode.end()); - ret = Rewriter::rewrite(itn->second); + ret = d_rr->rewrite(itn->second); } // should be the same as substitution + rewriting, or possibly null if - // useRewriter is false - Assert((ret.isNull() && !useRewriter) + // d_rr is nullptr + Assert((ret.isNull() && d_rr == nullptr) || ret - == Rewriter::rewrite(n.substitute( - args.begin(), args.end(), vals.begin(), vals.end()))); + == d_rr->rewrite(n.substitute( + args.begin(), args.end(), vals.begin(), vals.end()))); return ret; } @@ -188,8 +191,7 @@ EvalResult Evaluator::evalInternal( const std::vector& args, const std::vector& vals, std::unordered_map& evalAsNode, - std::unordered_map& results, - bool useRewriter) const + std::unordered_map& results) const { std::vector queue; queue.emplace_back(n); @@ -290,11 +292,11 @@ EvalResult Evaluator::evalInternal( // successfully evaluated, and the children that did not. Trace("evaluator") << "Evaluator: collect arguments" << std::endl; currNodeVal = reconstruct(currNodeVal, results, evalAsNode); - if (useRewriter) + if (d_rr != nullptr) { // Rewrite the result now, if we use the rewriter. We will see below // if we are able to turn it into a valid EvalResult. - currNodeVal = Rewriter::rewrite(currNodeVal); + currNodeVal = d_rr->rewrite(currNodeVal); } } needsReconstruct = false; @@ -360,12 +362,8 @@ EvalResult Evaluator::evalInternal( // evalAsNodeC but favor avoiding this copy for performance reasons. std::unordered_map evalAsNodeC; std::unordered_map resultsC; - results[currNode] = evalInternal(op[1], - lambdaArgs, - lambdaVals, - evalAsNodeC, - resultsC, - useRewriter); + results[currNode] = evalInternal( + op[1], lambdaArgs, lambdaVals, evalAsNodeC, resultsC); Trace("evaluator") << "Evaluated via arguments to " << results[currNode].d_tag << std::endl; if (results[currNode].d_tag == EvalResult::INVALID) @@ -676,7 +674,7 @@ EvalResult Evaluator::evalInternal( case kind::STRING_FROM_CODE: { Integer i = results[currNode[0]].d_rat.getNumerator(); - if (i >= 0 && i < strings::utils::getAlphabetCardinality()) + if (i >= 0 && i < d_alphaCard) { std::vector svec = {i.toUnsignedInt()}; results[currNode] = EvalResult(String(svec)); diff --git a/src/theory/evaluator.h b/src/theory/evaluator.h index 42cc34749..2e96952b8 100644 --- a/src/theory/evaluator.h +++ b/src/theory/evaluator.h @@ -80,6 +80,8 @@ struct EvalResult Node toNode() const; }; +class Rewriter; + /** * The class that performs the actual evaluation of a term under a * substitution. Right now, the class does not cache anything between different @@ -88,6 +90,7 @@ struct EvalResult class Evaluator { public: + Evaluator(Rewriter* rr); /** * Evaluates node `n` under the substitution described by the variable names * `args` and the corresponding values `vals`. This method uses evaluation @@ -103,22 +106,20 @@ class Evaluator * The result of this call is either equivalent to: * (1) Rewriter::rewrite(n.substitute(args,vars)) * (2) Node::null(). - * If useRewriter is true, then we are always in the first case. If - * useRewriter is false, then we may be in case (2) if computing the + * If d_rr is non-null, then we are always in the first case. If + * useRewriter is null, then we may be in case (2) if computing the * rewritten, substituted form of n could not be determined by evaluation. */ Node eval(TNode n, const std::vector& args, - const std::vector& vals, - bool useRewriter = true) const; + const std::vector& vals) const; /** * Same as above, but with a precomputed visited map. */ Node eval(TNode n, const std::vector& args, const std::vector& vals, - const std::unordered_map& visited, - bool useRewriter = true) const; + const std::unordered_map& visited) const; private: /** @@ -141,8 +142,7 @@ class Evaluator const std::vector& args, const std::vector& vals, std::unordered_map& evalAsNode, - std::unordered_map& results, - bool useRewriter) const; + std::unordered_map& results) const; /** reconstruct * * This function reconstructs the result of evaluating n using a combination @@ -155,6 +155,10 @@ class Evaluator Node reconstruct(TNode n, std::unordered_map& eresults, std::unordered_map& evalAsNode) const; + /** The (optional) rewriter to be used */ + Rewriter* d_rr; + /** The cardinality of the alphabet of strings */ + uint32_t d_alphaCard; }; } // namespace theory diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp index 88da629a0..4b06589b3 100644 --- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp @@ -15,15 +15,14 @@ #include "theory/quantifiers/cegqi/ceg_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_arith_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_bv_instantiator.h" -#include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" - #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/quantifiers/cegqi/ceg_arith_instantiator.h" +#include "theory/quantifiers/cegqi/ceg_bv_instantiator.h" +#include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" #include "theory/quantifiers/cegqi/inst_strategy_cegqi.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/quantifiers_attributes.h" @@ -31,6 +30,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp index 1ccfd8ede..04fa1d2fe 100644 --- a/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp +++ b/src/theory/quantifiers/cegqi/inst_strategy_cegqi.cpp @@ -25,6 +25,7 @@ #include "theory/quantifiers/term_registry.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index 5380fc7d5..d8e3b7950 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -29,6 +29,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_registry.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/expr_miner_manager.cpp b/src/theory/quantifiers/expr_miner_manager.cpp index 8af456ea8..e53fd9424 100644 --- a/src/theory/quantifiers/expr_miner_manager.cpp +++ b/src/theory/quantifiers/expr_miner_manager.cpp @@ -16,6 +16,7 @@ #include "theory/quantifiers/expr_miner_manager.h" #include "options/quantifiers_options.h" +#include "smt/env.h" namespace cvc5 { namespace theory { @@ -33,7 +34,8 @@ ExpressionMinerManager::ExpressionMinerManager(Env& env) options::sygusRewSynthAccel(), false), d_qg(env), - d_sols(env) + d_sols(env), + d_sampler(env) { } diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp index 4a3e13dd0..44352c6fe 100644 --- a/src/theory/quantifiers/fmf/bounded_integers.cpp +++ b/src/theory/quantifiers/fmf/bounded_integers.cpp @@ -29,6 +29,7 @@ #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5; using namespace std; diff --git a/src/theory/quantifiers/fun_def_evaluator.cpp b/src/theory/quantifiers/fun_def_evaluator.cpp index 36f557db8..78a09641b 100644 --- a/src/theory/quantifiers/fun_def_evaluator.cpp +++ b/src/theory/quantifiers/fun_def_evaluator.cpp @@ -26,7 +26,7 @@ namespace cvc5 { namespace theory { namespace quantifiers { -FunDefEvaluator::FunDefEvaluator() {} +FunDefEvaluator::FunDefEvaluator(Env& env) : EnvObj(env) {} void FunDefEvaluator::assertDefinition(Node q) { @@ -51,11 +51,11 @@ void FunDefEvaluator::assertDefinition(Node q) << fdi.d_args << " / " << fdi.d_body << std::endl; } -Node FunDefEvaluator::evaluate(Node n) const +Node FunDefEvaluator::evaluateDefinitions(Node n) const { // should do standard rewrite before this call Assert(Rewriter::rewrite(n) == n); - Trace("fd-eval") << "FunDefEvaluator: evaluate " << n << std::endl; + Trace("fd-eval") << "FunDefEvaluator: evaluateDefinitions " << n << std::endl; NodeManager* nm = NodeManager::currentNM(); std::unordered_map funDefCount; std::unordered_map::iterator itCount; @@ -185,7 +185,7 @@ Node FunDefEvaluator::evaluate(Node n) const if (!args.empty()) { // invoke it on arguments using the evaluator - sbody = d_eval.eval(sbody, args, children); + sbody = evaluate(sbody, args, children); if (Trace.isOn("fd-eval-debug2")) { Trace("fd-eval-debug2") diff --git a/src/theory/quantifiers/fun_def_evaluator.h b/src/theory/quantifiers/fun_def_evaluator.h index a3b79bec7..c8e811968 100644 --- a/src/theory/quantifiers/fun_def_evaluator.h +++ b/src/theory/quantifiers/fun_def_evaluator.h @@ -20,8 +20,9 @@ #include #include + #include "expr/node.h" -#include "theory/evaluator.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -30,10 +31,10 @@ namespace quantifiers { /** * Techniques for evaluating recursively defined functions. */ -class FunDefEvaluator +class FunDefEvaluator : protected EnvObj { public: - FunDefEvaluator(); + FunDefEvaluator(Env& env); ~FunDefEvaluator() {} /** * Assert definition of a (recursive) function definition given by quantified @@ -45,7 +46,7 @@ class FunDefEvaluator * class. If n cannot be simplified to a constant, then this method returns * null. */ - Node evaluate(Node n) const; + Node evaluateDefinitions(Node n) const; /** * Has a call to assertDefinition been made? If this returns false, then * the evaluate method is the same as calling the rewriter, and returning @@ -74,8 +75,6 @@ class FunDefEvaluator std::map d_funDefMap; /** list of all definitions */ std::vector d_funDefs; - /** evaluator utility */ - Evaluator d_eval; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/quant_bound_inference.cpp b/src/theory/quantifiers/quant_bound_inference.cpp index a78f66c51..83e48bf9c 100644 --- a/src/theory/quantifiers/quant_bound_inference.cpp +++ b/src/theory/quantifiers/quant_bound_inference.cpp @@ -17,6 +17,7 @@ #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/quant_conflict_find.cpp b/src/theory/quantifiers/quant_conflict_find.cpp index 1de60422f..b26b65018 100644 --- a/src/theory/quantifiers/quant_conflict_find.cpp +++ b/src/theory/quantifiers/quant_conflict_find.cpp @@ -28,6 +28,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; using namespace std; diff --git a/src/theory/quantifiers/quantifiers_attributes.cpp b/src/theory/quantifiers/quantifiers_attributes.cpp index deed1c761..1a0e03bfc 100644 --- a/src/theory/quantifiers/quantifiers_attributes.cpp +++ b/src/theory/quantifiers/quantifiers_attributes.cpp @@ -19,6 +19,8 @@ #include "theory/arith/arith_msum.h" #include "theory/quantifiers/sygus/synth_engine.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" +#include "util/string.h" using namespace std; using namespace cvc5::kind; diff --git a/src/theory/quantifiers/skolemize.cpp b/src/theory/quantifiers/skolemize.cpp index bb0fa3899..a34547f45 100644 --- a/src/theory/quantifiers/skolemize.cpp +++ b/src/theory/quantifiers/skolemize.cpp @@ -28,6 +28,7 @@ #include "theory/quantifiers/term_util.h" #include "theory/rewriter.h" #include "theory/sort_inference.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/solution_filter.cpp b/src/theory/quantifiers/solution_filter.cpp index 8844950c7..19bfcab66 100644 --- a/src/theory/quantifiers/solution_filter.cpp +++ b/src/theory/quantifiers/solution_filter.cpp @@ -19,6 +19,7 @@ #include "options/base_options.h" #include "options/quantifiers_options.h" +#include "smt/env.h" #include "util/random.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/cegis.cpp b/src/theory/quantifiers/sygus/cegis.cpp index f5774c761..d9e4b61af 100644 --- a/src/theory/quantifiers/sygus/cegis.cpp +++ b/src/theory/quantifiers/sygus/cegis.cpp @@ -39,6 +39,7 @@ Cegis::Cegis(Env& env, SynthConjecture* p) : SygusModule(env, qs, qim, tds, p), d_eval_unfold(tds->getEvalUnfold()), + d_cegis_sampler(env), d_usingSymCons(false) { } @@ -594,7 +595,6 @@ bool Cegis::checkRefinementEvalLemmas(const std::vector& vs, } } - Evaluator* eval = d_tds->getEvaluator(); for (unsigned r = 0; r < 2; r++) { std::unordered_set& rlemmas = @@ -603,7 +603,7 @@ bool Cegis::checkRefinementEvalLemmas(const std::vector& vs, { // We may have computed the evaluation of some function applications // via example-based symmetry breaking, stored in evalVisited. - Node lemcsu = eval->eval(lem, vs, ms, evalVisited); + Node lemcsu = evaluate(lem, vs, ms, evalVisited); if (lemcsu.isConst() && !lemcsu.getConst()) { return true; diff --git a/src/theory/quantifiers/sygus/cegis.h b/src/theory/quantifiers/sygus/cegis.h index d72805950..8e0fffdd1 100644 --- a/src/theory/quantifiers/sygus/cegis.h +++ b/src/theory/quantifiers/sygus/cegis.h @@ -28,6 +28,8 @@ namespace cvc5 { namespace theory { namespace quantifiers { +class SygusEvalUnfold; + /** Cegis * * The default sygus module for synthesis, counterexample-guided inductive diff --git a/src/theory/quantifiers/sygus/cegis_core_connective.cpp b/src/theory/quantifiers/sygus/cegis_core_connective.cpp index a42323227..b9066b079 100644 --- a/src/theory/quantifiers/sygus/cegis_core_connective.cpp +++ b/src/theory/quantifiers/sygus/cegis_core_connective.cpp @@ -311,7 +311,7 @@ bool CegisCoreConnective::constructSolution( Assert(candidates.size() == 1 && candidates[0] == d_candidate); TNode cval = candidate_values[0]; Node ets = d_eterm.substitute(d_candidate, cval); - Node etsr = Rewriter::rewrite(ets); + Node etsr = rewrite(ets); Trace("sygus-ccore-debug") << "...predicate is: " << etsr << std::endl; NodeManager* nm = NodeManager::currentNM(); for (unsigned d = 0; d < 2; d++) @@ -476,7 +476,7 @@ Node CegisCoreConnective::Component::getRefinementPt( visited.insert(id); Trace("sygus-ccore-ref") << "...eval " << std::endl; // check if it is true - Node en = p->evaluate(n, id, ctx); + Node en = p->evaluatePt(n, id, ctx); if (en.isConst() && en.getConst()) { ss = ctx; @@ -553,7 +553,7 @@ bool CegisCoreConnective::Component::addToAsserts(CegisCoreConnective* p, for (unsigned i = currIndex, psize = passerts.size(); i < psize; i++) { Node cn = passerts[i]; - Node cne = p->evaluate(cn, mvId, mvs); + Node cne = p->evaluatePt(cn, mvId, mvs); if (cne.isConst() && !cne.getConst()) { n = cn; @@ -635,9 +635,9 @@ Result CegisCoreConnective::checkSat(Node n, std::vector& mvs) const return r; } -Node CegisCoreConnective::evaluate(Node n, - Node id, - const std::vector& mvs) +Node CegisCoreConnective::evaluatePt(Node n, + Node id, + const std::vector& mvs) { Kind nk = n.getKind(); if (nk == AND || nk == OR) @@ -647,7 +647,7 @@ Node CegisCoreConnective::evaluate(Node n, // split AND/OR for (const Node& nc : n) { - Node enc = evaluate(nc, id, mvs); + Node enc = evaluatePt(nc, id, mvs); Assert(enc.isConst()); if (enc.getConst() == expRes) { @@ -666,12 +666,8 @@ Node CegisCoreConnective::evaluate(Node n, } } // use evaluator - Node cn = d_eval.eval(n, d_vars, mvs); - if (cn.isNull()) - { - cn = n.substitute(d_vars.begin(), d_vars.end(), mvs.begin(), mvs.end()); - cn = Rewriter::rewrite(cn); - } + Node cn = evaluate(n, d_vars, mvs); + Assert(!cn.isNull()); if (!id.isNull()) { ec[id] = cn; @@ -844,7 +840,7 @@ Node CegisCoreConnective::constructSolutionFromPool(Component& ccheck, mvs.clear(); getModel(*checkSol, mvs); // should evaluate to true - Node ean = evaluate(an, Node::null(), mvs); + Node ean = evaluatePt(an, Node::null(), mvs); Assert(ean.isConst() && ean.getConst()); Trace("sygus-ccore") << "--- Add refinement point " << mvs << std::endl; // In terms of Variant #2, this is the line: diff --git a/src/theory/quantifiers/sygus/cegis_core_connective.h b/src/theory/quantifiers/sygus/cegis_core_connective.h index 80ba6f26e..ebcd871aa 100644 --- a/src/theory/quantifiers/sygus/cegis_core_connective.h +++ b/src/theory/quantifiers/sygus/cegis_core_connective.h @@ -23,7 +23,6 @@ #include "expr/node.h" #include "expr/node_trie.h" #include "smt/env_obj.h" -#include "theory/evaluator.h" #include "theory/quantifiers/sygus/cegis.h" #include "util/result.h" @@ -365,11 +364,9 @@ class CegisCoreConnective : public Cegis * If id is non-null, then id is a unique identifier for mvs, and we cache * the result of n for this point. */ - Node evaluate(Node n, Node id, const std::vector& mvs); + Node evaluatePt(Node n, Node id, const std::vector& mvs); /** A cache of the above function */ std::unordered_map> d_eval_cache; - /** The evaluator utility used for the above function */ - Evaluator d_eval; //-----------------------------------end for evaluation /** Construct solution from pool diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index 6b260bb81..42306383b 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -23,6 +23,7 @@ #include "theory/quantifiers/sygus/sygus_unif_rl.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/enum_value_manager.cpp b/src/theory/quantifiers/sygus/enum_value_manager.cpp index e7b3bbaa9..937537ce9 100644 --- a/src/theory/quantifiers/sygus/enum_value_manager.cpp +++ b/src/theory/quantifiers/sygus/enum_value_manager.cpp @@ -106,7 +106,7 @@ Node EnumValueManager::getEnumeratedValue(bool& activeIncomplete) std::ostream* out = nullptr; if (options::sygusRewVerify()) { - d_samplerRrV.reset(new SygusSampler); + d_samplerRrV.reset(new SygusSampler(d_env)); d_samplerRrV->initializeSygus( d_tds, e, options::sygusSamples(), false); // use the default output for the output of sygusRewVerify diff --git a/src/theory/quantifiers/sygus/rcons_type_info.cpp b/src/theory/quantifiers/sygus/rcons_type_info.cpp index 78f8d303c..72a8e6a56 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.cpp +++ b/src/theory/quantifiers/sygus/rcons_type_info.cpp @@ -16,8 +16,10 @@ #include "theory/quantifiers/sygus/rcons_type_info.h" #include "expr/skolem_manager.h" +#include "smt/env.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/quantifiers/sygus/rcons_obligation.h" +#include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { namespace theory { @@ -37,8 +39,9 @@ void RConsTypeInfo::initialize(Env& env, d_crd.reset(new CandidateRewriteDatabase(env, true, false, true, false)); // since initial samples are not always useful for equivalence checks, set // their number to 0 - d_sygusSampler.initialize(stn, builtinVars, 0); - d_crd->initialize(builtinVars, &d_sygusSampler); + d_sygusSampler.reset(new SygusSampler(env)); + d_sygusSampler->initialize(stn, builtinVars, 0); + d_crd->initialize(builtinVars, d_sygusSampler.get()); } Node RConsTypeInfo::nextEnum() diff --git a/src/theory/quantifiers/sygus/rcons_type_info.h b/src/theory/quantifiers/sygus/rcons_type_info.h index 294454fe2..5f68993ad 100644 --- a/src/theory/quantifiers/sygus/rcons_type_info.h +++ b/src/theory/quantifiers/sygus/rcons_type_info.h @@ -20,7 +20,6 @@ #include "theory/quantifiers/candidate_rewrite_database.h" #include "theory/quantifiers/sygus/sygus_enumerator.h" -#include "theory/quantifiers/sygus_sampler.h" namespace cvc5 { namespace theory { @@ -28,6 +27,7 @@ namespace quantifiers { class RConsObligation; class CandidateRewriteDatabase; +class SygusSampler; /** * A utility class for Sygus Reconstruct datatype types (grammar non-terminals). @@ -93,7 +93,7 @@ class RConsTypeInfo /** Candidate rewrite database for this class' sygus datatype type */ std::unique_ptr d_crd; /** Sygus sampler needed for initializing the candidate rewrite database */ - SygusSampler d_sygusSampler; + std::unique_ptr d_sygusSampler; /** A map from a builtin term to its obligation. * * Each sygus datatype type has its own version of this map because it is diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index fca09c43d..959532d98 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -25,6 +25,7 @@ #include "theory/quantifiers/sygus/synth_engine.h" #include "theory/quantifiers/sygus/type_node_id_trie.h" #include "theory/rewriter.h" +#include "util/rational.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 7072b77e1..43c958ff9 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -32,6 +32,7 @@ #include "theory/rewriter.h" #include "theory/strings/word.h" #include "util/floatingpoint.h" +#include "util/string.h" using namespace cvc5::kind; diff --git a/src/theory/quantifiers/sygus/sygus_unif_io.cpp b/src/theory/quantifiers/sygus/sygus_unif_io.cpp index 3fb80f917..e703569d9 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_io.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_io.cpp @@ -17,7 +17,6 @@ #include "options/quantifiers_options.h" #include "theory/datatypes/sygus_datatype_utils.h" -#include "theory/evaluator.h" #include "theory/quantifiers/sygus/example_infer.h" #include "theory/quantifiers/sygus/synth_conjecture.h" #include "theory/quantifiers/sygus/term_database_sygus.h" diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 035db433e..2e528b213 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -55,8 +55,7 @@ TermDbSygus::TermDbSygus(Env& env, QuantifiersState& qs) : EnvObj(env), d_qstate(qs), d_syexp(new SygusExplain(this)), - d_eval(new Evaluator), - d_funDefEval(new FunDefEvaluator), + d_funDefEval(new FunDefEvaluator(env)), d_eval_unfold(new SygusEvalUnfold(this)) { d_true = NodeManager::currentNM()->mkConst( true ); @@ -759,7 +758,7 @@ Node TermDbSygus::rewriteNode(Node n) const { // If recursive functions are enabled, then we use the recursive function // evaluation utility. - Node fres = d_funDefEval->evaluate(res); + Node fres = d_funDefEval->evaluateDefinitions(res); if (!fres.isNull()) { return fres; @@ -996,7 +995,7 @@ Node TermDbSygus::evaluateBuiltin(TypeNode tn, // This may fail if there is a subterm of bn under the // substitution that is not constant, or if an operator in bn is not // supported by the evaluator - res = d_eval->eval(bn, varlist, args); + res = evaluate(bn, varlist, args); } if (res.isNull()) { diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 7b05c70e4..59e0f4776 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -22,7 +22,6 @@ #include "expr/dtype.h" #include "smt/env_obj.h" -#include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/fun_def_evaluator.h" #include "theory/quantifiers/sygus/sygus_eval_unfold.h" @@ -80,8 +79,6 @@ class TermDbSygus : protected EnvObj //------------------------------utilities /** get the explanation utility */ SygusExplain* getExplain() { return d_syexp.get(); } - /** get the evaluator */ - Evaluator* getEvaluator() { return d_eval.get(); } /** (recursive) function evaluator utility */ FunDefEvaluator* getFunDefEvaluator() { return d_funDefEval.get(); } /** evaluation unfolding utility */ @@ -309,8 +306,6 @@ class TermDbSygus : protected EnvObj //------------------------------utilities /** sygus explanation */ std::unique_ptr d_syexp; - /** evaluator */ - std::unique_ptr d_eval; /** (recursive) function evaluator utility */ std::unique_ptr d_funDefEval; /** evaluation function unfolding utility */ diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index 0cbc4df5b..08fab59eb 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -24,17 +24,20 @@ #include "options/quantifiers_options.h" #include "printer/printer.h" #include "theory/quantifiers/lazy_trie.h" +#include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/rewriter.h" #include "util/bitvector.h" #include "util/random.h" +#include "util/rational.h" #include "util/sampler.h" +#include "util/string.h" namespace cvc5 { namespace theory { namespace quantifiers { -SygusSampler::SygusSampler() - : d_tds(nullptr), d_use_sygus_type(false), d_is_valid(false) +SygusSampler::SygusSampler(Env& env) + : d_env(env), d_tds(nullptr), d_use_sygus_type(false), d_is_valid(false) { } @@ -471,21 +474,11 @@ Node SygusSampler::evaluate(Node n, unsigned index) { Assert(index < d_samples.size()); // do beta-reductions in n first - n = Rewriter::rewrite(n); + n = d_env.getRewriter()->rewrite(n); // use efficient rewrite for substitution + rewrite - Node ev = d_eval.eval(n, d_vars, d_samples[index]); + Node ev = d_env.evaluate(n, d_vars, d_samples[index], true); + Assert(!ev.isNull()); Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> "; - if (!ev.isNull()) - { - Trace("sygus-sample-ev") << ev << std::endl; - return ev; - } - Trace("sygus-sample-ev") << "null" << std::endl; - Trace("sygus-sample-ev") << "Rewrite -> "; - // substitution + rewrite - std::vector& pt = d_samples[index]; - ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end()); - ev = Rewriter::rewrite(ev); Trace("sygus-sample-ev") << ev << std::endl; return ev; } @@ -617,7 +610,7 @@ Node SygusSampler::getRandomValue(TypeNode tn) // negative ret = nm->mkNode(kind::UMINUS, ret); } - ret = Rewriter::rewrite(ret); + ret = d_env.getRewriter()->rewrite(ret); Assert(ret.isConst()); return ret; } @@ -715,7 +708,7 @@ Node SygusSampler::getSygusRandomValue(TypeNode tn, Trace("sygus-sample-grammar") << "mkGeneric" << std::endl; Node ret = d_tds->mkGeneric(dt, cindex, pre); Trace("sygus-sample-grammar") << "...returned " << ret << std::endl; - ret = Rewriter::rewrite(ret); + ret = d_env.getRewriter()->rewrite(ret); Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl; // A rare case where we generate a non-constant value from constant // leaves is (/ n 0). diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index 85606adc6..3695270e1 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -19,15 +19,18 @@ #define CVC5__THEORY__QUANTIFIERS__SYGUS_SAMPLER_H #include -#include "theory/evaluator.h" #include "theory/quantifiers/lazy_trie.h" -#include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" namespace cvc5 { + +class Env; + namespace theory { namespace quantifiers { +class TermDbSygus; + /** SygusSampler * * This class can be used to test whether two expressions are equivalent @@ -65,7 +68,7 @@ namespace quantifiers { class SygusSampler : public LazyTrieEvaluator { public: - SygusSampler(); + SygusSampler(Env& env); ~SygusSampler() override {} /** initialize @@ -178,14 +181,14 @@ class SygusSampler : public LazyTrieEvaluator void checkEquivalent(Node bv, Node bvr, std::ostream& out); protected: + /** The environment we are using to evaluate terms and samples */ + Env& d_env; /** sygus term database of d_qe */ TermDbSygus* d_tds; /** term enumerator object (used for random sampling) */ TermEnumeration d_tenum; /** samples */ std::vector > d_samples; - /** evaluator class */ - Evaluator d_eval; /** data structure to check duplication of sample points */ class PtTrie { diff --git a/src/theory/rewriter.cpp b/src/theory/rewriter.cpp index 460813084..4e571a66b 100644 --- a/src/theory/rewriter.cpp +++ b/src/theory/rewriter.cpp @@ -479,35 +479,5 @@ void Rewriter::clearCaches() clearCachesInternal(); } -Node Rewriter::rewriteViaMethod(TNode n, MethodId idr) -{ - if (idr == MethodId::RW_REWRITE) - { - return rewrite(n); - } - if (idr == MethodId::RW_EXT_REWRITE) - { - return extendedRewrite(n); - } - if (idr == MethodId::RW_REWRITE_EQ_EXT) - { - return rewriteEqualityExt(n); - } - if (idr == MethodId::RW_EVALUATE) - { - Evaluator eval; - return eval.eval(n, {}, {}, false); - } - if (idr == MethodId::RW_IDENTITY) - { - // does nothing - return n; - } - // unknown rewriter - Unhandled() << "Rewriter::rewriteViaMethod: no rewriter for " << idr - << std::endl; - return n; -} - } // namespace theory } // namespace cvc5 diff --git a/src/theory/rewriter.h b/src/theory/rewriter.h index d87043a67..697253e03 100644 --- a/src/theory/rewriter.h +++ b/src/theory/rewriter.h @@ -18,22 +18,24 @@ #pragma once #include "expr/node.h" -#include "proof/method_id.h" #include "theory/theory_rewriter.h" namespace cvc5 { +class Env; class TConvProofGenerator; class ProofNodeManager; class TrustNode; namespace theory { +class Evaluator; + /** * The main rewriter class. */ class Rewriter { - + friend class cvc5::Env; // to initialize the evaluators of this class public: Rewriter(); @@ -62,6 +64,9 @@ class Rewriter { Node rewriteEqualityExt(TNode node); /** + * !!! Temporary until static access to rewriter is eliminated. This method + * should be moved to same place as evaluate (currently in Env). + * * Extended rewrite of the given node. This method is implemented by a * custom ExtendRewriter class that wraps this class to perform custom * rewrites (usually those that are not useful for solving, but e.g. useful @@ -103,17 +108,6 @@ class Rewriter { /** Get the theory rewriter for the given id */ TheoryRewriter* getTheoryRewriter(theory::TheoryId theoryId); - /** - * Apply rewrite on n via the rewrite method identifier idr (see method_id.h). - * This encapsulates the exact behavior of a REWRITE step in a proof. - * - * @param n The node to rewrite, - * @param idr The method identifier of the rewriter, by default RW_REWRITE - * specifying a call to rewrite. - * @return The rewritten form of n. - */ - Node rewriteViaMethod(TNode n, MethodId idr = MethodId::RW_REWRITE); - private: /** * Get the rewriter associated with the SmtEngine in scope. diff --git a/src/theory/rewriter_tables_template.h b/src/theory/rewriter_tables_template.h index 36d320fb7..e86d748fd 100644 --- a/src/theory/rewriter_tables_template.h +++ b/src/theory/rewriter_tables_template.h @@ -80,10 +80,7 @@ ${post_rewrite_set_cache} } } -Rewriter::Rewriter() : d_tpg(nullptr) -{ - -} +Rewriter::Rewriter() : d_tpg(nullptr) {} void Rewriter::clearCachesInternal() { diff --git a/test/unit/theory/evaluator_white.cpp b/test/unit/theory/evaluator_white.cpp index a1f56eaba..c2c6cf77e 100644 --- a/test/unit/theory/evaluator_white.cpp +++ b/test/unit/theory/evaluator_white.cpp @@ -59,10 +59,11 @@ TEST_F(TestTheoryWhiteEvaluator, simple) std::vector args = {w, x, y, z}; std::vector vals = {c1, zero, one, c1}; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); Node r = eval.eval(t, args, vals); ASSERT_EQ(r, - Rewriter::rewrite(t.substitute( + rr->rewrite(t.substitute( args.begin(), args.end(), vals.begin(), vals.end()))); } @@ -90,10 +91,11 @@ TEST_F(TestTheoryWhiteEvaluator, loop) std::vector args = {x}; std::vector vals = {c}; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); Node r = eval.eval(t, args, vals); ASSERT_EQ(r, - Rewriter::rewrite(t.substitute( + rr->rewrite(t.substitute( args.begin(), args.end(), vals.begin(), vals.end()))); } @@ -106,30 +108,31 @@ TEST_F(TestTheoryWhiteEvaluator, strIdOf) std::vector args; std::vector vals; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, empty, one); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, a, one); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, empty, two); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } { Node n = d_nodeManager->mkNode(kind::STRING_INDEXOF, a, a, two); Node r = eval.eval(n, args, vals); - ASSERT_EQ(r, Rewriter::rewrite(n)); + ASSERT_EQ(r, rr->rewrite(n)); } } @@ -140,7 +143,8 @@ TEST_F(TestTheoryWhiteEvaluator, code) std::vector args; std::vector vals; - Evaluator eval; + Rewriter* rr = d_smtEngine->getRewriter(); + Evaluator eval(rr); // (str.code "A") ---> 65 {