Use standard interface for sygus default grammar construction (#3466)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 16 Nov 2019 04:52:58 +0000 (22:52 -0600)
committerGitHub <noreply@github.com>
Sat, 16 Nov 2019 04:52:58 +0000 (22:52 -0600)
src/expr/sygus_datatype.cpp
src/expr/sygus_datatype.h
src/theory/quantifiers/sygus/sygus_grammar_cons.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.h
src/theory/quantifiers/sygus/sygus_grammar_norm.cpp

index be2b0440261b2810527104be94b9a5937c6da983..d8ee2e1eae7952a66d0a01fd6fc9310a37164d7c 100644 (file)
@@ -26,9 +26,9 @@ std::string SygusDatatype::getName() const { return d_dt.getName(); }
 
 void SygusDatatype::addConstructor(Node op,
                                    const std::string& name,
+                                   const std::vector<TypeNode>& consTypes,
                                    std::shared_ptr<SygusPrintCallback> spc,
-                                   int weight,
-                                   const std::vector<TypeNode>& consTypes)
+                                   int weight)
 {
   d_ops.push_back(op);
   d_cons_names.push_back(std::string(name));
@@ -50,14 +50,26 @@ void SygusDatatype::addAnyConstantConstructor(TypeNode tn)
   std::vector<TypeNode> builtinArg;
   builtinArg.push_back(tn);
   addConstructor(
-      av, cname, printer::SygusEmptyPrintCallback::getEmptyPC(), 0, builtinArg);
+      av, cname, builtinArg, printer::SygusEmptyPrintCallback::getEmptyPC(), 0);
+}
+void SygusDatatype::addConstructor(Kind k,
+                                   const std::vector<TypeNode>& consTypes,
+                                   std::shared_ptr<SygusPrintCallback> spc,
+                                   int weight)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  addConstructor(nm->operatorOf(k), kindToString(k), consTypes, spc, weight);
 }
 
+size_t SygusDatatype::getNumConstructors() const { return d_ops.size(); }
+
 void SygusDatatype::initializeDatatype(TypeNode sygusType,
                                        Node sygusVars,
                                        bool allowConst,
                                        bool allowAll)
 {
+  // should not have initialized (set sygus) yet
+  Assert(!d_dt.isSygus());
   /* Use the sygus type to not lose reference to the original types (Bool,
    * Int, etc) */
   d_dt.setSygus(sygusType.toType(), sygusVars.toExpr(), allowConst, allowAll);
@@ -76,6 +88,11 @@ void SygusDatatype::initializeDatatype(TypeNode sygusType,
   Trace("sygus-type-cons") << "...built datatype " << d_dt << " ";
 }
 
-const Datatype& SygusDatatype::getDatatype() const { return d_dt; }
+const Datatype& SygusDatatype::getDatatype() const
+{
+  // should have initialized by this point
+  Assert(d_dt.isSygus());
+  return d_dt;
+}
 
 }  // namespace CVC4
index d7b18644a7343960698da5899003a6fae2ce2aa2..132406c69b75156d9f877060046d9edfd9095a00 100644 (file)
@@ -61,18 +61,35 @@ class SygusDatatype
    * weight: the weight of this constructor,
    * consTypes: the argument types of the constructor (typically other sygus
    * datatype types).
+   *
+   * It should be the case that consTypes are sygus datatype types (possibly
+   * unresolved) that encode the arguments of the builtin operator. That is,
+   * if op is the builtin PLUS operator, then consTypes could contain 2+
+   * sygus datatype types that encode integer.
    */
   void addConstructor(Node op,
                       const std::string& name,
-                      std::shared_ptr<SygusPrintCallback> spc,
-                      int weight,
-                      const std::vector<TypeNode>& consTypes);
+                      const std::vector<TypeNode>& consTypes,
+                      std::shared_ptr<SygusPrintCallback> spc = nullptr,
+                      int weight = -1);
+  /**
+   * Add constructor that encodes an application of builtin kind k. Like above,
+   * the arguments consTypes should correspond to sygus datatypes that encode
+   * the types of the arguments of the kind.
+   */
+  void addConstructor(Kind k,
+                      const std::vector<TypeNode>& consTypes,
+                      std::shared_ptr<SygusPrintCallback> spc = nullptr,
+                      int weight = -1);
   /**
    * This adds a constructor that corresponds to the any constant constructor
    * for the given (builtin) type tn.
    */
   void addAnyConstantConstructor(TypeNode tn);
 
