From f6034c8ede6e9b81f4eb8729594301a8ff3982ff Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 29 Apr 2022 16:49:13 -0500 Subject: [PATCH] Properly represent Tuples in the TypeNode AST (#8648) This makes it so that Tuple types are properly represented in the AST. It also removes a spurious restriction that disallowed higher-order tuples (this was leftover from a very old sanity check in the old API). For example, a tuple type over (Int, Int) is now (TUPLE_TYPE INT INT) instead of a DATATYPE_TYPE constant. Tuple types behave exactly like datatypes; we can still retrieve their DType as before. This is in preparation for gradual types and symbolic tuple projections. --- src/expr/node_manager_attributes.h | 7 +++ src/expr/node_manager_template.cpp | 60 +++++++++++++------ src/expr/node_manager_template.h | 14 ++++- src/expr/type_node.cpp | 45 ++++++-------- src/printer/smt2/smt2_printer.cpp | 1 + src/proof/lfsc/lfsc_node_converter.cpp | 56 ++++++++--------- src/theory/datatypes/kinds | 13 ++++ .../datatypes/theory_datatypes_type_rules.cpp | 2 +- test/unit/api/cpp/solver_black.cpp | 3 +- test/unit/api/java/SolverTest.java | 3 +- test/unit/api/python/test_solver.py | 3 +- 11 files changed, 127 insertions(+), 80 deletions(-) diff --git a/src/expr/node_manager_attributes.h b/src/expr/node_manager_attributes.h index df5734062..eb284b304 100644 --- a/src/expr/node_manager_attributes.h +++ b/src/expr/node_manager_attributes.h @@ -33,6 +33,9 @@ namespace attr { struct UnresolvedDatatypeTag { }; + struct TupleDatatypeTag + { + }; } // namespace attr typedef Attribute VarNameAttr; @@ -44,5 +47,9 @@ typedef expr::Attribute TypeCheckedAttr; using UnresolvedDatatypeAttr = expr::Attribute; +/** Mapping tuples to their datatype type encoding */ +using TupleDatatypeAttr = + expr::Attribute; + } // namespace expr } // namespace cvc5::internal diff --git a/src/expr/node_manager_template.cpp b/src/expr/node_manager_template.cpp index 64f24f89a..24e1f01f3 100644 --- a/src/expr/node_manager_template.cpp +++ b/src/expr/node_manager_template.cpp @@ -302,6 +302,25 @@ NodeManager::~NodeManager() d_attrManager = NULL; } +const DType& NodeManager::getDTypeFor(TypeNode tn) const +{ + Kind k = tn.getKind(); + if (k == kind::DATATYPE_TYPE) + { + DatatypeIndexConstant dic = tn.getConst(); + return getDTypeForIndex(dic.getIndex()); + } + else if (k == kind::TUPLE_TYPE) + { + // lookup its datatype encoding + TypeNode dtt = getAttribute(tn, expr::TupleDatatypeAttr()); + Assert(!dtt.isNull()); + return getDTypeFor(dtt); + } + Assert(k == kind::PARAMETRIC_DATATYPE); + return getDTypeFor(tn[0]); +} + const DType& NodeManager::getDTypeForIndex(size_t index) const { // if this assertion fails, it is likely due to not managing datatypes @@ -599,6 +618,22 @@ std::vector NodeManager::mkMutualDatatypeTypesInternal( if (dtp->getNumParameters() == 0) { typeNode = mkTypeConst(DatatypeIndexConstant(index)); + // if the datatype is a tuple, the type will be (TUPLE_TYPE ...) + if (dt.isTuple()) + { + TypeNode dtt = typeNode; + const DTypeConstructor& dc = dt[0]; + std::vector tupleTypes; + for (size_t i = 0, nargs = dc.getNumArgs(); i < nargs; i++) + { + // selector should be initialized to the range type, it is not null + // or unresolved since tuples are not recursive + tupleTypes.push_back(dc[i].getType()); + } + // Set its datatype representation + typeNode = mkTypeNode(kind::TUPLE_TYPE, tupleTypes); + typeNode.setAttribute(expr::TupleDatatypeAttr(), dtt); + } } else { @@ -740,9 +775,8 @@ TypeNode NodeManager::mkDatatypeUpdateType(TypeNode domain, TypeNode range) return mkTypeNode(kind::UPDATER_TYPE, domain, range); } -TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm, - std::vector& types, - unsigned index) +TypeNode NodeManager::TupleTypeCache::getTupleType( + NodeManager* nm, const std::vector& types, unsigned index) { if (index == types.size()) { @@ -750,7 +784,8 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm, { std::stringstream sst; sst << "__cvc5_tuple"; - for (unsigned i = 0; i < types.size(); ++i) + size_t ntypes = types.size(); + for (size_t i = 0; i < ntypes; ++i) { sst << "_" << types[i]; } @@ -760,7 +795,7 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm, ssc << sst.str() << "_ctor"; std::shared_ptr c = std::make_shared(ssc.str()); - for (unsigned i = 0; i < types.size(); ++i) + for (size_t i = 0; i < ntypes; ++i) { std::stringstream ss; ss << sst.str() << "_stor_" << i; @@ -768,6 +803,7 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm, } dt.addConstructor(c); d_data = nm->mkDatatypeType(dt); + Assert(d_data.isTuple()); Trace("tuprec-debug") << "Return type : " << d_data << std::endl; } return d_data; @@ -804,6 +840,7 @@ TypeNode NodeManager::RecTypeCache::getRecordType(NodeManager* nm, } dt.addConstructor(c); d_data = nm->mkDatatypeType(dt); + Assert(d_data.isRecord()); Trace("tuprec-debug") << "Return type : " << d_data << std::endl; } return d_data; @@ -847,18 +884,7 @@ TypeNode NodeManager::mkFunctionType(const std::vector& argTypes, TypeNode NodeManager::mkTupleType(const std::vector& types) { - std::vector ts; - Trace("tuprec-debug") << "Make tuple type : "; - for (unsigned i = 0; i < types.size(); ++i) - { - CheckArgument(!types[i].isFunctionLike(), - types, - "cannot put function-like types in tuples"); - ts.push_back(types[i]); - Trace("tuprec-debug") << types[i] << " "; - } - Trace("tuprec-debug") << std::endl; - return d_tt_cache.getTupleType(this, ts); + return d_tt_cache.getTupleType(this, types); } TypeNode NodeManager::mkRecordType(const Record& rec) diff --git a/src/expr/node_manager_template.h b/src/expr/node_manager_template.h index fc72be009..678728c78 100644 --- a/src/expr/node_manager_template.h +++ b/src/expr/node_manager_template.h @@ -131,6 +131,14 @@ class NodeManager * which is used as an index to retrieve the DType via this call. */ const DType& getDTypeForIndex(size_t index) const; + /** + * Get the DType for a type. If tn is a datatype type, then we retrieve its + * internal index and use the above method to lookup its datatype. + * + * If it is a tuple, then we lookup its datatype representation and call + * this method on it. + */ + const DType& getDTypeFor(TypeNode tn) const; /** get the canonical bound variable list for function type tn */ Node getBoundVarListForFunctionType(TypeNode tn); @@ -793,7 +801,8 @@ class NodeManager }; /** - * A map of tuple and record types to their corresponding datatype. + * A map of tuple types to their corresponding datatype type, which are + * TypeNode of kind TUPLE_TYPE. */ class TupleTypeCache { @@ -801,9 +810,10 @@ class NodeManager std::map d_children; TypeNode d_data; TypeNode getTupleType(NodeManager* nm, - std::vector& types, + const std::vector& types, unsigned index = 0); }; + /** Same as above, for records */ class RecTypeCache { public: diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 59feac41f..d0801126c 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -376,10 +376,7 @@ std::vector TypeNode::getInstantiatedParamTypes() const return params; } -bool TypeNode::isTuple() const -{ - return (getKind() == kind::DATATYPE_TYPE && getDType().isTuple()); -} +bool TypeNode::isTuple() const { return getKind() == kind::TUPLE_TYPE; } bool TypeNode::isRecord() const { @@ -388,34 +385,35 @@ bool TypeNode::isRecord() const size_t TypeNode::getTupleLength() const { Assert(isTuple()); - const DType& dt = getDType(); - Assert(dt.getNumConstructors() == 1); - return dt[0].getNumArgs(); + return getNumChildren(); } vector TypeNode::getTupleTypes() const { Assert(isTuple()); - const DType& dt = getDType(); - Assert(dt.getNumConstructors() == 1); - vector types; - for(unsigned i = 0; i < dt[0].getNumArgs(); ++i) { - types.push_back(dt[0][i].getRangeType()); + std::vector args; + for (uint32_t i = 0, i_end = getNumChildren(); i < i_end; ++i) + { + args.push_back((*this)[i]); } - return types; + return args; } /** Is this an instantiated datatype type */ bool TypeNode::isInstantiatedDatatype() const { - if(getKind() == kind::DATATYPE_TYPE) { + Kind k = getKind(); + if (k == kind::DATATYPE_TYPE || k == kind::TUPLE_TYPE) + { return true; } - if(getKind() != kind::PARAMETRIC_DATATYPE) { + if (k != kind::PARAMETRIC_DATATYPE) + { return false; } const DType& dt = (*this)[0].getDType(); - unsigned n = dt.getNumParameters(); + size_t n = dt.getNumParameters(); Assert(n < getNumChildren()); - for(unsigned i = 0; i < n; ++i) { + for (size_t i = 0; i < n; ++i) + { if (dt.getParameter(i) == (*this)[i + 1]) { return false; @@ -534,8 +532,9 @@ bool TypeNode::isBitVector() const { return getKind() == kind::BITVECTOR_TYPE; } bool TypeNode::isDatatype() const { - return getKind() == kind::DATATYPE_TYPE - || getKind() == kind::PARAMETRIC_DATATYPE; + Kind k = getKind(); + return k == kind::DATATYPE_TYPE || k == kind::PARAMETRIC_DATATYPE + || k == kind::TUPLE_TYPE; } bool TypeNode::isParametricDatatype() const @@ -589,13 +588,7 @@ std::string TypeNode::toString() const { const DType& TypeNode::getDType() const { - if (getKind() == kind::DATATYPE_TYPE) - { - DatatypeIndexConstant dic = getConst(); - return NodeManager::currentNM()->getDTypeForIndex(dic.getIndex()); - } - Assert(getKind() == kind::PARAMETRIC_DATATYPE); - return (*this)[0].getDType(); + return NodeManager::currentNM()->getDTypeFor(*this); } bool TypeNode::isBag() const diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 5367c0e19..cb2c1d083 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1141,6 +1141,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) // datatypes theory case kind::APPLY_TESTER: return "is"; case kind::APPLY_UPDATER: return "update"; + case kind::TUPLE_TYPE: return "Tuple"; // set theory case kind::SET_UNION: return "set.union"; diff --git a/src/proof/lfsc/lfsc_node_converter.cpp b/src/proof/lfsc/lfsc_node_converter.cpp index 563e7809b..f8e79e014 100644 --- a/src/proof/lfsc/lfsc_node_converter.cpp +++ b/src/proof/lfsc/lfsc_node_converter.cpp @@ -537,39 +537,38 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn) Node s = nm->mkConstInt(Rational(tn.getFloatingPointSignificandSize())); tnn = nm->mkNode(APPLY_UF, tnn, e, s); } - else if (tn.getNumChildren() == 0) + else if (k == TUPLE_TYPE) { - // an uninterpreted sort, or an uninstantiatied (maybe parametric) datatype - d_declTypes.insert(tn); // special case: tuples must be distinguished by their arity - if (tn.isTuple()) + size_t nargs = tn.getNumChildren(); + if (nargs > 0) { - const DType& dt = tn.getDType(); - unsigned int nargs = dt[0].getNumArgs(); - if (nargs > 0) + std::vector types; + std::vector convTypes; + std::vector targs; + for (size_t i = 0; i < nargs; i++) { - std::vector types; - std::vector convTypes; - std::vector targs; - for (unsigned int i = 0; i < nargs; i++) - { - // it is not converted yet, convert here - TypeNode tnc = convertType(dt[0][i].getRangeType()); - types.push_back(d_sortType); - convTypes.push_back(tnc); - targs.push_back(typeAsNode(tnc)); - } - TypeNode ftype = nm->mkFunctionType(types, d_sortType); - // must distinguish by arity - std::stringstream ss; - ss << "Tuple_" << nargs; - targs.insert(targs.begin(), getSymbolInternal(k, ftype, ss.str())); - tnn = nm->mkNode(APPLY_UF, targs); - // we are changing its name, we must make a sort constructor - cur = nm->mkSortConstructor(ss.str(), nargs); - cur = nm->mkSort(cur, convTypes); + TypeNode tnc = tn[i]; + types.push_back(d_sortType); + convTypes.push_back(tnc); + targs.push_back(typeAsNode(tnc)); } + TypeNode ftype = nm->mkFunctionType(types, d_sortType); + // must distinguish by arity + std::stringstream ss; + ss << "Tuple_" << nargs; + targs.insert(targs.begin(), getSymbolInternal(k, ftype, ss.str())); + tnn = nm->mkNode(APPLY_UF, targs); + // we are changing its name, we must make a sort constructor + cur = nm->mkSortConstructor(ss.str(), nargs); + cur = nm->mkSort(cur, convTypes); } + } + else if (tn.getNumChildren() == 0) + { + Assert(!tn.isTuple()); + // an uninterpreted sort, or an uninstantiatied (maybe parametric) datatype + d_declTypes.insert(tn); if (tnn.isNull()) { std::stringstream ss; @@ -582,7 +581,7 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn) cur = nm->mkSortConstructor(s, tn.getUninterpretedSortConstructorArity()); } - else if (tn.isUninterpretedSort() || (tn.isDatatype() && !tn.isTuple())) + else if (tn.isUninterpretedSort() || tn.isDatatype()) { std::string s = getNameForUserNameOfInternal(tn.getId(), ss.str()); tnn = getSymbolInternal(k, d_sortType, s, false); @@ -590,6 +589,7 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn) } else { + // all other builtin type constants, e.g. Int tnn = getSymbolInternal(k, d_sortType, ss.str()); } } diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds index b2c46cf17..2d7a30918 100644 --- a/src/theory/datatypes/kinds +++ b/src/theory/datatypes/kinds @@ -77,6 +77,19 @@ enumerator PARAMETRIC_DATATYPE \ "::cvc5::internal::theory::datatypes::DatatypesEnumerator" \ "theory/datatypes/type_enumerator.h" +operator TUPLE_TYPE 0: "tuple type" +cardinality TUPLE_TYPE \ + "%TYPE%.getDType().getCardinality(%TYPE%)" \ + "expr/dtype.h" +well-founded TUPLE_TYPE \ + "%TYPE%.getDType().isWellFounded()" \ + "%TYPE%.getDType().mkGroundTerm(%TYPE%)" \ + "expr/dtype.h" + +enumerator TUPLE_TYPE \ + "::cvc5::internal::theory::datatypes::DatatypesEnumerator" \ + "expr/dtype.h" + parameterized APPLY_TYPE_ASCRIPTION ASCRIPTION_TYPE 1 \ "type ascription, for datatype constructor applications; first parameter is an ASCRIPTION_TYPE, second is the datatype constructor application being ascribed" constant ASCRIPTION_TYPE \ diff --git a/src/theory/datatypes/theory_datatypes_type_rules.cpp b/src/theory/datatypes/theory_datatypes_type_rules.cpp index edf797da1..94aebbc63 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.cpp +++ b/src/theory/datatypes/theory_datatypes_type_rules.cpp @@ -267,7 +267,7 @@ TypeNode DatatypeAscriptionTypeRule::computeType(NodeManager* nodeManager, { m.addTypesFromDatatype(childType.getDatatypeConstructorRangeType()); } - else if (childType.getKind() == kind::DATATYPE_TYPE) + else if (childType.isDatatype()) { m.addTypesFromDatatype(childType); } diff --git a/test/unit/api/cpp/solver_black.cpp b/test/unit/api/cpp/solver_black.cpp index e65f5fa43..6fbc2f46f 100644 --- a/test/unit/api/cpp/solver_black.cpp +++ b/test/unit/api/cpp/solver_black.cpp @@ -385,8 +385,7 @@ TEST_F(TestApiBlackSolver, mkTupleSort) ASSERT_NO_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort()})); Sort funSort = d_solver.mkFunctionSort({d_solver.mkUninterpretedSort("u")}, d_solver.getIntegerSort()); - ASSERT_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort(), funSort}), - CVC5ApiException); + ASSERT_NO_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort(), funSort})); Solver slv; ASSERT_THROW(slv.mkTupleSort({d_solver.getIntegerSort()}), CVC5ApiException); diff --git a/test/unit/api/java/SolverTest.java b/test/unit/api/java/SolverTest.java index a13c3c3ec..f967efe30 100644 --- a/test/unit/api/java/SolverTest.java +++ b/test/unit/api/java/SolverTest.java @@ -378,8 +378,7 @@ class SolverTest assertDoesNotThrow(() -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort()})); Sort funSort = d_solver.mkFunctionSort(d_solver.mkUninterpretedSort("u"), d_solver.getIntegerSort()); - assertThrows(CVC5ApiException.class, - () -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort(), funSort})); + assertDoesNotThrow(() -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort(), funSort})); Solver slv = new Solver(); assertThrows( diff --git a/test/unit/api/python/test_solver.py b/test/unit/api/python/test_solver.py index 6f02bbe87..736111d96 100644 --- a/test/unit/api/python/test_solver.py +++ b/test/unit/api/python/test_solver.py @@ -308,8 +308,7 @@ def test_mk_tuple_sort(solver): solver.mkTupleSort(solver.getIntegerSort()) funSort = solver.mkFunctionSort(solver.mkUninterpretedSort("u"),\ solver.getIntegerSort()) - with pytest.raises(RuntimeError): - solver.mkTupleSort(solver.getIntegerSort(), funSort) + solver.mkTupleSort(solver.getIntegerSort(), funSort) slv = cvc5.Solver() with pytest.raises(RuntimeError): -- 2.30.2