Update sygus grammar normalization to use node-level datatype. (#3567)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 31 Jan 2020 16:43:36 +0000 (10:43 -0600)
committerGitHub <noreply@github.com>
Fri, 31 Jan 2020 16:43:36 +0000 (10:43 -0600)
src/theory/quantifiers/sygus/sygus_grammar_norm.cpp
src/theory/quantifiers/sygus/sygus_grammar_norm.h

index c7c1d820fa3be66dc565deae9d8e31deddda15cf..b2e7d268184765262bc2b5cfb803f29a6309ce5f 100644 (file)
@@ -162,13 +162,15 @@ Node SygusGrammarNorm::TypeObject::eliminatePartialOperators(Node n)
   return visited[n];
 }
 
-void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
-                                               const DatatypeConstructor& cons)
+void SygusGrammarNorm::TypeObject::addConsInfo(
+    SygusGrammarNorm* sygus_norm,
+    const DTypeConstructor& cons,
+    std::shared_ptr<SygusPrintCallback> spc)
 {
   Trace("sygus-grammar-normalize") << "...for " << cons.getName() << "\n";
   /* Recover the sygus operator to not lose reference to the original
    * operator (NOT, ITE, etc) */
-  Node sygus_op = Node::fromExpr(cons.getSygusOp());
+  Node sygus_op = cons.getSygusOp();
   Trace("sygus-grammar-normalize-debug")
       << ".....operator is " << sygus_op << std::endl;
   Node exp_sop_n = sygus_op;
@@ -208,11 +210,12 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
   }
 
   std::vector<TypeNode> consTypes;
