Input user grammar in sygus abduct (#3119)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 26 Jul 2019 02:24:26 +0000 (21:24 -0500)
committerGitHub <noreply@github.com>
Fri, 26 Jul 2019 02:24:26 +0000 (21:24 -0500)
src/preprocessing/passes/sygus_abduct.cpp
src/preprocessing/passes/sygus_abduct.h

index 346915b51547dfdd9506c9de0021368ed9c9ba81..b2dc872e497466f3c73d99452f1270f0bdf860e9 100644 (file)
@@ -15,7 +15,9 @@
 
 #include "preprocessing/passes/sygus_abduct.h"
 
+#include "expr/datatype.h"
 #include "expr/node_algorithm.h"
+#include "printer/sygus_print_callback.h"
 #include "smt/smt_engine.h"
 #include "smt/smt_engine_scope.h"
 #include "smt/smt_statistics_registry.h"
@@ -37,11 +39,9 @@ SygusAbduct::SygusAbduct(PreprocessingPassContext* preprocContext)
 PreprocessingPassResult SygusAbduct::applyInternal(
     AssertionPipeline* assertionsToPreprocess)
 {
-  NodeManager* nm = NodeManager::currentNM();
   Trace("sygus-abduct") << "Run sygus abduct..." << std::endl;
 
   Trace("sygus-abduct-debug") << "Collect symbols..." << std::endl;
-  std::unordered_set<Node, NodeHashFunction> symset;
   std::vector<Node>& asserts = assertionsToPreprocess->ref();
   // do we have any assumptions, e.g. via check-sat-assuming?
   bool usingAssumptions = (assertionsToPreprocess->getNumAssumptions() > 0);
@@ -52,15 +52,44 @@ PreprocessingPassResult SygusAbduct::applyInternal(
   // conjecture Fc, and
   // - The conjunction of all other assertions are the axioms Fa.
   std::vector<Node> axioms;
-  for (size_t i = 0, size = asserts.size(); i < size; i++)
+  if (usingAssumptions)
   {
-    expr::getSymbols(asserts[i], symset);
-    // if we are not an assumption, add it to the set of axioms
-    if (usingAssumptions && i < assertionsToPreprocess->getAssumptionsStart())
+    for (size_t i = 0, astart = assertionsToPreprocess->getAssumptionsStart();
+         i < astart;
+         i++)
     {
+      // if we are not an assumption, add it to the set of axioms
       axioms.push_back(asserts[i]);
     }
   }
+
+  // the abduction grammar type we are using (null for now, until a future
+  // commit)
+  TypeNode abdGType;
+
+  Node res = mkAbductionConjecture(asserts, axioms, abdGType);
+
+  Node trueNode = NodeManager::currentNM()->mkConst(true);
+
+  assertionsToPreprocess->replace(0, res);
+  for (size_t i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
+  {
+    assertionsToPreprocess->replace(i, trueNode);
+  }
+
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+Node SygusAbduct::mkAbductionConjecture(const std::vector<Node>& asserts,
+                                        const std::vector<Node>& axioms,
+                                        TypeNode abdGType)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  std::unordered_set<Node, NodeHashFunction> symset;
+  for (size_t i = 0, size = asserts.size(); i < size; i++)
+  {
+    expr::getSymbols(asserts[i], symset);
+  }
   Trace("sygus-abduct-debug")
       << "...finish, got " << symset.size() << " symbols." << std::endl;
 
@@ -84,6 +113,8 @@ PreprocessingPassResult SygusAbduct::applyInternal(
       varlistTypes.push_back(tn);
     }
   }
+  // make the sygus variable list
+  Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
   Trace("sygus-abduct-debug") << "...finish" << std::endl;
 
   Trace("sygus-abduct-debug") << "Make abduction predicate..." << std::endl;
@@ -93,6 +124,170 @@ PreprocessingPassResult SygusAbduct::applyInternal(
   Node abd = nm->mkBoundVar("A", abdType);
   Trace("sygus-abduct-debug") << "...finish" << std::endl;
 
+  // if provided, we will associate it with the function-to-synthesize
+  if (!abdGType.isNull())
+  {
+    Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus());
+    // must convert all constructors to version with bound variables in "vars"
+    std::vector<Datatype> datatypes;
+    std::set<Type> unres;
+
+    Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
+    Trace("sygus-abduct-debug") << abdGType.getDatatype() << std::endl;
+
+    // datatype types we need to process
+    std::vector<TypeNode> dtToProcess;
+    // datatype types we have processed
+    std::map<TypeNode, TypeNode> dtProcessed;
+    dtToProcess.push_back(abdGType);
+    std::stringstream ssutn0;
+    ssutn0 << abdGType.getDatatype().getName() << "_s";
+    TypeNode abdTNew =
+        nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+    unres.insert(abdTNew.toType());
+    dtProcessed[abdGType] = abdTNew;
+
+    // We must convert all symbols in the sygus datatype type abdGType to
+    // apply the substitution { syms -> varlist }, where syms is the free
+    // variables of the input problem, and varlist is the formal argument list
+    // of the abduct-to-synthesize. For example, given user-provided sygus
+    // grammar:
+    //   G -> a | +( b, G )
+    // we synthesize a abduct A with two arguments x_a and x_b corresponding to
+    // a and b, and reconstruct the grammar:
+    //   G' -> x_a | +( x_b, G' )
+    // In this way, x_a and x_b are treated as bound variables and handled as
+    // arguments of the abduct-to-synthesize instead of as free variables with
+    // no relation to A. We additionally require that x_a, when printed, prints
+    // "a", which we do with a custom sygus callback below.
+
+    // We are traversing over the subfield types of the datatype to convert
+    // them into the form described above.
+    while (!dtToProcess.empty())
+    {
+      std::vector<TypeNode> dtNextToProcess;
+      for (const TypeNode& curr : dtToProcess)
+      {
+        Assert(curr.isDatatype() && curr.getDatatype().isSygus());
+        const Datatype& dtc = curr.getDatatype();
+        std::stringstream ssdtn;
+        ssdtn << dtc.getName() << "_s";
+        datatypes.push_back(Datatype(ssdtn.str()));
+        Trace("sygus-abduct-debug")
+            << "Process datatype " << datatypes.back().getName() << "..."
+            << std::endl;
+        for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
+        {
+          Node op = Node::fromExpr(dtc[j].getSygusOp());
+          // apply the substitution to the argument
+          Node ops = op.substitute(
+              syms.begin(), syms.end(), varlist.begin(), varlist.end());
+          Trace("sygus-abduct-debug") << "  Process constructor " << op << " / "
+                                      << ops << "..." << std::endl;
+          std::vector<Type> cargs;
+          for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
+          {
+            TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k));
+            std::map<TypeNode, TypeNode>::iterator itdp =
+                dtProcessed.find(argt);
+            TypeNode argtNew;
+            if (itdp == dtProcessed.end())
+            {
+              std::stringstream ssutn;
+              ssutn << argt.getDatatype().getName() << "_s";
+              argtNew =
+                  nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER);
+              Trace("sygus-abduct-debug")
+                  << "    ...unresolved type " << argtNew << " for " << argt
+                  << std::endl;
+              unres.insert(argtNew.toType());
+              dtProcessed[argt] = argtNew;
+              dtNextToProcess.push_back(argt);
+            }
+            else
+            {
+              argtNew = itdp->second;
+            }
+            Trace("sygus-abduct-debug")
+                << "    Arg #" << k << ": " << argtNew << std::endl;
+            cargs.push_back(argtNew.toType());
+          }
+          // callback prints as the expression
+          std::shared_ptr<SygusPrintCallback> spc;
+          std::vector<Expr> args;
+          if (op.getKind() == LAMBDA)
+          {
+            Node opBody = op[1];
+            for (const Node& v : op[0])
+            {
+              args.push_back(v.toExpr());
+            }
+            spc = std::make_shared<printer::SygusExprPrintCallback>(
+                opBody.toExpr(), args);
+          }
+          else if (cargs.empty())
+          {
+            spc = std::make_shared<printer::SygusExprPrintCallback>(op.toExpr(),
+                                                                    args);
+          }
+          std::stringstream ss;
+          ss << ops.getKind();
+          Trace("sygus-abduct-debug")
+              << "Add constructor : " << ops << std::endl;
+          datatypes.back().addSygusConstructor(
+              ops.toExpr(), ss.str(), cargs, spc);
+        }
+        Trace("sygus-abduct-debug")
+            << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
+        datatypes.back().setSygus(dtc.getSygusType(),
+                                  abvl.toExpr(),
+                                  dtc.getSygusAllowConst(),
+                                  dtc.getSygusAllowAll());
+      }
+      dtToProcess.clear();
+      dtToProcess.insert(
+          dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
+    }
+    Trace("sygus-abduct-debug")
+        << "Make " << datatypes.size() << " datatype types..." << std::endl;
+    // make the datatype types
+    std::vector<DatatypeType> datatypeTypes =
+        nm->toExprManager()->mkMutualDatatypeTypes(
+            datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
+    TypeNode abdGTypeS = TypeNode::fromType(datatypeTypes[0]);
+    if (Trace.isOn("sygus-abduct-debug"))
+    {
+      Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl;
+      for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++)
+      {
+        const Datatype& dtj = datatypeTypes[j].getDatatype();
+        Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl;
+        for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++)
+        {
+          for (unsigned l = 0, nargs = dtj[k].getNumArgs(); l < nargs; l++)
+          {
+            if (!dtj[k].getArgType(l).isDatatype())
+            {
+              Trace("sygus-abduct-debug")
+                  << "Argument " << l << " of " << dtj[k]
+                  << " is not datatype : " << dtj[k].getArgType(l) << std::endl;
+              AlwaysAssert(false);
+            }
+          }
+        }
+      }
+    }
+
+    Trace("sygus-abduct-debug")
+        << "Make sygus grammar attribute..." << std::endl;
+    Node sym = nm->mkBoundVar("sfproxy_abduct", abdGTypeS);
+    // Set the sygus grammar attribute to indicate that abdGTypeS encodes the
+    // grammar for abd.
+    theory::SygusSynthGrammarAttribute ssg;
+    abd.setAttribute(ssg, sym);
+    Trace("sygus-abduct-debug") << "Finished setting up grammar." << std::endl;
+  }
+
   Trace("sygus-abduct-debug") << "Make abduction predicate app..." << std::endl;
   std::vector<Node> achildren;
   achildren.push_back(abd);
