Properly represent Tuples in the TypeNode AST (#8648)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 29 Apr 2022 21:49:13 +0000 (16:49 -0500)
committerGitHub <noreply@github.com>
Fri, 29 Apr 2022 21:49:13 +0000 (21:49 +0000)
This makes it so that Tuple types are properly represented in the AST. It also removes a spurious restriction that disallowed higher-order tuples (this was leftover from a very old sanity check in the old API).

For example, a tuple type over (Int, Int) is now (TUPLE_TYPE INT INT) instead of a DATATYPE_TYPE constant.

Tuple types behave exactly like datatypes; we can still retrieve their DType as before.

This is in preparation for gradual types and symbolic tuple projections.

src/expr/node_manager_attributes.h
src/expr/node_manager_template.cpp
src/expr/node_manager_template.h
src/expr/type_node.cpp
src/printer/smt2/smt2_printer.cpp
src/proof/lfsc/lfsc_node_converter.cpp
src/theory/datatypes/kinds
src/theory/datatypes/theory_datatypes_type_rules.cpp
test/unit/api/cpp/solver_black.cpp
test/unit/api/java/SolverTest.java
test/unit/api/python/test_solver.py

index df5734062476aaada408c411df4af393e4ddc0e2..eb284b304cfd48d83b495356d65d9584fa343303 100644 (file)
@@ -33,6 +33,9 @@ namespace attr {
   struct UnresolvedDatatypeTag
   {
   };
+  struct TupleDatatypeTag
+  {
+  };
   }  // namespace attr
 
 typedef Attribute<attr::VarNameTag, std::string> VarNameAttr;
@@ -44,5 +47,9 @@ typedef expr::Attribute<expr::attr::TypeCheckedTag, bool> TypeCheckedAttr;
 using UnresolvedDatatypeAttr =
     expr::Attribute<expr::attr::UnresolvedDatatypeTag, bool>;
 
+/** Mapping tuples to their datatype type encoding */
+using TupleDatatypeAttr =
+    expr::Attribute<expr::attr::TupleDatatypeTag, TypeNode>;
+
 }  // namespace expr
 }  // namespace cvc5::internal
index 64f24f89a8f68fbd464374a950ced3b08911ef87..24e1f01f3a4842df59e6ed7505249e76f4ab8721 100644 (file)
@@ -302,6 +302,25 @@ NodeManager::~NodeManager()
   d_attrManager = NULL;
 }
 