+  /** Get the number of constructors added to this class so far */
+  size_t getNumConstructors() const;
+
   /** builds a datatype with the information in the type object
    *
    * sygusType: the builtin type that this datatype encodes,
index b8bf6c865efe3bcc47ba703676c9d4fca8386399..d00df38c5ae50939a92d9a4f66d17510f78ffb65 100644 (file)
@@ -48,8 +48,7 @@ bool CegGrammarConstructor::hasSyntaxRestrictions(Node q)
     if (!gv.isNull())
     {
       TypeNode tn = gv.getType();
-      if (tn.isDatatype()
-          && static_cast<DatatypeType>(tn.toType()).getDatatype().isSygus())
+      if (tn.isDatatype() && tn.getDatatype().isSygus())
       {
         return true;
       }
@@ -260,7 +259,7 @@ Node CegGrammarConstructor::process(Node q,
     }
     tds->registerSygusType(tn);
     Assert(tn.isDatatype());
-    const Datatype& dt = static_cast<DatatypeType>(tn.toType()).getDatatype();
+    const Datatype& dt = tn.getDatatype();
     Assert(dt.isSygus());
     if( !dt.getSygusAllowAll() ){
       d_is_syntax_restricted = true;
@@ -433,21 +432,16 @@ void CegGrammarConstructor::collectSygusGrammarTypesFor(
           for (unsigned j = 0, size_args = dt[i].getNumArgs(); j < size_args;
                ++j)
           {
-            collectSygusGrammarTypesFor(
-                TypeNode::fromType(static_cast<SelectorType>(dt[i][j].getType())
-                                       .getRangeType()),
-                types);
+            TypeNode tn = TypeNode::fromType(dt[i][j].getRangeType());
+            collectSygusGrammarTypesFor(tn, types);
           }
         }
       }
       else if (range.isArray())
       {
-        ArrayType arrayType = static_cast<ArrayType>(range.toType());
         // add index and constituent type
-        collectSygusGrammarTypesFor(
-            TypeNode::fromType(arrayType.getIndexType()), types);
-        collectSygusGrammarTypesFor(
-            TypeNode::fromType(arrayType.getConstituentType()), types);
+        collectSygusGrammarTypesFor(range.getArrayIndexType(), types);
+        collectSygusGrammarTypesFor(range.getArrayConstituentType(), types);
       }
       else if (range.isString() )
       {
@@ -487,11 +481,12 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
     Node bvl,
     const std::string& fun,
     std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons,
-    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& exc_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+        exclude_cons,
     const std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
-        inc_cons,
+        include_cons,
     std::unordered_set<Node, NodeHashFunction>& term_irrelevant,
-    std::vector<CVC4::Datatype>& datatypes,
+    std::vector<SygusDatatypeGenerator>& sdts,
     std::set<Type>& unres)
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -515,19 +510,9 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       }
     }
   }
-  // operators for each constructor in type
-  std::vector<std::vector<Expr>> ops;
-  // names for the operators
-  std::vector<std::vector<std::string>> cnames;
-  // argument types of operators
-  std::vector<std::vector<std::vector<Type>>> cargs;
-  // set of callbacks for each constructor
-  std::vector<std::vector<std::shared_ptr<SygusPrintCallback>>> pcs;
-  // weights for each constructor
-  std::vector<std::vector<int>> weights;
   // index of top datatype, i.e. the datatype for the range type
   int startIndex = -1;
-  std::map< Type, Type > sygus_to_builtin;
+  std::map<TypeNode, TypeNode> sygus_to_builtin;
 
   std::vector<TypeNode> types;
   // collect connected types for each of the variables
@@ -542,32 +527,39 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
   std::stringstream ssb;
   ssb << fun << "_Bool";
   std::string dbname = ssb.str();
-  Type unres_bt = mkUnresolvedType(ssb.str(), unres).toType();
+  TypeNode unres_bt = mkUnresolvedType(ssb.str(), unres);
 
   // create placeholders for collected types
-  std::vector< Type > unres_types;
-  std::map< TypeNode, Type > type_to_unres;
+  std::vector<TypeNode> unres_types;
+  std::map<TypeNode, TypeNode> type_to_unres;
+  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::const_iterator
+      itc;
   for (unsigned i = 0, size = types.size(); i < size; ++i)
   {
     std::stringstream ss;
     ss << fun << "_" << types[i];
     std::string dname = ss.str();
-    datatypes.push_back(Datatype(dname));
-    ops.push_back(std::vector< Expr >());
-    cnames.push_back(std::vector<std::string>());
-    cargs.push_back(std::vector<std::vector<Type>>());
-    pcs.push_back(std::vector<std::shared_ptr<SygusPrintCallback>>());
-    weights.push_back(std::vector<int>());
+    sdts.push_back(SygusDatatypeGenerator(dname));
+    itc = exclude_cons.find(types[i]);
+    if (itc != exclude_cons.end())
+    {
+      sdts.back().d_exclude_cons = itc->second;
+    }
+    itc = include_cons.find(types[i]);
+    if (itc != include_cons.end())
+    {
+      sdts.back().d_include_cons = itc->second;
+    }
     //make unresolved type
-    Type unres_t = mkUnresolvedType(dname, unres).toType();
+    TypeNode unres_t = mkUnresolvedType(dname, unres);
     unres_types.push_back(unres_t);
     type_to_unres[types[i]] = unres_t;
-    sygus_to_builtin[unres_t] = types[i].toType();
+    sygus_to_builtin[unres_t] = types[i];
   }
   for (unsigned i = 0, size = types.size(); i < size; ++i)
   {
     Trace("sygus-grammar-def") << "Make grammar for " << types[i] << " " << unres_types[i] << std::endl;
-    Type unres_t = unres_types[i];
+    TypeNode unres_t = unres_types[i];
     //add variables
     for (const Node& sv : sygus_vars)
     {
@@ -578,18 +570,15 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
         ss << sv;
         Trace("sygus-grammar-def")
             << "...add for variable " << ss.str() << std::endl;
-        ops[i].push_back(sv.toExpr());
-        cnames[i].push_back(ss.str());
-        cargs[i].push_back(std::vector<Type>());
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        std::vector<TypeNode> cargsEmpty;
+        sdts[i].addConstructor(sv, ss.str(), cargsEmpty);
       }
       else if (svt.isFunction() && svt.getRangeType() == types[i])
       {
         // We add an APPLY_UF for all function whose return type is this type
         // whose argument types are the other sygus types we are constructing.
         std::vector<TypeNode> argTypes = svt.getArgTypes();
-        std::vector<Type> stypes;
+        std::vector<TypeNode> stypes;
         for (unsigned k = 0, ntypes = argTypes.size(); k < ntypes; k++)
         {
           unsigned index =
@@ -599,11 +588,7 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
         }
         std::stringstream ss;
         ss << "apply_" << sv;
-        ops[i].push_back(sv.toExpr());
-        cnames[i].push_back(ss.str());
-        cargs[i].push_back(stypes);
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        sdts[i].addConstructor(sv, ss.str(), stypes);
       }
     }
     //add constants
@@ -628,23 +613,17 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       std::stringstream ss;
       ss << consts[j];
       Trace("sygus-grammar-def") << "...add for constant " << ss.str() << std::endl;
-      ops[i].push_back( consts[j].toExpr() );
-      cnames[i].push_back(ss.str());
-      cargs[i].push_back(std::vector<Type>());
-      pcs[i].push_back(nullptr);
-      weights[i].push_back(-1);
+      std::vector<TypeNode> cargsEmpty;
+      sdts[i].addConstructor(consts[j], ss.str(), cargsEmpty);
     }
     // ITE
     Kind k = ITE;
     Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-    ops[i].push_back(nm->operatorOf(k).toExpr());
-    cnames[i].push_back(kindToString(k));
-    cargs[i].push_back(std::vector<Type>());
-    cargs[i].back().push_back(unres_bt);
-    cargs[i].back().push_back(unres_t);
-    cargs[i].back().push_back(unres_t);
-    pcs[i].push_back(nullptr);
-    weights[i].push_back(-1);
+    std::vector<TypeNode> cargsIte;
+    cargsIte.push_back(unres_bt);
+    cargsIte.push_back(unres_t);
+    cargsIte.push_back(unres_t);
+    sdts[i].addConstructor(k, cargsIte);
 
     if (types[i].isReal())
     {
@@ -653,13 +632,10 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       for (const Kind k : kinds)
       {
         Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-        ops[i].push_back(nm->operatorOf(k).toExpr());
-        cnames[i].push_back(kindToString(k));
-        cargs[i].push_back(std::vector<Type>());
-        cargs[i].back().push_back(unres_t);
-        cargs[i].back().push_back(unres_t);
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        std::vector<TypeNode> cargsOp;
+        cargsOp.push_back(unres_t);
+        cargsOp.push_back(unres_t);
+        sdts[i].addConstructor(k, cargsOp);
       }
       if (!types[i].isInteger())
       {
@@ -670,61 +646,42 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
         ss << fun << "_PosInt";
         std::string pos_int_name = ss.str();
         // make unresolved type
-        Type unres_pos_int_t = mkUnresolvedType(pos_int_name, unres).toType();
-        // make data type
-        datatypes.push_back(Datatype(pos_int_name));
-        /* add placeholders */
-        std::vector<Expr> ops_pos_int;
-        std::vector<std::string> cnames_pos_int;
-        std::vector<std::vector<Type>> cargs_pos_int;
+        TypeNode unres_pos_int_t = mkUnresolvedType(pos_int_name, unres);
+        // make data type for positive constant integers
+        sdts.push_back(SygusDatatypeGenerator(pos_int_name));
         /* Add operator 1 */
         Trace("sygus-grammar-def") << "\t...add for 1 to Pos_Int\n";
