From 3076c4e70ded49d4b54585738d2dfc1d4aed1b9c Mon Sep 17 00:00:00 2001 From: Aina Niemetz Date: Wed, 30 Mar 2022 08:07:13 -0700 Subject: [PATCH] TypeNode: Unify functions to instantiate parametric sorts. (#8449) This unifies `instantiateParametricDatatype()` and `instantiateSortConstructor()` into `instantiate()`. It further fixes how the API calls TypeNode instantation. --- src/api/cpp/cvc5.cpp | 7 +--- src/expr/dtype.cpp | 2 +- src/expr/dtype_cons.cpp | 5 +-- src/expr/symbol_table.cpp | 4 +- src/expr/type_node.cpp | 39 +++++++++---------- src/expr/type_node.h | 35 ++++++++--------- .../datatypes/theory_datatypes_type_rules.cpp | 2 +- 7 files changed, 41 insertions(+), 53 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index a22923cc9..a0604537c 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -1424,12 +1424,7 @@ Sort Sort::instantiate(const std::vector& params) const << "Arity mismatch for instantiated sort constructor"; //////// all checks before this line std::vector tparams = sortVectorToTypeNodes(params); - if (d_type->isDatatype()) - { - return Sort(d_solver, d_type->instantiateParametricDatatype(tparams)); - } - Assert(d_type->isUninterpretedSortConstructor()); - return Sort(d_solver, d_solver->getNodeManager()->mkSort(*d_type, tparams)); + return Sort(d_solver, d_type->instantiate(tparams)); //////// CVC5_API_TRY_CATCH_END; } diff --git a/src/expr/dtype.cpp b/src/expr/dtype.cpp index 5fbceee64..fa332f2d4 100644 --- a/src/expr/dtype.cpp +++ b/src/expr/dtype.cpp @@ -833,7 +833,7 @@ TypeNode DType::getTypeNode(const std::vector& params) const { Assert(isResolved()); Assert(!d_self.isNull() && d_self.isParametricDatatype()); - return d_self.instantiateParametricDatatype(params); + return d_self.instantiate(params); } const DTypeConstructor& DType::operator[](size_t index) const diff --git a/src/expr/dtype_cons.cpp b/src/expr/dtype_cons.cpp index 8cde888d6..b19deea0e 100644 --- a/src/expr/dtype_cons.cpp +++ b/src/expr/dtype_cons.cpp @@ -648,11 +648,10 @@ TypeNode DTypeConstructor::doParametricSubstitution( if (paramTypes[i].getUninterpretedSortConstructorArity() == origChildren.size()) { - TypeNode tn = paramTypes[i].instantiateSortConstructor(origChildren); + TypeNode tn = paramTypes[i].instantiate(origChildren); if (range == tn) { - TypeNode tret = - paramReplacements[i].instantiateParametricDatatype(children); + TypeNode tret = paramReplacements[i].instantiate(children); return tret; } } diff --git a/src/expr/symbol_table.cpp b/src/expr/symbol_table.cpp index 7280c5902..f6153372a 100644 --- a/src/expr/symbol_table.cpp +++ b/src/expr/symbol_table.cpp @@ -511,8 +511,8 @@ cvc5::Sort SymbolTable::Implementation::lookupType( << "type is " << p.second << std::endl; } cvc5::Sort instantiation = isUninterpretedSortConstructor - ? p.second.instantiate(params) - : p.second.substitute(p.first, params); + ? p.second.instantiate(params) + : p.second.substitute(p.first, params); Trace("sort") << "instance is " << instantiation << std::endl; return instantiation; diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 2f92e9233..9b4fd46c1 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -342,7 +342,7 @@ TypeNode TypeNode::getBaseType() const { for(size_t i = 1; i < getNumChildren(); ++i) { v.push_back((*this)[i].getBaseType()); } - return (*this)[0].getDType().getTypeNode().instantiateParametricDatatype(v); + return (*this)[0].getDType().getTypeNode().instantiate(v); } return *this; } @@ -428,20 +428,24 @@ bool TypeNode::isInstantiated() const || (isUninterpretedSort() && getNumChildren() > 0); } -TypeNode TypeNode::instantiateParametricDatatype( - const std::vector& params) const +TypeNode TypeNode::instantiate(const std::vector& params) const { - AssertArgument(getKind() == kind::PARAMETRIC_DATATYPE, *this); - AssertArgument(params.size() == getNumChildren() - 1, *this); NodeManager* nm = NodeManager::currentNM(); - TypeNode cons = nm->mkTypeConst((*this)[0].getConst()); - std::vector paramsNodes; - paramsNodes.push_back(cons); - for (const TypeNode& t : params) + if (getKind() == kind::PARAMETRIC_DATATYPE) { - paramsNodes.push_back(t); + Assert(params.size() == getNumChildren() - 1); + TypeNode cons = + nm->mkTypeConst((*this)[0].getConst()); + std::vector paramsNodes; + paramsNodes.push_back(cons); + for (const TypeNode& t : params) + { + paramsNodes.push_back(t); + } + return nm->mkTypeNode(kind::PARAMETRIC_DATATYPE, paramsNodes); } - return nm->mkTypeNode(kind::PARAMETRIC_DATATYPE, paramsNodes); + Assert(isUninterpretedSortConstructor()); + return nm->mkSort(*this, params); } uint64_t TypeNode::getUninterpretedSortConstructorArity() const @@ -457,18 +461,11 @@ std::string TypeNode::getName() const return getAttribute(expr::VarNameAttr()); } -TypeNode TypeNode::instantiateSortConstructor( - const std::vector& params) const +bool TypeNode::isParameterInstantiatedDatatype(size_t n) const { - Assert(isUninterpretedSortConstructor()); - return NodeManager::currentNM()->mkSort(*this, params); -} - -/** Is this an instantiated datatype parameter */ -bool TypeNode::isParameterInstantiatedDatatype(unsigned n) const { - AssertArgument(getKind() == kind::PARAMETRIC_DATATYPE, *this); + Assert(getKind() == kind::PARAMETRIC_DATATYPE); const DType& dt = (*this)[0].getDType(); - AssertArgument(n < dt.getNumParameters(), *this); + Assert(n < dt.getNumParameters()); return dt.getParameter(n) != (*this)[n + 1]; } diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 584c64554..495f5b383 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -613,17 +613,25 @@ private: bool isSygusDatatype() const; /** - * Get instantiated datatype type. The type on which this method is called - * should be a parametric datatype whose parameter list is the same size as - * argument params. This constructs the instantiated version of this - * parametric datatype, e.g. passing (par (A) (List A)), { Int } ) to this - * method returns (List Int). + * Instantiate parametric type (parametric datatype or uninterpreted sort + * constructor type). + * + * The parameter list of this type must be the same size as the list of + * argument parameters `params`. + * + * If this TypeNode is a parametric datatype, this constructs the + * instantiated version of this parametric datatype. For example, passing + * (par (A) (List A)), { Int } ) to this method returns (List Int). + * + * If this is an uninterpreted sort constructor type, this constructs the + * instantiated version of this sort constructor. For example, for a sort + * constructor declared via (declare-sort U 2), passing { Int, Int } will + * generate the instantiated sort (U Int Int). */ - TypeNode instantiateParametricDatatype( - const std::vector& params) const; + TypeNode instantiate(const std::vector& params) const; /** Is this an instantiated datatype parameter */ - bool isParameterInstantiatedDatatype(unsigned n) const; + bool isParameterInstantiatedDatatype(size_t n) const; /** Is this a constructor type */ bool isConstructor() const; @@ -663,17 +671,6 @@ private: */ std::string getName() const; - /** - * Instantiate a sort constructor type. The type on which this method is - * called should be a sort constructor type whose parameter list is the - * same size as argument params. This constructs the instantiated version of - * this sort constructor. For example, this is a sort constructor, e.g. - * declared via (declare-sort U 2), then calling this method with - * { Int, Int } will generate the instantiated sort (U Int Int). - */ - TypeNode instantiateSortConstructor( - const std::vector& params) const; - /** Get the most general base type of the type */ TypeNode getBaseType() const; diff --git a/src/theory/datatypes/theory_datatypes_type_rules.cpp b/src/theory/datatypes/theory_datatypes_type_rules.cpp index bc4861e40..21bfb46b5 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.cpp +++ b/src/theory/datatypes/theory_datatypes_type_rules.cpp @@ -67,7 +67,7 @@ TypeNode DatatypeConstructorTypeRule::computeType(NodeManager* nodeManager, } std::vector instTypes; m.getMatches(instTypes); - TypeNode range = t.instantiateParametricDatatype(instTypes); + TypeNode range = t.instantiate(instTypes); Trace("typecheck-idt") << "Return " << range << std::endl; return range; } -- 2.30.2