Add sygus datatype substitution utility method (#4390)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 4 Jun 2020 17:38:36 +0000 (12:38 -0500)
committerGitHub <noreply@github.com>
Thu, 4 Jun 2020 17:38:36 +0000 (12:38 -0500)
This makes the method for substiutiton and generalization of sygus datatypes a generic utility method. It updates the abduction method to use it. Interpolation is another target user of this utility.

src/theory/datatypes/theory_datatypes_utils.cpp
src/theory/datatypes/theory_datatypes_utils.h
src/theory/quantifiers/sygus/sygus_abduct.cpp

index ee0fd814ef38ba73f2ba11b7c5cf6b688c051218..ea67ab79dda332b978585919c618ccfb937d040a 100644 (file)
@@ -23,6 +23,7 @@
 #include "smt/smt_engine_scope.h"
 #include "theory/evaluator.h"
 #include "theory/rewriter.h"
+#include "printer/sygus_print_callback.h"
 
 using namespace CVC4;
 using namespace CVC4::kind;
@@ -711,6 +712,223 @@ Node sygusToBuiltinEval(Node n, const std::vector<Node>& args)
   return visited[n];
 }
 
+void getFreeSymbolsSygusType(TypeNode sdt,
+                             std::unordered_set<Node, NodeHashFunction>& syms)
+{
+  // datatype types we need to process
+  std::vector<TypeNode> typeToProcess;
+  // datatype types we have processed
+  std::map<TypeNode, TypeNode> typesProcessed;
+  typeToProcess.push_back(sdt);
+  while (!typeToProcess.empty())
+  {
+    std::vector<TypeNode> typeNextToProcess;
+    for (const TypeNode& curr : typeToProcess)
+    {
+      Assert(curr.isDatatype() && curr.getDType().isSygus());
+      const DType& dtc = curr.getDType();
+      for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
+      {
+        // collect the symbols from the operator
+        Node op = dtc[j].getSygusOp();
+        expr::getSymbols(op, syms);
+        // traverse the argument types
+        for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
+        {
+          TypeNode argt = dtc[j].getArgType(k);
+          if (!argt.isDatatype() || !argt.getDType().isSygus())
+          {
+            // not a sygus datatype
+            continue;
+          }
+          if (typesProcessed.find(argt) == typesProcessed.end())
+          {
+            typeNextToProcess.push_back(argt);
+          }
+        }
+      }
+    }
+    typeToProcess.clear();
+    typeToProcess.insert(typeToProcess.end(),
+                         typeNextToProcess.begin(),
+                         typeNextToProcess.end());
+  }
+}
+
+TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
+                                          const std::vector<Node>& syms,
+                                          const std::vector<Node>& vars)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  const DType& sdtd = sdt.getDType();
+  // compute the new formal argument list
+  std::vector<Node> formalVars;
+  Node prevVarList = sdtd.getSygusVarList();
+  if (!prevVarList.isNull())
+  {
+    for (const Node& v : prevVarList)
+    {
+      // if it is not being replaced
+      if (std::find(syms.begin(), syms.end(), v) != syms.end())
+      {
+        formalVars.push_back(v);
+      }
+    }
+  }
+  for (const Node& v : vars)
+  {
+    if (v.getKind() == BOUND_VARIABLE)
+    {
+      formalVars.push_back(v);
+    }
+  }
+  // make the sygus variable list for the formal argument list
+  Node abvl = nm->mkNode(BOUND_VAR_LIST, formalVars);
+  Trace("sygus-abduct-debug") << "...finish" << std::endl;
+
+  // must convert all constructors to version with variables in "vars"
+  std::vector<SygusDatatype> sdts;
+  std::set<Type> unres;
+
+  Trace("dtsygus-gen-debug") << "Process sygus type:" << std::endl;
+  Trace("dtsygus-gen-debug") << sdtd.getName() << std::endl;
+
+  // datatype types we need to process
+  std::vector<TypeNode> dtToProcess;
+  // datatype types we have processed
+  std::map<TypeNode, TypeNode> dtProcessed;
+  dtToProcess.push_back(sdt);
+  std::stringstream ssutn0;
+  ssutn0 << sdtd.getName() << "_s";
+  TypeNode abdTNew =
+      nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+  unres.insert(abdTNew.toType());
+  dtProcessed[sdt] = abdTNew;
+
+  // We must convert all symbols in the sygus datatype type sdt to
+  // apply the substitution { syms -> vars }, where syms is the free
+  // variables of the input problem, and vars is the formal argument list
+  // of the function-to-synthesize.
+
+  // We are traversing over the subfield types of the datatype to convert
+  // them into the form described above.
+  while (!dtToProcess.empty())
+  {
+    std::vector<TypeNode> dtNextToProcess;
+    for (const TypeNode& curr : dtToProcess)
+    {
+      Assert(curr.isDatatype() && curr.getDType().isSygus());
+      const DType& dtc = curr.getDType();
+      std::stringstream ssdtn;
+      ssdtn << dtc.getName() << "_s";
+      sdts.push_back(SygusDatatype(ssdtn.str()));
+      Trace("dtsygus-gen-debug")
+          << "Process datatype " << sdts.back().getName() << "..." << std::endl;
+      for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
+      {
+        Node op = dtc[j].getSygusOp();
+        // apply the substitution to the argument
+        Node ops =
+            op.substitute(syms.begin(), syms.end(), vars.begin(), vars.end());
+        Trace("dtsygus-gen-debug") << "  Process constructor " << op << " / "
+                                   << ops << "..." << std::endl;
+        std::vector<TypeNode> cargs;
+        for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
+        {
+          TypeNode argt = dtc[j].getArgType(k);
+          std::map<TypeNode, TypeNode>::iterator itdp = dtProcessed.find(argt);
+          TypeNode argtNew;
+          if (itdp == dtProcessed.end())
+          {
+            std::stringstream ssutn;
+            ssutn << argt.getDType().getName() << "_s";
+            argtNew =
+                nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+            Trace("dtsygus-gen-debug") << "    ...unresolved type " << argtNew
+                                       << " for " << argt << std::endl;
+            unres.insert(argtNew.toType());
+            dtProcessed[argt] = argtNew;
+            dtNextToProcess.push_back(argt);
+          }
+          else
+          {
+            argtNew = itdp->second;
+          }
+          Trace("dtsygus-gen-debug")
+              << "    Arg #" << k << ": " << argtNew << std::endl;
+          cargs.push_back(argtNew);
+        }
+        // callback prints as the expression
+        std::shared_ptr<SygusPrintCallback> spc;
+        std::vector<Expr> args;
+        if (op.getKind() == LAMBDA)
+        {
+          Node opBody = op[1];
+          for (const Node& v : op[0])
+          {
+            args.push_back(v.toExpr());
+          }
+          spc = std::make_shared<printer::SygusExprPrintCallback>(
+              opBody.toExpr(), args);
+        }
+        else if (cargs.empty())
+        {
+          spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
+                                                                  args);
+        }
+        std::stringstream ss;
+        ss << ops.getKind();
+        Trace("dtsygus-gen-debug") << "Add constructor : " << ops << std::endl;
+        sdts.back().addConstructor(ops, ss.str(), cargs, spc);
+      }
+      Trace("dtsygus-gen-debug")
+          << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
+      TypeNode stn = dtc.getSygusType();
+      sdts.back().initializeDatatype(
+          stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
+    }
+    dtToProcess.clear();
+    dtToProcess.insert(
+        dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
+  }
+  Trace("dtsygus-gen-debug")
+      << "Make " << sdts.size() << " datatype types..." << std::endl;
+  // extract the datatypes
+  std::vector<Datatype> datatypes;
+  for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+  {
+    datatypes.push_back(sdts[i].getDatatype());
+  }
+  // make the datatype types
+  std::vector<DatatypeType> datatypeTypes =
+      nm->toExprManager()->mkMutualDatatypeTypes(
+          datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
+  TypeNode sdtS = TypeNode::fromType(datatypeTypes[0]);
+  if (Trace.isOn("dtsygus-gen-debug"))
+  {
+    Trace("dtsygus-gen-debug") << "Made datatype types:" << std::endl;
+    for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
+    {
+      const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType();
+      Trace("dtsygus-gen-debug") << "#" << j << ": " << dtj << std::endl;
+      for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
+      {
+        for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
+        {
+          if (!dtj[k].getArgType(l).isDatatype())
+          {
+            Trace("dtsygus-gen-debug")
+                << "Argument " << l << " of " << dtj[k]
+                << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
+            AlwaysAssert(false);
+          }
+        }
+      }
+    }
+  }
+  return sdtS;
+}
+
 }  // namespace utils
 }  // namespace datatypes
 }  // namespace theory