-        ops_pos_int.push_back(nm->mkConst(Rational(1)).toExpr());
-        ss.str("");
-        ss << "1";
-        cnames_pos_int.push_back(ss.str());
-        cargs_pos_int.push_back(std::vector<Type>());
+        std::vector<TypeNode> cargsEmpty;
+        sdts.back().addConstructor(nm->mkConst(Rational(1)), "1", cargsEmpty);
         /* Add operator PLUS */
         Kind k = PLUS;
         Trace("sygus-grammar-def") << "\t...add for PLUS to Pos_Int\n";
-        ops_pos_int.push_back(nm->operatorOf(k).toExpr());
-        cnames_pos_int.push_back(kindToString(k));
-        cargs_pos_int.push_back(std::vector<Type>());
-        cargs_pos_int.back().push_back(unres_pos_int_t);
-        cargs_pos_int.back().push_back(unres_pos_int_t);
-        datatypes.back().setSygus(types[i].toType(), bvl.toExpr(), true, true);
-        for (unsigned j = 0, size_j = ops_pos_int.size(); j < size_j; ++j)
-        {
-          datatypes.back().addSygusConstructor(
-              ops_pos_int[j], cnames_pos_int[j], cargs_pos_int[j]);
-        }
+        std::vector<TypeNode> cargsPlus;
+        cargsPlus.push_back(unres_pos_int_t);
+        cargsPlus.push_back(unres_pos_int_t);
+        sdts.back().addConstructor(k, cargsPlus);
+        sdts.back().d_sdt.initializeDatatype(types[i], bvl, true, true);
         Trace("sygus-grammar-def")
-            << "  ...built datatype " << datatypes.back() << " ";
+            << "  ...built datatype " << sdts.back().d_sdt.getDatatype() << " ";
         /* Adding division at root */
         k = DIVISION;
         Trace("sygus-grammar-def") << "\t...add for " << k << std::endl;
