This makes the method for substiutiton and generalization of sygus datatypes a generic utility method. It updates the abduction method to use it. Interpolation is another target user of this utility.
#include "smt/smt_engine_scope.h"
#include "theory/evaluator.h"
#include "theory/rewriter.h"
+#include "printer/sygus_print_callback.h"
using namespace CVC4;
using namespace CVC4::kind;
return visited[n];
}
+void getFreeSymbolsSygusType(TypeNode sdt,
+ std::unordered_set<Node, NodeHashFunction>& syms)
+{
+ // datatype types we need to process
+ std::vector<TypeNode> typeToProcess;
+ // datatype types we have processed
+ std::map<TypeNode, TypeNode> typesProcessed;
+ typeToProcess.push_back(sdt);
+ while (!typeToProcess.empty())
+ {
+ std::vector<TypeNode> typeNextToProcess;
+ for (const TypeNode& curr : typeToProcess)
+ {
+ Assert(curr.isDatatype() && curr.getDType().isSygus());
+ const DType& dtc = curr.getDType();
+ for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
+ {
+ // collect the symbols from the operator
+ Node op = dtc[j].getSygusOp();
+ expr::getSymbols(op, syms);
+ // traverse the argument types
+ for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
+ {
+ TypeNode argt = dtc[j].getArgType(k);
+ if (!argt.isDatatype() || !argt.getDType().isSygus())
+ {
+ // not a sygus datatype
+ continue;
+ }
+ if (typesProcessed.find(argt) == typesProcessed.end())
+ {
+ typeNextToProcess.push_back(argt);
+ }
+ }
+ }
+ }
+ typeToProcess.clear();
+ typeToProcess.insert(typeToProcess.end(),
+ typeNextToProcess.begin(),
+ typeNextToProcess.end());
+ }
+}
+
+TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
+ const std::vector<Node>& syms,
+ const std::vector<Node>& vars)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ const DType& sdtd = sdt.getDType();
+ // compute the new formal argument list
+ std::vector<Node> formalVars;
+ Node prevVarList = sdtd.getSygusVarList();
+ if (!prevVarList.isNull())
+ {
+ for (const Node& v : prevVarList)
+ {
+ // if it is not being replaced
+ if (std::find(syms.begin(), syms.end(), v) != syms.end())
+ {
+ formalVars.push_back(v);
+ }
+ }
+ }
+ for (const Node& v : vars)
+ {
+ if (v.getKind() == BOUND_VARIABLE)
+ {
+ formalVars.push_back(v);
+ }
+ }
+ // make the sygus variable list for the formal argument list
+ Node abvl = nm->mkNode(BOUND_VAR_LIST, formalVars);
+ Trace("sygus-abduct-debug") << "...finish" << std::endl;
+
+ // must convert all constructors to version with variables in "vars"
+ std::vector<SygusDatatype> sdts;
+ std::set<Type> unres;
+
+ Trace("dtsygus-gen-debug") << "Process sygus type:" << std::endl;
+ Trace("dtsygus-gen-debug") << sdtd.getName() << std::endl;
+
+ // datatype types we need to process
+ std::vector<TypeNode> dtToProcess;
+ // datatype types we have processed
+ std::map<TypeNode, TypeNode> dtProcessed;
+ dtToProcess.push_back(sdt);
+ std::stringstream ssutn0;
+ ssutn0 << sdtd.getName() << "_s";
+ TypeNode abdTNew =
+ nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+ unres.insert(abdTNew.toType());
+ dtProcessed[sdt] = abdTNew;
+
+ // We must convert all symbols in the sygus datatype type sdt to
+ // apply the substitution { syms -> vars }, where syms is the free
+ // variables of the input problem, and vars is the formal argument list
+ // of the function-to-synthesize.
+
+ // We are traversing over the subfield types of the datatype to convert
+ // them into the form described above.
+ while (!dtToProcess.empty())
+ {
+ std::vector<TypeNode> dtNextToProcess;
+ for (const TypeNode& curr : dtToProcess)
+ {
+ Assert(curr.isDatatype() && curr.getDType().isSygus());
+ const DType& dtc = curr.getDType();
+ std::stringstream ssdtn;
+ ssdtn << dtc.getName() << "_s";
+ sdts.push_back(SygusDatatype(ssdtn.str()));
+ Trace("dtsygus-gen-debug")
+ << "Process datatype " << sdts.back().getName() << "..." << std::endl;
+ for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
+ {
+ Node op = dtc[j].getSygusOp();
+ // apply the substitution to the argument
+ Node ops =
+ op.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
+ Trace("dtsygus-gen-debug") << " Process constructor " << op << " / "
+ << ops << "..." << std::endl;
+ std::vector<TypeNode> cargs;
+ for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
+ {
+ TypeNode argt = dtc[j].getArgType(k);
+ std::map<TypeNode, TypeNode>::iterator itdp = dtProcessed.find(argt);
+ TypeNode argtNew;
+ if (itdp == dtProcessed.end())
+ {
+ std::stringstream ssutn;
+ ssutn << argt.getDType().getName() << "_s";
+ argtNew =
+ nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+ Trace("dtsygus-gen-debug") << " ...unresolved type " << argtNew
+ << " for " << argt << std::endl;
+ unres.insert(argtNew.toType());
+ dtProcessed[argt] = argtNew;
+ dtNextToProcess.push_back(argt);
+ }
+ else
+ {
+ argtNew = itdp->second;
+ }
+ Trace("dtsygus-gen-debug")
+ << " Arg #" << k << ": " << argtNew << std::endl;
+ cargs.push_back(argtNew);
+ }
+ // callback prints as the expression
+ std::shared_ptr<SygusPrintCallback> spc;
+ std::vector<Expr> args;
+ if (op.getKind() == LAMBDA)
+ {
+ Node opBody = op[1];
+ for (const Node& v : op[0])
+ {
+ args.push_back(v.toExpr());
+ }
+ spc = std::make_shared<printer::SygusExprPrintCallback>(
+ opBody.toExpr(), args);
+ }
+ else if (cargs.empty())
+ {
+ spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
+ args);
+ }
+ std::stringstream ss;
+ ss << ops.getKind();
+ Trace("dtsygus-gen-debug") << "Add constructor : " << ops << std::endl;
+ sdts.back().addConstructor(ops, ss.str(), cargs, spc);
+ }
+ Trace("dtsygus-gen-debug")
+ << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
+ TypeNode stn = dtc.getSygusType();
+ sdts.back().initializeDatatype(
+ stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
+ }
+ dtToProcess.clear();
+ dtToProcess.insert(
+ dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
+ }
+ Trace("dtsygus-gen-debug")
+ << "Make " << sdts.size() << " datatype types..." << std::endl;
+ // extract the datatypes
+ std::vector<Datatype> datatypes;
+ for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+ {
+ datatypes.push_back(sdts[i].getDatatype());
+ }
+ // make the datatype types
+ std::vector<DatatypeType> datatypeTypes =
+ nm->toExprManager()->mkMutualDatatypeTypes(
+ datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
+ TypeNode sdtS = TypeNode::fromType(datatypeTypes[0]);
+ if (Trace.isOn("dtsygus-gen-debug"))
+ {
+ Trace("dtsygus-gen-debug") << "Made datatype types:" << std::endl;
+ for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
+ {
+ const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType();
+ Trace("dtsygus-gen-debug") << "#" << j << ": " << dtj << std::endl;
+ for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
+ {
+ for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
+ {
+ if (!dtj[k].getArgType(l).isDatatype())
+ {
+ Trace("dtsygus-gen-debug")
+ << "Argument " << l << " of " << dtj[k]
+ << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
+ AlwaysAssert(false);
+ }
+ }
+ }
+ }
+ }
+ return sdtS;
+}
+
} // namespace utils
} // namespace datatypes
} // namespace theory
*/
Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
+/** Get free symbols in a sygus datatype type
+ *
+ * Add the free symbols (expr::getSymbols) in terms that can be generated by
+ * sygus datatype sdt to the set syms. For example, given sdt encodes the
+ * grammar:
+ * G -> a | +( b, G ) | c | e
+ * We have that { a, b, c, e } are added to syms. Notice that expr::getSymbols
+ * excludes variables whose kind is BOUND_VARIABLE.
+ */
+void getFreeSymbolsSygusType(TypeNode sdt,
+ std::unordered_set<Node, NodeHashFunction>& syms);
+
+/** Substitute and generalize a sygus datatype type
+ *
+ * This transforms a sygus datatype sdt into another one sdt' that generates
+ * terms t such that t * { vars -> syms } is generated by sdt.
+ *
+ * The arguments syms and vars should be vectors of the same size and types.
+ * It is recommended that the arguments in syms and vars should be variables
+ * (return true for .isVar()) but this is not required.
+ *
+ * The variables in vars of type BOUND_VARIABLE are added to the
+ * formal argument list of t. Other symbols are not.
+ *
+ * For example, given sdt encodes the grammar:
+ * G -> a | +( b, G ) | c | e
+ * Let syms = { a, b, c } and vars = { x_a, x_b, d }, where x_a and x_b have
+ * type BOUND_VARIABLE and d does not.
+ * The returned type encodes the grammar:
+ * G' -> x_a | +( x_b, G' ) | d | e
+ * Additionally, x_a and x_b are treated as formal arguments of a function
+ * to synthesize whose syntax restrictions are specified by G'.
+ *
+ * This method traverses the type definition of the datatype corresponding to
+ * the argument sdt.
+ */
+TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
+ const std::vector<Node>& syms,
+ const std::vector<Node>& vars);
+
// ------------------------ end sygus utils
} // namespace utils
#include "expr/dtype.h"
#include "expr/node_algorithm.h"
#include "expr/sygus_datatype.h"
-#include "printer/sygus_print_callback.h"
#include "theory/datatypes/theory_datatypes_utils.h"
#include "theory/quantifiers/quantifiers_attributes.h"
#include "theory/quantifiers/quantifiers_rewriter.h"
SygusVarToTermAttribute sta;
vlv.setAttribute(sta, s);
}
- // make the sygus variable list
- Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
Trace("sygus-abduct-debug") << "...finish" << std::endl;
Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl;
Node abd = nm->mkBoundVar(name.c_str(), abdType);
Trace("sygus-abduct-debug") << "...finish" << std::endl;
- // if provided, we will associate it with the function-to-synthesize
+ // the sygus variable list
+ Node abvl;
+ // if provided, we will associate the provide sygus datatype type with the
+ // function-to-synthesize. However, we must convert it so that its
+ // free symbols are universally quantified.
if (!abdGType.isNull())
{
Assert(abdGType.isDatatype() && abdGType.getDType().isSygus());
- // must convert all constructors to version with bound variables in "vars"
- std::vector<SygusDatatype> sdts;
- std::set<Type> unres;
-
Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
Trace("sygus-abduct-debug") << abdGType.getDType().getName() << std::endl;
- // datatype types we need to process
- std::vector<TypeNode> dtToProcess;
- // datatype types we have processed
- std::map<TypeNode, TypeNode> dtProcessed;
- dtToProcess.push_back(abdGType);
- std::stringstream ssutn0;
- ssutn0 << abdGType.getDType().getName() << "_s";
- TypeNode abdTNew =
- nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
- unres.insert(abdTNew.toType());
- dtProcessed[abdGType] = abdTNew;
-
- // We must convert all symbols in the sygus datatype type abdGType to
- // apply the substitution { syms -> varlist }, where syms is the free
- // variables of the input problem, and varlist is the formal argument list
- // of the abduct-to-synthesize. For example, given user-provided sygus
- // grammar:
- // G -> a | +( b, G )
- // we synthesize a abduct A with two arguments x_a and x_b corresponding to
- // a and b, and reconstruct the grammar:
- // G' -> x_a | +( x_b, G' )
- // In this way, x_a and x_b are treated as bound variables and handled as
- // arguments of the abduct-to-synthesize instead of as free variables with
- // no relation to A. We additionally require that x_a, when printed, prints
- // "a", which we do with a custom sygus callback below.
+ // substitute the free symbols of the grammar with variables corresponding
+ // to the formal argument list of the new sygus datatype type.
+ TypeNode abdGTypeS = datatypes::utils::substituteAndGeneralizeSygusType(
+ abdGType, syms, varlist);
- // We are traversing over the subfield types of the datatype to convert
- // them into the form described above.
- while (!dtToProcess.empty())
- {
- std::vector<TypeNode> dtNextToProcess;
- for (const TypeNode& curr : dtToProcess)
- {
- Assert(curr.isDatatype() && curr.getDType().isSygus());
- const DType& dtc = curr.getDType();
- std::stringstream ssdtn;
- ssdtn << dtc.getName() << "_s";
- sdts.push_back(SygusDatatype(ssdtn.str()));
- Trace("sygus-abduct-debug")
- << "Process datatype " << sdts.back().getName() << "..."
- << std::endl;
- for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
- {
- Node op = dtc[j].getSygusOp();
- // apply the substitution to the argument
- Node ops = op.substitute(
- syms.begin(), syms.end(), varlist.begin(), varlist.end());
- Trace("sygus-abduct-debug") << " Process constructor " << op << " / "
- << ops << "..." << std::endl;
- std::vector<TypeNode> cargs;
- for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
- {
- TypeNode argt = dtc[j].getArgType(k);
- std::map<TypeNode, TypeNode>::iterator itdp =
- dtProcessed.find(argt);
- TypeNode argtNew;
- if (itdp == dtProcessed.end())
- {
- std::stringstream ssutn;
- ssutn << argt.getDType().getName() << "_s";
- argtNew =
- nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
- Trace("sygus-abduct-debug")
- << " ...unresolved type " << argtNew << " for " << argt
- << std::endl;
- unres.insert(argtNew.toType());
- dtProcessed[argt] = argtNew;
- dtNextToProcess.push_back(argt);
- }
- else
- {
- argtNew = itdp->second;
- }
- Trace("sygus-abduct-debug")
- << " Arg #" << k << ": " << argtNew << std::endl;
- cargs.push_back(argtNew);
- }
- // callback prints as the expression
- std::shared_ptr<SygusPrintCallback> spc;
- std::vector<Expr> args;
- if (op.getKind() == LAMBDA)
- {
- Node opBody = op[1];
- for (const Node& v : op[0])
- {
- args.push_back(v.toExpr());
- }
- spc = std::make_shared<printer::SygusExprPrintCallback>(
- opBody.toExpr(), args);
- }
- else if (cargs.empty())
- {
- spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
- args);
- }
- std::stringstream ss;
- ss << ops.getKind();
- Trace("sygus-abduct-debug")
- << "Add constructor : " << ops << std::endl;
- sdts.back().addConstructor(ops, ss.str(), cargs, spc);
- }
- Trace("sygus-abduct-debug")
- << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
- TypeNode stn = dtc.getSygusType();
- sdts.back().initializeDatatype(
- stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
- }
- dtToProcess.clear();
- dtToProcess.insert(
- dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
- }
- Trace("sygus-abduct-debug")
- << "Make " << sdts.size() << " datatype types..." << std::endl;
- // extract the datatypes
- std::vector<Datatype> datatypes;
- for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
- {
- datatypes.push_back(sdts[i].getDatatype());
- }
- // make the datatype types
- std::vector<DatatypeType> datatypeTypes =
- nm->toExprManager()->mkMutualDatatypeTypes(
- datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
- TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]);
- if (Trace.isOn("sygus-abduct-debug"))
- {
- Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl;
- for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
- {
- const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType();
- Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl;
- for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
- {
- for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
- {
- if (!dtj[k].getArgType(l).isDatatype())
- {
- Trace("sygus-abduct-debug")
- << "Argument " << l << " of " << dtj[k]
- << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
- AlwaysAssert(false);
- }
- }
- }
- }
- }
+ Assert(abdGTypeS.isDatatype() && abdGTypeS.getDType().isSygus());
Trace("sygus-abduct-debug")
<< "Make sygus grammar attribute..." << std::endl;
theory::SygusSynthGrammarAttribute ssg;
abd.setAttribute(ssg, sym);
Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl;
+
+ // use the bound variable list from the new substituted grammar type
+ const DType& agtsd = abdGTypeS.getDType();
+ abvl = agtsd.getSygusVarList();
+ Assert(!abvl.isNull() && abvl.getKind() == BOUND_VAR_LIST);
+ }
+ else
+ {
+ // the bound variable list of the abduct-to-synthesize is determined by
+ // the variable list above
+ abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
+ // We do not set a grammar type for abd (SygusSynthGrammarAttribute).
+ // Its grammar will be constructed internally in the default way
}
Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;