Use standard sygus interface for abduction and rewrite rule synthesis (#3471)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 18 Nov 2019 19:52:18 +0000 (13:52 -0600)
committerGitHub <noreply@github.com>
Mon, 18 Nov 2019 19:52:18 +0000 (13:52 -0600)
src/preprocessing/passes/synth_rew_rules.cpp
src/theory/quantifiers/sygus/sygus_abduct.cpp

index 6d6e8fb2740d47380ae36d141af640cca507ef2b..47e64b2e4d2f0f54b4c249f62c6013a04e057b13 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "preprocessing/passes/synth_rew_rules.h"
 
+#include "expr/sygus_datatype.h"
 #include "expr/term_canonize.h"
 #include "options/base_options.h"
 #include "options/quantifiers_options.h"
@@ -236,14 +237,13 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
   Trace("srs-input") << "...finished." << std::endl;
   // the sygus variable list
   Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars);
-  Expr sygusVarListE = sygusVarList.toExpr();
   Trace("srs-input") << "Have " << cterms.size() << " canonical subterms."
                      << std::endl;
 
   Trace("srs-input") << "Construct unresolved types..." << std::endl;
   // each canonical subterm corresponds to a grammar type
   std::set<Type> unres;
-  std::vector<Datatype> datatypes;
+  std::vector<SygusDatatype> sdts;
   // make unresolved types for each canonical term
   std::map<Node, TypeNode> cterm_to_utype;
   for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
@@ -255,11 +255,11 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
     TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER);
     cterm_to_utype[ct] = tnu;
     unres.insert(tnu.toType());
-    datatypes.push_back(Datatype(tname));
+    sdts.push_back(SygusDatatype(tname));
   }
   Trace("srs-input") << "...finished." << std::endl;
 
-  Trace("srs-input") << "Construct datatypes..." << std::endl;
+  Trace("srs-input") << "Construct sygus datatypes..." << std::endl;
   for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
   {
     Node ct = cterms[i];
@@ -268,7 +268,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
     // add the variables for the type
     TypeNode ctt = ct.getType();
     Assert(tvars.find(ctt) != tvars.end());
-    std::vector<Type> argList;
+    std::vector<TypeNode> argList;
     // we add variable constructors if we are not Boolean, we are interested
     // in purely propositional rewrites (via the option), or this term is
     // a Boolean variable.
@@ -279,7 +279,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
       {
         std::stringstream ssc;
         ssc << "C_" << i << "_" << v;
-        datatypes[i].addSygusConstructor(v.toExpr(), ssc.str(), argList);
+        sdts[i].addConstructor(v, ssc.str(), argList);
       }
     }
     // add the constructor for the operator if it is not a variable
@@ -295,7 +295,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
         Node ctc = term_to_cterm[tc];
         Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end());
         // get the type
-        argList.push_back(cterm_to_utype[ctc].toType());
+        argList.push_back(cterm_to_utype[ctc]);
       }
       // check if we should chain
       bool do_chain = false;
@@ -305,12 +305,12 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
         do_chain = theory::quantifiers::TermUtil::isAssoc(k)
                    && theory::quantifiers::TermUtil::isComm(k);
         // eliminate duplicate child types
-        std::vector<Type> argListTmp = argList;
+        std::vector<TypeNode> argListTmp = argList;
         argList.clear();
