From 3c2099bc67595bc015eb3b491e1110b1e94c0d25 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Tue, 11 Jun 2019 16:47:13 -0500 Subject: [PATCH] Do not require sygus constructors to be flattened (#3049) --- src/parser/smt2/Smt2.g | 2 +- src/theory/datatypes/datatypes_rewriter.cpp | 98 +++++++++++++++---- src/theory/datatypes/datatypes_rewriter.h | 53 ++++++++++ .../quantifiers/sygus/sygus_eval_unfold.cpp | 13 ++- .../quantifiers/sygus/sygus_grammar_cons.cpp | 9 +- .../quantifiers/sygus/sygus_grammar_red.cpp | 1 - .../quantifiers/sygus/term_database_sygus.cpp | 27 +++-- .../quantifiers/sygus/term_database_sygus.h | 10 ++ test/regress/regress1/sygus/sygus-dt.sy | 2 +- 9 files changed, 181 insertions(+), 34 deletions(-) diff --git a/src/parser/smt2/Smt2.g b/src/parser/smt2/Smt2.g index 9ba7f4b2e..b224032e8 100644 --- a/src/parser/smt2/Smt2.g +++ b/src/parser/smt2/Smt2.g @@ -1004,7 +1004,7 @@ sygusGTerm[CVC4::SygusGTerm& sgt, std::string& fun] }else if( PARSER_STATE->isDeclared(name,SYM_VARIABLE) ){ Debug("parser-sygus") << "Sygus grammar " << fun << " : symbol " << name << std::endl; - sgt.d_expr = PARSER_STATE->getVariable(name); + sgt.d_expr = PARSER_STATE->getExpressionForName(name); sgt.d_name = name; sgt.d_gterm_type = SygusGTerm::gterm_op; }else{ diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index be87b7e8d..ac3bff21b 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -16,6 +16,8 @@ #include "theory/datatypes/datatypes_rewriter.h" +#include "expr/node_algorithm.h" + using namespace CVC4; using namespace CVC4::kind; @@ -115,8 +117,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) if (ev.getKind() == APPLY_CONSTRUCTOR) { Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n"; - const Datatype& dt = - static_cast(ev.getType().toType()).getDatatype(); + const Datatype& dt = ev.getType().getDatatype(); unsigned i = indexOf(ev.getOperator()); Node op = Node::fromExpr(dt[i].getSygusOp()); // if it is the "any constant" constructor, return its argument @@ -141,14 +142,8 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc)); } Node ret = mkSygusTerm(dt, i, children); - // if it is a variable, apply the substitution - if (ret.getKind() == BOUND_VARIABLE) - { - Assert(ret.hasAttribute(SygusVarNumAttribute())); - int vn = ret.getAttribute(SygusVarNumAttribute()); - Assert(Node::fromExpr(dt.getSygusVarList())[vn] == ret); - ret = args[vn]; - } + // apply the appropriate substitution + ret = applySygusArgs(dt, op, ret, args); Trace("dt-sygus-util") << "...got " << ret << "\n"; return RewriteResponse(REWRITE_AGAIN_FULL, ret); } @@ -186,6 +181,67 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) return RewriteResponse(REWRITE_DONE, in); } +Node DatatypesRewriter::applySygusArgs(const Datatype& dt, + Node op, + Node n, + const std::vector& args) +{ + if (n.getKind() == BOUND_VARIABLE) + { + Assert(n.hasAttribute(SygusVarNumAttribute())); + int vn = n.getAttribute(SygusVarNumAttribute()); + Assert(Node::fromExpr(dt.getSygusVarList())[vn] == n); + return args[vn]; + } + // n is an application of operator op. + // We must compute the free variables in op to determine if there are + // any substitutions we need to make to n. + TNode val; + if (!op.hasAttribute(SygusVarFreeAttribute())) + { + std::unordered_set fvs; + if (expr::getFreeVariables(op, fvs)) + { + if (fvs.size() == 1) + { + for (const Node& v : fvs) + { + val = v; + } + } + else + { + val = op; + } + } + Trace("dt-sygus-fv") << "Free var in " << op << " : " << val << std::endl; + op.setAttribute(SygusVarFreeAttribute(), val); + } + else + { + val = op.getAttribute(SygusVarFreeAttribute()); + } + if (val.isNull()) + { + return n; + } + if (val.getKind() == BOUND_VARIABLE) + { + // single substitution case + int vn = val.getAttribute(SygusVarNumAttribute()); + TNode sub = args[vn]; + return n.substitute(val, sub); + } + // do the full substitution + std::vector vars; + Node bvl = Node::fromExpr(dt.getSygusVarList()); + for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++) + { + vars.push_back(bvl[i]); + } + return n.substitute(vars.begin(), vars.end(), args.begin(), args.end()); +} + Kind DatatypesRewriter::getOperatorKindForSygusBuiltin(Node op) { Assert(op.getKind() != BUILTIN); @@ -224,6 +280,13 @@ Node DatatypesRewriter::mkSygusTerm(const Datatype& dt, Assert(!dt[i].getSygusOp().isNull()); std::vector schildren; Node op = Node::fromExpr(dt[i].getSygusOp()); + Trace("dt-sygus-util") << "Operator is " << op << std::endl; + if (children.empty()) + { + // no children, return immediately + Trace("dt-sygus-util") << "...return direct op" << std::endl; + return op; + } // if it is the any constant, we simply return the child if (op.getAttribute(SygusAnyConstAttribute())) { @@ -243,18 +306,13 @@ Node DatatypesRewriter::mkSygusTerm(const Datatype& dt, return ret; } Kind ok = NodeManager::operatorToKind(op); + Trace("dt-sygus-util") << "operator kind is " << ok << std::endl; if (ok != UNDEFINED_KIND) { - if (ok == APPLY_UF && schildren.size() == 1) - { - // This case is triggered for defined constant symbols. In this case, - // we return the operator itself instead of an APPLY_UF node. - ret = schildren[0]; - } - else - { - ret = NodeManager::currentNM()->mkNode(ok, schildren); - } + // If it is an APPLY_UF operator, we should have at least an operator and + // a child. + Assert(ok != APPLY_UF || schildren.size() != 1); + ret = NodeManager::currentNM()->mkNode(ok, schildren); Trace("dt-sygus-util") << "...return (op) " << ret << std::endl; return ret; } diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 6c1d64e5b..1a1735402 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -51,6 +51,28 @@ struct SygusSymBreakOkAttributeId typedef expr::Attribute SygusSymBreakOkAttribute; +/** sygus var free + * + * This attribute is used to mark whether sygus operators have free occurrences + * of variables from the formal argument list of the function-to-synthesize. + * + * We store three possible cases for sygus operators op: + * (1) op.getAttribute(SygusVarFreeAttribute())==Node::null() + * In this case, op has no free variables from the formal argument list of the + * function-to-synthesize. + * (2) op.getAttribute(SygusVarFreeAttribute())==v, where v is a bound variable. + * In this case, op has exactly one free variable, v. + * (3) op.getAttribute(SygusVarFreeAttribute())==op + * In this case, op has an arbitrary set (cardinality >1) of free variables from + * the formal argument list of the function to synthesize. + * + * This attribute is used to compute applySygusArgs below. + */ +struct SygusVarFreeAttributeId +{ +}; +typedef expr::Attribute SygusVarFreeAttribute; + namespace datatypes { class DatatypesRewriter { @@ -149,6 +171,37 @@ public: static Node mkSygusTerm(const Datatype& dt, unsigned i, const std::vector& children); + /** + * n is a builtin term that is an application of operator op. + * + * This returns an n' such that (eval n args) is n', where n' is a instance of + * n for the appropriate substitution. + * + * For example, given a function-to-synthesize with formal argument list (x,y), + * say we have grammar: + * A -> A+A | A+x | A+(x+y) | y + * These lead to constructors with sygus ops: + * C1 / (lambda w1 w2. w1+w2) + * C2 / (lambda w1. w1+x) + * C3 / (lambda w1. w1+(x+y)) + * C4 / y + * Examples of calling this function: + * applySygusArgs( dt, C1, (APPLY_UF (lambda w1 w2. w1+w2) t1 t2), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1 w2. w1+w2) t1 t2). + * applySygusArgs( dt, C2, (APPLY_UF (lambda w1. w1+x) t1), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1. w1+3) t1). + * applySygusArgs( dt, C3, (APPLY_UF (lambda w1. w1+(x+y)) t1), { 3, 5 } ) + * ... returns (APPLY_UF (lambda w1. w1+(3+5)) t1). + * applySygusArgs( dt, C4, y, { 3, 5 } ) + * ... returns 5. + * Notice the attribute SygusVarFreeAttribute is applied to C1, C2, C3, C4, + * to cache the results of whether the evaluation of this constructor needs + * a substitution over the formal argument list of the function-to-synthesize. + */ + static Node applySygusArgs(const Datatype& dt, + Node op, + Node n, + const std::vector& args); /** * Get the builtin sygus operator for constructor term n of sygus datatype * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus diff --git a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp index e44b604d0..7324add50 100644 --- a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp +++ b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp @@ -133,7 +133,18 @@ void SygusEvalUnfold::registerModelValue(Node a, bool do_unfold = false; if (options::sygusEvalUnfoldBool()) { - if (bTerm.getKind() == ITE || bTerm.getType().isBoolean()) + Node bTermUse = bTerm; + if (bTerm.getKind() == APPLY_UF) + { + // if the builtin term is non-beta-reduced application of lambda, + // we look at the body of the lambda. + Node bTermOp = bTerm.getOperator(); + if (bTermOp.getKind() == LAMBDA) + { + bTermUse = bTermOp[0]; + } + } + if (bTermUse.getKind() == ITE || bTermUse.getType().isBoolean()) { do_unfold = true; } diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 48da8e8e8..263c88d15 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -712,7 +712,14 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k) { Trace("sygus-grammar-def") << "...for " << dt[k].getName() << std::endl; - ops[i].push_back( dt[k].getConstructor() ); + Expr cop = dt[k].getConstructor(); + if (dt[k].getNumArgs() == 0) + { + // Nullary constructors are interpreted as terms, not operators. + // Thus, we apply them to no arguments here. + cop = nm->mkNode(APPLY_CONSTRUCTOR, Node::fromExpr(cop)).toExpr(); + } + ops[i].push_back(cop); cnames[i].push_back(dt[k].getName()); cargs[i].push_back(std::vector()); Trace("sygus-grammar-def") << "...add for selectors" << std::endl; diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp index 6ad590f28..2b2c87f38 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp @@ -43,7 +43,6 @@ void SygusRedundantCons::initialize(QuantifiersEngine* qe, TypeNode tn) std::map pre; Node g = tds->mkGeneric(dt, i, pre); Trace("sygus-red-debug") << " ...pre-rewrite : " << g << std::endl; - Assert(g.getNumChildren() == dt[i].getNumArgs()); d_gen_terms[i] = g; for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index af820b0fc..01d08dad8 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -411,6 +411,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ) { Trace("sygus-db") << ", kind = " << sk; d_kinds[tn][sk] = i; d_arg_kind[tn][i] = sk; + if (sk == ITE) + { + // mark that this type has an ITE + d_hasIte[tn] = true; + } } else if (sop.isConst() && dt[i].getNumArgs() == 0) { @@ -432,6 +437,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ) { << ", argument to a lambda constructor is not " << lat << std::endl; } + if (sop[0].getKind() == ITE) + { + // mark that this type has an ITE + d_hasIte[tn] = true; + } } // symbolic constructors if (n.getAttribute(SygusAnyConstAttribute())) @@ -602,7 +612,7 @@ void TermDbSygus::registerEnumerator(Node e, // solution" clauses. const Datatype& dt = et.getDatatype(); if (options::sygusStream() - || (!hasKind(et, ITE) && !dt.getSygusType().isBoolean())) + || (!hasIte(et) && !dt.getSygusType().isBoolean())) { isActiveGen = true; } @@ -1003,6 +1013,10 @@ int TermDbSygus::getOpConsNum( TypeNode tn, Node n ) { bool TermDbSygus::hasKind( TypeNode tn, Kind k ) { return getKindConsNum( tn, k )!=-1; } +bool TermDbSygus::hasIte(TypeNode tn) const +{ + return d_hasIte.find(tn) != d_hasIte.end(); +} bool TermDbSygus::hasConst( TypeNode tn, Node n ) { return getConstConsNum( tn, n )!=-1; } @@ -1502,14 +1516,9 @@ Node TermDbSygus::unfold( Node en, std::map< Node, Node >& vtm, std::vector< Nod pre[j] = nm->mkNode(DT_SYGUS_EVAL, cc); } Node ret = mkGeneric(dt, i, pre); - // if it is a variable, apply the substitution - if (ret.getKind() == kind::BOUND_VARIABLE) - { - Assert(ret.hasAttribute(SygusVarNumAttribute())); - int i = ret.getAttribute(SygusVarNumAttribute()); - Assert(Node::fromExpr(dt.getSygusVarList())[i] == ret); - return args[i]; - } + // apply the appropriate substitution to ret + ret = datatypes::DatatypesRewriter::applySygusArgs(dt, sop, ret, args); + // rewrite ret = Rewriter::rewrite(ret); return ret; } diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 0f3d650d3..2854ecab6 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -393,6 +393,11 @@ class TermDbSygus { std::map > d_var_list; std::map > d_arg_kind; std::map > d_kinds; + /** + * Whether this sygus type has a constructors whose sygus operator is ITE, + * or is a lambda whose body is ITE. + */ + std::map d_hasIte; std::map > d_arg_const; std::map > d_consts; std::map > d_ops; @@ -462,6 +467,11 @@ class TermDbSygus { int getConstConsNum( TypeNode tn, Node n ); int getOpConsNum( TypeNode tn, Node n ); bool hasKind( TypeNode tn, Kind k ); + /** + * Returns true if this sygus type has a constructors whose sygus operator is + * ITE, or is a lambda whose body is ITE. + */ + bool hasIte(TypeNode tn) const; bool hasConst( TypeNode tn, Node n ); bool hasOp( TypeNode tn, Node n ); Node getConsNumConst( TypeNode tn, int i ); diff --git a/test/regress/regress1/sygus/sygus-dt.sy b/test/regress/regress1/sygus/sygus-dt.sy index 2f3f4dbb9..336c59b27 100644 --- a/test/regress/regress1/sygus/sygus-dt.sy +++ b/test/regress/regress1/sygus/sygus-dt.sy @@ -7,7 +7,7 @@ (define-fun g ((x Int)) List (cons (+ x 1) nil)) (define-fun i () List (cons 3 nil)) -(synth-fun f ((x Int)) List ((Start List ((g StartInt) i (cons StartInt Start) (nil) (tail Start))) +(synth-fun f ((x Int)) List ((Start List ((g StartInt) i (cons StartInt Start) nil (tail Start))) (StartInt Int (x 0 1 (+ StartInt StartInt))))) (declare-var x Int) -- 2.30.2