index 58f719910ab90fdaf6b5c6353acbb1062f50f801..038922f3771065eae422005c0c904d4fd2c1873a 100644 (file)
@@ -245,6 +245,46 @@ Node sygusToBuiltin(Node c, bool isExternal = false);
  */
 Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
 
+/** Get free symbols in a sygus datatype type
+ *
+ * Add the free symbols (expr::getSymbols) in terms that can be generated by
+ * sygus datatype sdt to the set syms. For example, given sdt encodes the
+ * grammar:
+ *   G -> a | +( b, G ) | c | e
+ * We have that { a, b, c, e } are added to syms. Notice that expr::getSymbols
+ * excludes variables whose kind is BOUND_VARIABLE.
+ */
+void getFreeSymbolsSygusType(TypeNode sdt,
+                             std::unordered_set<Node, NodeHashFunction>& syms);
+
+/** Substitute and generalize a sygus datatype type
+ *
+ * This transforms a sygus datatype sdt into another one sdt' that generates
+ * terms t such that t * { vars -> syms } is generated by sdt.
+ *
+ * The arguments syms and vars should be vectors of the same size and types.
+ * It is recommended that the arguments in syms and vars should be variables
+ * (return true for .isVar()) but this is not required.
+ *
+ * The variables in vars of type BOUND_VARIABLE are added to the
+ * formal argument list of t. Other symbols are not.
+ *
+ * For example, given sdt encodes the grammar:
+ *   G -> a | +( b, G ) | c | e
+ * Let syms = { a, b, c } and vars = { x_a, x_b, d }, where x_a and x_b have
+ * type BOUND_VARIABLE and d does not.
+ * The returned type encodes the grammar:
+ *   G' -> x_a | +( x_b, G' ) | d | e
+ * Additionally, x_a and x_b are treated as formal arguments of a function
+ * to synthesize whose syntax restrictions are specified by G'.
+ *
+ * This method traverses the type definition of the datatype corresponding to
+ * the argument sdt.
+ */
+TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
+                                          const std::vector<Node>& syms,
+                                          const std::vector<Node>& vars);
+
 // ------------------------ end sygus utils
 
 }  // namespace utils