-        std::map<Type, bool> hasArgType;
+        std::map<TypeNode, bool> hasArgType;
         for (unsigned j = 0, size = argListTmp.size(); j < size; j++)
         {
-          Type t = argListTmp[j];
+          TypeNode t = argListTmp[j];
           if (hasArgType.find(t) == hasArgType.end())
           {
             hasArgType[t] = true;
@@ -323,9 +323,9 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
         // we make one type per child
         // the operator of each constructor is a no-op
         Node tbv = nm->mkBoundVar(ctt);
-        Expr lambdaOp =
-            nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr();
-        std::vector<Type> argListc;
+        Node lambdaOp =
+            nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
+        std::vector<TypeNode> argListc;
         // the following construction admits any number of repeated factors,
         // so for instance, t1+t2+t3, we generate the grammar:
         // T_{t1+t2+t3} ->
@@ -341,44 +341,49 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
           std::stringstream sscs;
           sscs << "C_factor_" << i << "_" << j;
           // ID function is not printed and does not count towards weight
-          datatypes[i].addSygusConstructor(
-              lambdaOp,
-              sscs.str(),
-              argListc,
-              printer::SygusEmptyPrintCallback::getEmptyPC(),
-              0);
+          sdts[i].addConstructor(lambdaOp,
+                                 sscs.str(),
+                                 argListc,
+                                 printer::SygusEmptyPrintCallback::getEmptyPC(),
+                                 0);
         }
         // recursive apply
-        Type recType = cterm_to_utype[ct].toType();
+        TypeNode recType = cterm_to_utype[ct];
         argListc.clear();
         argListc.push_back(recType);
         argListc.push_back(recType);
         std::stringstream ssc;
         ssc << "C_" << i << "_rec_" << op;
-        datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argListc);
+        sdts[i].addConstructor(op, ssc.str(), argListc);
       }
       else
       {
         std::stringstream ssc;
         ssc << "C_" << i << "_" << op;
-        datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argList);
+        sdts[i].addConstructor(op, ssc.str(), argList);
       }
     }
-    Assert(datatypes[i].getNumConstructors() > 0);
-    datatypes[i].setSygus(ctt.toType(), sygusVarListE, false, false);
+    Assert(sdts[i].getNumConstructors() > 0);
+    sdts[i].initializeDatatype(ctt, sygusVarList, false, false);
   }
   Trace("srs-input") << "...finished." << std::endl;
 
   Trace("srs-input") << "Make mutual datatype types for subterms..."
                      << std::endl;
+  // extract the datatypes
+  std::vector<Datatype> datatypes;
+  for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+  {
+    datatypes.push_back(sdts[i].getDatatype());
+  }
   std::vector<DatatypeType> types = nm->toExprManager()->mkMutualDatatypeTypes(
       datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
   Trace("srs-input") << "...finished." << std::endl;
   Assert(types.size() == unres.size());
-  std::map<Node, DatatypeType> subtermTypes;
+  std::map<Node, TypeNode> subtermTypes;
   for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++)
   {
-    subtermTypes[cterms[i]] = types[i];
+    subtermTypes[cterms[i]] = TypeNode::fromType(types[i]);
   }
 
   Trace("srs-input") << "Construct the top-level types..." << std::endl;
@@ -389,34 +394,34 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal(
     TypeNode t = tcp.first;
     std::stringstream ss;
     ss << "T_" << t;
-    Datatype dttl(ss.str());
+    SygusDatatype sdttl(ss.str());
     Node tbv = nm->mkBoundVar(t);
     // the operator of each constructor is a no-op
-    Expr lambdaOp =
-        nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr();
+    Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv);
     Trace("srs-input") << "  We have " << tcp.second.size()
                        << " subterms of type " << t << std::endl;
     for (unsigned i = 0, size = tcp.second.size(); i < size; i++)
     {
       Node n = tcp.second[i];
       // add constructor that encodes abstractions of this subterm
-      std::vector<Type> argList;
+      std::vector<TypeNode> argList;
       Assert(subtermTypes.find(n) != subtermTypes.end());
       argList.push_back(subtermTypes[n]);
       std::stringstream ssc;
       ssc << "Ctl_" << i;
       // the no-op should not be printed, hence we pass an empty callback
-      dttl.addSygusConstructor(lambdaOp,
-                               ssc.str(),
-                               argList,
-                               printer::SygusEmptyPrintCallback::getEmptyPC(),
-                               0);
+      sdttl.addConstructor(lambdaOp,
+                           ssc.str(),
+                           argList,
+                           printer::SygusEmptyPrintCallback::getEmptyPC(),
+                           0);
       Trace("srs-input-debug")
           << "Grammar for subterm " << n << " is: " << std::endl;
       Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl;
     }
     // set that this is a sygus datatype
