From: Andrew Reynolds Date: Fri, 30 Oct 2020 02:51:18 +0000 (-0500) Subject: Update api::Sort to use TypeNode instead of Type (#5363) X-Git-Tag: cvc5-1.0.0~2644 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=21fd193bdaad1a952845326aa1c84654cfce1503;p=cvc5.git Update api::Sort to use TypeNode instead of Type (#5363) This is work towards removing the old API. This makes TypeNode the backend for Sort instead of Type. It also updates a unit test for methods isUninterpretedSortParameterized and getUninterpretedSortParamSorts whose implementation was previously buggy due to the implementation of Type-level SortType. --- diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 507e270bb..e16d8c519 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -49,6 +49,7 @@ #include "expr/node_manager.h" #include "expr/sequence.h" #include "expr/type.h" +#include "expr/type_node.h" #include "options/main_options.h" #include "options/options.h" #include "options/smt_options.h" @@ -945,13 +946,26 @@ std::ostream& operator<<(std::ostream& out, const Result& r) /* -------------------------------------------------------------------------- */ Sort::Sort(const Solver* slv, const CVC4::Type& t) - : d_solver(slv), d_type(new CVC4::Type(t)) + : d_solver(slv), d_type(new CVC4::TypeNode(TypeNode::fromType(t))) +{ +} +Sort::Sort(const Solver* slv, const CVC4::TypeNode& t) + : d_solver(slv), d_type(new CVC4::TypeNode(t)) { } -Sort::Sort() : d_solver(nullptr), d_type(new CVC4::Type()) {} +Sort::Sort() : d_solver(nullptr), d_type(new CVC4::TypeNode()) {} -Sort::~Sort() {} +Sort::~Sort() +{ + if (d_solver != nullptr) + { + // Ensure that the correct node manager is in scope when the node is + // destroyed. + NodeManagerScope scope(d_solver->getNodeManager()); + d_type.reset(); + } +} /* Helpers */ /* -------------------------------------------------------------------------- */ @@ -996,7 +1010,7 @@ bool Sort::isDatatype() const { return d_type->isDatatype(); } bool Sort::isParametricDatatype() const { if (!d_type->isDatatype()) return false; - return TypeNode::fromType(*d_type).isParametricDatatype(); + return d_type->isParametricDatatype(); } bool Sort::isConstructor() const { return d_type->isConstructor(); } @@ -1015,7 +1029,7 @@ bool Sort::isArray() const { return d_type->isArray(); } bool Sort::isSet() const { return d_type->isSet(); } -bool Sort::isBag() const { return TypeNode::fromType(*d_type).isBag(); } +bool Sort::isBag() const { return d_type->isBag(); } bool Sort::isSequence() const { return d_type->isSequence(); } @@ -1038,7 +1052,7 @@ Datatype Sort::getDatatype() const { NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_CHECK(isDatatype()) << "Expected datatype sort."; - return Datatype(d_solver, TypeNode::fromType(*d_type).getDType()); + return Datatype(d_solver, d_type->getDType()); } Sort Sort::instantiate(const std::vector& params) const @@ -1046,23 +1060,13 @@ Sort Sort::instantiate(const std::vector& params) const NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_CHECK(isParametricDatatype() || isSortConstructor()) << "Expected parametric datatype or sort constructor sort."; - std::vector tparams; - for (const Sort& s : params) - { - tparams.push_back(TypeNode::fromType(*s.d_type.get())); - } + std::vector tparams = sortVectorToTypeNodes(params); if (d_type->isDatatype()) { - return Sort(d_solver, - TypeNode::fromType(*d_type) - .instantiateParametricDatatype(tparams) - .toType()); + return Sort(d_solver, d_type->instantiateParametricDatatype(tparams)); } Assert(d_type->isSortConstructor()); - return Sort(d_solver, - d_solver->getNodeManager() - ->mkSort(TypeNode::fromType(*d_type), tparams) - .toType()); + return Sort(d_solver, d_solver->getNodeManager()->mkSort(*d_type, tparams)); } std::string Sort::toString() const @@ -1077,27 +1081,32 @@ std::string Sort::toString() const // !!! This is only temporarily available until the parser is fully migrated // to the new API. !!! -CVC4::Type Sort::getType(void) const { return *d_type; } +CVC4::Type Sort::getType(void) const +{ + if (d_type->isNull()) return Type(); + NodeManagerScope scope(d_solver->getNodeManager()); + return d_type->toType(); +} +const CVC4::TypeNode& Sort::getTypeNode(void) const { return *d_type; } /* Constructor sort ------------------------------------------------------- */ size_t Sort::getConstructorArity() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - return ConstructorType(*d_type).getArity(); + return d_type->getNumChildren() - 1; } std::vector Sort::getConstructorDomainSorts() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - std::vector types = ConstructorType(*d_type).getArgTypes(); - return typeVectorToSorts(d_solver, types); + return typeNodeVectorToSorts(d_solver, d_type->getArgTypes()); } Sort Sort::getConstructorCodomainSort() const { CVC4_API_CHECK(isConstructor()) << "Not a constructor sort: " << (*this); - return Sort(d_solver, ConstructorType(*d_type).getRangeType()); + return Sort(d_solver, d_type->getConstructorRangeType()); } /* Selector sort ------------------------------------------------------- */ @@ -1105,15 +1114,13 @@ Sort Sort::getConstructorCodomainSort() const Sort Sort::getSelectorDomainSort() const { CVC4_API_CHECK(isSelector()) << "Not a selector sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getSelectorDomainType().toType()); + return Sort(d_solver, d_type->getSelectorDomainType()); } Sort Sort::getSelectorCodomainSort() const { CVC4_API_CHECK(isSelector()) << "Not a selector sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getSelectorRangeType().toType()); + return Sort(d_solver, d_type->getSelectorRangeType()); } /* Tester sort ------------------------------------------------------- */ @@ -1121,8 +1128,7 @@ Sort Sort::getSelectorCodomainSort() const Sort Sort::getTesterDomainSort() const { CVC4_API_CHECK(isTester()) << "Not a tester sort: " << (*this); - TypeNode typeNode = TypeNode::fromType(*d_type); - return Sort(d_solver, typeNode.getTesterDomainType().toType()); + return Sort(d_solver, d_type->getTesterDomainType()); } Sort Sort::getTesterCodomainSort() const @@ -1136,20 +1142,19 @@ Sort Sort::getTesterCodomainSort() const size_t Sort::getFunctionArity() const { CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); - return FunctionType(*d_type).getArity(); + return d_type->getNumChildren() - 1; } std::vector Sort::getFunctionDomainSorts() const { CVC4_API_CHECK(isFunction()) << "Not a function sort: " << (*this); - std::vector types = FunctionType(*d_type).getArgTypes(); - return typeVectorToSorts(d_solver, types); + return typeNodeVectorToSorts(d_solver, d_type->getArgTypes()); } Sort Sort::getFunctionCodomainSort() const { CVC4_API_CHECK(isFunction()) << "Not a function sort" << (*this); - return Sort(d_solver, FunctionType(*d_type).getRangeType()); + return Sort(d_solver, d_type->getRangeType()); } /* Array sort ---------------------------------------------------------- */ @@ -1157,13 +1162,13 @@ Sort Sort::getFunctionCodomainSort() const Sort Sort::getArrayIndexSort() const { CVC4_API_CHECK(isArray()) << "Not an array sort."; - return Sort(d_solver, ArrayType(*d_type).getIndexType()); + return Sort(d_solver, d_type->getArrayIndexType()); } Sort Sort::getArrayElementSort() const { CVC4_API_CHECK(isArray()) << "Not an array sort."; - return Sort(d_solver, ArrayType(*d_type).getConstituentType()); + return Sort(d_solver, d_type->getArrayConstituentType()); } /* Set sort ------------------------------------------------------------ */ @@ -1171,7 +1176,7 @@ Sort Sort::getArrayElementSort() const Sort Sort::getSetElementSort() const { CVC4_API_CHECK(isSet()) << "Not a set sort."; - return Sort(d_solver, SetType(*d_type).getElementType()); + return Sort(d_solver, d_type->getSetElementType()); } /* Bag sort ------------------------------------------------------------ */ @@ -1179,9 +1184,7 @@ Sort Sort::getSetElementSort() const Sort Sort::getBagElementSort() const { CVC4_API_CHECK(isBag()) << "Not a bag sort."; - TypeNode typeNode = TypeNode::fromType(*d_type); - Type type = typeNode.getBagElementType().toType(); - return Sort(d_solver, type); + return Sort(d_solver, d_type->getBagElementType()); } /* Sequence sort ------------------------------------------------------- */ @@ -1189,7 +1192,7 @@ Sort Sort::getBagElementSort() const Sort Sort::getSequenceElementSort() const { CVC4_API_CHECK(isSequence()) << "Not a sequence sort."; - return Sort(d_solver, SequenceType(*d_type).getElementType()); + return Sort(d_solver, d_type->getSequenceElementType()); } /* Uninterpreted sort -------------------------------------------------- */ @@ -1197,20 +1200,28 @@ Sort Sort::getSequenceElementSort() const std::string Sort::getUninterpretedSortName() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - return SortType(*d_type).getName(); + return d_type->getName(); } bool Sort::isUninterpretedSortParameterized() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - return SortType(*d_type).isParameterized(); + // This method is not implemented in the NodeManager, since whether a + // uninterpreted sort is parametrized is irrelevant for solving. + return d_type->getNumChildren() > 0; } std::vector Sort::getUninterpretedSortParamSorts() const { CVC4_API_CHECK(isUninterpretedSort()) << "Not an uninterpreted sort."; - std::vector types = SortType(*d_type).getParamTypes(); - return typeVectorToSorts(d_solver, types); + // This method is not implemented in the NodeManager, since whether a + // uninterpreted sort is parametrized is irrelevant for solving. + std::vector params; + for (size_t i = 0, nchildren = d_type->getNumChildren(); i < nchildren; i++) + { + params.push_back((*d_type)[i]); + } + return typeNodeVectorToSorts(d_solver, params); } /* Sort constructor sort ----------------------------------------------- */ @@ -1218,13 +1229,13 @@ std::vector Sort::getUninterpretedSortParamSorts() const std::string Sort::getSortConstructorName() const { CVC4_API_CHECK(isSortConstructor()) << "Not a sort constructor sort."; - return SortConstructorType(*d_type).getName(); + return d_type->getName(); } size_t Sort::getSortConstructorArity() const { CVC4_API_CHECK(isSortConstructor()) << "Not a sort constructor sort."; - return SortConstructorType(*d_type).getArity(); + return d_type->getSortConstructorArity(); } /* Bit-vector sort ----------------------------------------------------- */ @@ -1232,7 +1243,7 @@ size_t Sort::getSortConstructorArity() const uint32_t Sort::getBVSize() const { CVC4_API_CHECK(isBitVector()) << "Not a bit-vector sort."; - return BitVectorType(*d_type).getSize(); + return d_type->getBitVectorSize(); } /* Floating-point sort ------------------------------------------------- */ @@ -1240,13 +1251,13 @@ uint32_t Sort::getBVSize() const uint32_t Sort::getFPExponentSize() const { CVC4_API_CHECK(isFloatingPoint()) << "Not a floating-point sort."; - return FloatingPointType(*d_type).getExponentSize(); + return d_type->getFloatingPointExponentSize(); } uint32_t Sort::getFPSignificandSize() const { CVC4_API_CHECK(isFloatingPoint()) << "Not a floating-point sort."; - return FloatingPointType(*d_type).getSignificandSize(); + return d_type->getFloatingPointSignificandSize(); } /* Datatype sort ------------------------------------------------------- */ @@ -1254,20 +1265,14 @@ uint32_t Sort::getFPSignificandSize() const std::vector Sort::getDatatypeParamSorts() const { CVC4_API_CHECK(isParametricDatatype()) << "Not a parametric datatype sort."; - std::vector typeNodes = - TypeNode::fromType(*d_type).getParamTypes(); - std::vector sorts; - for (size_t i = 0, tsize = typeNodes.size(); i < tsize; i++) - { - sorts.push_back(Sort(d_solver, typeNodes[i].toType())); - } - return sorts; + std::vector typeNodes = d_type->getParamTypes(); + return typeNodeVectorToSorts(d_solver, typeNodes); } size_t Sort::getDatatypeArity() const { CVC4_API_CHECK(isDatatype()) << "Not a datatype sort."; - return TypeNode::fromType(*d_type).getNumChildren() - 1; + return d_type->getNumChildren() - 1; } /* Tuple sort ---------------------------------------------------------- */ @@ -1275,20 +1280,14 @@ size_t Sort::getDatatypeArity() const size_t Sort::getTupleLength() const { CVC4_API_CHECK(isTuple()) << "Not a tuple sort."; - return TypeNode::fromType(*d_type).getTupleLength(); + return d_type->getTupleLength(); } std::vector Sort::getTupleSorts() const { CVC4_API_CHECK(isTuple()) << "Not a tuple sort."; - std::vector typeNodes = - TypeNode::fromType(*d_type).getTupleTypes(); - std::vector sorts; - for (size_t i = 0, tsize = typeNodes.size(); i < tsize; i++) - { - sorts.push_back(Sort(d_solver, typeNodes[i].toType())); - } - return sorts; + std::vector typeNodes = d_type->getTupleTypes(); + return typeNodeVectorToSorts(d_solver, typeNodes); } /* --------------------------------------------------------------------- */ @@ -1301,7 +1300,7 @@ std::ostream& operator<<(std::ostream& out, const Sort& s) size_t SortHashFunction::operator()(const Sort& s) const { - return TypeHashFunction()(*s.d_type); + return TypeNodeHashFunction()(*s.d_type); } /* -------------------------------------------------------------------------- */ @@ -1329,7 +1328,7 @@ Op::~Op() { if (d_solver != nullptr) { - // Ensure that the correct node manager is in scope when the node is + // Ensure that the correct node manager is in scope when the type node is // destroyed. NodeManagerScope scope(d_solver->getNodeManager()); d_node.reset(); @@ -1709,7 +1708,7 @@ Sort Term::getSort() const { CVC4_API_CHECK_NOT_NULL; NodeManagerScope scope(d_solver->getNodeManager()); - return Sort(d_solver, d_node->getType().toType()); + return Sort(d_solver, d_node->getType()); } Term Term::substitute(Term e, Term replacement) const @@ -2133,7 +2132,7 @@ void DatatypeConstructorDecl::addSelector(const std::string& name, Sort sort) NodeManagerScope scope(d_solver->getNodeManager()); CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null range sort for selector"; - d_ctor->addArg(name, TypeNode::fromType(*sort.d_type)); + d_ctor->addArg(name, *sort.d_type); } void DatatypeConstructorDecl::addSelectorSelf(const std::string& name) @@ -2188,9 +2187,7 @@ DatatypeDecl::DatatypeDecl(const Solver* slv, bool isCoDatatype) : d_solver(slv), d_dtype(new CVC4::DType( - name, - std::vector{TypeNode::fromType(*param.d_type)}, - isCoDatatype)) + name, std::vector{*param.d_type}, isCoDatatype)) { } @@ -2200,11 +2197,7 @@ DatatypeDecl::DatatypeDecl(const Solver* slv, bool isCoDatatype) : d_solver(slv) { - std::vector tparams; - for (const Sort& p : params) - { - tparams.push_back(TypeNode::fromType(*p.d_type)); - } + std::vector tparams = sortVectorToTypeNodes(params); d_dtype = std::shared_ptr( new CVC4::DType(name, tparams, isCoDatatype)); } @@ -2297,7 +2290,7 @@ Term DatatypeSelector::getSelectorTerm() const Sort DatatypeSelector::getRangeSort() const { - return Sort(d_solver, d_stor->getRangeType().toType()); + return Sort(d_solver, d_stor->getRangeType()); } std::string DatatypeSelector::toString() const @@ -2363,13 +2356,11 @@ Term DatatypeConstructor::getSpecializedConstructorTerm(Sort retSort) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; NodeManager* nm = d_solver->getNodeManager(); - Node ret = nm->mkNode( - kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(d_ctor - ->getSpecializedConstructorType( - TypeNode::fromType(retSort.getType())) - .toType())), - d_ctor->getConstructor()); + Node ret = + nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType( + d_ctor->getSpecializedConstructorType(*retSort.d_type))), + d_ctor->getConstructor()); (void)ret.getType(true); /* kick off type checking */ // apply type ascription to the operator Term sctor = api::Term(d_solver, ret); @@ -2902,7 +2893,7 @@ Sort Grammar::resolve() // make the unresolved type, used for referencing the final version of // the ntsymbol's datatype ntsToUnres[ntsymbol] = - Sort(d_solver, d_solver->getExprManager()->mkSort(ntsymbol.toString())); + Sort(d_solver, d_solver->getNodeManager()->mkSort(ntsymbol.toString())); } std::vector datatypes; @@ -2922,8 +2913,8 @@ Sort Grammar::resolve() if (d_allowVars.find(ntSym) != d_allowVars.cend()) { - addSygusConstructorVariables( - dtDecl, Sort(d_solver, ntSym.d_node->getType().toType())); + addSygusConstructorVariables(dtDecl, + Sort(d_solver, ntSym.d_node->getType())); } bool aci = d_allowConst.find(ntSym) != d_allowConst.end(); @@ -2938,7 +2929,7 @@ Sort Grammar::resolve() << " produced an empty rule list"; datatypes.push_back(*dtDecl.d_dtype); - unresTypes.insert(TypeNode::fromType(*ntsToUnres[ntSym].d_type)); + unresTypes.insert(*ntsToUnres[ntSym].d_type); } std::vector datatypeTypes = @@ -2946,7 +2937,7 @@ Sort Grammar::resolve() datatypes, unresTypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); // return is the first datatype - return Sort(d_solver, datatypeTypes[0].toType()); + return Sort(d_solver, datatypeTypes[0]); } void Grammar::addSygusConstructorTerm( @@ -2978,11 +2969,7 @@ void Grammar::addSygusConstructorTerm( d_solver->getExprManager()->mkExpr( CVC4::kind::LAMBDA, {lbvl.d_node->toExpr(), op.d_node->toExpr()})); } - std::vector cargst; - for (const Sort& s : cargs) - { - cargst.push_back(TypeNode::fromType(s.getType())); - } + std::vector cargst = sortVectorToTypeNodes(cargs); dt.d_dtype->addSygusConstructor(*op.d_node, ssCName.str(), cargst); } @@ -3044,7 +3031,7 @@ void Grammar::addSygusConstructorVariables(DatatypeDecl& dt, Sort sort) const for (unsigned i = 0, size = d_sygusVars.size(); i < size; i++) { Term v = d_sygusVars[i]; - if (v.d_node->getType().toType() == *sort.d_type) + if (v.d_node->getType() == *sort.d_type) { std::stringstream ss; ss << v; @@ -3320,19 +3307,11 @@ std::vector Solver::mkDatatypeSortsInternal( { CVC4_API_SOLVER_CHECK_SORT(sort); } - - std::set utypes; - for (const Sort& s : unresolvedSorts) - { - utypes.insert(TypeNode::fromType(s.getType())); - } + + std::set utypes = sortSetToTypeNodes(unresolvedSorts); std::vector dtypes = getNodeManager()->mkMutualDatatypeTypes(datatypes, utypes); - std::vector retTypes; - for (CVC4::TypeNode t : dtypes) - { - retTypes.push_back(Sort(this, t.toType())); - } + std::vector retTypes = typeNodeVectorToSorts(this, dtypes); return retTypes; CVC4_API_SOLVER_TRY_CATCH_END; @@ -3348,7 +3327,7 @@ std::vector Solver::sortVectorToTypes( for (const Sort& s : sorts) { CVC4_API_SOLVER_CHECK_SORT(s); - res.push_back(*s.d_type); + res.push_back(s.d_type->toType()); } return res; } @@ -3401,42 +3380,42 @@ bool Solver::supportsFloatingPoint() const Sort Solver::getNullSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, Type()); + return Sort(this, TypeNode()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getBooleanSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->booleanType()); + return Sort(this, getNodeManager()->booleanType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getIntegerSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->integerType()); + return Sort(this, getNodeManager()->integerType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getRealSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->realType()); + return Sort(this, getNodeManager()->realType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getRegExpSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->regExpType()); + return Sort(this, getNodeManager()->regExpType()); CVC4_API_SOLVER_TRY_CATCH_END; } Sort Solver::getStringSort(void) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->stringType()); + return Sort(this, getNodeManager()->stringType()); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3445,7 +3424,7 @@ Sort Solver::getRoundingModeSort(void) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_CHECK(Configuration::isBuiltWithSymFPU()) << "Expected CVC4 to be compiled with SymFPU support"; - return Sort(this, d_exprMgr->roundingModeType()); + return Sort(this, getNodeManager()->roundingModeType()); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3461,8 +3440,8 @@ Sort Solver::mkArraySort(Sort indexSort, Sort elemSort) const CVC4_API_SOLVER_CHECK_SORT(indexSort); CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, - d_exprMgr->mkArrayType(*indexSort.d_type, *elemSort.d_type)); + return Sort( + this, getNodeManager()->mkArrayType(*indexSort.d_type, *elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3472,7 +3451,7 @@ Sort Solver::mkBitVectorSort(uint32_t size) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(size > 0, size) << "size > 0"; - return Sort(this, d_exprMgr->mkBitVectorType(size)); + return Sort(this, getNodeManager()->mkBitVectorType(size)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3485,7 +3464,7 @@ Sort Solver::mkFloatingPointSort(uint32_t exp, uint32_t sig) const CVC4_API_ARG_CHECK_EXPECTED(exp > 0, exp) << "exponent size > 0"; CVC4_API_ARG_CHECK_EXPECTED(sig > 0, sig) << "significand size > 0"; - return Sort(this, d_exprMgr->mkFloatingPointType(exp, sig)); + return Sort(this, getNodeManager()->mkFloatingPointType(exp, sig)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3499,8 +3478,7 @@ Sort Solver::mkDatatypeSort(DatatypeDecl dtypedecl) const CVC4_API_ARG_CHECK_EXPECTED(dtypedecl.getNumConstructors() > 0, dtypedecl) << "a datatype declaration with at least one constructor"; - return Sort(this, - getNodeManager()->mkDatatypeType(*dtypedecl.d_dtype).toType()); + return Sort(this, getNodeManager()->mkDatatypeType(*dtypedecl.d_dtype)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3536,8 +3514,8 @@ Sort Solver::mkFunctionSort(Sort domain, Sort codomain) const << "first-class sort as codomain sort for function sort"; Assert(!codomain.isFunction()); /* A function sort is not first-class. */ - return Sort(this, - d_exprMgr->mkFunctionType(*domain.d_type, *codomain.d_type)); + return Sort( + this, getNodeManager()->mkFunctionType(*domain.d_type, *codomain.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3566,8 +3544,9 @@ Sort Solver::mkFunctionSort(const std::vector& sorts, Sort codomain) const << "first-class sort as codomain sort for function sort"; Assert(!codomain.isFunction()); /* A function sort is not first-class. */ - std::vector argTypes = sortVectorToTypes(sorts); - return Sort(this, d_exprMgr->mkFunctionType(argTypes, *codomain.d_type)); + std::vector argTypes = sortVectorToTypeNodes(sorts); + return Sort(this, + getNodeManager()->mkFunctionType(argTypes, *codomain.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3575,8 +3554,9 @@ Sort Solver::mkFunctionSort(const std::vector& sorts, Sort codomain) const Sort Solver::mkParamSort(const std::string& symbol) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, - d_exprMgr->mkSort(symbol, ExprManager::SORT_FLAG_PLACEHOLDER)); + return Sort( + this, + getNodeManager()->mkSort(symbol, ExprManager::SORT_FLAG_PLACEHOLDER)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3597,9 +3577,9 @@ Sort Solver::mkPredicateSort(const std::vector& sorts) const sorts[i].isFirstClass(), "parameter sort", sorts[i], i) << "first-class sort as parameter sort for predicate sort"; } - std::vector types = sortVectorToTypes(sorts); + std::vector types = sortVectorToTypeNodes(sorts); - return Sort(this, d_exprMgr->mkPredicateType(types)); + return Sort(this, getNodeManager()->mkPredicateType(types)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3620,10 +3600,10 @@ Sort Solver::mkRecordSort( this == p.second.d_solver, "parameter sort", p.second, i) << "sort associated to this solver object"; i += 1; - f.emplace_back(p.first, *p.second.d_type); + f.emplace_back(p.first, p.second.d_type->toType()); } - return Sort(this, getNodeManager()->mkRecordType(Record(f)).toType()); + return Sort(this, getNodeManager()->mkRecordType(Record(f))); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3635,7 +3615,7 @@ Sort Solver::mkSetSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, d_exprMgr->mkSetType(*elemSort.d_type)); + return Sort(this, getNodeManager()->mkSetType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3647,9 +3627,7 @@ Sort Solver::mkBagSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - TypeNode typeNode = TypeNode::fromType(*elemSort.d_type); - Type type = getNodeManager()->mkBagType(typeNode).toType(); - return Sort(this, type); + return Sort(this, getNodeManager()->mkBagType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3661,7 +3639,7 @@ Sort Solver::mkSequenceSort(Sort elemSort) const << "non-null element sort"; CVC4_API_SOLVER_CHECK_SORT(elemSort); - return Sort(this, d_exprMgr->mkSequenceType(*elemSort.d_type)); + return Sort(this, getNodeManager()->mkSequenceType(*elemSort.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3669,7 +3647,7 @@ Sort Solver::mkSequenceSort(Sort elemSort) const Sort Solver::mkUninterpretedSort(const std::string& symbol) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - return Sort(this, d_exprMgr->mkSort(symbol)); + return Sort(this, getNodeManager()->mkSort(symbol)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3679,7 +3657,7 @@ Sort Solver::mkSortConstructorSort(const std::string& symbol, CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_EXPECTED(arity > 0, arity) << "an arity > 0"; - return Sort(this, d_exprMgr->mkSortConstructor(symbol, arity)); + return Sort(this, getNodeManager()->mkSortConstructor(symbol, arity)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3699,12 +3677,8 @@ Sort Solver::mkTupleSort(const std::vector& sorts) const !sorts[i].isFunctionLike(), "parameter sort", sorts[i], i) << "non-function-like sort as parameter sort for tuple sort"; } - std::vector typeNodes; - for (const Sort& s : sorts) - { - typeNodes.push_back(TypeNode::fromType(*s.d_type)); - } - return Sort(this, getNodeManager()->mkTupleType(typeNodes).toType()); + std::vector typeNodes = sortVectorToTypeNodes(sorts); + return Sort(this, getNodeManager()->mkTupleType(typeNodes)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3840,8 +3814,7 @@ Term Solver::mkEmptySet(Sort s) const CVC4_API_ARG_CHECK_EXPECTED(s.isNull() || this == s.d_solver, s) << "set sort associated to this solver object"; - return mkValHelper( - CVC4::EmptySet(TypeNode::fromType(*s.d_type))); + return mkValHelper(CVC4::EmptySet(*s.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3855,8 +3828,7 @@ Term Solver::mkSingleton(Sort s, Term t) const CVC4_API_SOLVER_CHECK_TERM(t); checkMkTerm(SINGLETON, 1); - TypeNode typeNode = TypeNode::fromType(*s.d_type); - Node res = getNodeManager()->mkSingleton(typeNode, *t.d_node); + Node res = getNodeManager()->mkSingleton(*s.d_type, *t.d_node); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -3872,8 +3844,7 @@ Term Solver::mkEmptyBag(Sort s) const CVC4_API_ARG_CHECK_EXPECTED(s.isNull() || this == s.d_solver, s) << "bag sort associated to this solver object"; - return mkValHelper( - CVC4::EmptyBag(TypeNode::fromType(*s.d_type))); + return mkValHelper(CVC4::EmptyBag(*s.d_type)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -3884,7 +3855,8 @@ Term Solver::mkSepNil(Sort sort) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = d_exprMgr->mkNullaryOperator(*sort.d_type, CVC4::kind::SEP_NIL); + Node res = + getNodeManager()->mkNullaryOperator(*sort.d_type, CVC4::kind::SEP_NIL); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -3926,8 +3898,7 @@ Term Solver::mkEmptySequence(Sort sort) const CVC4_API_SOLVER_CHECK_SORT(sort); std::vector seq; - Expr res = - d_exprMgr->mkConst(Sequence(TypeNode::fromType(*sort.d_type), seq)); + Expr res = d_exprMgr->mkConst(Sequence(*sort.d_type, seq)); return Term(this, res); CVC4_API_SOLVER_TRY_CATCH_END; @@ -3939,8 +3910,8 @@ Term Solver::mkUniverseSet(Sort sort) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = - d_exprMgr->mkNullaryOperator(*sort.d_type, CVC4::kind::UNIVERSE_SET); + Node res = getNodeManager()->mkNullaryOperator(*sort.d_type, + CVC4::kind::UNIVERSE_SET); // TODO(#2771): Reenable? // (void)res->getType(true); /* kick off type checking */ return Term(this, res); @@ -3990,7 +3961,7 @@ Term Solver::mkConstArray(Sort sort, Term val) const n = n[0]; } Term res = mkValHelper( - CVC4::ArrayStoreAll(TypeNode::fromType(*sort.d_type), n)); + CVC4::ArrayStoreAll(*sort.d_type, n)); return res; CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4071,7 +4042,7 @@ Term Solver::mkUninterpretedConst(Sort sort, int32_t index) const CVC4_API_SOLVER_CHECK_SORT(sort); return mkValHelper( - CVC4::UninterpretedConstant(TypeNode::fromType(*sort.d_type), index)); + CVC4::UninterpretedConstant(*sort.d_type, index)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4134,8 +4105,8 @@ Term Solver::mkConst(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = symbol.empty() ? d_exprMgr->mkVar(*sort.d_type) - : d_exprMgr->mkVar(symbol, *sort.d_type); + Expr res = symbol.empty() ? d_exprMgr->mkVar(sort.d_type->toType()) + : d_exprMgr->mkVar(symbol, sort.d_type->toType()); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -4151,8 +4122,9 @@ Term Solver::mkVar(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_EXPECTED(!sort.isNull(), sort) << "non-null sort"; CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = symbol.empty() ? d_exprMgr->mkBoundVar(*sort.d_type) - : d_exprMgr->mkBoundVar(symbol, *sort.d_type); + Expr res = symbol.empty() + ? d_exprMgr->mkBoundVar(sort.d_type->toType()) + : d_exprMgr->mkBoundVar(symbol, sort.d_type->toType()); (void)res.getType(true); /* kick off type checking */ return Term(this, res); @@ -4776,7 +4748,7 @@ Sort Solver::declareDatatype( << "datatype constructor declaration associated to this solver object"; dtdecl.addConstructor(ctors[i]); } - return Sort(this, getNodeManager()->mkDatatypeType(*dtdecl.d_dtype).toType()); + return Sort(this, getNodeManager()->mkDatatypeType(*dtdecl.d_dtype)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4801,13 +4773,13 @@ Term Solver::declareFun(const std::string& symbol, << "first-class sort as function codomain sort"; CVC4_API_SOLVER_CHECK_SORT(sort); Assert(!sort.isFunction()); /* A function sort is not first-class. */ - Type type = *sort.d_type; + TypeNode type = *sort.d_type; if (!sorts.empty()) { - std::vector types = sortVectorToTypes(sorts); - type = d_exprMgr->mkFunctionType(types, type); + std::vector types = sortVectorToTypeNodes(sorts); + type = getNodeManager()->mkFunctionType(types, type); } - return Term(this, d_exprMgr->mkVar(symbol, type)); + return Term(this, d_exprMgr->mkVar(symbol, type.toType())); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4817,8 +4789,11 @@ Term Solver::declareFun(const std::string& symbol, Sort Solver::declareSort(const std::string& symbol, uint32_t arity) const { CVC4_API_SOLVER_TRY_CATCH_BEGIN; - if (arity == 0) return Sort(this, d_exprMgr->mkSort(symbol)); - return Sort(this, d_exprMgr->mkSortConstructor(symbol, arity)); + if (arity == 0) + { + return Sort(this, getNodeManager()->mkSort(symbol)); + } + return Sort(this, getNodeManager()->mkSortConstructor(symbol, arity)); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -4846,18 +4821,18 @@ Term Solver::defineFun(const std::string& symbol, bound_vars[i], i) << "a bound variable"; - CVC4::Type t = bound_vars[i].d_node->getType().toType(); + CVC4::TypeNode t = bound_vars[i].d_node->getType(); CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( t.isFirstClass(), "sort of parameter", bound_vars[i], i) << "first-class sort of parameter of defined function"; - domain_types.push_back(TypeNode::fromType(t)); + domain_types.push_back(t); } CVC4_API_SOLVER_CHECK_SORT(sort); CVC4_API_CHECK(sort == term.getSort()) << "Invalid sort of function body '" << term << "', expected '" << sort << "'"; NodeManager* nm = getNodeManager(); - TypeNode type = TypeNode::fromType(*sort.d_type); + TypeNode type = *sort.d_type; if (!domain_types.empty()) { type = nm->mkFunctionType(domain_types, type); @@ -4965,7 +4940,7 @@ Term Solver::defineFunRec(const std::string& symbol, << "'"; CVC4_API_SOLVER_CHECK_TERM(term); NodeManager* nm = getNodeManager(); - TypeNode type = TypeNode::fromType(*sort.d_type); + TypeNode type = *sort.d_type; if (!domain_types.empty()) { type = nm->mkFunctionType(domain_types, type); @@ -5354,8 +5329,8 @@ bool Solver::getInterpolant(Term conj, Grammar& g, Term& output) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); Node result; - bool success = d_smtEngine->getInterpol( - *conj.d_node, TypeNode::fromType(*g.resolve().d_type), result); + bool success = + d_smtEngine->getInterpol(*conj.d_node, *g.resolve().d_type, result); if (success) { output = Term(this, result); @@ -5383,8 +5358,8 @@ bool Solver::getAbduct(Term conj, Grammar& g, Term& output) const CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4::ExprManagerScope exmgrs(*(d_exprMgr.get())); Node result; - bool success = d_smtEngine->getAbduct( - *conj.d_node, TypeNode::fromType(*g.resolve().d_type), result); + bool success = + d_smtEngine->getAbduct(*conj.d_node, *g.resolve().d_type, result); if (success) { output = Term(this, result); @@ -5569,10 +5544,10 @@ Term Solver::mkSygusVar(Sort sort, const std::string& symbol) const CVC4_API_ARG_CHECK_NOT_NULL(sort); CVC4_API_SOLVER_CHECK_SORT(sort); - Expr res = d_exprMgr->mkBoundVar(symbol, *sort.d_type); + Node res = getNodeManager()->mkBoundVar(symbol, *sort.d_type); (void)res.getType(true); /* kick off type checking */ - d_smtEngine->declareSygusVar(symbol, res, TypeNode::fromType(*sort.d_type)); + d_smtEngine->declareSygusVar(symbol, res, *sort.d_type); return Term(this, res); @@ -5641,7 +5616,7 @@ Term Solver::synthInv(const std::string& symbol, const std::vector& boundVars) const { return synthFunHelper( - symbol, boundVars, Sort(this, d_exprMgr->booleanType()), true); + symbol, boundVars, Sort(this, getNodeManager()->booleanType()), true); } Term Solver::synthInv(const std::string& symbol, @@ -5649,7 +5624,7 @@ Term Solver::synthInv(const std::string& symbol, Grammar& g) const { return synthFunHelper( - symbol, boundVars, Sort(this, d_exprMgr->booleanType()), true, &g); + symbol, boundVars, Sort(this, getNodeManager()->booleanType()), true, &g); } Term Solver::synthFunHelper(const std::string& symbol, @@ -5661,7 +5636,7 @@ Term Solver::synthFunHelper(const std::string& symbol, CVC4_API_SOLVER_TRY_CATCH_BEGIN; CVC4_API_ARG_CHECK_NOT_NULL(sort); - std::vector varTypes; + std::vector varTypes; for (size_t i = 0, n = boundVars.size(); i < n; ++i) { CVC4_API_ARG_AT_INDEX_CHECK_EXPECTED( @@ -5676,36 +5651,28 @@ Term Solver::synthFunHelper(const std::string& symbol, boundVars[i], i) << "a bound variable"; - varTypes.push_back(boundVars[i].d_node->getType().toType()); + varTypes.push_back(boundVars[i].d_node->getType()); } CVC4_API_SOLVER_CHECK_SORT(sort); if (g != nullptr) { - CVC4_API_CHECK(g->d_ntSyms[0].d_node->getType().toType() == *sort.d_type) + CVC4_API_CHECK(g->d_ntSyms[0].d_node->getType() == *sort.d_type) << "Invalid Start symbol for Grammar g, Expected Start's sort to be " << *sort.d_type << " but found " << g->d_ntSyms[0].d_node->getType(); } - Type funType = varTypes.empty() - ? *sort.d_type - : d_exprMgr->mkFunctionType(varTypes, *sort.d_type); + TypeNode funType = varTypes.empty() ? *sort.d_type + : getNodeManager()->mkFunctionType( + varTypes, *sort.d_type); - Node fun = getNodeManager()->mkBoundVar(symbol, TypeNode::fromType(funType)); + Node fun = getNodeManager()->mkBoundVar(symbol, funType); (void)fun.getType(true); /* kick off type checking */ - std::vector bvns; - for (const Term& t : boundVars) - { - bvns.push_back(*t.d_node); - } + std::vector bvns = termVectorToNodes(boundVars); d_smtEngine->declareSynthFun( - symbol, - fun, - TypeNode::fromType(g == nullptr ? funType : *g->resolve().d_type), - isInv, - bvns); + symbol, fun, g == nullptr ? funType : *g->resolve().d_type, isInv, bvns); return Term(this, fun); @@ -5744,21 +5711,21 @@ void Solver::addSygusInvConstraint(Term inv, CVC4_API_ARG_CHECK_EXPECTED(inv.d_node->getType().isFunction(), inv) << "a function"; - FunctionType invType = inv.d_node->getType().toType(); + TypeNode invType = inv.d_node->getType(); CVC4_API_ARG_CHECK_EXPECTED(invType.getRangeType().isBoolean(), inv) << "boolean range"; - CVC4_API_CHECK(pre.d_node->getType().toType() == invType) + CVC4_API_CHECK(pre.d_node->getType() == invType) << "Expected inv and pre to have the same sort"; - CVC4_API_CHECK(post.d_node->getType().toType() == invType) + CVC4_API_CHECK(post.d_node->getType() == invType) << "Expected inv and post to have the same sort"; - const std::vector& invArgTypes = invType.getArgTypes(); + const std::vector& invArgTypes = invType.getArgTypes(); - std::vector expectedTypes; - expectedTypes.reserve(2 * invType.getArity() + 1); + std::vector expectedTypes; + expectedTypes.reserve(2 * invArgTypes.size() + 1); for (size_t i = 0, n = invArgTypes.size(); i < 2 * n; i += 2) { @@ -5767,15 +5734,13 @@ void Solver::addSygusInvConstraint(Term inv, } expectedTypes.push_back(invType.getRangeType()); - FunctionType expectedTransType = d_exprMgr->mkFunctionType(expectedTypes); + TypeNode expectedTransType = getNodeManager()->mkFunctionType(expectedTypes); - CVC4_API_CHECK(trans.d_node->toExpr().getType() == expectedTransType) + CVC4_API_CHECK(trans.d_node->getType() == expectedTransType) << "Expected trans's sort to be " << invType; - d_smtEngine->assertSygusInvConstraint(inv.d_node->toExpr(), - pre.d_node->toExpr(), - trans.d_node->toExpr(), - post.d_node->toExpr()); + d_smtEngine->assertSygusInvConstraint( + *inv.d_node, *pre.d_node, *trans.d_node, *post.d_node); CVC4_API_SOLVER_TRY_CATCH_END; } @@ -5907,7 +5872,7 @@ std::vector sortVectorToTypes(const std::vector& sorts) std::vector types; for (size_t i = 0, ssize = sorts.size(); i < ssize; i++) { - types.push_back(sorts[i].getType()); + types.push_back(sorts[i].getTypeNode().toType()); } return types; } @@ -5917,17 +5882,17 @@ std::vector sortVectorToTypeNodes(const std::vector& sorts) std::vector typeNodes; for (const Sort& sort : sorts) { - typeNodes.push_back(TypeNode::fromType(sort.getType())); + typeNodes.push_back(sort.getTypeNode()); } return typeNodes; } -std::set sortSetToTypes(const std::set& sorts) +std::set sortSetToTypeNodes(const std::set& sorts) { - std::set types; + std::set types; for (const Sort& s : sorts) { - types.insert(s.getType()); + types.insert(s.getTypeNode()); } return types; } @@ -5945,6 +5910,16 @@ std::vector exprVectorToTerms(const Solver* slv, std::vector typeVectorToSorts(const Solver* slv, const std::vector& types) +{ + std::vector sorts; + for (size_t i = 0, tsize = types.size(); i < tsize; i++) + { + sorts.push_back(Sort(slv, TypeNode::fromType(types[i]))); + } + return sorts; +} +std::vector typeNodeVectorToSorts(const Solver* slv, + const std::vector& types) { std::vector sorts; for (size_t i = 0, tsize = types.size(); i < tsize; i++) diff --git a/src/api/cvc4cpp.h b/src/api/cvc4cpp.h index c05390e42..33d87af91 100644 --- a/src/api/cvc4cpp.h +++ b/src/api/cvc4cpp.h @@ -206,6 +206,7 @@ class Datatype; */ class CVC4_PUBLIC Sort { + friend class DatatypeConstructor; friend class DatatypeConstructorDecl; friend class DatatypeDecl; friend class Op; @@ -224,6 +225,7 @@ class CVC4_PUBLIC Sort * @return the Sort */ Sort(const Solver* slv, const CVC4::Type& t); + Sort(const Solver* slv, const CVC4::TypeNode& t); /** * Constructor. @@ -488,6 +490,7 @@ class CVC4_PUBLIC Sort // !!! This is only temporarily available until the parser is fully migrated // to the new API. !!! CVC4::Type getType(void) const; + const CVC4::TypeNode& getTypeNode(void) const; /* Constructor sort ------------------------------------------------------- */ @@ -670,7 +673,7 @@ class CVC4_PUBLIC Sort * memory allocation (CVC4::Type is already ref counted, so this could be * a unique_ptr instead). */ - std::shared_ptr d_type; + std::shared_ptr d_type; }; /** @@ -3528,11 +3531,13 @@ std::vector termVectorToExprs(const std::vector& terms); std::vector termVectorToNodes(const std::vector& terms); std::vector sortVectorToTypes(const std::vector& sorts); std::vector sortVectorToTypeNodes(const std::vector& sorts); -std::set sortSetToTypes(const std::set& sorts); +std::set sortSetToTypeNodes(const std::set& sorts); std::vector exprVectorToTerms(const Solver* slv, const std::vector& terms); std::vector typeVectorToSorts(const Solver* slv, const std::vector& sorts); +std::vector typeNodeVectorToSorts(const Solver* slv, + const std::vector& types); } // namespace api diff --git a/src/expr/CMakeLists.txt b/src/expr/CMakeLists.txt index ccc575289..d7af52fec 100644 --- a/src/expr/CMakeLists.txt +++ b/src/expr/CMakeLists.txt @@ -12,6 +12,7 @@ libcvc4_add_sources( array.h array_store_all.cpp array_store_all.h + ascription_type.cpp ascription_type.h attribute.h attribute.cpp diff --git a/src/expr/ascription_type.cpp b/src/expr/ascription_type.cpp new file mode 100644 index 000000000..d9466fdbf --- /dev/null +++ b/src/expr/ascription_type.cpp @@ -0,0 +1,58 @@ +/********************* */ +/*! \file ascription_type.cpp + ** \verbatim + ** Top contributors (to current version): + ** Morgan Deters, Tim King, Mathias Preiner + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A class representing a type ascription + **/ + +#include "expr/ascription_type.h" + +#include + +#include "expr/type_node.h" + +namespace CVC4 { + +AscriptionType::AscriptionType(TypeNode t) : d_type(new TypeNode(t)) {} + +AscriptionType::AscriptionType(const AscriptionType& at) + : d_type(new TypeNode(at.getType())) +{ +} + +AscriptionType& AscriptionType::operator=(const AscriptionType& at) +{ + (*d_type) = at.getType(); + return *this; +} + +AscriptionType::~AscriptionType() {} +TypeNode AscriptionType::getType() const { return *d_type.get(); } +bool AscriptionType::operator==(const AscriptionType& other) const +{ + return getType() == other.getType(); +} +bool AscriptionType::operator!=(const AscriptionType& other) const +{ + return getType() != other.getType(); +} + +size_t AscriptionTypeHashFunction::operator()(const AscriptionType& at) const +{ + return TypeNodeHashFunction()(at.getType()); +} + +std::ostream& operator<<(std::ostream& out, AscriptionType at) +{ + out << at.getType(); + return out; +} + +} // namespace CVC4 diff --git a/src/expr/ascription_type.h b/src/expr/ascription_type.h index 0ce3df88d..800f46e0a 100644 --- a/src/expr/ascription_type.h +++ b/src/expr/ascription_type.h @@ -19,10 +19,13 @@ #ifndef CVC4__ASCRIPTION_TYPE_H #define CVC4__ASCRIPTION_TYPE_H -#include "expr/type.h" +#include +#include namespace CVC4 { +class TypeNode; + /** * A class used to parameterize a type ascription. For example, * "nil :: list" is an expression of kind APPLY_TYPE_ASCRIPTION. @@ -31,35 +34,29 @@ namespace CVC4 { * coerce a Type into the expression tree.) */ class CVC4_PUBLIC AscriptionType { - Type d_type; - public: - AscriptionType(Type t) : d_type(t) {} - Type getType() const { return d_type; } - bool operator==(const AscriptionType& other) const - { - return d_type == other.d_type; - } - bool operator!=(const AscriptionType& other) const - { - return d_type != other.d_type; - } + AscriptionType(TypeNode t); + ~AscriptionType(); + AscriptionType(const AscriptionType& other); + AscriptionType& operator=(const AscriptionType& other); + TypeNode getType() const; + bool operator==(const AscriptionType& other) const; + bool operator!=(const AscriptionType& other) const; + + private: + /** The type */ + std::unique_ptr d_type; };/* class AscriptionType */ /** * A hash function for type ascription operators. */ struct CVC4_PUBLIC AscriptionTypeHashFunction { - inline size_t operator()(const AscriptionType& at) const { - return TypeHashFunction()(at.getType()); - } + size_t operator()(const AscriptionType& at) const; };/* struct AscriptionTypeHashFunction */ /** An output routine for AscriptionTypes */ -inline std::ostream& operator<<(std::ostream& out, AscriptionType at) { - out << at.getType(); - return out; -} +std::ostream& operator<<(std::ostream& out, AscriptionType at) CVC4_PUBLIC; }/* CVC4 namespace */ diff --git a/src/expr/dtype_cons.cpp b/src/expr/dtype_cons.cpp index 8e86ba49d..d63db28d5 100644 --- a/src/expr/dtype_cons.cpp +++ b/src/expr/dtype_cons.cpp @@ -438,7 +438,7 @@ Node DTypeConstructor::computeGroundTerm(TypeNode t, << ", ascribe to " << t << std::endl; groundTerms[0] = nm->mkNode( APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(getSpecializedConstructorType(t).toType())), + nm->mkConst(AscriptionType(getSpecializedConstructorType(t))), groundTerms[0]); groundTerm = nm->mkNode(APPLY_CONSTRUCTOR, groundTerms); } diff --git a/src/expr/type.h b/src/expr/type.h index 5fd68e89e..6867673f8 100644 --- a/src/expr/type.h +++ b/src/expr/type.h @@ -30,7 +30,7 @@ namespace CVC4 { class NodeManager; class CVC4_PUBLIC ExprManager; -class CVC4_PUBLIC Expr; +class Expr; class TypeNode; struct CVC4_PUBLIC ExprManagerMapCollection; diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index b9b72e0c5..ce3bd7438 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -475,6 +475,12 @@ uint64_t TypeNode::getSortConstructorArity() const return getAttribute(expr::SortArityAttr()); } +std::string TypeNode::getName() const +{ + Assert(isSort() || isSortConstructor()); + return getAttribute(expr::VarNameAttr()); +} + TypeNode TypeNode::instantiateSortConstructor( const std::vector& params) const { diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 57cdfc43b..41adc1d3b 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -697,6 +697,11 @@ public: /** Get sort constructor arity */ uint64_t getSortConstructorArity() const; + /** + * Get name, for uninterpreted sorts and uninterpreted sort constructors. + */ + std::string getName() const; + /** * Instantiate a sort constructor type. The type on which this method is * called should be a sort constructor type whose parameter list is the diff --git a/src/parser/cvc/Cvc.g b/src/parser/cvc/Cvc.g index 0bb41b483..81319e25a 100644 --- a/src/parser/cvc/Cvc.g +++ b/src/parser/cvc/Cvc.g @@ -1150,7 +1150,7 @@ declareVariables[std::unique_ptr* cmd, CVC4::api::Sort& t, PARSER_STATE->checkDeclaration(*i, CHECK_UNDECLARED, SYM_VARIABLE); api::Term func = PARSER_STATE->mkVar( *i, - api::Sort(SOLVER, t.getType()), + t, ExprManager::VAR_FLAG_GLOBAL | ExprManager::VAR_FLAG_DEFINED); PARSER_STATE->defineVar(*i, f); Command* decl = new DefineFunctionCommand(*i, func, f, true); diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index e815d9024..6feb298c2 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -229,8 +229,7 @@ std::vector Parser::bindBoundVars( std::vector vars; for (std::pair& i : sortedVarNames) { - vars.push_back( - bindBoundVar(i.first, api::Sort(d_solver, i.second.getType()))); + vars.push_back(bindBoundVar(i.first, i.second)); } return vars; } @@ -244,7 +243,7 @@ api::Term Parser::mkAnonymousFunction(const std::string& prefix, } stringstream name; name << prefix << "_anon_" << ++d_anonymousFunctionCount; - return mkVar(name.str(), api::Sort(d_solver, type.getType()), flags); + return mkVar(name.str(), type, flags); } std::vector Parser::bindVars(const std::vector names, diff --git a/src/printer/cvc/cvc_printer.cpp b/src/printer/cvc/cvc_printer.cpp index 9ccf02301..2a55cb972 100644 --- a/src/printer/cvc/cvc_printer.cpp +++ b/src/printer/cvc/cvc_printer.cpp @@ -409,7 +409,7 @@ void CvcPrinter::toStream( case kind::APPLY_TYPE_ASCRIPTION: { toStream(out, n[0], depth, types, false); out << "::"; - TypeNode t = TypeNode::fromType(n.getOperator().getConst().getType()); + TypeNode t = n.getOperator().getConst().getType(); out << (t.isFunctionLike() ? t.getRangeType() : t); } return; diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index a0cd8cf9c..cdaa61295 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -399,8 +399,7 @@ void Smt2Printer::toStream(std::ostream& out, Node type_asc_arg; if (n.getKind() == kind::APPLY_TYPE_ASCRIPTION) { - force_nt = TypeNode::fromType( - n.getOperator().getConst().getType()); + force_nt = n.getOperator().getConst().getType(); type_asc_arg = n[0]; } else if (!force_nt.isNull() && n.getType() != force_nt) diff --git a/src/smt/abstract_values.cpp b/src/smt/abstract_values.cpp index 2d21e7a1b..7d3ff64a6 100644 --- a/src/smt/abstract_values.cpp +++ b/src/smt/abstract_values.cpp @@ -46,7 +46,7 @@ Node AbstractValues::mkAbstractValue(TNode n) d_abstractValueMap.addSubstitution(val, n); } // We are supposed to ascribe types to all abstract values that go out. - Node ascription = d_nm->mkConst(AscriptionType(n.getType().toType())); + Node ascription = d_nm->mkConst(AscriptionType(n.getType())); Node retval = d_nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, ascription, val); return retval; } diff --git a/src/smt/command.cpp b/src/smt/command.cpp index 9b0784831..8a5247dec 100644 --- a/src/smt/command.cpp +++ b/src/smt/command.cpp @@ -602,7 +602,7 @@ void DeclareSygusVarCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareVar( - out, d_var.getNode(), TypeNode::fromType(d_sort.getType())); + out, d_var.getNode(), d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -663,9 +663,9 @@ void SynthFunCommand::toStream(std::ostream& out, out, d_symbol, nodeVars, - TypeNode::fromType(d_sort.getType()), + d_sort.getTypeNode(), d_isInv, - TypeNode::fromType(d_grammar->resolve().getType())); + d_grammar->resolve().getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1130,7 +1130,7 @@ void DeclareFunctionCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareFunction( - out, d_func.toString(), TypeNode::fromType(d_sort.getType())); + out, d_func.toString(), d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1168,7 +1168,7 @@ void DeclareSortCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdDeclareType( - out, d_sort.toString(), d_arity, TypeNode::fromType(d_sort.getType())); + out, d_sort.toString(), d_arity, d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1215,7 +1215,7 @@ void DefineSortCommand::toStream(std::ostream& out, out, d_symbol, api::sortVectorToTypeNodes(d_params), - TypeNode::fromType(d_sort.getType())); + d_sort.getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -1337,7 +1337,7 @@ void DefineNamedFunctionCommand::toStream(std::ostream& out, out, d_func.toString(), api::termVectorToNodes(d_formals), - TypeNode::fromType(d_func.getSort().getFunctionCodomainSort().getType()), + d_func.getSort().getFunctionCodomainSort().getTypeNode(), d_formula.getNode()); } @@ -2100,10 +2100,7 @@ void GetInterpolCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetInterpol( - out, - d_name, - d_conj.getNode(), - TypeNode::fromType(d_sygus_grammar->resolve().getType())); + out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode()); } /* -------------------------------------------------------------------------- */ @@ -2189,10 +2186,7 @@ void GetAbductCommand::toStream(std::ostream& out, OutputLanguage language) const { Printer::getPrinter(language)->toStreamCmdGetAbduct( - out, - d_name, - d_conj.getNode(), - TypeNode::fromType(d_sygus_grammar->resolve().getType())); + out, d_name, d_conj.getNode(), d_sygus_grammar->resolve().getTypeNode()); } /* -------------------------------------------------------------------------- */ diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index fa5f2fe33..747ed89b7 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -280,7 +280,7 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in) const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)]; // create ascribed constructor type Node tc = NodeManager::currentNM()->mkConst( - AscriptionType(dtc.getSpecializedConstructorType(tn).toType())); + AscriptionType(dtc.getSpecializedConstructorType(tn))); Node op_new = NodeManager::currentNM()->mkNode( kind::APPLY_TYPE_ASCRIPTION, tc, op); // make new node @@ -397,7 +397,7 @@ RewriteResponse DatatypesRewriter::rewriteSelector(TNode in) { gt = NodeManager::currentNM()->mkNode( kind::APPLY_TYPE_ASCRIPTION, - NodeManager::currentNM()->mkConst(AscriptionType(tn.toType())), + NodeManager::currentNM()->mkConst(AscriptionType(tn)), gt); } Trace("datatypes-rewrite") diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index 0e67d8b3d..2834b86ba 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -196,8 +196,7 @@ struct DatatypeAscriptionTypeRule { bool check) { Debug("typecheck-idt") << "typechecking ascription: " << n << std::endl; Assert(n.getKind() == kind::APPLY_TYPE_ASCRIPTION); - TypeNode t = TypeNode::fromType( - n.getOperator().getConst().getType()); + TypeNode t = n.getOperator().getConst().getType(); if (check) { TypeNode childType = n[0].getType(check); diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index ca51242cc..c55b4a14f 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -55,7 +55,7 @@ Node getInstCons(Node n, const DType& dt, int index) Debug("datatypes-parametric") << "Type specification is " << tspec << std::endl; children[0] = nm->mkNode(APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(tspec.toType())), + nm->mkConst(AscriptionType(tspec)), children[0]); n_ic = nm->mkNode(APPLY_CONSTRUCTOR, children); Assert(n_ic.getType() == tn); diff --git a/src/theory/datatypes/type_enumerator.cpp b/src/theory/datatypes/type_enumerator.cpp index 2946070c7..079430342 100644 --- a/src/theory/datatypes/type_enumerator.cpp +++ b/src/theory/datatypes/type_enumerator.cpp @@ -143,7 +143,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ NodeManager* nm = NodeManager::currentNM(); TypeNode typ = ctor.getSpecializedConstructorType(d_type); b << nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(typ.toType())), + nm->mkConst(AscriptionType(typ)), ctor.getConstructor()); } else diff --git a/test/unit/api/sort_black.h b/test/unit/api/sort_black.h index 60b2dd299..e6d712191 100644 --- a/test/unit/api/sort_black.h +++ b/test/unit/api/sort_black.h @@ -243,7 +243,10 @@ void SortBlack::testGetUninterpretedSortName() void SortBlack::testIsUninterpretedSortParameterized() { Sort uSort = d_solver.mkUninterpretedSort("u"); - TS_ASSERT_THROWS_NOTHING(uSort.isUninterpretedSortParameterized()); + TS_ASSERT(!uSort.isUninterpretedSortParameterized()); + Sort sSort = d_solver.mkSortConstructorSort("s", 1); + Sort siSort = sSort.instantiate({uSort}); + TS_ASSERT(siSort.isUninterpretedSortParameterized()); Sort bvSort = d_solver.mkBitVectorSort(32); TS_ASSERT_THROWS(bvSort.isUninterpretedSortParameterized(), CVC4ApiException&); @@ -253,6 +256,9 @@ void SortBlack::testGetUninterpretedSortParamSorts() { Sort uSort = d_solver.mkUninterpretedSort("u"); TS_ASSERT_THROWS_NOTHING(uSort.getUninterpretedSortParamSorts()); + Sort sSort = d_solver.mkSortConstructorSort("s", 2); + Sort siSort = sSort.instantiate({uSort, uSort}); + TS_ASSERT(siSort.getUninterpretedSortParamSorts().size() == 2); Sort bvSort = d_solver.mkBitVectorSort(32); TS_ASSERT_THROWS(bvSort.getUninterpretedSortParamSorts(), CVC4ApiException&); }