TypeNode: Refactor get param types handling. (#8428)
authorAina Niemetz <aina.niemetz@gmail.com>
Tue, 29 Mar 2022 08:33:05 +0000 (01:33 -0700)
committerGitHub <noreply@github.com>
Tue, 29 Mar 2022 08:33:05 +0000 (08:33 +0000)
This introduces TypeNode::getInstantiatedParamTypes() and deletes
TypeNode::getParamTypes(). This is in preparation for adding
api::Sort::getInstantiatedParameters().

src/api/cpp/cvc5.cpp
src/expr/dtype_cons.cpp
src/expr/type_matcher.cpp
src/expr/type_node.cpp
src/expr/type_node.h

index 8b16ea8bc16f27ff0815ae8e79de86ff659edb4b..4879779632856e597c3944540f8e3749964535cd 100644 (file)
@@ -1727,7 +1727,7 @@ std::vector<Sort> Sort::getDatatypeParamSorts() const
   CVC5_API_CHECK(d_type->isParametricDatatype())
       << "Not a parametric datatype sort.";
   //////// all checks before this line
-  return typeNodeVectorToSorts(d_solver, d_type->getParamTypes());
+  return typeNodeVectorToSorts(d_solver, d_type->getDType().getParameters());
   ////////
   CVC5_API_TRY_CATCH_END;
 }
index 269f3ebca3b96bea43d58a2d65758b1841fa8a63..c9c70f7789319a8541393ca3f8cf37e87b6da0a5 100644 (file)
@@ -199,7 +199,7 @@ std::pair<CardinalityClass, bool> DTypeConstructor::computeCardinalityInfo(
   if (isParam)
   {
     paramTypes = t.getDType().getParameters();
-    instTypes = t.getParamTypes();
+    instTypes = t.getInstantiatedParamTypes();
   }
   for (unsigned i = 0, nargs = getNumArgs(); i < nargs; i++)
   {
@@ -328,7 +328,7 @@ Cardinality DTypeConstructor::computeCardinality(
   if (isParam)
   {
     paramTypes = t.getDType().getParameters();
-    instTypes = t.getParamTypes();
+    instTypes = t.getInstantiatedParamTypes();
   }
   for (size_t i = 0, nargs = d_args.size(); i < nargs; i++)
   {
@@ -390,7 +390,7 @@ Node DTypeConstructor::computeGroundTerm(TypeNode t,
   if (isParam)
   {
     paramTypes = t.getDType().getParameters();
-    instTypes = TypeNode(t).getParamTypes();
+    instTypes = TypeNode(t).getInstantiatedParamTypes();
   }
   for (size_t i = 0, nargs = getNumArgs(); i < nargs; i++)
   {
index cd244df148e6844b2f5e2984e89b8832728ac776..cc33494d8e66c4c8249d3e77fb04e8b022a1a8a5 100644 (file)
@@ -15,6 +15,8 @@
 
 #include "type_matcher.h"
 
+#include "expr/dtype.h"
+
 namespace cvc5 {
 
 TypeMatcher::TypeMatcher(TypeNode dt)
@@ -25,7 +27,15 @@ TypeMatcher::TypeMatcher(TypeNode dt)
 
 void TypeMatcher::addTypesFromDatatype(TypeNode dt)
 {
-  std::vector<TypeNode> argTypes = dt.getParamTypes();
+  std::vector<TypeNode> argTypes;
+  if (dt.isInstantiated())
+  {
+    argTypes = dt.getInstantiatedParamTypes();
+  }
+  else
+  {
+    argTypes = dt.getDType().getParameters();
+  }
   addTypes(argTypes);
   Trace("typecheck-idt") << "instantiating matcher for " << dt << std::endl;
   for (unsigned i = 0, narg = argTypes.size(); i < narg; ++i)
index a90efd055865557cc51b5dbcc7712347c4d30a7f..b784e8ce154b7571934abcfa38761dbbdb40676d 100644 (file)
@@ -360,10 +360,12 @@ std::vector<TypeNode> TypeNode::getArgTypes() const {
   return args;
 }
 
-std::vector<TypeNode> TypeNode::getParamTypes() const {
+std::vector<TypeNode> TypeNode::getInstantiatedParamTypes() const
+{
+  Assert(isInstantiated());
   vector<TypeNode> params;
-  Assert(isParametricDatatype());
-  for(unsigned i = 1, i_end = getNumChildren(); i < i_end; ++i) {
+  for (uint32_t i = 1, i_end = getNumChildren(); i < i_end; ++i)
+  {
     params.push_back((*this)[i]);
   }
   return params;
index 8f8985d8650b664ca62b094d6ac2a1c4b45a1027..60b24d57e6140990849a28ffd59f3065ad6a3fd8 100644 (file)
@@ -529,10 +529,16 @@ private:
   std::vector<TypeNode> getArgTypes() const;
 
   /**
-   * Get the paramater types of a parameterized datatype.  Fails an
-   * assertion if this type is not a parametric datatype.
+   * Get the types used to instantiate the type parameters of a parametric
+   * type (parametric datatype or uninterpreted sort constructor type,
+   * see TypeNode::instantiate(const std::vector<TypeNode>& const).
+   *
+   * Asserts that this type is an instantiated type.
+   *
+   * @return the types used to instantiate the type parameters of a
+   *         parametric type
    */
-  std::vector<TypeNode> getParamTypes() const;
+  std::vector<TypeNode> getInstantiatedParamTypes() const;
 
   /**
    * Get the range type (i.e., the type of the result) of a function,