-        ops[i].push_back(nm->operatorOf(k).toExpr());
-        cnames[i].push_back(kindToString(k));
-        cargs[i].push_back(std::vector<Type>());
-        cargs[i].back().push_back(unres_t);
-        cargs[i].back().push_back(unres_pos_int_t);
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        std::vector<TypeNode> cargsDiv;
+        cargsDiv.push_back(unres_t);
+        cargsDiv.push_back(unres_pos_int_t);
+        sdts[i].addConstructor(k, cargsDiv);
       }
     }
     else if (types[i].isBitVector())
     {
       // unary apps
       std::vector<Kind> un_kinds = {BITVECTOR_NOT, BITVECTOR_NEG};
+      std::vector<TypeNode> cargsUnary;
+      cargsUnary.push_back(unres_t);
       for (const Kind k : un_kinds)
       {
         Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-        ops[i].push_back(nm->operatorOf(k).toExpr());
-        cnames[i].push_back(kindToString(k));
-        cargs[i].push_back(std::vector<Type>());
-        cargs[i].back().push_back(unres_t);
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        sdts[i].addConstructor(k, cargsUnary);
       }
       // binary apps
       std::vector<Kind> bin_kinds = {BITVECTOR_AND,
@@ -740,28 +697,22 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
                                      BITVECTOR_SHL,
                                      BITVECTOR_LSHR,
                                      BITVECTOR_ASHR};
+      std::vector<TypeNode> cargsBinary;
+      cargsBinary.push_back(unres_t);
+      cargsBinary.push_back(unres_t);
       for (const Kind k : bin_kinds)
       {
         Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-        ops[i].push_back(nm->operatorOf(k).toExpr());
-        cnames[i].push_back(kindToString(k));
-        cargs[i].push_back(std::vector<Type>());
-        cargs[i].back().push_back(unres_t);
-        cargs[i].back().push_back(unres_t);
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        sdts[i].addConstructor(k, cargsBinary);
       }
     }
     else if (types[i].isString())
     {
       // concatenation
-      ops[i].push_back(nm->operatorOf(STRING_CONCAT).toExpr());
-      cnames[i].push_back(kindToString(STRING_CONCAT));
-      cargs[i].push_back(std::vector<Type>());
-      cargs[i].back().push_back(unres_t);
-      cargs[i].back().push_back(unres_t);
-      pcs[i].push_back(nullptr);
-      weights[i].push_back(-1);
+      std::vector<TypeNode> cargsBinary;
+      cargsBinary.push_back(unres_t);
+      cargsBinary.push_back(unres_t);
+      sdts[i].addConstructor(STRING_CONCAT, cargsBinary);
       // length
       TypeNode intType = nm->integerType();
       Assert(std::find(types.begin(), types.end(), intType) != types.end());
@@ -770,63 +721,46 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
           std::find(types.begin(),
                     types.end(),
                     intType));
-      ops[i_intType].push_back(nm->operatorOf(STRING_LENGTH).toExpr());
-      cnames[i_intType].push_back(kindToString(STRING_LENGTH));
-      cargs[i_intType].push_back(std::vector<Type>());
-      cargs[i_intType].back().push_back(unres_t);
-      pcs[i_intType].push_back(nullptr);
-      weights[i_intType].push_back(-1);
+      std::vector<TypeNode> cargsLen;
+      cargsLen.push_back(unres_t);
+      sdts[i_intType].addConstructor(STRING_LENGTH, cargsLen);
     }
     else if (types[i].isArray())
     {
-      ArrayType arrayType = static_cast<ArrayType>(types[i].toType());
-      Trace("sygus-grammar-def")
-          << "...building for array type " << arrayType << "\n";
       Trace("sygus-grammar-def")
-          << "......finding unres type for index type "
-          << arrayType.getIndexType() << " with typenode "
-          << TypeNode::fromType(arrayType.getIndexType()) << "\n";
+          << "...building for array type " << types[i] << "\n";
+      Trace("sygus-grammar-def") << "......finding unres type for index type "
+                                 << types[i].getArrayIndexType() << "\n";
       // retrieve index and constituent unresolved types
-      Assert(std::find(types.begin(),
-                       types.end(),
-                       TypeNode::fromType(arrayType.getIndexType()))
+      Assert(std::find(types.begin(), types.end(), types[i].getArrayIndexType())
              != types.end());
       unsigned i_indexType = std::distance(
           types.begin(),
-          std::find(types.begin(),
-                    types.end(),
-                    TypeNode::fromType(arrayType.getIndexType())));
-      Type unres_indexType = unres_types[i_indexType];
-      Assert(std::find(types.begin(),
-                       types.end(),
-                       TypeNode::fromType(arrayType.getConstituentType()))
+          std::find(types.begin(), types.end(), types[i].getArrayIndexType()));
+      TypeNode unres_indexType = unres_types[i_indexType];
+      Assert(std::find(
+                 types.begin(), types.end(), types[i].getArrayConstituentType())
              != types.end());
       unsigned i_constituentType = std::distance(
           types.begin(),
-          std::find(types.begin(),
-                    types.end(),
-                    TypeNode::fromType(arrayType.getConstituentType())));
-      Type unres_constituentType = unres_types[i_constituentType];
+          std::find(
+              types.begin(), types.end(), types[i].getArrayConstituentType()));
+      TypeNode unres_constituentType = unres_types[i_constituentType];
       // add (store ArrayType IndexType ConstituentType)
       Trace("sygus-grammar-def") << "...add for STORE\n";
-      ops[i].push_back(nm->operatorOf(STORE).toExpr());
-      cnames[i].push_back(kindToString(STORE));
-      cargs[i].push_back(std::vector<Type>());
-      cargs[i].back().push_back(unres_t);
-      cargs[i].back().push_back(unres_indexType);
-      cargs[i].back().push_back(unres_constituentType);
-      pcs[i].push_back(nullptr);
-      weights[i].push_back(-1);
+
+      std::vector<TypeNode> cargsStore;
+      cargsStore.push_back(unres_t);
+      cargsStore.push_back(unres_indexType);
+      cargsStore.push_back(unres_constituentType);
+      sdts[i].addConstructor(STORE, cargsStore);
       // add to constituent type : (select ArrayType IndexType)
       Trace("sygus-grammar-def") << "...add select for constituent type"
                                  << unres_constituentType << "\n";
-      ops[i_constituentType].push_back(nm->operatorOf(SELECT).toExpr());
-      cnames[i_constituentType].push_back(kindToString(SELECT));
-      cargs[i_constituentType].push_back(std::vector<Type>());
-      cargs[i_constituentType].back().push_back(unres_t);
-      cargs[i_constituentType].back().push_back(unres_indexType);
-      pcs[i_constituentType].push_back(nullptr);
-      weights[i_constituentType].push_back(-1);
+      std::vector<TypeNode> cargsSelect;
+      cargsSelect.push_back(unres_t);
+      cargsSelect.push_back(unres_indexType);
+      sdts[i_constituentType].addConstructor(SELECT, cargsSelect);
     }
     else if (types[i].isDatatype())
     {
@@ -835,42 +769,36 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k)
       {
         Trace("sygus-grammar-def") << "...for " << dt[k].getName() << std::endl;
-        Expr cop = dt[k].getConstructor();
+        Node cop = Node::fromExpr(dt[k].getConstructor());
         if (dt[k].getNumArgs() == 0)
         {
           // Nullary constructors are interpreted as terms, not operators.
           // Thus, we apply them to no arguments here.
-          cop = nm->mkNode(APPLY_CONSTRUCTOR, Node::fromExpr(cop)).toExpr();
+          cop = nm->mkNode(APPLY_CONSTRUCTOR, cop);
         }
-        ops[i].push_back(cop);
-        cnames[i].push_back(dt[k].getName());
-        cargs[i].push_back(std::vector<Type>());
+        std::vector<TypeNode> cargsCons;
         Trace("sygus-grammar-def") << "...add for selectors" << std::endl;
         for (unsigned j = 0, size_j = dt[k].getNumArgs(); j < size_j; ++j)
         {
           Trace("sygus-grammar-def")
               << "...for " << dt[k][j].getName() << std::endl;
-          TypeNode crange = TypeNode::fromType(
-              static_cast<SelectorType>(dt[k][j].getType()).getRangeType());
+          TypeNode crange = TypeNode::fromType(dt[k][j].getRangeType());
           Assert(type_to_unres.find(crange) != type_to_unres.end());
-          cargs[i].back().push_back(type_to_unres[crange]);
+          cargsCons.push_back(type_to_unres[crange]);
           // add to the selector type the selector operator
 
           Assert(std::find(types.begin(), types.end(), crange) != types.end());
           unsigned i_selType = std::distance(
               types.begin(), std::find(types.begin(), types.end(), crange));
-          TypeNode arg_type = TypeNode::fromType(
-              static_cast<SelectorType>(dt[k][j].getType()).getDomain());
-          ops[i_selType].push_back(dt[k][j].getSelector());
-          cnames[i_selType].push_back(dt[k][j].getName());
-          cargs[i_selType].push_back(std::vector<Type>());
+          TypeNode arg_type = TypeNode::fromType(dt[k][j].getType());
+          arg_type = arg_type.getSelectorDomainType();
           Assert(type_to_unres.find(arg_type) != type_to_unres.end());
-          cargs[i_selType].back().push_back(type_to_unres[arg_type]);
-          pcs[i_selType].push_back(nullptr);
-          weights[i_selType].push_back(-1);
+          std::vector<TypeNode> cargsSel;
+          cargsSel.push_back(type_to_unres[arg_type]);
+          Node sel = Node::fromExpr(dt[k][j].getSelector());
+          sdts[i_selType].addConstructor(sel, dt[k][j].getName(), cargsSel);
         }
-        pcs[i].push_back(nullptr);
-        weights[i].push_back(-1);
+        sdts[i].addConstructor(cop, dt[k].getName(), cargsCons);
       }
     }
     else if (types[i].isSort() || types[i].isFunction())
