From: Andrew Reynolds Date: Fri, 1 Oct 2021 04:49:08 +0000 (-0500) Subject: Use the proper evaluator for optimized SyGuS datatype rewriting (#7266) X-Git-Tag: cvc5-1.0.0~1142 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bb0e6dcde2e7267e391a46b868b990d7cb7e42bd;p=cvc5.git Use the proper evaluator for optimized SyGuS datatype rewriting (#7266) This updates the datatypes rewriter to use the evaluator from Env instead of creating local copies of Evaluator. This makes all uses of the Evaluator dependent on the proper options (e.g. which will be based later on the cardinality of the alphabet for strings). This moves one utility method (sygusToBuiltinEval) to the datatypes rewriter, as it uses an Evaluator that will be dependent on options. Notice that this is another instance where it is important for us to make the cache for the rewriter local. The same issue occurs for other places where rewriting is dependent on options. This issue will be revisited when the option for strings alphabet cardinality is added. --- diff --git a/src/smt/env.cpp b/src/smt/env.cpp index 0ffe1c4b9..5c7836fb7 100644 --- a/src/smt/env.cpp +++ b/src/smt/env.cpp @@ -105,6 +105,11 @@ bool Env::isTheoryProofProducing() const theory::Rewriter* Env::getRewriter() { return d_rewriter.get(); } +theory::Evaluator* Env::getEvaluator(bool useRewriter) +{ + return useRewriter ? d_evalRew.get() : d_eval.get(); +} + theory::TrustSubstitutionMap& Env::getTopLevelSubstitutions() { return *d_topLevelSubs.get(); diff --git a/src/smt/env.h b/src/smt/env.h index e3a34cf4a..8d2b1636e 100644 --- a/src/smt/env.h +++ b/src/smt/env.h @@ -105,6 +105,14 @@ class Env /** Get a pointer to the Rewriter owned by this Env. */ theory::Rewriter* getRewriter(); + /** + * Get a pointer to the Evaluator owned by this Env. There are two variants + * of the evaluator, one that invokes the rewriter when evaluation is not + * applicable, and one that does not. The former evaluator is returned when + * useRewriter is true. + */ + theory::Evaluator* getEvaluator(bool useRewriter = false); + /** Get a reference to the top-level substitution map */ theory::TrustSubstitutionMap& getTopLevelSubstitutions(); diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 33d143a36..c446504fd 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -35,6 +35,11 @@ namespace cvc5 { namespace theory { namespace datatypes { +DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval) + : d_sygusEval(sygusEval) +{ +} + RewriteResponse DatatypesRewriter::postRewrite(TNode in) { Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl; @@ -137,7 +142,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) { args.push_back(in[j]); } - Node ret = utils::sygusToBuiltinEval(ev, args); + Node ret = sygusToBuiltinEval(ev, args); Trace("dt-sygus-util") << "...got " << ret << "\n"; Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl; Assert(in.getType().isComparableTo(ret.getType())); @@ -920,6 +925,126 @@ TrustNode DatatypesRewriter::expandDefinition(Node n) return TrustNode::null(); } +Node DatatypesRewriter::sygusToBuiltinEval(Node n, + const std::vector& args) +{ + Assert(d_sygusEval != nullptr); + NodeManager* nm = NodeManager::currentNM(); + // constant arguments? + bool constArgs = true; + for (const Node& a : args) + { + if (!a.isConst()) + { + constArgs = false; + break; + } + } + std::vector eargs; + bool svarsInit = false; + std::vector svars; + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + TNode cur; + unsigned index; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + if (it == visited.end()) + { + TypeNode tn = cur.getType(); + if (!tn.isDatatype() || !tn.getDType().isSygus()) + { + visited[cur] = cur; + } + else if (cur.isConst()) + { + // convert to builtin term + Node bt = utils::sygusToBuiltin(cur); + // run the evaluator if possible + if (!svarsInit) + { + svarsInit = true; + TypeNode type = cur.getType(); + Node varList = type.getDType().getSygusVarList(); + for (const Node& v : varList) + { + svars.push_back(v); + } + } + Assert(args.size() == svars.size()); + // try evaluation if we have constant arguments + Node ret = + constArgs ? d_sygusEval->eval(bt, svars, args) : Node::null(); + if (ret.isNull()) + { + // if evaluation was not available, use a substitution + ret = bt.substitute( + svars.begin(), svars.end(), args.begin(), args.end()); + } + visited[cur] = ret; + } + else + { + if (cur.getKind() == APPLY_CONSTRUCTOR) + { + visited[cur] = Node::null(); + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + else + { + // it is the evaluation of this term on the arguments + if (eargs.empty()) + { + eargs.push_back(cur); + eargs.insert(eargs.end(), args.begin(), args.end()); + } + else + { + eargs[0] = cur; + } + visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs); + } + } + } + else if (it->second.isNull()) + { + Node ret = cur; + Assert(cur.getKind() == APPLY_CONSTRUCTOR); + const DType& dt = cur.getType().getDType(); + // non sygus-datatype terms are also themselves + if (dt.isSygus()) + { + std::vector children; + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + children.push_back(it->second); + } + index = utils::indexOf(cur.getOperator()); + // apply to children, which constructs the builtin term + ret = utils::mkSygusTerm(dt, index, children); + // now apply it to arguments in args + ret = utils::applySygusArgs(dt, dt[index].getSygusOp(), ret, args); + } + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + } // namespace datatypes } // namespace theory } // namespace cvc5 diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 56dde76a0..31e2a1bef 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -18,6 +18,7 @@ #ifndef CVC5__THEORY__DATATYPES__DATATYPES_REWRITER_H #define CVC5__THEORY__DATATYPES__DATATYPES_REWRITER_H +#include "theory/evaluator.h" #include "theory/theory_rewriter.h" namespace cvc5 { @@ -37,6 +38,7 @@ namespace datatypes { class DatatypesRewriter : public TheoryRewriter { public: + DatatypesRewriter(Evaluator* sygusEval); RewriteResponse postRewrite(TNode in) override; RewriteResponse preRewrite(TNode in) override; @@ -164,7 +166,32 @@ class DatatypesRewriter : public TheoryRewriter Node orig, TypeNode orig_tn, unsigned depth); -}; /* class DatatypesRewriter */ + + /** Sygus to builtin eval + * + * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice + * that n does not necessarily need to be a constant. + * + * It does so by (1) converting constant subterms of n to builtin terms and + * evaluating them on the arguments args, (2) unfolding non-constant + * applications of sygus constructors in n with respect to args and (3) + * converting all other non-constant subterms of n to applications of + * DT_SYGUS_EVAL. + * + * For example, if + * n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 } + * where n' is a variable, then this method returns: + * 12 + (DT_SYGUS_EVAL n' 3 4) + * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin + * equivalent x*y and evaluated under the substition { x -> 3, y -> 4 } giving + * 12. The subterm n' is non-constant and thus we return its evaluation under + * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level + * constructor is C_+, these terms are added together to give the result. + */ + Node sygusToBuiltinEval(Node n, const std::vector& args); + /** Pointer to the evaluator, used as an optimization for the above method */ + Evaluator* d_sygusEval; +}; } // namespace datatypes } // namespace theory diff --git a/src/theory/datatypes/sygus_datatype_utils.cpp b/src/theory/datatypes/sygus_datatype_utils.cpp index 12c255f57..c68b87d85 100644 --- a/src/theory/datatypes/sygus_datatype_utils.cpp +++ b/src/theory/datatypes/sygus_datatype_utils.cpp @@ -388,124 +388,6 @@ Node sygusToBuiltin(Node n, bool isExternal) return visited[n]; } -Node sygusToBuiltinEval(Node n, const std::vector& args) -{ - NodeManager* nm = NodeManager::currentNM(); - Evaluator eval(nullptr); - // constant arguments? - bool constArgs = true; - for (const Node& a : args) - { - if (!a.isConst()) - { - constArgs = false; - break; - } - } - std::vector eargs; - bool svarsInit = false; - std::vector svars; - std::unordered_map visited; - std::unordered_map::iterator it; - std::vector visit; - TNode cur; - unsigned index; - visit.push_back(n); - do - { - cur = visit.back(); - visit.pop_back(); - it = visited.find(cur); - if (it == visited.end()) - { - TypeNode tn = cur.getType(); - if (!tn.isDatatype() || !tn.getDType().isSygus()) - { - visited[cur] = cur; - } - else if (cur.isConst()) - { - // convert to builtin term - Node bt = sygusToBuiltin(cur); - // run the evaluator if possible - if (!svarsInit) - { - svarsInit = true; - TypeNode type = cur.getType(); - Node varList = type.getDType().getSygusVarList(); - for (const Node& v : varList) - { - svars.push_back(v); - } - } - Assert(args.size() == svars.size()); - // try evaluation if we have constant arguments - Node ret = constArgs ? eval.eval(bt, svars, args) : Node::null(); - if (ret.isNull()) - { - // if evaluation was not available, use a substitution - ret = bt.substitute( - svars.begin(), svars.end(), args.begin(), args.end()); - } - visited[cur] = ret; - } - else - { - if (cur.getKind() == APPLY_CONSTRUCTOR) - { - visited[cur] = Node::null(); - visit.push_back(cur); - for (const Node& cn : cur) - { - visit.push_back(cn); - } - } - else - { - // it is the evaluation of this term on the arguments - if (eargs.empty()) - { - eargs.push_back(cur); - eargs.insert(eargs.end(), args.begin(), args.end()); - } - else - { - eargs[0] = cur; - } - visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs); - } - } - } - else if (it->second.isNull()) - { - Node ret = cur; - Assert(cur.getKind() == APPLY_CONSTRUCTOR); - const DType& dt = cur.getType().getDType(); - // non sygus-datatype terms are also themselves - if (dt.isSygus()) - { - std::vector children; - for (const Node& cn : cur) - { - it = visited.find(cn); - Assert(it != visited.end()); - Assert(!it->second.isNull()); - children.push_back(it->second); - } - index = indexOf(cur.getOperator()); - // apply to children, which constructs the builtin term - ret = mkSygusTerm(dt, index, children); - // now apply it to arguments in args - ret = applySygusArgs(dt, dt[index].getSygusOp(), ret, args); - } - visited[cur] = ret; - } - } while (!visit.empty()); - Assert(visited.find(n) != visited.end()); - Assert(!visited.find(n)->second.isNull()); - return visited[n]; -} - Node builtinVarToSygus(Node v) { BuiltinVarToSygusAttribute bvtsa; diff --git a/src/theory/datatypes/sygus_datatype_utils.h b/src/theory/datatypes/sygus_datatype_utils.h index 5784fe34a..3ea6b62e9 100644 --- a/src/theory/datatypes/sygus_datatype_utils.h +++ b/src/theory/datatypes/sygus_datatype_utils.h @@ -165,29 +165,6 @@ Node sygusToBuiltin(Node n, bool isExternal = false); */ Node builtinVarToSygus(Node v); -/** Sygus to builtin eval - * - * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that - * n does not necessarily need to be a constant. - * - * It does so by (1) converting constant subterms of n to builtin terms and - * evaluating them on the arguments args, (2) unfolding non-constant - * applications of sygus constructors in n with respect to args and (3) - * converting all other non-constant subterms of n to applications of - * DT_SYGUS_EVAL. - * - * For example, if - * n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 } - * where n' is a variable, then this method returns: - * 12 + (DT_SYGUS_EVAL n' 3 4) - * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin - * equivalent x*y and evaluated under the substition { x -> 3, y -> 4 } giving - * 12. The subterm n' is non-constant and thus we return its evaluation under - * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level constructor - * is C_+, these terms are added together to give the result. - */ -Node sygusToBuiltinEval(Node n, const std::vector& args); - /** Get free symbols in a sygus datatype type * * Add the free symbols (expr::getSymbols) in terms that can be generated by diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index feb19b182..a1c6942a5 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -61,6 +61,7 @@ TheoryDatatypes::TheoryDatatypes(Env& env, d_functionTerms(context()), d_singleton_eq(userContext()), d_sygusExtension(nullptr), + d_rewriter(env.getEvaluator()), d_state(env, valuation), d_im(env, *this, d_state, d_pnm), d_notify(d_im, *this)