index 22b56eb6b21a87b7fc80ba28abe3f922c29400dd..ef2e7e445e448454ce61c1b07b683298374f726c 100644 (file)
@@ -19,7 +19,6 @@
 #include "expr/dtype.h"
 #include "expr/node_algorithm.h"
 #include "expr/sygus_datatype.h"
-#include "printer/sygus_print_callback.h"
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/quantifiers_rewriter.h"
@@ -79,8 +78,6 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
     SygusVarToTermAttribute sta;
     vlv.setAttribute(sta, s);
   }
-  // make the sygus variable list
-  Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
   Trace("sygus-abduct-debug") << "...finish" << std::endl;
 
   Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl;
@@ -90,163 +87,23 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
   Node abd = nm->mkBoundVar(name.c_str(), abdType);
   Trace("sygus-abduct-debug") << "...finish" << std::endl;
 
-  // if provided, we will associate it with the function-to-synthesize
+  // the sygus variable list
+  Node abvl;
+  // if provided, we will associate the provide sygus datatype type with the
+  // function-to-synthesize. However, we must convert it so that its
+  // free symbols are universally quantified.
   if (!abdGType.isNull())
   {
     Assert(abdGType.isDatatype() && abdGType.getDType().isSygus());
-    // must convert all constructors to version with bound variables in "vars"
-    std::vector<SygusDatatype> sdts;
-    std::set<Type> unres;
-
     Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
     Trace("sygus-abduct-debug") << abdGType.getDType().getName() << std::endl;
 
-    // datatype types we need to process
-    std::vector<TypeNode> dtToProcess;
-    // datatype types we have processed
-    std::map<TypeNode, TypeNode> dtProcessed;
-    dtToProcess.push_back(abdGType);
-    std::stringstream ssutn0;
-    ssutn0 << abdGType.getDType().getName() << "_s";
-    TypeNode abdTNew =
-        nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
-    unres.insert(abdTNew.toType());
-    dtProcessed[abdGType] = abdTNew;
-
-    // We must convert all symbols in the sygus datatype type abdGType to
-    // apply the substitution { syms -> varlist }, where syms is the free
-    // variables of the input problem, and varlist is the formal argument list
-    // of the abduct-to-synthesize. For example, given user-provided sygus
-    // grammar:
-    //   G -> a | +( b, G )
-    // we synthesize a abduct A with two arguments x_a and x_b corresponding to
-    // a and b, and reconstruct the grammar:
-    //   G' -> x_a | +( x_b, G' )
-    // In this way, x_a and x_b are treated as bound variables and handled as
-    // arguments of the abduct-to-synthesize instead of as free variables with
-    // no relation to A. We additionally require that x_a, when printed, prints
-    // "a", which we do with a custom sygus callback below.
+    // substitute the free symbols of the grammar with variables corresponding
+    // to the formal argument list of the new sygus datatype type.
+    TypeNode abdGTypeS = datatypes::utils::substituteAndGeneralizeSygusType(
+        abdGType, syms, varlist);
 
-    // We are traversing over the subfield types of the datatype to convert
-    // them into the form described above.
-    while (!dtToProcess.empty())
-    {
-      std::vector<TypeNode> dtNextToProcess;
-      for (const TypeNode& curr : dtToProcess)
-      {
-        Assert(curr.isDatatype() && curr.getDType().isSygus());
-        const DType& dtc = curr.getDType();
-        std::stringstream ssdtn;
-        ssdtn << dtc.getName() << "_s";
-        sdts.push_back(SygusDatatype(ssdtn.str()));
-        Trace("sygus-abduct-debug")
-            << "Process datatype " << sdts.back().getName() << "..."
-            << std::endl;
-        for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
-        {
-          Node op = dtc[j].getSygusOp();
-          // apply the substitution to the argument
-          Node ops = op.substitute(
-              syms.begin(), syms.end(), varlist.begin(), varlist.end());
-          Trace("sygus-abduct-debug") << "  Process constructor " << op << " / "
-                                      << ops << "..." << std::endl;
-          std::vector<TypeNode> cargs;
-          for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
-          {
-            TypeNode argt = dtc[j].getArgType(k);
-            std::map<TypeNode, TypeNode>::iterator itdp =
-                dtProcessed.find(argt);
-            TypeNode argtNew;
-            if (itdp == dtProcessed.end())
-            {
-              std::stringstream ssutn;
-              ssutn << argt.getDType().getName() << "_s";
-              argtNew =
-                  nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
-              Trace("sygus-abduct-debug")
-                  << "    ...unresolved type " << argtNew << " for " << argt
-                  << std::endl;
-              unres.insert(argtNew.toType());
-              dtProcessed[argt] = argtNew;
-              dtNextToProcess.push_back(argt);
-            }
-            else
-            {
-              argtNew = itdp->second;
-            }
-            Trace("sygus-abduct-debug")
-                << "    Arg #" << k << ": " << argtNew << std::endl;
-            cargs.push_back(argtNew);
-          }
-          // callback prints as the expression
-          std::shared_ptr<SygusPrintCallback> spc;
-          std::vector<Expr> args;
-          if (op.getKind() == LAMBDA)
-          {
-            Node opBody = op[1];
-            for (const Node& v : op[0])
-            {
-              args.push_back(v.toExpr());
-            }
-            spc = std::make_shared<printer::SygusExprPrintCallback>(
-                opBody.toExpr(), args);
-          }
-          else if (cargs.empty())
-          {
-            spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
-                                                                    args);
-          }
-          std::stringstream ss;
-          ss << ops.getKind();
-          Trace("sygus-abduct-debug")
-              << "Add constructor : " << ops << std::endl;
-          sdts.back().addConstructor(ops, ss.str(), cargs, spc);
-        }
-        Trace("sygus-abduct-debug")
-            << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
-        TypeNode stn = dtc.getSygusType();
-        sdts.back().initializeDatatype(
-            stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
-      }
-      dtToProcess.clear();
-      dtToProcess.insert(
-          dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
-    }
-    Trace("sygus-abduct-debug")
-        << "Make " << sdts.size() << " datatype types..." << std::endl;
-    // extract the datatypes
-    std::vector<Datatype> datatypes;
-    for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
-    {
-      datatypes.push_back(sdts[i].getDatatype());
-    }
-    // make the datatype types
-    std::vector<DatatypeType> datatypeTypes =
-        nm->toExprManager()->mkMutualDatatypeTypes(
-            datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
-    TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]);
-    if (Trace.isOn("sygus-abduct-debug"))
-    {
-      Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl;
-      for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
-      {
-        const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType();
-        Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl;
-        for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
-        {
-          for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
-          {
-            if (!dtj[k].getArgType(l).isDatatype())
-            {
-              Trace("sygus-abduct-debug")
-                  << "Argument " << l << " of " << dtj[k]
-                  << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
-              AlwaysAssert(false);
-            }
-          }
-        }
-      }
-    }
+    Assert(abdGTypeS.isDatatype() && abdGTypeS.getDType().isSygus());
 
     Trace("sygus-abduct-debug")
         << "Make sygus grammar attribute..." << std::endl;
@@ -256,6 +113,19 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
     theory::SygusSynthGrammarAttribute ssg;
     abd.setAttribute(ssg, sym);
     Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl;
+
+    // use the bound variable list from the new substituted grammar type
+    const DType& agtsd = abdGTypeS.getDType();
+    abvl = agtsd.getSygusVarList();
+    Assert(!abvl.isNull() && abvl.getKind() == BOUND_VAR_LIST);
+  }
+  else
+  {
+    // the bound variable list of the abduct-to-synthesize is determined by
+    // the variable list above
+    abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
+    // We do not set a grammar type for abd (SygusSynthGrammarAttribute).
+    // Its grammar will be constructed internally in the default way
   }
 
   Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;