TypeNode: Unify functions to instantiate parametric sorts. (#8449)
authorAina Niemetz <aina.niemetz@gmail.com>
Wed, 30 Mar 2022 15:07:13 +0000 (08:07 -0700)
committerGitHub <noreply@github.com>
Wed, 30 Mar 2022 15:07:13 +0000 (15:07 +0000)
This unifies `instantiateParametricDatatype()` and
`instantiateSortConstructor()` into `instantiate()`. It further fixes
how the API calls TypeNode instantation.

src/api/cpp/cvc5.cpp
src/expr/dtype.cpp
src/expr/dtype_cons.cpp
src/expr/symbol_table.cpp
src/expr/type_node.cpp
src/expr/type_node.h
src/theory/datatypes/theory_datatypes_type_rules.cpp

index a22923cc9d7bfb5db41215646e77aea5d248ab6d..a0604537c1b895b9a3bdfb04ed6800f286cbb6b4 100644 (file)
@@ -1424,12 +1424,7 @@ Sort Sort::instantiate(const std::vector<Sort>& params) const
       << "Arity mismatch for instantiated sort constructor";
   //////// all checks before this line
   std::vector<internal::TypeNode> tparams = sortVectorToTypeNodes(params);
-  if (d_type->isDatatype())
-  {
-    return Sort(d_solver, d_type->instantiateParametricDatatype(tparams));
-  }
-  Assert(d_type->isUninterpretedSortConstructor());
-  return Sort(d_solver, d_solver->getNodeManager()->mkSort(*d_type, tparams));
+  return Sort(d_solver, d_type->instantiate(tparams));
   ////////
   CVC5_API_TRY_CATCH_END;
 }
index 5fbceee64ec7dd4a20a288204cae0d61fd4b4b1e..fa332f2d4cea75b12278c921a9a5ed214e47abaf 100644 (file)
@@ -833,7 +833,7 @@ TypeNode DType::getTypeNode(const std::vector<TypeNode>& params) const
 {
   Assert(isResolved());
   Assert(!d_self.isNull() && d_self.isParametricDatatype());
-  return d_self.instantiateParametricDatatype(params);
+  return d_self.instantiate(params);
 }
 
 const DTypeConstructor& DType::operator[](size_t index) const
index 8cde888d6975e73c78526d11f02b026c0d1d0f01..b19deea0e619fe2faca5d8185e50f6125b1bc920 100644 (file)
@@ -648,11 +648,10 @@ TypeNode DTypeConstructor::doParametricSubstitution(
     if (paramTypes[i].getUninterpretedSortConstructorArity()
         == origChildren.size())
     {
-      TypeNode tn = paramTypes[i].instantiateSortConstructor(origChildren);
+      TypeNode tn = paramTypes[i].instantiate(origChildren);
       if (range == tn)
       {
-        TypeNode tret =
-            paramReplacements[i].instantiateParametricDatatype(children);
+        TypeNode tret = paramReplacements[i].instantiate(children);
         return tret;
       }
     }
index 7280c590290bd86fb6c8ea2ee5ec4d034e234eed..f6153372a03fbc6aab752544f59caff961946aff 100644 (file)
@@ -511,8 +511,8 @@ cvc5::Sort SymbolTable::Implementation::lookupType(
                   << "type is      " << p.second << std::endl;
   }
   cvc5::Sort instantiation = isUninterpretedSortConstructor
-                                ? p.second.instantiate(params)
-                                : p.second.substitute(p.first, params);
+                                 ? p.second.instantiate(params)
+                                 : p.second.substitute(p.first, params);
   Trace("sort") << "instance is  " << instantiation << std::endl;
 
   return instantiation;
index 2f92e92332b0f91acb12069bbd837ea982884fa0..9b4fd46c103198d08ba539a16c944cc0c4927b4a 100644 (file)
@@ -342,7 +342,7 @@ TypeNode TypeNode::getBaseType() const {
     for(size_t i = 1; i < getNumChildren(); ++i) {
       v.push_back((*this)[i].getBaseType());
     }
-    return (*this)[0].getDType().getTypeNode().instantiateParametricDatatype(v);
+    return (*this)[0].getDType().getTypeNode().instantiate(v);
   }
   return *this;
 }
@@ -428,20 +428,24 @@ bool TypeNode::isInstantiated() const
          || (isUninterpretedSort() && getNumChildren() > 0);
 }
 