-  for (const DatatypeConstructorArg& arg : cons)
+  const std::vector<std::shared_ptr<DTypeSelector> >& args = cons.getArgs();
+  for (const std::shared_ptr<DTypeSelector>& arg : args)
   {
     // Collect unresolved type nodes corresponding to the typenode of the
     // arguments.
-    TypeNode atype = TypeNode::fromType(arg.getRangeType());
+    TypeNode atype = arg->getRangeType();
     // normalize it recursively
     atype = sygus_norm->normalizeSygusRec(atype);
     consTypes.push_back(atype);
@@ -220,19 +223,16 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm,
 
   Trace("sygus-type-cons-defs") << "\tOriginal op: " << cons.getSygusOp()
                                 << "\n\tExpanded one: " << exp_sop_n << "\n\n";
-  d_sdt.addConstructor(exp_sop_n,
-                       cons.getName(),
-                       consTypes,
-                       cons.getSygusPrintCallback(),
-                       cons.getWeight());
+  d_sdt.addConstructor(
+      exp_sop_n, cons.getName(), consTypes, spc, cons.getWeight());
 }
 
 void SygusGrammarNorm::TypeObject::initializeDatatype(
-    SygusGrammarNorm* sygus_norm, const Datatype& dt)
+    SygusGrammarNorm* sygus_norm, const DType& dt)
 {
   /* Use the sygus type to not lose reference to the original types (Bool,
    * Int, etc) */
-  TypeNode sygusType = TypeNode::fromType(dt.getSygusType());
+  TypeNode sygusType = dt.getSygusType();
   d_sdt.initializeDatatype(sygusType,
                            sygus_norm->d_sygus_vars.toExpr(),
                            dt.getSygusAllowConst(),
@@ -247,7 +247,7 @@ void SygusGrammarNorm::TypeObject::initializeDatatype(
 
 void SygusGrammarNorm::TransfDrop::buildType(SygusGrammarNorm* sygus_norm,
                                              TypeObject& to,
-                                             const Datatype& dt,
+                                             const DType& dt,
                                              std::vector<unsigned>& op_pos)
 {
   std::vector<unsigned> difference;
@@ -287,7 +287,7 @@ bool SygusGrammarNorm::TransfChain::isId(TypeNode tn, Node op, Node n)
 
 void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm,
                                               TypeObject& to,
-                                              const Datatype& dt,
+                                              const DType& dt,
                                               std::vector<unsigned>& op_pos)
 {
   NodeManager* nm = NodeManager::currentNM();
@@ -324,8 +324,7 @@ void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm,
     Trace("sygus-grammar-normalize-chain") << "\n";
   }
   /* Build identity operator and empty callback */
-  Node iden_op =
-      SygusGrammarNorm::getIdOp(TypeNode::fromType(dt.getSygusType()));
+  Node iden_op = SygusGrammarNorm::getIdOp(dt.getSygusType());
   /* If all operators are claimed, create a monomial */
   if (nb_op_pos == d_elem_pos.size() + 1)
   {
@@ -398,10 +397,10 @@ std::map<TypeNode, Node> SygusGrammarNorm::d_tn_to_id = {};
  * returns true if collected anything
  */
 std::unique_ptr<SygusGrammarNorm::Transf> SygusGrammarNorm::inferTransf(
-    TypeNode tn, const Datatype& dt, const std::vector<unsigned>& op_pos)
+    TypeNode tn, const DType& dt, const std::vector<unsigned>& op_pos)
 {
   NodeManager* nm = NodeManager::currentNM();
-  TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType());
+  TypeNode sygus_tn = dt.getSygusType();
   Trace("sygus-gnorm") << "Infer transf for " << dt.getName() << "..."
                        << std::endl;
   Trace("sygus-gnorm") << "  #cons = " << op_pos.size() << " / "
@@ -436,21 +435,20 @@ std::unique_ptr<SygusGrammarNorm::Transf> SygusGrammarNorm::inferTransf(
   for (unsigned i = 0, size = op_pos.size(); i < size; ++i)
   {
     Assert(op_pos[i] < dt.getNumConstructors());
-    Expr sop = dt[op_pos[i]].getSygusOp();
+    Node sop = dt[op_pos[i]].getSygusOp();
     /* Collects a chainable operator such as PLUS */
-    if (sop.getKind() == BUILTIN
-        && TransfChain::isChainable(sygus_tn, Node::fromExpr(sop)))
+    if (sop.getKind() == BUILTIN && TransfChain::isChainable(sygus_tn, sop))
     {
-      Assert(nm->operatorToKind(Node::fromExpr(sop)) == PLUS);
+      Assert(nm->operatorToKind(sop) == PLUS);
       /* TODO #1304: be robust for this case */
       /* For now only transforms applications whose arguments have the same type
        * as the root */
       bool same_type_plus = true;
-      for (const DatatypeConstructorArg& arg : dt[op_pos[i]])
+      const std::vector<std::shared_ptr<DTypeSelector> >& args =
+          dt[op_pos[i]].getArgs();
+      for (const std::shared_ptr<DTypeSelector>& arg : args)
       {
-        if (TypeNode::fromType(
-                static_cast<SelectorType>(arg.getType()).getRangeType())
-            != tn)
+        if (arg->getRangeType() != tn)
         {
           same_type_plus = false;
           break;
@@ -472,7 +470,7 @@ std::unique_ptr<SygusGrammarNorm::Transf> SygusGrammarNorm::inferTransf(
     }
     /* TODO #1304: check this for each operator */
     /* Collects elements that are not the identity (e.g. 0 is the id of PLUS) */
-    if (!TransfChain::isId(sygus_tn, nm->operatorOf(PLUS), Node::fromExpr(sop)))
+    if (!TransfChain::isId(sygus_tn, nm->operatorOf(PLUS), sop))
     {
       Trace("sygus-grammar-normalize-infer")
           << "\tCollecting for NON_ID_ELEMS the sop " << sop
@@ -492,7 +490,7 @@ std::unique_ptr<SygusGrammarNorm::Transf> SygusGrammarNorm::inferTransf(
 }
 
 TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn,
-                                             const Datatype& dt,
+                                             const DType& dt,
                                              std::vector<unsigned>& op_pos)
 {
   Assert(tn.isDatatype());
@@ -541,7 +539,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn,
 
   if (dt.getSygusAllowConst())
   {
-    TypeNode sygus_type = TypeNode::fromType(dt.getSygusType());
+    TypeNode sygus_type = dt.getSygusType();
     // must be handled by counterexample-guided instantiation
     // don't do it for Boolean (not worth the trouble, since it has only
     // minimal gain (1 any constant vs 2 constructors for true/false), and
@@ -551,7 +549,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn,
         && !sygus_type.isBoolean())
     {
       Trace("sygus-grammar-normalize") << "...add any constant constructor.\n";
-      TypeNode dtn = TypeNode::fromType(dt.getSygusType());
+      TypeNode dtn = dt.getSygusType();
       // we add this constructor first since we use left associative chains
       // and our symmetry breaking should group any constants together
       // beneath the same application
@@ -570,11 +568,15 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn,
     transformation->buildType(this, to, dt, op_pos);
   }
 
-  /* Remaining operators are rebuilt as they are */
+  // Remaining operators are rebuilt as they are.
+  // Notice that we must extract the Datatype here to get the (Expr-layer)
+  // sygus print callback.
+  const Datatype& dtt = DatatypeType(tn.toType()).getDatatype();
   for (unsigned i = 0, size = op_pos.size(); i < size; ++i)
   {
-    Assert(op_pos[i] < dt.getNumConstructors());
-    to.addConsInfo(this, dt[op_pos[i]]);
+    unsigned oi = op_pos[i];
+    Assert(oi < dt.getNumConstructors());
+    to.addConsInfo(this, dt[oi], dtt[oi].getSygusPrintCallback());
   }
   /* Build normalize datatype */
   if (Trace.isOn("sygus-grammar-normalize"))
@@ -599,7 +601,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn)
     return tn;
   }
   /* Collect all operators for normalization */
-  const Datatype& dt = DatatypeType(tn.toType()).getDatatype();
+  const DType& dt = tn.getDType();
   if (!dt.isSygus())
   {
     // datatype but not sygus datatype case
index f9c53c4ad94037b5e9685f414ff4d356aa10a913..360762b38f463a619b7b25bfea8ba73751e69b9f 100644 (file)
@@ -198,7 +198,8 @@ class SygusGrammarNorm
      * The types of the arguments of "cons" are recursively normalized
      */
     void addConsInfo(SygusGrammarNorm* sygus_norm,
-                     const DatatypeConstructor& cons);
+                     const DTypeConstructor& cons,
+                     std::shared_ptr<SygusPrintCallback> spc);
     /**
      * Returns the total version of Kind k if it is a partial operator, or
      * otherwise k itself.
@@ -219,7 +220,7 @@ class SygusGrammarNorm
      * The initialized datatype and its unresolved type are saved in the global
      * accumulators of "sygus_norm"
      */
-    void initializeDatatype(SygusGrammarNorm* sygus_norm, const Datatype& dt);
+    void initializeDatatype(SygusGrammarNorm* sygus_norm, const DType& dt);
 
     //---------- information stored from original type node
 
@@ -253,7 +254,7 @@ class SygusGrammarNorm
      */
     virtual void buildType(SygusGrammarNorm* sygus_norm,
                            TypeObject& to,
-                           const Datatype& dt,
+                           const DType& dt,
                            std::vector<unsigned>& op_pos) = 0;
   }; /* class Transf */
 
@@ -271,7 +272,7 @@ class SygusGrammarNorm
     /** build type */
     void buildType(SygusGrammarNorm* sygus_norm,
                    TypeObject& to,
-                   const Datatype& dt,
+                   const DType& dt,
                    std::vector<unsigned>& op_pos) override;
 
    private:
@@ -329,7 +330,7 @@ class SygusGrammarNorm
      */
     void buildType(SygusGrammarNorm* sygus_norm,
                    TypeObject& to,
-                   const Datatype& dt,
+                   const DType& dt,
                    std::vector<unsigned>& op_pos) override;
 
     /** Whether operator is chainable for the type (e.g. PLUS for Int)
@@ -421,7 +422,7 @@ class SygusGrammarNorm
    * recursion depth is limited by the height of the types, which is small
    */
   TypeNode normalizeSygusRec(TypeNode tn,
-                             const Datatype& dt,
+                             const DType& dt,
                              std::vector<unsigned>& op_pos);
 
   /** wrapper for the above function
@@ -436,7 +437,7 @@ class SygusGrammarNorm
    * TODO: #1304: Infer more complex transformations
    */
   std::unique_ptr<Transf> inferTransf(TypeNode tn,
-                                      const Datatype& dt,
+                                      const DType& dt,
                                       const std::vector<unsigned>& op_pos);
 }; /* class SygusGrammarNorm */