@@ -887,40 +815,9 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
   // make datatypes
   for (unsigned i = 0, size = types.size(); i < size; ++i)
   {
-    Trace("sygus-grammar-def") << "...make datatype " << datatypes[i] << std::endl;
-    datatypes[i].setSygus( types[i].toType(), bvl.toExpr(), true, true );
-    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::iterator
-        itexc = exc_cons.find(types[i]);
-    std::map<TypeNode,
-             std::unordered_set<Node, NodeHashFunction>>::const_iterator itinc =
-        inc_cons.find(types[i]);
-    for (unsigned j = 0, size = ops[i].size(); j < size; ++j)
-    {
-      // add the constructor if it is not excluded,
-      // and it is in inc_cons, in case it is not empty
-      Node opn = Node::fromExpr(ops[i][j]);
-      Trace("sygus-grammar-def")
-          << "...considering " << opn.toString() << " of kind " << opn.getKind()
-          << " and of type " << opn.getType() << " and of kind of type "
-          << opn.getType().getKind() << " of metakind " << opn.getMetaKind()
-          << std::endl;
-      if (itexc == exc_cons.end()
-          || std::find(itexc->second.begin(), itexc->second.end(), opn)
-                 == itexc->second.end())
-      {
-        Trace("sygus-grammar-def") << "......not excluded " << std::endl;
-        if ((opn.isVar()) || (opn.getType().getKind() != kind::TYPE_CONSTANT)
-            || (itinc == inc_cons.end())
-            || (std::find(itinc->second.begin(), itinc->second.end(), opn)
-                != itinc->second.end()))
-        {
-          Trace("sygus-grammar-def") << "......included " << std::endl;
-          datatypes[i].addSygusConstructor(
-              ops[i][j], cnames[i][j], cargs[i][j], pcs[i][j], weights[i][j]);
-        }
-      }
-    }
-    Trace("sygus-grammar-def") << "...built datatype " << datatypes[i] << " ";
+    sdts[i].d_sdt.initializeDatatype(types[i], bvl, true, true);
+    Trace("sygus-grammar-def")
+        << "...built datatype " << sdts[i].d_sdt.getDatatype() << " ";
     //set start index if applicable
     if( types[i]==range ){
       startIndex = i;
@@ -929,13 +826,9 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
 
   //------ make Boolean type
   TypeNode btype = nm->booleanType();
-  datatypes.push_back(Datatype(dbname));
-  ops.push_back(std::vector<Expr>());
-  cnames.push_back(std::vector<std::string>());
-  cargs.push_back(std::vector<std::vector<Type>>());
-  pcs.push_back(std::vector<std::shared_ptr<SygusPrintCallback>>());
-  weights.push_back(std::vector<int>());
-  Trace("sygus-grammar-def") << "Make grammar for " << btype << " " << datatypes.back() << std::endl;
+  sdts.push_back(SygusDatatypeGenerator(dbname));
+  SygusDatatypeGenerator& sdtBool = sdts.back();
+  Trace("sygus-grammar-def") << "Make grammar for " << btype << std::endl;
   //add variables
   for (unsigned i = 0, size = sygus_vars.size(); i < size; ++i)
   {
@@ -943,12 +836,9 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       std::stringstream ss;
       ss << sygus_vars[i];
       Trace("sygus-grammar-def") << "...add for variable " << ss.str() << std::endl;
-      ops.back().push_back( sygus_vars[i].toExpr() );
-      cnames.back().push_back(ss.str());
-      cargs.back().push_back(std::vector<Type>());
-      pcs.back().push_back(nullptr);
+      std::vector<TypeNode> cargsEmpty;
       // make boolean variables weight as non-nullary constructors
-      weights.back().push_back(1);
+      sdtBool.addConstructor(sygus_vars[i], ss.str(), cargsEmpty, nullptr, 1);
     }
   }
   // add constants
@@ -960,11 +850,8 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
     ss << consts[i];
     Trace("sygus-grammar-def") << "...add for constant " << ss.str()
                                << std::endl;
-    ops.back().push_back(consts[i].toExpr());
-    cnames.back().push_back(ss.str());
-    cargs.back().push_back(std::vector<Type>());
-    pcs.back().push_back(nullptr);
-    weights.back().push_back(-1);
+    std::vector<TypeNode> cargsEmpty;
+    sdtBool.addConstructor(consts[i], ss.str(), cargsEmpty);
   }
   // add predicates for types
   for (unsigned i = 0, size = types.size(); i < size; ++i)
