From: Andrew Reynolds Date: Fri, 6 Dec 2019 19:12:12 +0000 (-0600) Subject: Optimize the rewriter for DT_SYGUS_EVAL (#3529) X-Git-Tag: cvc5-1.0.0~3788 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ec865a83596fd1285e033426b80ddfc1c35085cd;p=cvc5.git Optimize the rewriter for DT_SYGUS_EVAL (#3529) This makes it so that we don't construct intermediate unfoldings of applications of DT_SYGUS_EVAL, which wastes time in node construction. It makes the sygusToBuiltin utility in TermDbSygus use this implementation. --- diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index be4226f69..080306d39 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -120,34 +120,16 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) if (ev.getKind() == APPLY_CONSTRUCTOR) { Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n"; - const Datatype& dt = ev.getType().getDatatype(); - unsigned i = utils::indexOf(ev.getOperator()); - Node op = Node::fromExpr(dt[i].getSygusOp()); - // if it is the "any constant" constructor, return its argument - if (op.getAttribute(SygusAnyConstAttribute())) - { - Assert(ev.getNumChildren() == 1); - Assert(ev[0].getType().isComparableTo(in.getType())); - return RewriteResponse(REWRITE_AGAIN_FULL, ev[0]); - } + Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl; std::vector args; for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++) { args.push_back(in[j]); } - Assert(!dt.isParametric()); - std::vector children; - for (const Node& evc : ev) - { - std::vector cc; - cc.push_back(evc); - cc.insert(cc.end(), args.begin(), args.end()); - children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc)); - } - Node ret = utils::mkSygusTerm(dt, i, children); - // apply the appropriate substitution - ret = utils::applySygusArgs(dt, op, ret, args); + Node ret = utils::sygusToBuiltinEval(ev, args); Trace("dt-sygus-util") << "...got " << ret << "\n"; + Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl; + Assert(in.getType().isComparableTo(ret.getType())); return RewriteResponse(REWRITE_AGAIN_FULL, ret); } } diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index 43d23b523..d2833a852 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -18,6 +18,7 @@ #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" +#include "theory/evaluator.h" using namespace CVC4; using namespace CVC4::kind; @@ -384,6 +385,200 @@ bool checkClash(Node n1, Node n2, std::vector& rew) return false; } +struct SygusToBuiltinTermAttributeId +{ +}; +typedef expr::Attribute + SygusToBuiltinTermAttribute; + +Node sygusToBuiltin(Node n) +{ + Assert(n.isConst()); + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + TNode cur; + unsigned index; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + if (it == visited.end()) + { + if (cur.getKind() == APPLY_CONSTRUCTOR) + { + if (cur.hasAttribute(SygusToBuiltinTermAttribute())) + { + visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute()); + } + else + { + visited[cur] = Node::null(); + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + } + else + { + // non-datatypes are themselves + visited[cur] = cur; + } + } + else if (it->second.isNull()) + { + Node ret = cur; + Assert(cur.getKind() == APPLY_CONSTRUCTOR); + const Datatype& dt = cur.getType().getDatatype(); + // Non sygus-datatype terms are also themselves. Notice we treat the + // case of non-sygus datatypes this way since it avoids computing + // the type / datatype of the node in the pre-traversal above. The + // case of non-sygus datatypes is very rare, so the extra addition to + // visited is justified performance-wise. + if (dt.isSygus()) + { + std::vector children; + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + children.push_back(it->second); + } + index = indexOf(cur.getOperator()); + ret = mkSygusTerm(dt, index, children); + } + visited[cur] = ret; + // cache + SygusToBuiltinTermAttribute stbt; + cur.setAttribute(stbt, ret); + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + +Node sygusToBuiltinEval(Node n, const std::vector& args) +{ + NodeManager* nm = NodeManager::currentNM(); + Evaluator eval; + // constant arguments? + bool constArgs = true; + for (const Node& a : args) + { + if (!a.isConst()) + { + constArgs = false; + break; + } + } + std::vector eargs; + bool svarsInit = false; + std::vector svars; + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + TNode cur; + unsigned index; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + if (it == visited.end()) + { + TypeNode tn = cur.getType(); + if (!tn.isDatatype() || !tn.getDatatype().isSygus()) + { + visited[cur] = cur; + } + else if (cur.isConst()) + { + // convert to builtin term + Node bt = sygusToBuiltin(cur); + // run the evaluator if possible + if (!svarsInit) + { + svarsInit = true; + TypeNode tn = cur.getType(); + Node varList = Node::fromExpr(tn.getDatatype().getSygusVarList()); + for (const Node& v : varList) + { + svars.push_back(v); + } + } + Assert(args.size() == svars.size()); + // try evaluation if we have constant arguments + Node ret = constArgs ? eval.eval(bt, svars, args) : Node::null(); + if (ret.isNull()) + { + // if evaluation was not available, use a substitution + ret = bt.substitute( + svars.begin(), svars.end(), args.begin(), args.end()); + } + visited[cur] = ret; + } + else + { + if (cur.getKind() == APPLY_CONSTRUCTOR) + { + visited[cur] = Node::null(); + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + else + { + // it is the evaluation of this term on the arguments + if (eargs.empty()) + { + eargs.push_back(cur); + eargs.insert(eargs.end(), args.begin(), args.end()); + } + else + { + eargs[0] = cur; + } + visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs); + } + } + } + else if (it->second.isNull()) + { + Node ret = cur; + Assert(cur.getKind() == APPLY_CONSTRUCTOR); + const Datatype& dt = cur.getType().getDatatype(); + // non sygus-datatype terms are also themselves + if (dt.isSygus()) + { + std::vector children; + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + children.push_back(it->second); + } + index = indexOf(cur.getOperator()); + // apply to arguments + ret = mkSygusTerm(dt, index, children); + } + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + } // namespace utils } // namespace datatypes } // namespace theory diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h index 5f74a4bee..46a6d56be 100644 --- a/src/theory/datatypes/theory_datatypes_utils.h +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -185,12 +185,36 @@ Node applySygusArgs(const Datatype& dt, Node op, Node n, const std::vector& args); -/** - * Get the builtin sygus operator for constructor term n of sygus datatype - * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus - * constructor whose sygus op is the builtin operator +, this method returns +. +/** Sygus to builtin + * + * This method converts a constant term of SyGuS datatype type to its builtin + * equivalent. For example, given input C_*( C_x(), C_y() ), this method returns + * x*y, assuming C_+, C_x, and C_y have sygus operators *, x, and y + * respectively. + */ +Node sygusToBuiltin(Node c); +/** Sygus to builtin eval + * + * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that + * n does not necessarily need to be a constant. + * + * It does so by (1) converting constant subterms of n to builtin terms and + * evaluating them on the arguments args, (2) unfolding non-constant + * applications of sygus constructors in n with respect to args and (3) + * converting all other non-constant subterms of n to applications of + * DT_SYGUS_EVAL. + * + * For example, if + * n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 } + * where n' is a variable, then this method returns: + * 12 + (DT_SYGUS_EVAL n' 3 4) + * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin + * equivalent x*y and evaluated under the substition { x -> 3, x -> 4 } giving + * 12. The subterm n' is non-constant and thus we return its evaluation under + * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level constructor + * is C_+, these terms are added together to give the result. */ -Node getSygusOpForCTerm(Node n); +Node sygusToBuiltinEval(Node n, const std::vector& args); // ------------------------ end sygus utils diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index d664a462d..c5ea0f9f3 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -277,6 +277,11 @@ typedef expr::Attribute Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) { + if (n.isConst()) + { + // if its a constant, we use the datatype utility version + return datatypes::utils::sygusToBuiltin(n); + } Assert(n.getType().isComparableTo(tn)); if (!tn.isDatatype()) {