-TypeNode TypeNode::instantiateParametricDatatype(
-    const std::vector<TypeNode>& params) const
+TypeNode TypeNode::instantiate(const std::vector<TypeNode>& params) const
 {
-  AssertArgument(getKind() == kind::PARAMETRIC_DATATYPE, *this);
-  AssertArgument(params.size() == getNumChildren() - 1, *this);
   NodeManager* nm = NodeManager::currentNM();
-  TypeNode cons = nm->mkTypeConst((*this)[0].getConst<DatatypeIndexConstant>());
-  std::vector<TypeNode> paramsNodes;
-  paramsNodes.push_back(cons);
-  for (const TypeNode& t : params)
+  if (getKind() == kind::PARAMETRIC_DATATYPE)
   {
-    paramsNodes.push_back(t);
+    Assert(params.size() == getNumChildren() - 1);
+    TypeNode cons =
+        nm->mkTypeConst((*this)[0].getConst<DatatypeIndexConstant>());
+    std::vector<TypeNode> paramsNodes;
+    paramsNodes.push_back(cons);
+    for (const TypeNode& t : params)
+    {
+      paramsNodes.push_back(t);
+    }
+    return nm->mkTypeNode(kind::PARAMETRIC_DATATYPE, paramsNodes);
   }
-  return nm->mkTypeNode(kind::PARAMETRIC_DATATYPE, paramsNodes);
+  Assert(isUninterpretedSortConstructor());
+  return nm->mkSort(*this, params);
 }
 
 uint64_t TypeNode::getUninterpretedSortConstructorArity() const
@@ -457,18 +461,11 @@ std::string TypeNode::getName() const
   return getAttribute(expr::VarNameAttr());
 }
 
-TypeNode TypeNode::instantiateSortConstructor(
-    const std::vector<TypeNode>& params) const
+bool TypeNode::isParameterInstantiatedDatatype(size_t n) const
 {
-  Assert(isUninterpretedSortConstructor());
-  return NodeManager::currentNM()->mkSort(*this, params);
-}
-
-/** Is this an instantiated datatype parameter */
-bool TypeNode::isParameterInstantiatedDatatype(unsigned n) const {
-  AssertArgument(getKind() == kind::PARAMETRIC_DATATYPE, *this);
+  Assert(getKind() == kind::PARAMETRIC_DATATYPE);
   const DType& dt = (*this)[0].getDType();
-  AssertArgument(n < dt.getNumParameters(), *this);
+  Assert(n < dt.getNumParameters());
   return dt.getParameter(n) != (*this)[n + 1];
 }
 
index 584c64554d940b6f1bcb368fd099ee2289899212..495f5b383ade030a6dc107c18055b8279e37b1f0 100644 (file)
@@ -613,17 +613,25 @@ private:
   bool isSygusDatatype() const;
 
   /**
-   * Get instantiated datatype type. The type on which this method is called
-   * should be a parametric datatype whose parameter list is the same size as
-   * argument params. This constructs the instantiated version of this
-   * parametric datatype, e.g. passing (par (A) (List A)), { Int } ) to this
-   * method returns (List Int).
+   * Instantiate parametric type (parametric datatype or uninterpreted sort
+   * constructor type).
+   *
+   * The parameter list of this type must be the same size as the list of
+   * argument parameters `params`.
+   *
+   * If this TypeNode is a parametric datatype, this constructs the
+   * instantiated version of this parametric datatype. For example, passing
+   * (par (A) (List A)), { Int } ) to this method returns (List Int).
+   *
+   * If this is an uninterpreted sort constructor type, this constructs the
+   * instantiated version of this sort constructor. For example, for a sort
+   * constructor declared via (declare-sort U 2), passing { Int, Int } will
+   * generate the instantiated sort (U Int Int).
    */
-  TypeNode instantiateParametricDatatype(
-      const std::vector<TypeNode>& params) const;
+  TypeNode instantiate(const std::vector<TypeNode>& params) const;
 
   /** Is this an instantiated datatype parameter */
-  bool isParameterInstantiatedDatatype(unsigned n) const;
+  bool isParameterInstantiatedDatatype(size_t n) const;
 
   /** Is this a constructor type */
   bool isConstructor() const;
@@ -663,17 +671,6 @@ private:
    */
   std::string getName() const;
 
-  /**
-   * Instantiate a sort constructor type. The type on which this method is
-   * called should be a sort constructor type whose parameter list is the
-   * same size as argument params. This constructs the instantiated version of
-   * this sort constructor. For example, this is a sort constructor, e.g.
-   * declared via (declare-sort U 2), then calling this method with
-   * { Int, Int } will generate the instantiated sort (U Int Int).
-   */
-  TypeNode instantiateSortConstructor(
-      const std::vector<TypeNode>& params) const;
-
   /** Get the most general base type of the type */
   TypeNode getBaseType() const;
 
index bc4861e4090aa503beefedfeffca949045157e6d..21bfb46b555b2505aafb93e745aef85fcaa82ab3 100644 (file)
@@ -67,7 +67,7 @@ TypeNode DatatypeConstructorTypeRule::computeType(NodeManager* nodeManager,
     }
     std::vector<TypeNode> instTypes;
     m.getMatches(instTypes);
-    TypeNode range = t.instantiateParametricDatatype(instTypes);
+    TypeNode range = t.instantiate(instTypes);
     Trace("typecheck-idt") << "Return " << range << std::endl;
     return range;
   }