+const DType& NodeManager::getDTypeFor(TypeNode tn) const
+{
+  Kind k = tn.getKind();
+  if (k == kind::DATATYPE_TYPE)
+  {
+    DatatypeIndexConstant dic = tn.getConst<DatatypeIndexConstant>();
+    return getDTypeForIndex(dic.getIndex());
+  }
+  else if (k == kind::TUPLE_TYPE)
+  {
+    // lookup its datatype encoding
+    TypeNode dtt = getAttribute(tn, expr::TupleDatatypeAttr());
+    Assert(!dtt.isNull());
+    return getDTypeFor(dtt);
+  }
+  Assert(k == kind::PARAMETRIC_DATATYPE);
+  return getDTypeFor(tn[0]);
+}
+
 const DType& NodeManager::getDTypeForIndex(size_t index) const
 {
   // if this assertion fails, it is likely due to not managing datatypes
@@ -599,6 +618,22 @@ std::vector<TypeNode> NodeManager::mkMutualDatatypeTypesInternal(
     if (dtp->getNumParameters() == 0)
     {
       typeNode = mkTypeConst(DatatypeIndexConstant(index));
+      // if the datatype is a tuple, the type will be (TUPLE_TYPE ...)
+      if (dt.isTuple())
+      {
+        TypeNode dtt = typeNode;
+        const DTypeConstructor& dc = dt[0];
+        std::vector<TypeNode> tupleTypes;
+        for (size_t i = 0, nargs = dc.getNumArgs(); i < nargs; i++)
+        {
+          // selector should be initialized to the range type, it is not null
+          // or unresolved since tuples are not recursive
+          tupleTypes.push_back(dc[i].getType());
+        }
+        // Set its datatype representation
+        typeNode = mkTypeNode(kind::TUPLE_TYPE, tupleTypes);
+        typeNode.setAttribute(expr::TupleDatatypeAttr(), dtt);
+      }
     }
     else
     {
@@ -740,9 +775,8 @@ TypeNode NodeManager::mkDatatypeUpdateType(TypeNode domain, TypeNode range)
   return mkTypeNode(kind::UPDATER_TYPE, domain, range);
 }
 
-TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm,
-                                                   std::vector<TypeNode>& types,
-                                                   unsigned index)
+TypeNode NodeManager::TupleTypeCache::getTupleType(
+    NodeManager* nm, const std::vector<TypeNode>& types, unsigned index)
 {
   if (index == types.size())
   {
@@ -750,7 +784,8 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm,
     {
       std::stringstream sst;
       sst << "__cvc5_tuple";
-      for (unsigned i = 0; i < types.size(); ++i)
+      size_t ntypes = types.size();
+      for (size_t i = 0; i < ntypes; ++i)
       {
         sst << "_" << types[i];
       }
@@ -760,7 +795,7 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm,
       ssc << sst.str() << "_ctor";
       std::shared_ptr<DTypeConstructor> c =
           std::make_shared<DTypeConstructor>(ssc.str());
-      for (unsigned i = 0; i < types.size(); ++i)
+      for (size_t i = 0; i < ntypes; ++i)
       {
         std::stringstream ss;
         ss << sst.str() << "_stor_" << i;
@@ -768,6 +803,7 @@ TypeNode NodeManager::TupleTypeCache::getTupleType(NodeManager* nm,
       }
       dt.addConstructor(c);
       d_data = nm->mkDatatypeType(dt);
+      Assert(d_data.isTuple());
       Trace("tuprec-debug") << "Return type : " << d_data << std::endl;
     }
     return d_data;
@@ -804,6 +840,7 @@ TypeNode NodeManager::RecTypeCache::getRecordType(NodeManager* nm,
       }
       dt.addConstructor(c);
       d_data = nm->mkDatatypeType(dt);
+      Assert(d_data.isRecord());
       Trace("tuprec-debug") << "Return type : " << d_data << std::endl;
     }
     return d_data;
@@ -847,18 +884,7 @@ TypeNode NodeManager::mkFunctionType(const std::vector<TypeNode>& argTypes,
 
 TypeNode NodeManager::mkTupleType(const std::vector<TypeNode>& types)
 {
-  std::vector<TypeNode> ts;
-  Trace("tuprec-debug") << "Make tuple type : ";
-  for (unsigned i = 0; i < types.size(); ++i)
-  {
-    CheckArgument(!types[i].isFunctionLike(),
-                  types,
-                  "cannot put function-like types in tuples");
-    ts.push_back(types[i]);
-    Trace("tuprec-debug") << types[i] << " ";
-  }
-  Trace("tuprec-debug") << std::endl;
-  return d_tt_cache.getTupleType(this, ts);
+  return d_tt_cache.getTupleType(this, types);
 }
 
 TypeNode NodeManager::mkRecordType(const Record& rec)
index fc72be0092555b1b699d3dbfa5f89c29f3ec369e..678728c780b667f4c8920ff494e33b2e67868b62 100644 (file)
@@ -131,6 +131,14 @@ class NodeManager
    * which is used as an index to retrieve the DType via this call.
    */
   const DType& getDTypeForIndex(size_t index) const;
+  /**
+   * Get the DType for a type. If tn is a datatype type, then we retrieve its
+   * internal index and use the above method to lookup its datatype.
+   *
+   * If it is a tuple, then we lookup its datatype representation and call
+   * this method on it.
+   */
+  const DType& getDTypeFor(TypeNode tn) const;
 
   /** get the canonical bound variable list for function type tn */
   Node getBoundVarListForFunctionType(TypeNode tn);
@@ -793,7 +801,8 @@ class NodeManager
   };
 
   /**
-   * A map of tuple and record types to their corresponding datatype.
+   * A map of tuple types to their corresponding datatype type, which are
+   * TypeNode of kind TUPLE_TYPE.
    */
   class TupleTypeCache
   {
@@ -801,9 +810,10 @@ class NodeManager
     std::map<TypeNode, TupleTypeCache> d_children;
     TypeNode d_data;
     TypeNode getTupleType(NodeManager* nm,
-                          std::vector<TypeNode>& types,
+                          const std::vector<TypeNode>& types,
                           unsigned index = 0);
   };
+  /** Same as above, for records */
   class RecTypeCache
   {
    public:
index 59feac41f844fea1227602f45445c5199560c268..d0801126c5b49444dd6a3ce1da4be6a356ace9a1 100644 (file)
@@ -376,10 +376,7 @@ std::vector<TypeNode> TypeNode::getInstantiatedParamTypes() const
   return params;
 }
 
-bool TypeNode::isTuple() const
-{
-  return (getKind() == kind::DATATYPE_TYPE && getDType().isTuple());
-}
+bool TypeNode::isTuple() const { return getKind() == kind::TUPLE_TYPE; }
 
 bool TypeNode::isRecord() const
 {
@@ -388,34 +385,35 @@ bool TypeNode::isRecord() const
 
 size_t TypeNode::getTupleLength() const {
   Assert(isTuple());
-  const DType& dt = getDType();
-  Assert(dt.getNumConstructors() == 1);
-  return dt[0].getNumArgs();
+  return getNumChildren();
 }
 
 vector<TypeNode> TypeNode::getTupleTypes() const {
   Assert(isTuple());
-  const DType& dt = getDType();
-  Assert(dt.getNumConstructors() == 1);
-  vector<TypeNode> types;
-  for(unsigned i = 0; i < dt[0].getNumArgs(); ++i) {
-    types.push_back(dt[0][i].getRangeType());
+  std::vector<TypeNode> args;
+  for (uint32_t i = 0, i_end = getNumChildren(); i < i_end; ++i)
+  {
+    args.push_back((*this)[i]);
   }
-  return types;
+  return args;
 }
 
 /** Is this an instantiated datatype type */
 bool TypeNode::isInstantiatedDatatype() const {
-  if(getKind() == kind::DATATYPE_TYPE) {
+  Kind k = getKind();
+  if (k == kind::DATATYPE_TYPE || k == kind::TUPLE_TYPE)
+  {
     return true;
   }
-  if(getKind() != kind::PARAMETRIC_DATATYPE) {
+  if (k != kind::PARAMETRIC_DATATYPE)
+  {
     return false;
   }
   const DType& dt = (*this)[0].getDType();
-  unsigned n = dt.getNumParameters();
+  size_t n = dt.getNumParameters();
   Assert(n < getNumChildren());
-  for(unsigned i = 0; i < n; ++i) {
+  for (size_t i = 0; i < n; ++i)
+  {
     if (dt.getParameter(i) == (*this)[i + 1])
     {
       return false;
@@ -534,8 +532,9 @@ bool TypeNode::isBitVector() const { return getKind() == kind::BITVECTOR_TYPE; }
 
 bool TypeNode::isDatatype() const
 {
-  return getKind() == kind::DATATYPE_TYPE
-         || getKind() == kind::PARAMETRIC_DATATYPE;
+  Kind k = getKind();
+  return k == kind::DATATYPE_TYPE || k == kind::PARAMETRIC_DATATYPE
+         || k == kind::TUPLE_TYPE;
 }
 
 bool TypeNode::isParametricDatatype() const
@@ -589,13 +588,7 @@ std::string TypeNode::toString() const {
 
 const DType& TypeNode::getDType() const
 {
-  if (getKind() == kind::DATATYPE_TYPE)
-  {
-    DatatypeIndexConstant dic = getConst<DatatypeIndexConstant>();
-    return NodeManager::currentNM()->getDTypeForIndex(dic.getIndex());
-  }
-  Assert(getKind() == kind::PARAMETRIC_DATATYPE);
-  return (*this)[0].getDType();
+  return NodeManager::currentNM()->getDTypeFor(*this);
 }
 
 bool TypeNode::isBag() const
index 5367c0e1959d8cf6e6baace04676551ccffe4c10..cb2c1d08305d36b47854623d0888867606218441 100644 (file)
@@ -1141,6 +1141,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   // datatypes theory
   case kind::APPLY_TESTER: return "is";
   case kind::APPLY_UPDATER: return "update";
+  case kind::TUPLE_TYPE: return "Tuple";
 
   // set theory
   case kind::SET_UNION: return "set.union";
index 563e7809b168d582105d3e97a79286d4feea1694..f8e79e014146d73133c08ac8e6f34c0726276629 100644 (file)
@@ -537,39 +537,38 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
     Node s = nm->mkConstInt(Rational(tn.getFloatingPointSignificandSize()));
     tnn = nm->mkNode(APPLY_UF, tnn, e, s);
   }
-  else if (tn.getNumChildren() == 0)
+  else if (k == TUPLE_TYPE)
   {
-    // an uninterpreted sort, or an uninstantiatied (maybe parametric) datatype
-    d_declTypes.insert(tn);
     // special case: tuples must be distinguished by their arity
-    if (tn.isTuple())
+    size_t nargs = tn.getNumChildren();
+    if (nargs > 0)
     {
-      const DType& dt = tn.getDType();
-      unsigned int nargs = dt[0].getNumArgs();
-      if (nargs > 0)
+      std::vector<TypeNode> types;
+      std::vector<TypeNode> convTypes;
+      std::vector<Node> targs;
+      for (size_t i = 0; i < nargs; i++)
       {
-        std::vector<TypeNode> types;
-        std::vector<TypeNode> convTypes;
-        std::vector<Node> targs;
-        for (unsigned int i = 0; i < nargs; i++)
-        {
-          // it is not converted yet, convert here
-          TypeNode tnc = convertType(dt[0][i].getRangeType());
-          types.push_back(d_sortType);
-          convTypes.push_back(tnc);
-          targs.push_back(typeAsNode(tnc));
-        }
-        TypeNode ftype = nm->mkFunctionType(types, d_sortType);
-        // must distinguish by arity
-        std::stringstream ss;
-        ss << "Tuple_" << nargs;
-        targs.insert(targs.begin(), getSymbolInternal(k, ftype, ss.str()));
-        tnn = nm->mkNode(APPLY_UF, targs);
-        // we are changing its name, we must make a sort constructor
-        cur = nm->mkSortConstructor(ss.str(), nargs);
-        cur = nm->mkSort(cur, convTypes);
+        TypeNode tnc = tn[i];
+        types.push_back(d_sortType);
+        convTypes.push_back(tnc);
+        targs.push_back(typeAsNode(tnc));
       }
+      TypeNode ftype = nm->mkFunctionType(types, d_sortType);
+      // must distinguish by arity
+      std::stringstream ss;
+      ss << "Tuple_" << nargs;
+      targs.insert(targs.begin(), getSymbolInternal(k, ftype, ss.str()));
+      tnn = nm->mkNode(APPLY_UF, targs);
+      // we are changing its name, we must make a sort constructor
+      cur = nm->mkSortConstructor(ss.str(), nargs);
+      cur = nm->mkSort(cur, convTypes);
     }
+  }
+  else if (tn.getNumChildren() == 0)
+  {
+    Assert(!tn.isTuple());
+    // an uninterpreted sort, or an uninstantiatied (maybe parametric) datatype
+    d_declTypes.insert(tn);
     if (tnn.isNull())
     {
       std::stringstream ss;
@@ -582,7 +581,7 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
         cur =
             nm->mkSortConstructor(s, tn.getUninterpretedSortConstructorArity());
       }
-      else if (tn.isUninterpretedSort() || (tn.isDatatype() && !tn.isTuple()))
+      else if (tn.isUninterpretedSort() || tn.isDatatype())
       {
         std::string s = getNameForUserNameOfInternal(tn.getId(), ss.str());
         tnn = getSymbolInternal(k, d_sortType, s, false);
@@ -590,6 +589,7 @@ TypeNode LfscNodeConverter::postConvertType(TypeNode tn)
       }
       else
       {
+        // all other builtin type constants, e.g. Int
         tnn = getSymbolInternal(k, d_sortType, ss.str());
       }
     }
index b2c46cf17db45f31f39fccc76e271be79d635d2b..2d7a3091841ea788bf354e7cdc5501174dfd0bd0 100644 (file)
@@ -77,6 +77,19 @@ enumerator PARAMETRIC_DATATYPE \
     "::cvc5::internal::theory::datatypes::DatatypesEnumerator" \
     "theory/datatypes/type_enumerator.h"
 
+operator TUPLE_TYPE 0: "tuple type"
+cardinality TUPLE_TYPE \
+    "%TYPE%.getDType().getCardinality(%TYPE%)" \
+    "expr/dtype.h"
+well-founded TUPLE_TYPE \
+    "%TYPE%.getDType().isWellFounded()" \
+    "%TYPE%.getDType().mkGroundTerm(%TYPE%)" \
+    "expr/dtype.h"
+
+enumerator TUPLE_TYPE \
+    "::cvc5::internal::theory::datatypes::DatatypesEnumerator" \
+    "expr/dtype.h"
+
 parameterized APPLY_TYPE_ASCRIPTION ASCRIPTION_TYPE 1 \
     "type ascription, for datatype constructor applications; first parameter is an ASCRIPTION_TYPE, second is the datatype constructor application being ascribed"
 constant ASCRIPTION_TYPE \
index edf797da199353ee385226c6161235dba2f6b518..94aebbc63eef660256b3638b6b15f333d7da6bde 100644 (file)
@@ -267,7 +267,7 @@ TypeNode DatatypeAscriptionTypeRule::computeType(NodeManager* nodeManager,
     {
       m.addTypesFromDatatype(childType.getDatatypeConstructorRangeType());
     }
-    else if (childType.getKind() == kind::DATATYPE_TYPE)
+    else if (childType.isDatatype())
     {
       m.addTypesFromDatatype(childType);
     }
index e65f5fa43cfab4641a575f494e35be2e047ac5d0..6fbc2f46f5b433ba26311d046ae0ac2526ceba1f 100644 (file)
@@ -385,8 +385,7 @@ TEST_F(TestApiBlackSolver, mkTupleSort)
   ASSERT_NO_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort()}));
   Sort funSort = d_solver.mkFunctionSort({d_solver.mkUninterpretedSort("u")},
                                          d_solver.getIntegerSort());
-  ASSERT_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort(), funSort}),
-               CVC5ApiException);
+  ASSERT_NO_THROW(d_solver.mkTupleSort({d_solver.getIntegerSort(), funSort}));
 
   Solver slv;
   ASSERT_THROW(slv.mkTupleSort({d_solver.getIntegerSort()}), CVC5ApiException);
index a13c3c3ec7eee5ba61ea7657f32c3cb4a56b586d..f967efe3024cc270a3a696c11e9c3cefd9ef1ffb 100644 (file)
@@ -378,8 +378,7 @@ class SolverTest
     assertDoesNotThrow(() -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort()}));
     Sort funSort =
         d_solver.mkFunctionSort(d_solver.mkUninterpretedSort("u"), d_solver.getIntegerSort());
-    assertThrows(CVC5ApiException.class,
-        () -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort(), funSort}));
+    assertDoesNotThrow(() -> d_solver.mkTupleSort(new Sort[] {d_solver.getIntegerSort(), funSort}));
 
     Solver slv = new Solver();
     assertThrows(
index 6f02bbe8795da88078b6d4269e68b4d8d186658a..736111d96fc09a9b828abfba03a2b866d03e4fd8 100644 (file)
@@ -308,8 +308,7 @@ def test_mk_tuple_sort(solver):
     solver.mkTupleSort(solver.getIntegerSort())
     funSort = solver.mkFunctionSort(solver.mkUninterpretedSort("u"),\
                                     solver.getIntegerSort())
-    with pytest.raises(RuntimeError):
-        solver.mkTupleSort(solver.getIntegerSort(), funSort)
+    solver.mkTupleSort(solver.getIntegerSort(), funSort)
 
     slv = cvc5.Solver()
     with pytest.raises(RuntimeError):