@@ -102,7 +297,6 @@ PreprocessingPassResult SygusAbduct::applyInternal(
 
   Trace("sygus-abduct-debug") << "Set attributes..." << std::endl;
   // set the sygus bound variable list
-  Node abvl = nm->mkNode(BOUND_VAR_LIST, varlist);
   abd.setAttribute(theory::SygusSynthFunVarListAttribute(), abvl);
   Trace("sygus-abduct-debug") << "...finish" << std::endl;
 
@@ -158,15 +352,7 @@ PreprocessingPassResult SygusAbduct::applyInternal(
 
   Trace("sygus-abduct") << "Generate: " << res << std::endl;
 
-  Node trueNode = nm->mkConst(true);
-
-  assertionsToPreprocess->replace(0, res);
-  for (size_t i = 1, size = assertionsToPreprocess->size(); i < size; ++i)
-  {
-    assertionsToPreprocess->replace(i, trueNode);
-  }
-
-  return PreprocessingPassResult::NO_CONFLICT;
+  return res;
 }
 
 }  // namespace passes
index 0e0868cda2af6eb21a2f7ddcb304eb631b305fb1..db40b9688f9561093b2fbd9dbcf53b101e5d9f69 100644 (file)
@@ -56,6 +56,18 @@ class SygusAbduct : public PreprocessingPass
  public:
   SygusAbduct(PreprocessingPassContext* preprocContext);
 
+  /**
+   * Returns the sygus conjecture corresponding to the abduction problem for
+   * input problem (F above) given by asserts, and axioms (Fa above) given by
+   * axioms. Note that axioms is expected to be a subset of asserts.
+   *
+   * The type abdGType (if non-null) is a sygus datatype type that encodes the
+   * grammar that should be used for solutions of the abduction conjecture.
+   */
+  static Node mkAbductionConjecture(const std::vector<Node>& asserts,
+                                    const std::vector<Node>& axioms,
+                                    TypeNode abdGType);
+
  protected:
   /**
    * Replaces the set of assertions by an abduction sygus problem described