From 4e310461b2e41f9ccf1426797b5d8b58e27bc1c7 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 10 Apr 2020 23:37:43 -0500 Subject: [PATCH] Ensure exported sygus solutions match grammar (#4270) Previously we were doing rewriting/expand definitions during grammar normalization, which overwrote the original sygus operators. The connection to the original grammar was maintained via the SygusPrintCallback utility, which ensured that a sygus term printed in a way that matched the grammar. We now have several use cases where solutions from SyGuS will be directly exported to the user, including the current use of get-abduct. This means that the terms must match the grammar, and we cannot simply rely on the print callback. This moves the code to normalize sygus operators to datatypes utils, where the conversion between sygus and builtin terms takes place. This allows a version of this function where isExternal = true, which constructs terms matching the original grammar. This PR enables the SyGuS API to have an accurate getSynthSolution method. It also will eliminate the need for SygusPrintCallback altogether, once the v1 parser is deleted. --- src/expr/dtype_cons.h | 9 +- .../datatypes/theory_datatypes_utils.cpp | 151 +++++++++++++++++- src/theory/datatypes/theory_datatypes_utils.h | 23 ++- .../quantifiers/sygus/sygus_grammar_norm.cpp | 120 +------------- .../quantifiers/sygus/sygus_grammar_norm.h | 10 -- .../quantifiers/sygus/synth_conjecture.cpp | 11 +- test/regress/CMakeLists.txt | 1 + .../regress/regress1/sygus/yoni-true-sol.smt2 | 20 +++ 8 files changed, 199 insertions(+), 146 deletions(-) create mode 100644 test/regress/regress1/sygus/yoni-true-sol.smt2 diff --git a/src/expr/dtype_cons.h b/src/expr/dtype_cons.h index d5d0013de..ca4806316 100644 --- a/src/expr/dtype_cons.h +++ b/src/expr/dtype_cons.h @@ -87,12 +87,9 @@ class DTypeConstructor void setSygus(Node op); /** get sygus op * - * This method returns the operator or - * term that this constructor represents - * in the sygus encoding. This may be a - * builtin operator, defined function, variable, - * or constant that this constructor encodes in this - * deep embedding. + * This method returns the operator or term that this constructor represents + * in the sygus encoding. This may be a builtin operator, defined function, + * variable, or constant that this constructor encodes in this deep embedding. */ Node getSygusOp() const; /** is this a sygus identity function? diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index 13cc8fc19..ee0fd814e 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -19,7 +19,10 @@ #include "expr/dtype.h" #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" +#include "smt/smt_engine.h" +#include "smt/smt_engine_scope.h" #include "theory/evaluator.h" +#include "theory/rewriter.h" using namespace CVC4; using namespace CVC4::kind; @@ -117,10 +120,99 @@ Kind getOperatorKindForSygusBuiltin(Node op) return UNDEFINED_KIND; } +struct SygusOpRewrittenAttributeId +{ +}; +typedef expr::Attribute + SygusOpRewrittenAttribute; + +Kind getEliminateKind(Kind ok) +{ + Kind nk = ok; + // We also must ensure that builtin operators which are eliminated + // during expand definitions are replaced by the proper operator. + if (ok == BITVECTOR_UDIV) + { + nk = BITVECTOR_UDIV_TOTAL; + } + else if (ok == BITVECTOR_UREM) + { + nk = BITVECTOR_UREM_TOTAL; + } + else if (ok == DIVISION) + { + nk = DIVISION_TOTAL; + } + else if (ok == INTS_DIVISION) + { + nk = INTS_DIVISION_TOTAL; + } + else if (ok == INTS_MODULUS) + { + nk = INTS_MODULUS_TOTAL; + } + return nk; +} + +Node eliminatePartialOperators(Node n) +{ + NodeManager* nm = NodeManager::currentNM(); + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + TNode cur; + visit.push_back(n); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + + if (it == visited.end()) + { + visited[cur] = Node::null(); + visit.push_back(cur); + for (const Node& cn : cur) + { + visit.push_back(cn); + } + } + else if (it->second.isNull()) + { + Node ret = cur; + bool childChanged = false; + std::vector children; + if (cur.getMetaKind() == metakind::PARAMETERIZED) + { + children.push_back(cur.getOperator()); + } + for (const Node& cn : cur) + { + it = visited.find(cn); + Assert(it != visited.end()); + Assert(!it->second.isNull()); + childChanged = childChanged || cn != it->second; + children.push_back(it->second); + } + Kind ok = cur.getKind(); + Kind nk = getEliminateKind(ok); + if (nk != ok || childChanged) + { + ret = nm->mkNode(nk, children); + } + visited[cur] = ret; + } + } while (!visit.empty()); + Assert(visited.find(n) != visited.end()); + Assert(!visited.find(n)->second.isNull()); + return visited[n]; +} + Node mkSygusTerm(const DType& dt, unsigned i, const std::vector& children, - bool doBetaReduction) + bool doBetaReduction, + bool isExternal) { Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i << "] with children: " << children << std::endl; @@ -128,7 +220,49 @@ Node mkSygusTerm(const DType& dt, Assert(dt.isSygus()); Assert(!dt[i].getSygusOp().isNull()); Node op = dt[i].getSygusOp(); - return mkSygusTerm(op, children, doBetaReduction); + Node opn = op; + if (!isExternal) + { + // Get the normalized version of the sygus operator. We do this by + // expanding definitions, rewriting it, and eliminating partial operators. + if (!op.hasAttribute(SygusOpRewrittenAttribute())) + { + if (op.isConst()) + { + // If it is a builtin operator, convert to total version if necessary. + // First, get the kind for the operator. + Kind ok = NodeManager::operatorToKind(op); + Trace("sygus-grammar-normalize-debug") + << "...builtin kind is " << ok << std::endl; + Kind nk = getEliminateKind(ok); + if (nk != ok) + { + Trace("sygus-grammar-normalize-debug") + << "...replace by builtin operator " << nk << std::endl; + opn = NodeManager::currentNM()->operatorOf(nk); + } + } + else + { + // Only expand definitions if the operator is not constant, since + // calling expandDefinitions on them should be a no-op. This check + // ensures we don't try to expand e.g. bitvector extract operators, + // whose type is undefined, and thus should not be passed to + // expandDefinitions. + opn = Node::fromExpr( + smt::currentSmtEngine()->expandDefinitions(op.toExpr())); + opn = Rewriter::rewrite(opn); + opn = eliminatePartialOperators(opn); + SygusOpRewrittenAttribute sora; + op.setAttribute(sora, opn); + } + } + else + { + opn = op.getAttribute(SygusOpRewrittenAttribute()); + } + } + return mkSygusTerm(opn, children, doBetaReduction); } Node mkSygusTerm(Node op, @@ -386,7 +520,7 @@ struct SygusToBuiltinTermAttributeId typedef expr::Attribute SygusToBuiltinTermAttribute; -Node sygusToBuiltin(Node n) +Node sygusToBuiltin(Node n, bool isExternal) { Assert(n.isConst()); std::unordered_map visited; @@ -404,7 +538,7 @@ Node sygusToBuiltin(Node n) { if (cur.getKind() == APPLY_CONSTRUCTOR) { - if (cur.hasAttribute(SygusToBuiltinTermAttribute())) + if (!isExternal && cur.hasAttribute(SygusToBuiltinTermAttribute())) { visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute()); } @@ -445,12 +579,15 @@ Node sygusToBuiltin(Node n) children.push_back(it->second); } index = indexOf(cur.getOperator()); - ret = mkSygusTerm(dt, index, children); + ret = mkSygusTerm(dt, index, children, true, isExternal); } visited[cur] = ret; // cache - SygusToBuiltinTermAttribute stbt; - cur.setAttribute(stbt, ret); + if (!isExternal) + { + SygusToBuiltinTermAttribute stbt; + cur.setAttribute(stbt, ret); + } } } while (!visit.empty()); Assert(visited.find(n) != visited.end()); diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h index b23302276..58f719910 100644 --- a/src/theory/datatypes/theory_datatypes_utils.h +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -146,17 +146,31 @@ bool checkClash(Node n1, Node n2, std::vector& rew); * function mkSygusTerm. */ Kind getOperatorKindForSygusBuiltin(Node op); +/** + * Returns the total version of Kind k if it is a partial operator, or + * otherwise k itself. + */ +Kind getEliminateKind(Kind k); +/** + * Returns a version of n where all partial functions such as bvudiv + * have been replaced by their total versions like bvudiv_total. + */ +Node eliminatePartialOperators(Node n); /** make sygus term * * This function returns a builtin term f( children[0], ..., children[n] ) * where f is the builtin op that the i^th constructor of sygus datatype dt * encodes. If doBetaReduction is true, then lambdas are eagerly eliminated * via beta reduction. + * + * If isExternal is true, then the returned term respects the original grammar + * that was provided. This includes the use of defined functions. */ Node mkSygusTerm(const DType& dt, unsigned i, const std::vector& children, - bool doBetaReduction = true); + bool doBetaReduction = true, + bool isExternal = false); /** * Same as above, but we already have the sygus operator op. The above method * is syntax sugar for calling this method on dt[i].getSygusOp(). @@ -201,8 +215,13 @@ Node applySygusArgs(const DType& dt, * equivalent. For example, given input C_*( C_x(), C_y() ), this method returns * x*y, assuming C_+, C_x, and C_y have sygus operators *, x, and y * respectively. + * + * If isExternal is true, then the returned term respects the original grammar + * that was provided. This includes the use of defined functions. This argument + * should typically be false, unless we are e.g. exporting the value of the + * term as a final solution. */ -Node sygusToBuiltin(Node c); +Node sygusToBuiltin(Node c, bool isExternal = false); /** Sygus to builtin eval * * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index f00fd0092..3b2c56974 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -81,87 +81,6 @@ SygusGrammarNorm::TypeObject::TypeObject(TypeNode src_tn, TypeNode unres_tn) d_sdt(unres_tn.getAttribute(expr::VarNameAttr())) { } -Kind SygusGrammarNorm::TypeObject::getEliminateKind(Kind ok) -{ - Kind nk = ok; - // We also must ensure that builtin operators which are eliminated - // during expand definitions are replaced by the proper operator. - if (ok == BITVECTOR_UDIV) - { - nk = BITVECTOR_UDIV_TOTAL; - } - else if (ok == BITVECTOR_UREM) - { - nk = BITVECTOR_UREM_TOTAL; - } - else if (ok == DIVISION) - { - nk = DIVISION_TOTAL; - } - else if (ok == INTS_DIVISION) - { - nk = INTS_DIVISION_TOTAL; - } - else if (ok == INTS_MODULUS) - { - nk = INTS_MODULUS_TOTAL; - } - return nk; -} - -Node SygusGrammarNorm::TypeObject::eliminatePartialOperators(Node n) -{ - NodeManager* nm = NodeManager::currentNM(); - std::unordered_map visited; - std::unordered_map::iterator it; - std::vector visit; - TNode cur; - visit.push_back(n); - do - { - cur = visit.back(); - visit.pop_back(); - it = visited.find(cur); - - if (it == visited.end()) - { - visited[cur] = Node::null(); - visit.push_back(cur); - for (const Node& cn : cur) - { - visit.push_back(cn); - } - } - else if (it->second.isNull()) - { - Node ret = cur; - bool childChanged = false; - std::vector children; - if (cur.getMetaKind() == metakind::PARAMETERIZED) - { - children.push_back(cur.getOperator()); - } - for (const Node& cn : cur) - { - it = visited.find(cn); - Assert(it != visited.end()); - Assert(!it->second.isNull()); - childChanged = childChanged || cn != it->second; - children.push_back(it->second); - } - Kind ok = cur.getKind(); - Kind nk = getEliminateKind(ok); - if (nk != ok || childChanged) - { - ret = nm->mkNode(nk, children); - } - visited[cur] = ret; - } - } while (!visit.empty()); - Assert(visited.find(n) != visited.end()); - Assert(!visited.find(n)->second.isNull()); - return visited[n]; -} void SygusGrammarNorm::TypeObject::addConsInfo( SygusGrammarNorm* sygus_norm, @@ -174,41 +93,6 @@ void SygusGrammarNorm::TypeObject::addConsInfo( Node sygus_op = cons.getSygusOp(); Trace("sygus-grammar-normalize-debug") << ".....operator is " << sygus_op << std::endl; - Node exp_sop_n = sygus_op; - if (exp_sop_n.isConst()) - { - // If it is a builtin operator, convert to total version if necessary. - // First, get the kind for the operator. - Kind ok = NodeManager::operatorToKind(exp_sop_n); - Trace("sygus-grammar-normalize-debug") - << "...builtin kind is " << ok << std::endl; - Kind nk = getEliminateKind(ok); - if (nk != ok) - { - Trace("sygus-grammar-normalize-debug") - << "...replace by builtin operator " << nk << std::endl; - exp_sop_n = NodeManager::currentNM()->operatorOf(nk); - } - } - else - { - // Only expand definitions if the operator is not constant, since calling - // expandDefinitions on them should be a no-op. This check ensures we don't - // try to expand e.g. bitvector extract operators, whose type is undefined, - // and thus should not be passed to expandDefinitions. - exp_sop_n = Node::fromExpr( - smt::currentSmtEngine()->expandDefinitions(sygus_op.toExpr())); - exp_sop_n = Rewriter::rewrite(exp_sop_n); - Trace("sygus-grammar-normalize-debug") - << ".....operator (post-rewrite) is " << exp_sop_n << std::endl; - // eliminate all partial operators from it - exp_sop_n = eliminatePartialOperators(exp_sop_n); - Trace("sygus-grammar-normalize-debug") - << ".....operator (eliminate partial operators) is " << exp_sop_n - << std::endl; - // rewrite again - exp_sop_n = Rewriter::rewrite(exp_sop_n); - } std::vector consTypes; const std::vector >& args = cons.getArgs(); @@ -222,10 +106,8 @@ void SygusGrammarNorm::TypeObject::addConsInfo( consTypes.push_back(atype); } - Trace("sygus-type-cons-defs") << "\tOriginal op: " << cons.getSygusOp() - << "\n\tExpanded one: " << exp_sop_n << "\n\n"; d_sdt.addConstructor( - exp_sop_n, cons.getName(), consTypes, spc, cons.getWeight()); + sygus_op, cons.getName(), consTypes, spc, cons.getWeight()); } void SygusGrammarNorm::TypeObject::initializeDatatype( diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.h b/src/theory/quantifiers/sygus/sygus_grammar_norm.h index 360762b38..956228f38 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.h @@ -200,16 +200,6 @@ class SygusGrammarNorm void addConsInfo(SygusGrammarNorm* sygus_norm, const DTypeConstructor& cons, std::shared_ptr spc); - /** - * Returns the total version of Kind k if it is a partial operator, or - * otherwise k itself. - */ - static Kind getEliminateKind(Kind k); - /** - * Returns a version of n where all partial functions such as bvudiv - * have been replaced by their total versions like bvudiv_total. - */ - static Node eliminatePartialOperators(Node n); /** initializes a datatype with the information in the type object * diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index 1596c30f0..e69d746fe 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -1178,8 +1178,10 @@ bool SynthConjecture::getSynthSolutions( NodeManager* nm = NodeManager::currentNM(); std::vector sols; std::vector statuses; + Trace("cegqi-debug") << "getSynthSolutions..." << std::endl; if (!getSynthSolutionsInternal(sols, statuses)) { + Trace("cegqi-debug") << "...failed internal" << std::endl; return false; } // we add it to the solution map, indexed by this conjecture @@ -1188,12 +1190,16 @@ bool SynthConjecture::getSynthSolutions( { Node sol = sols[i]; int status = statuses[i]; + Trace("cegqi-debug") << "...got " << i << ": " << sol + << ", status=" << status << std::endl; // get the builtin solution Node bsol = sol; if (status != 0) { - // convert sygus to builtin here - bsol = d_tds->sygusToBuiltin(sol, sol.getType()); + // Convert sygus to builtin here. + // We must use the external representation to ensure bsol matches the + // grammar. + bsol = datatypes::utils::sygusToBuiltin(sol, true); } // convert to lambda TypeNode tn = d_embed_quant[0][i].getType(); @@ -1214,6 +1220,7 @@ bool SynthConjecture::getSynthSolutions( } // store in map smc[fvar] = bsol; + Trace("cegqi-debug") << "...return " << bsol << std::endl; } return true; } diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 8aae1890d..06dc2d87c 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1951,6 +1951,7 @@ set(regress_1_tests regress1/sygus/unbdd_inv_gen_ex7.sy regress1/sygus/unbdd_inv_gen_winf1.sy regress1/sygus/univ_2-long-repeat.sy + regress1/sygus/yoni-true-sol.smt2 regress1/sym/q-constant.smt2 regress1/sym/q-function.smt2 regress1/sym/qf-function.smt2 diff --git a/test/regress/regress1/sygus/yoni-true-sol.smt2 b/test/regress/regress1/sygus/yoni-true-sol.smt2 new file mode 100644 index 000000000..464f7c729 --- /dev/null +++ b/test/regress/regress1/sygus/yoni-true-sol.smt2 @@ -0,0 +1,20 @@ +; COMMAND-LINE: --produce-abducts +; EXPECT: (define-fun A () Bool (>= j i)) +(set-logic QF_LIA) +(set-option :produce-abducts true) +(declare-fun n () Int) +(declare-fun m () Int) +(declare-fun i () Int) +(declare-fun j () Int) +(assert (and (>= n 0) (>= m 0))) +(assert (< n i)) +(assert (< (+ i j) m)) +(get-abduct A + (<= n m) + ((GA Bool) (GJ Int) (GI Int)) + ( + (GA Bool ((>= GJ GI))) + (GJ Int ( j)) + (GI Int ( i)) + ) +) -- 2.30.2