@@ -977,60 +864,43 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
     //add equality per type
     Kind k = EQUAL;
     Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-    ops.back().push_back(nm->operatorOf(k).toExpr());
     std::stringstream ss;
     ss << kindToString(k) << "_" << types[i];
-    cnames.back().push_back(ss.str());
-    cargs.back().push_back(std::vector<Type>());
-    cargs.back().back().push_back(unres_types[i]);
-    cargs.back().back().push_back(unres_types[i]);
-    pcs.back().push_back(nullptr);
-    weights.back().push_back(-1);
+    std::vector<TypeNode> cargsBinary;
+    cargsBinary.push_back(unres_types[i]);
+    cargsBinary.push_back(unres_types[i]);
+    sdtBool.addConstructor(nm->operatorOf(k), ss.str(), cargsBinary);
     // type specific predicates
     if (types[i].isReal())
     {
       Kind k = LEQ;
       Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-      ops.back().push_back(nm->operatorOf(k).toExpr());
-      cnames.back().push_back(kindToString(k));
-      cargs.back().push_back(std::vector<Type>());
-      cargs.back().back().push_back(unres_types[i]);
-      cargs.back().back().push_back(unres_types[i]);
-      pcs.back().push_back(nullptr);
-      weights.back().push_back(-1);
+      sdtBool.addConstructor(k, cargsBinary);
     }
     else if (types[i].isBitVector())
     {
       Kind k = BITVECTOR_ULT;
       Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-      ops.back().push_back(nm->operatorOf(k).toExpr());
-      cnames.back().push_back(kindToString(k));
-      cargs.back().push_back(std::vector<Type>());
-      cargs.back().back().push_back(unres_types[i]);
-      cargs.back().back().push_back(unres_types[i]);
-      pcs.back().push_back(nullptr);
-      weights.back().push_back(-1);
+      sdtBool.addConstructor(k, cargsBinary);
     }
     else if (types[i].isDatatype())
     {
       //add for testers
       Trace("sygus-grammar-def") << "...add for testers" << std::endl;
       const Datatype& dt = types[i].getDatatype();
+      std::vector<TypeNode> cargsTester;
+      cargsTester.push_back(unres_types[i]);
       for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k)
       {
         Trace("sygus-grammar-def") << "...for " << dt[k].getTesterName() << std::endl;
-        ops.back().push_back(dt[k].getTester());
-        cnames.back().push_back(dt[k].getTesterName());
-        cargs.back().push_back(std::vector<Type>());
-        cargs.back().back().push_back(unres_types[i]);
-        pcs.back().push_back(nullptr);
-        weights.back().push_back(-1);
+        sdtBool.addConstructor(
+            dt[k].getTester(), dt[k].getTesterName(), cargsTester);
       }
     }
   }
   // add Boolean connectives, if not in a degenerate case of (recursively)
   // having only constant constructors
