From: Andrew Reynolds Date: Thu, 4 Jun 2020 17:38:36 +0000 (-0500) Subject: Add sygus datatype substitution utility method (#4390) X-Git-Tag: cvc5-1.0.0~3258 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=c5bf818456ebe2dee833fecd4a0970f0105919f0;p=cvc5.git Add sygus datatype substitution utility method (#4390) This makes the method for substiutiton and generalization of sygus datatypes a generic utility method. It updates the abduction method to use it. Interpolation is another target user of this utility. --- diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index ee0fd814e..ea67ab79d 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -23,6 +23,7 @@ #include "smt/smt_engine_scope.h" #include "theory/evaluator.h" #include "theory/rewriter.h" +#include "printer/sygus_print_callback.h" using namespace CVC4; using namespace CVC4::kind; @@ -711,6 +712,223 @@ Node sygusToBuiltinEval(Node n, const std::vector& args) return visited[n]; } +void getFreeSymbolsSygusType(TypeNode sdt, + std::unordered_set& syms) +{ + // datatype types we need to process + std::vector typeToProcess; + // datatype types we have processed + std::map typesProcessed; + typeToProcess.push_back(sdt); + while (!typeToProcess.empty()) + { + std::vector typeNextToProcess; + for (const TypeNode& curr : typeToProcess) + { + Assert(curr.isDatatype() && curr.getDType().isSygus()); + const DType& dtc = curr.getDType(); + for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++) + { + // collect the symbols from the operator + Node op = dtc[j].getSygusOp(); + expr::getSymbols(op, syms); + // traverse the argument types + for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++) + { + TypeNode argt = dtc[j].getArgType(k); + if (!argt.isDatatype() || !argt.getDType().isSygus()) + { + // not a sygus datatype + continue; + } + if (typesProcessed.find(argt) == typesProcessed.end()) + { + typeNextToProcess.push_back(argt); + } + } + } + } + typeToProcess.clear(); + typeToProcess.insert(typeToProcess.end(), + typeNextToProcess.begin(), + typeNextToProcess.end()); + } +} + +TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, + const std::vector& syms, + const std::vector& vars) +{ + NodeManager* nm = NodeManager::currentNM(); + const DType& sdtd = sdt.getDType(); + // compute the new formal argument list + std::vector formalVars; + Node prevVarList = sdtd.getSygusVarList(); + if (!prevVarList.isNull()) + { + for (const Node& v : prevVarList) + { + // if it is not being replaced + if (std::find(syms.begin(), syms.end(), v) != syms.end()) + { + formalVars.push_back(v); + } + } + } + for (const Node& v : vars) + { + if (v.getKind() == BOUND_VARIABLE) + { + formalVars.push_back(v); + } + } + // make the sygus variable list for the formal argument list + Node abvl = nm->mkNode(BOUND_VAR_LIST, formalVars); + Trace("sygus-abduct-debug") << "...finish" << std::endl; + + // must convert all constructors to version with variables in "vars" + std::vector sdts; + std::set unres; + + Trace("dtsygus-gen-debug") << "Process sygus type:" << std::endl; + Trace("dtsygus-gen-debug") << sdtd.getName() << std::endl; + + // datatype types we need to process + std::vector dtToProcess; + // datatype types we have processed + std::map dtProcessed; + dtToProcess.push_back(sdt); + std::stringstream ssutn0; + ssutn0 << sdtd.getName() << "_s"; + TypeNode abdTNew = + nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER); + unres.insert(abdTNew.toType()); + dtProcessed[sdt] = abdTNew; + + // We must convert all symbols in the sygus datatype type sdt to + // apply the substitution { syms -> vars }, where syms is the free + // variables of the input problem, and vars is the formal argument list + // of the function-to-synthesize. + + // We are traversing over the subfield types of the datatype to convert + // them into the form described above. + while (!dtToProcess.empty()) + { + std::vector dtNextToProcess; + for (const TypeNode& curr : dtToProcess) + { + Assert(curr.isDatatype() && curr.getDType().isSygus()); + const DType& dtc = curr.getDType(); + std::stringstream ssdtn; + ssdtn << dtc.getName() << "_s"; + sdts.push_back(SygusDatatype(ssdtn.str())); + Trace("dtsygus-gen-debug") + << "Process datatype " << sdts.back().getName() << "..." << std::endl; + for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++) + { + Node op = dtc[j].getSygusOp(); + // apply the substitution to the argument + Node ops = + op.substitute(syms.begin(), syms.end(), vars.begin(), vars.end()); + Trace("dtsygus-gen-debug") << " Process constructor " << op << " / " + << ops << "..." << std::endl; + std::vector cargs; + for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++) + { + TypeNode argt = dtc[j].getArgType(k); + std::map::iterator itdp = dtProcessed.find(argt); + TypeNode argtNew; + if (itdp == dtProcessed.end()) + { + std::stringstream ssutn; + ssutn << argt.getDType().getName() << "_s"; + argtNew = + nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER); + Trace("dtsygus-gen-debug") << " ...unresolved type " << argtNew + << " for " << argt << std::endl; + unres.insert(argtNew.toType()); + dtProcessed[argt] = argtNew; + dtNextToProcess.push_back(argt); + } + else + { + argtNew = itdp->second; + } + Trace("dtsygus-gen-debug") + << " Arg #" << k << ": " << argtNew << std::endl; + cargs.push_back(argtNew); + } + // callback prints as the expression + std::shared_ptr spc; + std::vector args; + if (op.getKind() == LAMBDA) + { + Node opBody = op[1]; + for (const Node& v : op[0]) + { + args.push_back(v.toExpr()); + } + spc = std::make_shared( + opBody.toExpr(), args); + } + else if (cargs.empty()) + { + spc = std::make_shared(op.toExpr(), + args); + } + std::stringstream ss; + ss << ops.getKind(); + Trace("dtsygus-gen-debug") << "Add constructor : " << ops << std::endl; + sdts.back().addConstructor(ops, ss.str(), cargs, spc); + } + Trace("dtsygus-gen-debug") + << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl; + TypeNode stn = dtc.getSygusType(); + sdts.back().initializeDatatype( + stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll()); + } + dtToProcess.clear(); + dtToProcess.insert( + dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end()); + } + Trace("dtsygus-gen-debug") + << "Make " << sdts.size() << " datatype types..." << std::endl; + // extract the datatypes + std::vector datatypes; + for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++) + { + datatypes.push_back(sdts[i].getDatatype()); + } + // make the datatype types + std::vector datatypeTypes = + nm->toExprManager()->mkMutualDatatypeTypes( + datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER); + TypeNode sdtS = TypeNode::fromType(datatypeTypes[0]); + if (Trace.isOn("dtsygus-gen-debug")) + { + Trace("dtsygus-gen-debug") << "Made datatype types:" << std::endl; + for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++) + { + const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType(); + Trace("dtsygus-gen-debug") << "#" << j << ": " << dtj << std::endl; + for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++) + { + for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++) + { + if (!dtj[k].getArgType(l).isDatatype()) + { + Trace("dtsygus-gen-debug") + << "Argument " << l << " of " << dtj[k] + << " is not datatype : " << dtj[k].getArgType(l) << std::endl; + AlwaysAssert(false); + } + } + } + } + } + return sdtS; +} + } // namespace utils } // namespace datatypes } // namespace theory diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h index 58f719910..038922f37 100644 --- a/src/theory/datatypes/theory_datatypes_utils.h +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -245,6 +245,46 @@ Node sygusToBuiltin(Node c, bool isExternal = false); */ Node sygusToBuiltinEval(Node n, const std::vector& args); +/** Get free symbols in a sygus datatype type + * + * Add the free symbols (expr::getSymbols) in terms that can be generated by + * sygus datatype sdt to the set syms. For example, given sdt encodes the + * grammar: + * G -> a | +( b, G ) | c | e + * We have that { a, b, c, e } are added to syms. Notice that expr::getSymbols + * excludes variables whose kind is BOUND_VARIABLE. + */ +void getFreeSymbolsSygusType(TypeNode sdt, + std::unordered_set& syms); + +/** Substitute and generalize a sygus datatype type + * + * This transforms a sygus datatype sdt into another one sdt' that generates + * terms t such that t * { vars -> syms } is generated by sdt. + * + * The arguments syms and vars should be vectors of the same size and types. + * It is recommended that the arguments in syms and vars should be variables + * (return true for .isVar()) but this is not required. + * + * The variables in vars of type BOUND_VARIABLE are added to the + * formal argument list of t. Other symbols are not. + * + * For example, given sdt encodes the grammar: + * G -> a | +( b, G ) | c | e + * Let syms = { a, b, c } and vars = { x_a, x_b, d }, where x_a and x_b have + * type BOUND_VARIABLE and d does not. + * The returned type encodes the grammar: + * G' -> x_a | +( x_b, G' ) | d | e + * Additionally, x_a and x_b are treated as formal arguments of a function + * to synthesize whose syntax restrictions are specified by G'. + * + * This method traverses the type definition of the datatype corresponding to + * the argument sdt. + */ +TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, + const std::vector& syms, + const std::vector& vars); + // ------------------------ end sygus utils } // namespace utils diff --git a/src/theory/quantifiers/sygus/sygus_abduct.cpp b/src/theory/quantifiers/sygus/sygus_abduct.cpp index 22b56eb6b..ef2e7e445 100644 --- a/src/theory/quantifiers/sygus/sygus_abduct.cpp +++ b/src/theory/quantifiers/sygus/sygus_abduct.cpp @@ -19,7 +19,6 @@ #include "expr/dtype.h" #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" -#include "printer/sygus_print_callback.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/quantifiers_rewriter.h" @@ -79,8 +78,6 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, SygusVarToTermAttribute sta; vlv.setAttribute(sta, s); } - // make the sygus variable list - Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist); Trace("sygus-abduct-debug") << "...finish" << std::endl; Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl; @@ -90,163 +87,23 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, Node abd = nm->mkBoundVar(name.c_str(), abdType); Trace("sygus-abduct-debug") << "...finish" << std::endl; - // if provided, we will associate it with the function-to-synthesize + // the sygus variable list + Node abvl; + // if provided, we will associate the provide sygus datatype type with the + // function-to-synthesize. However, we must convert it so that its + // free symbols are universally quantified. if (!abdGType.isNull()) { Assert(abdGType.isDatatype() && abdGType.getDType().isSygus()); - // must convert all constructors to version with bound variables in "vars" - std::vector sdts; - std::set unres; - Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl; Trace("sygus-abduct-debug") << abdGType.getDType().getName() << std::endl; - // datatype types we need to process - std::vector dtToProcess; - // datatype types we have processed - std::map dtProcessed; - dtToProcess.push_back(abdGType); - std::stringstream ssutn0; - ssutn0 << abdGType.getDType().getName() << "_s"; - TypeNode abdTNew = - nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER); - unres.insert(abdTNew.toType()); - dtProcessed[abdGType] = abdTNew; - - // We must convert all symbols in the sygus datatype type abdGType to - // apply the substitution { syms -> varlist }, where syms is the free - // variables of the input problem, and varlist is the formal argument list - // of the abduct-to-synthesize. For example, given user-provided sygus - // grammar: - // G -> a | +( b, G ) - // we synthesize a abduct A with two arguments x_a and x_b corresponding to - // a and b, and reconstruct the grammar: - // G' -> x_a | +( x_b, G' ) - // In this way, x_a and x_b are treated as bound variables and handled as - // arguments of the abduct-to-synthesize instead of as free variables with - // no relation to A. We additionally require that x_a, when printed, prints - // "a", which we do with a custom sygus callback below. + // substitute the free symbols of the grammar with variables corresponding + // to the formal argument list of the new sygus datatype type. + TypeNode abdGTypeS = datatypes::utils::substituteAndGeneralizeSygusType( + abdGType, syms, varlist); - // We are traversing over the subfield types of the datatype to convert - // them into the form described above. - while (!dtToProcess.empty()) - { - std::vector dtNextToProcess; - for (const TypeNode& curr : dtToProcess) - { - Assert(curr.isDatatype() && curr.getDType().isSygus()); - const DType& dtc = curr.getDType(); - std::stringstream ssdtn; - ssdtn << dtc.getName() << "_s"; - sdts.push_back(SygusDatatype(ssdtn.str())); - Trace("sygus-abduct-debug") - << "Process datatype " << sdts.back().getName() << "..." - << std::endl; - for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++) - { - Node op = dtc[j].getSygusOp(); - // apply the substitution to the argument - Node ops = op.substitute( - syms.begin(), syms.end(), varlist.begin(), varlist.end()); - Trace("sygus-abduct-debug") << " Process constructor " << op << " / " - << ops << "..." << std::endl; - std::vector cargs; - for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++) - { - TypeNode argt = dtc[j].getArgType(k); - std::map::iterator itdp = - dtProcessed.find(argt); - TypeNode argtNew; - if (itdp == dtProcessed.end()) - { - std::stringstream ssutn; - ssutn << argt.getDType().getName() << "_s"; - argtNew = - nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER); - Trace("sygus-abduct-debug") - << " ...unresolved type " << argtNew << " for " << argt - << std::endl; - unres.insert(argtNew.toType()); - dtProcessed[argt] = argtNew; - dtNextToProcess.push_back(argt); - } - else - { - argtNew = itdp->second; - } - Trace("sygus-abduct-debug") - << " Arg #" << k << ": " << argtNew << std::endl; - cargs.push_back(argtNew); - } - // callback prints as the expression - std::shared_ptr spc; - std::vector args; - if (op.getKind() == LAMBDA) - { - Node opBody = op[1]; - for (const Node& v : op[0]) - { - args.push_back(v.toExpr()); - } - spc = std::make_shared( - opBody.toExpr(), args); - } - else if (cargs.empty()) - { - spc = std::make_shared(op.toExpr(), - args); - } - std::stringstream ss; - ss << ops.getKind(); - Trace("sygus-abduct-debug") - << "Add constructor : " << ops << std::endl; - sdts.back().addConstructor(ops, ss.str(), cargs, spc); - } - Trace("sygus-abduct-debug") - << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl; - TypeNode stn = dtc.getSygusType(); - sdts.back().initializeDatatype( - stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll()); - } - dtToProcess.clear(); - dtToProcess.insert( - dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end()); - } - Trace("sygus-abduct-debug") - << "Make " << sdts.size() << " datatype types..." << std::endl; - // extract the datatypes - std::vector datatypes; - for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++) - { - datatypes.push_back(sdts[i].getDatatype()); - } - // make the datatype types - std::vector datatypeTypes = - nm->toExprManager()->mkMutualDatatypeTypes( - datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER); - TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]); - if (Trace.isOn("sygus-abduct-debug")) - { - Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl; - for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++) - { - const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType(); - Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl; - for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++) - { - for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++) - { - if (!dtj[k].getArgType(l).isDatatype()) - { - Trace("sygus-abduct-debug") - << "Argument " << l << " of " << dtj[k] - << " is not datatype : " << dtj[k].getArgType(l) << std::endl; - AlwaysAssert(false); - } - } - } - } - } + Assert(abdGTypeS.isDatatype() && abdGTypeS.getDType().isSygus()); Trace("sygus-abduct-debug") << "Make sygus grammar attribute..." << std::endl; @@ -256,6 +113,19 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, theory::SygusSynthGrammarAttribute ssg; abd.setAttribute(ssg, sym); Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl; + + // use the bound variable list from the new substituted grammar type + const DType& agtsd = abdGTypeS.getDType(); + abvl = agtsd.getSygusVarList(); + Assert(!abvl.isNull() && abvl.getKind() == BOUND_VAR_LIST); + } + else + { + // the bound variable list of the abduct-to-synthesize is determined by + // the variable list above + abvl = nm->mkNode(BOUND_VAR_LIST, varlist); + // We do not set a grammar type for abd (SygusSynthGrammarAttribute). + // Its grammar will be constructed internally in the default way } Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;