-    dttl.setSygus(t.toType(), sygusVarListE, false, false);
+    sdttl.initializeDatatype(t, sygusVarList, false, false);
+    Datatype dttl = sdttl.getDatatype();
     DatatypeType tlt = nm->toExprManager()->mkDatatypeType(
         dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
     tlGrammarTypes[t] = TypeNode::fromType(tlt);
index 529ef037f4883b8fd5e5a836e194d796d26d08c7..0396aba86855f65b662289fe46d2d40ce0c903c8 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "expr/datatype.h"
 #include "expr/node_algorithm.h"
+#include "expr/sygus_datatype.h"
 #include "printer/sygus_print_callback.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/quantifiers_rewriter.h"
@@ -86,7 +87,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
   {
     Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus());
     // must convert all constructors to version with bound variables in "vars"
-    std::vector<Datatype> datatypes;
+    std::vector<SygusDatatype> sdts;
     std::set<Type> unres;
 
     Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl;
@@ -129,9 +130,9 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
         const Datatype& dtc = curr.getDatatype();
         std::stringstream ssdtn;
         ssdtn << dtc.getName() << "_s";
-        datatypes.push_back(Datatype(ssdtn.str()));
+        sdts.push_back(SygusDatatype(ssdtn.str()));
         Trace("sygus-abduct-debug")
-            << "Process datatype " << datatypes.back().getName() << "..."
+            << "Process datatype " << sdts.back().getName() << "..."
             << std::endl;
         for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++)
         {
@@ -141,7 +142,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
               syms.begin(), syms.end(), varlist.begin(), varlist.end());
           Trace("sygus-abduct-debug") << "  Process constructor " << op << " / "
                                       << ops << "..." << std::endl;
-          std::vector<Type> cargs;
+          std::vector<TypeNode> cargs;
           for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++)
           {
             TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k));
@@ -167,7 +168,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
             }
             Trace("sygus-abduct-debug")
                 << "    Arg #" << k << ": " << argtNew << std::endl;
-            cargs.push_back(argtNew.toType());
+            cargs.push_back(argtNew);
           }
           // callback prints as the expression
           std::shared_ptr<SygusPrintCallback> spc;
@@ -191,22 +192,26 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name,
           ss << ops.getKind();
           Trace("sygus-abduct-debug")
               << "Add constructor : " << ops << std::endl;
-          datatypes.back().addSygusConstructor(
-              ops.toExpr(), ss.str(), cargs, spc);
+          sdts.back().addConstructor(ops, ss.str(), cargs, spc);
         }
         Trace("sygus-abduct-debug")
             << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl;
-        datatypes.back().setSygus(dtc.getSygusType(),
-                                  abvl.toExpr(),
-                                  dtc.getSygusAllowConst(),
-                                  dtc.getSygusAllowAll());
+        TypeNode stn = TypeNode::fromType(dtc.getSygusType());
+        sdts.back().initializeDatatype(
+            stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll());
       }
       dtToProcess.clear();
       dtToProcess.insert(
           dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end());
     }
     Trace("sygus-abduct-debug")
-        << "Make " << datatypes.size() << " datatype types..." << std::endl;
+        << "Make " << sdts.size() << " datatype types..." << std::endl;
+    // extract the datatypes
+    std::vector<Datatype> datatypes;
+    for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+    {
+      datatypes.push_back(sdts[i].getDatatype());
+    }
     // make the datatype types
     std::vector<DatatypeType> datatypeTypes =
         nm->toExprManager()->mkMutualDatatypeTypes(