-  if (ops.back().size() > consts.size())
+  if (sdtBool.d_sdt.getNumConstructors() > consts.size())
   {
     for (unsigned i = 0; i < 4; i++)
     {
@@ -1043,42 +913,31 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
         continue;
       }
       Trace("sygus-grammar-def") << "...add for " << k << std::endl;
-      ops.back().push_back(nm->operatorOf(k).toExpr());
-      cnames.back().push_back(kindToString(k));
-      cargs.back().push_back(std::vector<Type>());
-      cargs.back().back().push_back(unres_bt);
+      std::vector<TypeNode> cargs;
+      cargs.push_back(unres_bt);
       if (k != NOT)
       {
-        cargs.back().back().push_back(unres_bt);
+        cargs.push_back(unres_bt);
         if (k == ITE)
         {
-          cargs.back().back().push_back(unres_bt);
+          cargs.push_back(unres_bt);
         }
       }
-      pcs.back().push_back(nullptr);
-      weights.back().push_back(-1);
+      sdtBool.addConstructor(k, cargs);
     }
   }
   if( range==btype ){
-    startIndex = datatypes.size()-1;
-  }
-  Trace("sygus-grammar-def") << "...make datatype " << datatypes.back() << std::endl;
-  datatypes.back().setSygus( btype.toType(), bvl.toExpr(), true, true );
-  for (unsigned i = 0, size = ops.back().size(); i < size; ++i)
-  {
-    datatypes.back().addSygusConstructor(ops.back()[i],
-                                         cnames.back()[i],
-                                         cargs.back()[i],
-                                         pcs.back()[i],
-                                         weights.back()[i]);
+    startIndex = sdts.size() - 1;
   }
-  Trace("sygus-grammar-def") << "...built datatype " << datatypes.back() << " ";
+  sdtBool.d_sdt.initializeDatatype(btype, bvl, true, true);
+  Trace("sygus-grammar-def")
+      << "...built datatype for Bool " << sdtBool.d_sdt.getDatatype() << " ";
   Trace("sygus-grammar-def") << "...finished make default grammar for " << fun << " " << range << std::endl;
   // make first datatype be the top level datatype
   if( startIndex>0 ){
-    Datatype tmp_dt = datatypes[0];
-    datatypes[0] = datatypes[startIndex];
-    datatypes[startIndex] = tmp_dt;
+    SygusDatatypeGenerator tmp_dt = sdts[0];
+    sdts[0] = sdts[startIndex];
+    sdts[startIndex] = tmp_dt;
   }
 }
 
@@ -1102,7 +961,7 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType(
     Trace("sygus-grammar-def") << "    ...using " << it->second.size() << " extra constants for " << it->first << std::endl;
   }
   std::set<Type> unres;
-  std::vector< CVC4::Datatype > datatypes;
+  std::vector<SygusDatatypeGenerator> sdts;
   mkSygusDefaultGrammar(range,
                         bvl,
                         fun,
@@ -1110,8 +969,14 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType(
                         exclude_cons,
                         include_cons,
                         term_irrelevant,
-                        datatypes,
+                        sdts,
                         unres);
+  // extract the datatypes from the sygus datatype generator objects
+  std::vector<Datatype> datatypes;
+  for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+  {
+    datatypes.push_back(sdts[i].d_sdt.getDatatype());
+  }
   Trace("sygus-grammar-def")  << "...made " << datatypes.size() << " datatypes, now make mutual datatype types..." << std::endl;
   Assert(!datatypes.empty());
   std::vector<DatatypeType> types =
@@ -1130,13 +995,13 @@ TypeNode CegGrammarConstructor::mkSygusTemplateTypeRec( Node templ, Node templ_a
   }else{
     tcount++;
     std::set<Type> unres;
-    std::vector< CVC4::Datatype > datatypes;
+    std::vector<SygusDatatype> sdts;
     std::stringstream ssd;
     ssd << fun << "_templ_" << tcount;
     std::string dbname = ssd.str();
-    datatypes.push_back(Datatype(dbname));
+    sdts.push_back(SygusDatatype(dbname));
     Node op;
-    std::vector< Type > argTypes;
+    std::vector<TypeNode> argTypes;
     if( templ.getNumChildren()==0 ){
       // TODO : can short circuit to this case when !TermUtil::containsTerm( templ, templ_arg )
       op = templ;
@@ -1147,15 +1012,20 @@ TypeNode CegGrammarConstructor::mkSygusTemplateTypeRec( Node templ, Node templ_a
       for( unsigned i=0; i<templ.getNumChildren(); i++ ){
         //recursion depth bound by the depth of SyGuS template expressions (low)
         TypeNode tnc = mkSygusTemplateTypeRec( templ[i], templ_arg, templ_arg_sygus_type, bvl, fun, tcount );
-        argTypes.push_back( tnc.toType() );
+        argTypes.push_back(tnc);
       }
     }
     std::stringstream ssdc;
     ssdc << fun << "_templ_cons_" << tcount;
-    std::string cname = ssdc.str();
     // we have a single sygus constructor that encodes the template
-    datatypes.back().addSygusConstructor( op.toExpr(), cname, argTypes );
-    datatypes.back().setSygus( templ.getType().toType(), bvl.toExpr(), true, true );
+    sdts.back().addConstructor(op, ssdc.str(), argTypes);
+    sdts.back().initializeDatatype(templ.getType(), bvl, true, true);
+    // extract the datatypes from the sygus datatype objects
+    std::vector<Datatype> datatypes;
+    for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++)
+    {
+      datatypes.push_back(sdts[i].getDatatype());
+    }
     std::vector<DatatypeType> types =
         NodeManager::currentNM()->toExprManager()->mkMutualDatatypeTypes(
             datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER);
@@ -1191,6 +1061,52 @@ Node CegGrammarConstructor::getSygusVarList(Node f)
   return sfvl;
 }
 
+CegGrammarConstructor::SygusDatatypeGenerator::SygusDatatypeGenerator(
+    const std::string& name)
+    : d_sdt(name)
+{
+}
+void CegGrammarConstructor::SygusDatatypeGenerator::addConstructor(
+    Node op,
+    const std::string& name,
+    const std::vector<TypeNode>& consTypes,
+    std::shared_ptr<SygusPrintCallback> spc,
+    int weight)
+{
+  if (shouldInclude(op))
+  {
+    d_sdt.addConstructor(op, name, consTypes, spc, weight);
+  }
+}
+void CegGrammarConstructor::SygusDatatypeGenerator::addConstructor(
+    Kind k,
+    const std::vector<TypeNode>& consTypes,
+    std::shared_ptr<SygusPrintCallback> spc,
+    int weight)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  addConstructor(nm->operatorOf(k), kindToString(k), consTypes, spc, weight);
+}
+bool CegGrammarConstructor::SygusDatatypeGenerator::shouldInclude(Node op) const
+{
+  if (d_exclude_cons.find(op) != d_exclude_cons.end())
+  {
+    return false;
+  }
+  if (!d_include_cons.empty())
+  {
+    // special case, variables and terms of certain types are always included
+    if (!op.isVar() && op.getType().getKind() == TYPE_CONSTANT)
+    {
+      if (d_include_cons.find(op) == d_include_cons.end())
+      {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 }/* namespace CVC4::theory::quantifiers */
 }/* namespace CVC4::theory */
 }/* namespace CVC4 */
index 85efddada44e26d999610dad8c27cf283e9f5072..00e9d45fba3024f7efbed770b2e0601b94284d1e 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "expr/attribute.h"
 #include "expr/node.h"
+#include "expr/sygus_datatype.h"
 
 namespace CVC4 {
 namespace theory {
@@ -186,6 +187,42 @@ public:
       Node n,
       std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& consts);
   //---------------- grammar construction
+  /** A class for generating sygus datatypes */
+  class SygusDatatypeGenerator
+  {
+   public:
+    SygusDatatypeGenerator(const std::string& name);
+    ~SygusDatatypeGenerator() {}
+    /**
+     * Possibly add a constructor to d_sdt, based on the criteria mentioned
+     * in this class below. For details see expr/sygus_datatype.h.
+     */
+    void addConstructor(Node op,
+                        const std::string& name,
+                        const std::vector<TypeNode>& consTypes,
+                        std::shared_ptr<SygusPrintCallback> spc = nullptr,
+                        int weight = -1);
+    /**
+     * Possibly add a constructor to d_sdt, based on the criteria mentioned
+     * in this class below.
+     */
+    void addConstructor(Kind k,
+                        const std::vector<TypeNode>& consTypes,
+                        std::shared_ptr<SygusPrintCallback> spc = nullptr,
+                        int weight = -1);
+    /** Should we include constructor with operator op? */
+    bool shouldInclude(Node op) const;
+    /** The constructors that should be excluded. */
+    std::unordered_set<Node, NodeHashFunction> d_exclude_cons;
+    /**
+     * If this set is non-empty, then only include variables and constructors
+     * from it.
+     */
+    std::unordered_set<Node, NodeHashFunction> d_include_cons;
+    /** The sygus datatype we are generating. */
+    SygusDatatype d_sdt;
+  };
+
   // helper for mkSygusDefaultGrammar (makes unresolved type for mutually recursive datatype construction)
   static TypeNode mkUnresolvedType(const std::string& name, std::set<Type>& unres);
   // collect the list of types that depend on type range
@@ -207,7 +244,7 @@ public:
       const std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
           include_cons,
       std::unordered_set<Node, NodeHashFunction>& term_irrelevant,
-      std::vector<CVC4::Datatype>& datatypes,
+      std::vector<SygusDatatypeGenerator>& sdts,
       std::set<Type>& unres);
 
   // helper function for mkSygusTemplateType
index 68445fca09db8cb6c57ab054486e1fa7ab83b5cc..40046ef15f1b07e832854d253ce43694c70c4fd0 100644 (file)
@@ -222,9 +222,9 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
                                 << "\n\tExpanded one: " << exp_sop_n << "\n\n";
   d_sdt.addConstructor(exp_sop_n,
                        cons.getName(),
+                       consTypes,
                        cons.getSygusPrintCallback(),
-                       cons.getWeight(),
-                       consTypes);
+                       cons.getWeight());
 }
 
 void SygusGrammarNorm::TypeObject::initializeDatatype(
@@ -342,17 +342,16 @@ void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm,
     ctypes.push_back(t);
     to.d_sdt.addConstructor(iden_op,
                             "id",
+                            ctypes,
                             printer::SygusEmptyPrintCallback::getEmptyPC(),
-                            0,
-                            ctypes);
+                            0);
     Trace("sygus-grammar-normalize-chain")
         << "\tAdding  " << t << " to " << to.d_unres_tn << "\n";
     /* adds to Root: "type + Root" */
     std::vector<TypeNode> ctypesp;
     ctypesp.push_back(t);
     ctypesp.push_back(to.d_unres_tn);
-    to.d_sdt.addConstructor(
-        nm->operatorOf(PLUS), kindToString(PLUS), nullptr, -1, ctypesp);
+    to.d_sdt.addConstructor(nm->operatorOf(PLUS), kindToString(PLUS), ctypesp);
     Trace("sygus-grammar-normalize-chain")
         << "\tAdding PLUS to " << to.d_unres_tn << " with arg types "
         << to.d_unres_tn << " and " << t << "\n";
@@ -385,9 +384,9 @@ void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm,
   ctypes.push_back(sygus_norm->normalizeSygusRec(to.d_tn, dt, d_elem_pos));
   to.d_sdt.addConstructor(iden_op,
                           "id_next",
+                          ctypes,
                           printer::SygusEmptyPrintCallback::getEmptyPC(),
-                          0,
-                          ctypes);
+                          0);
 }
 
 std::map<TypeNode, Node> SygusGrammarNorm::d_tn_to_id = {};