Add methods for constructing datatype types from NodeManager (#4823)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sun, 2 Aug 2020 03:27:08 +0000 (22:27 -0500)
committerGitHub <noreply@github.com>
Sun, 2 Aug 2020 03:27:08 +0000 (22:27 -0500)
This is work towards eliminating the Expr-level datatype.

This PR implements the required methods for constructing datatype types from NodeManager.

In particular, this PR copies the "mkMutualDatatypeTypes" methods and converts them to Node-level.

The next PRs will be in preparation for using these methods instead of the Expr-level ones.

It also adds a flag d_isRecord to DType, which is required for supporting record printing in the cvc printer, which will be updated in another PR.

It also eliminates an interface for constructing constructor types via Expr-level DatatypeConstructor objects, which was unused.

src/expr/dtype.cpp
src/expr/dtype.h
src/expr/expr_manager_template.cpp
src/expr/expr_manager_template.h
src/expr/node_manager.cpp
src/expr/node_manager.h
src/smt/listeners.cpp
src/smt/listeners.h

index 220e12c420b1d005f32219946efe96671b484875..91ae9d158445e8226b6fe285f4d5c789d2c99845 100644 (file)
@@ -25,6 +25,7 @@ DType::DType(std::string name, bool isCo)
       d_params(),
       d_isCo(isCo),
       d_isTuple(false),
+      d_isRecord(false),
       d_constructors(),
       d_resolved(false),
       d_self(),
@@ -43,6 +44,7 @@ DType::DType(std::string name, const std::vector<TypeNode>& params, bool isCo)
       d_params(params),
       d_isCo(isCo),
       d_isTuple(false),
+      d_isRecord(false),
       d_constructors(),
       d_resolved(false),
       d_self(),
@@ -82,6 +84,8 @@ bool DType::isSygus() const { return !d_sygusType.isNull(); }
 
 bool DType::isTuple() const { return d_isTuple; }
 
+bool DType::isRecord() const { return d_isRecord; }
+
 bool DType::isResolved() const { return d_resolved; }
 
 const DType& DType::datatypeOf(Node item)
@@ -231,6 +235,12 @@ void DType::setTuple()
   d_isTuple = true;
 }
 
+void DType::setRecord()
+{
+  Assert(!d_resolved);
+  d_isRecord = true;
+}
+
 Cardinality DType::getCardinality(TypeNode t) const
 {
   Trace("datatypes-init") << "DType::getCardinality " << std::endl;
index fea51cd366302ad269db9f29234d9750903bab43..1682614d0b1b9ec156cd357686b3a9c86e87491d 100644 (file)
@@ -216,6 +216,9 @@ class DType
   /** set that this datatype is a tuple */
   void setTuple();
 
+  /** set that this datatype is a record */
+  void setRecord();
+
   /** Get the name of this DType. */
   std::string getName() const;
 
@@ -243,6 +246,9 @@ class DType
   /** is this a tuple datatype? */
   bool isTuple() const;
 
+  /** is this a record datatype? */
+  bool isRecord() const;
+
   /**
    * Return the cardinality of this datatype.
    * The DType must be resolved.
@@ -553,6 +559,8 @@ class DType
   bool d_isCo;
   /** whether the datatype is a tuple */
   bool d_isTuple;
+  /** whether the datatype is a record */
+  bool d_isRecord;
   /** the constructors of this datatype */
   std::vector<std::shared_ptr<DTypeConstructor> > d_constructors;
   /** whether this datatype has been resolved */
index 445ca9ee77b209feede981c60bda8ad68b6c9c11..66824c07af555e205bad392764c408686cca119d 100644 (file)
@@ -690,7 +690,7 @@ std::vector<DatatypeType> ExprManager::mkMutualDatatypeTypes(
   for(std::vector<Datatype*>::iterator i = dt_copies.begin(), i_end = dt_copies.end(); i != i_end; ++i) {
     TypeNode* typeNode;
     // register datatype with the node manager
-    unsigned index = d_nodeManager->registerDatatype((*i)->d_internal);
+    size_t index = d_nodeManager->registerDatatype((*i)->d_internal);
     if( (*i)->getNumParameters() == 0 ) {
       typeNode = new TypeNode(d_nodeManager->mkTypeConst(DatatypeIndexConstant(index)));
       //typeNode = new TypeNode(d_nodeManager->mkTypeConst(*i));
@@ -760,6 +760,7 @@ std::vector<DatatypeType> ExprManager::mkMutualDatatypeTypes(
   }
 
   // Lastly, perform the final resolutions and checks.
+  std::vector<TypeNode> tns;
   for(std::vector<DatatypeType>::iterator i = dtts.begin(),
         i_end = dtts.end();
       i != i_end;
@@ -776,10 +777,11 @@ std::vector<DatatypeType> ExprManager::mkMutualDatatypeTypes(
     // Now run some checks, including a check to make sure that no
     // selector is function-valued.
     checkResolvedDatatype(*i);
+    tns.push_back(TypeNode::fromType(*i));
   }
 
   for(std::vector<NodeManagerListener*>::iterator i = d_nodeManager->d_listeners.begin(); i != d_nodeManager->d_listeners.end(); ++i) {
-    (*i)->nmNotifyNewDatatypes(dtts, flags);
+    (*i)->nmNotifyNewDatatypes(tns, flags);
   }
   
   return dtts;
@@ -825,11 +827,6 @@ void ExprManager::checkResolvedDatatype(DatatypeType dtt) const {
   }
 }
 
-ConstructorType ExprManager::mkConstructorType(const DatatypeConstructor& constructor, Type range) const {
-  NodeManagerScope nms(d_nodeManager);
-  return Type(d_nodeManager, new TypeNode(d_nodeManager->mkConstructorType(constructor, *range.d_typeNode)));
-}
-
 SelectorType ExprManager::mkSelectorType(Type domain, Type range) const {
   NodeManagerScope nms(d_nodeManager);
   return Type(d_nodeManager, new TypeNode(d_nodeManager->mkSelectorType(*domain.d_typeNode, *range.d_typeNode)));
index 3f180e95164eebbe1227191faa8f6d37f54b20d6..2b9a85acaf67601e26783410502bacef870069e6 100644 (file)
@@ -435,11 +435,6 @@ class CVC4_PUBLIC ExprManager {
       std::set<Type>& unresolvedTypes,
       uint32_t flags = DATATYPE_FLAG_NONE);
 
-  /**
-   * Make a type representing a constructor with the given parameterization.
-   */
-  ConstructorType mkConstructorType(const DatatypeConstructor& constructor, Type range) const;
-
   /** Make a type representing a selector with the given parameterization. */
   SelectorType mkSelectorType(Type domain, Type range) const;
 
index e9f56bf3fdb59ccbbfc7c6aa180ab1cb2abe376a..c72de95644987069d93b93f962d98afdc8b08159 100644 (file)
@@ -149,6 +149,8 @@ NodeManager::~NodeManager() {
   d_rt_cache.d_children.clear();
   d_rt_cache.d_data = dummy;
 
+  d_registeredDTypes.clear();
+  // clear the datatypes
   d_ownedDTypes.clear();
 
   Assert(!d_attrManager->inGarbageCollection());
@@ -198,15 +200,15 @@ NodeManager::~NodeManager() {
 
 size_t NodeManager::registerDatatype(std::shared_ptr<DType> dt)
 {
-  size_t sz = d_ownedDTypes.size();
-  d_ownedDTypes.push_back(dt);
+  size_t sz = d_registeredDTypes.size();
+  d_registeredDTypes.push_back(dt);
   return sz;
 }
 
-const DType& NodeManager::getDTypeForIndex(unsigned index) const
+const DType& NodeManager::getDTypeForIndex(size_t index) const
 {
-  Assert(index < d_ownedDTypes.size());
-  return *d_ownedDTypes[index];
+  Assert(index < d_registeredDTypes.size());
+  return *d_registeredDTypes[index];
 }
 
 void NodeManager::reclaimZombies() {
@@ -462,28 +464,132 @@ TypeNode NodeManager::mkSequenceType(TypeNode elementType)
   return mkTypeNode(kind::SEQUENCE_TYPE, elementType);
 }
 
-TypeNode NodeManager::mkConstructorType(const DatatypeConstructor& constructor,
-                                        TypeNode range) {
-  vector<TypeNode> sorts;
-  Debug("datatypes") << "ctor name: " << constructor.getName() << endl;
-  for(DatatypeConstructor::const_iterator i = constructor.begin();
-      i != constructor.end();
-      ++i) {
-    TypeNode selectorType = *(*i).getSelector().getType().d_typeNode;
-    Debug("datatypes") << selectorType << endl;
-    TypeNode sort = selectorType[1];
+TypeNode NodeManager::mkDatatypeType(DType& datatype, uint32_t flags)
+{
+  // Not worth a special implementation; this doesn't need to be fast
+  // code anyway.
+  std::vector<DType> datatypes;
+  datatypes.push_back(datatype);
+  std::vector<TypeNode> result = mkMutualDatatypeTypes(datatypes, flags);
+  Assert(result.size() == 1);
+  return result.front();
+}
+
+std::vector<TypeNode> NodeManager::mkMutualDatatypeTypes(
+    const std::vector<DType>& datatypes, uint32_t flags)
+{
+  std::set<TypeNode> unresolvedTypes;
+  return mkMutualDatatypeTypes(datatypes, unresolvedTypes, flags);
+}
 
-    // should be guaranteed here already, but just in case
-    Assert(!sort.isFunctionLike());
+std::vector<TypeNode> NodeManager::mkMutualDatatypeTypes(
+    const std::vector<DType>& datatypes,
+    const std::set<TypeNode>& unresolvedTypes,
+    uint32_t flags)
+{
+  NodeManagerScope nms(this);
+  std::map<std::string, TypeNode> nameResolutions;
+  std::vector<TypeNode> dtts;
 
-    Debug("datatypes") << "ctor sort: " << sort << endl;
-    sorts.push_back(sort);
+  // have to build deep copy so that datatypes will live in this class
+  std::vector<std::shared_ptr<DType> > dt_copies;
+  for (const DType& dt : datatypes)
+  {
+    d_ownedDTypes.push_back(std::unique_ptr<DType>(new DType(dt)));
+    dt_copies.push_back(std::move(d_ownedDTypes.back()));
   }
-  Debug("datatypes") << "ctor range: " << range << endl;
-  PrettyCheckArgument(!range.isFunctionLike(), range,
-                      "cannot create higher-order function types");
-  sorts.push_back(range);
-  return mkTypeNode(kind::CONSTRUCTOR_TYPE, sorts);
+
+  // First do some sanity checks, set up the final Type to be used for
+  // each datatype, and set up the "named resolutions" used to handle
+  // simple self- and mutual-recursion, for example in the definition
+  // "nat = succ(pred:nat) | zero", a named resolution can handle the
+  // pred selector.
+  for (const std::shared_ptr<DType>& dtc : dt_copies)
+  {
+    TypeNode typeNode;
+    // register datatype with the node manager
+    size_t index = registerDatatype(dtc);
+    if (dtc->getNumParameters() == 0)
+    {
+      typeNode = mkTypeConst(DatatypeIndexConstant(index));
+    }
+    else
+    {
+      TypeNode cons = mkTypeConst(DatatypeIndexConstant(index));
+      std::vector<TypeNode> params;
+      params.push_back(cons);
+      for (unsigned int ip = 0; ip < dtc->getNumParameters(); ++ip)
+      {
+        params.push_back(dtc->getParameter(ip));
+      }
+
+      typeNode = mkTypeNode(kind::PARAMETRIC_DATATYPE, params);
+    }
+    AlwaysAssert(nameResolutions.find(dtc->getName()) == nameResolutions.end())
+        << "cannot construct two datatypes at the same time "
+           "with the same name";
+    nameResolutions.insert(std::make_pair(dtc->getName(), typeNode));
+    dtts.push_back(typeNode);
+  }
+
+  // Second, set up the type substitution map for complex type
+  // resolution (e.g. if "list" is the type we're defining, and it has
+  // a selector of type "ARRAY INT OF list", this can't be taken care
+  // of using the named resolutions that we set up above.  A
+  // preliminary array type was set up, and now needs to have "list"
+  // substituted in it for the correct type.
+  //
+  // @TODO get rid of named resolutions altogether and handle
+  // everything with these resolutions?
+  std::vector<TypeNode> paramTypes;
+  std::vector<TypeNode> paramReplacements;
+  std::vector<TypeNode> placeholders;  // to hold the "unresolved placeholders"
+  std::vector<TypeNode> replacements;  // to hold our final, resolved types
+  for (const TypeNode& ut : unresolvedTypes)
+  {
+    std::string name = ut.getAttribute(expr::VarNameAttr());
+    std::map<std::string, TypeNode>::const_iterator resolver =
+        nameResolutions.find(name);
+    AlwaysAssert(resolver != nameResolutions.end())
+        << "cannot resolve type " + name
+               + "; it's not among the datatypes being defined";
+    // We will instruct the Datatype to substitute "ut" (the
+    // unresolved SortType used as a placeholder in complex types)
+    // with "(*resolver).second" (the TypeNode we created in the
+    // first step, above).
+    if (ut.isSort())
+    {
+      placeholders.push_back(ut);
+      replacements.push_back((*resolver).second);
+    }
+    else
+    {
+      Assert(ut.isSortConstructor());
+      paramTypes.push_back(ut);
+      paramReplacements.push_back((*resolver).second);
+    }
+  }
+
+  // Lastly, perform the final resolutions and checks.
+  for (const TypeNode& ut : dtts)
+  {
+    const DType& dt = ut.getDType();
+    if (!dt.isResolved())
+    {
+      const_cast<DType&>(dt).resolve(nameResolutions,
+                                     placeholders,
+                                     replacements,
+                                     paramTypes,
+                                     paramReplacements);
+    }
+  }
+
+  for (NodeManagerListener* nml : d_listeners)
+  {
+    nml->nmNotifyNewDatatypes(dtts, flags);
+  }
+
+  return dtts;
 }
 
 TypeNode NodeManager::mkConstructorType(const std::vector<TypeNode>& args,
index 84c4b44e08aa0e4c0c0d185a9a30b79434577c43..cfe771ca540d1a7e42da43e082a0d32ef6682cb1 100644 (file)
@@ -66,7 +66,7 @@ class NodeManagerListener {
   virtual void nmNotifyNewSortConstructor(TypeNode tn, uint32_t flags) {}
   virtual void nmNotifyInstantiateSortConstructor(TypeNode ctor, TypeNode sort,
                                                   uint32_t flags) {}
-  virtual void nmNotifyNewDatatypes(const std::vector<DatatypeType>& datatypes,
+  virtual void nmNotifyNewDatatypes(const std::vector<TypeNode>& datatypes,
                                     uint32_t flags)
   {
   }
@@ -172,8 +172,12 @@ class NodeManager {
    */
   std::vector<NodeManagerListener*> d_listeners;
 
-  /** A list of datatypes owned by this node manager. */
-  std::vector<std::shared_ptr<DType> > d_ownedDTypes;
+  /** A list of datatypes registered by its corresponding expr manager.
+   * !!! this member should be deleted when the Expr-layer is deleted.
+   */
+  std::vector<std::shared_ptr<DType> > d_registeredDTypes;
+  /** A list of datatypes owned by this node manager */
+  std::vector<std::unique_ptr<DType> > d_ownedDTypes;
 
   /**
    * A map of tuple and record types to their corresponding datatype.
@@ -407,8 +411,10 @@ public:
     Assert(elt != d_listeners.end()) << "listener not subscribed";
     d_listeners.erase(elt);
   }
-  
-  /** register datatype */
+
+  /** register that datatype dt was constructed by the expression manager
+   * !!! this interface should be deleted when the Expr-layer is deleted.
+   */
   size_t registerDatatype(std::shared_ptr<DType> dt);
   /**
    * Return the datatype at the given index owned by this class. Type nodes are
@@ -420,7 +426,7 @@ public:
    * would lead to memory leaks. Thus TypeNode are given a DatatypeIndexConstant
    * which is used as an index to retrieve the DType via this call.
    */
-  const DType& getDTypeForIndex(unsigned index) const;
+  const DType& getDTypeForIndex(size_t index) const;
 
   /** Get a Kind from an operator expression */
   static inline Kind operatorToKind(TNode n);
@@ -878,8 +884,61 @@ public:
   /** Make the type of sequences with the given parameterization */
   TypeNode mkSequenceType(TypeNode elementType);
 
-  /** Make a type representing a constructor with the given parameterization */
-  TypeNode mkConstructorType(const DatatypeConstructor& constructor, TypeNode range);
+  /** Bits for use in mkDatatypeType() flags.
+   *
+   * DATATYPE_FLAG_PLACEHOLDER indicates that the type should not be printed
+   * out as a definition, for example, in models or during dumping.
+   */
+  enum
+  {
+    DATATYPE_FLAG_NONE = 0,
+    DATATYPE_FLAG_PLACEHOLDER = 1
+  }; /* enum */
+
+  /** Make a type representing the given datatype. */
+  TypeNode mkDatatypeType(DType& datatype, uint32_t flags = DATATYPE_FLAG_NONE);
+
+  /**
+   * Make a set of types representing the given datatypes, which may be
+   * mutually recursive.
+   */
+  std::vector<TypeNode> mkMutualDatatypeTypes(
+      const std::vector<DType>& datatypes, uint32_t flags = DATATYPE_FLAG_NONE);
+
+  /**
+   * Make a set of types representing the given datatypes, which may
+   * be mutually recursive.  unresolvedTypes is a set of SortTypes
+   * that were used as placeholders in the Datatypes for the Datatypes
+   * of the same name.  This is just a more complicated version of the
+   * above mkMutualDatatypeTypes() function, but is required to handle
+   * complex types.
+   *
+   * For example, unresolvedTypes might contain the single sort "list"
+   * (with that name reported from SortType::getName()).  The
+   * datatypes list might have the single datatype
+   *
+   *   DATATYPE
+   *     list = cons(car:ARRAY INT OF list, cdr:list) | nil;
+   *   END;
+   *
+   * To represent the Type of the array, the user had to create a
+   * placeholder type (an uninterpreted sort) to stand for "list" in
+   * the type of "car".  It is this placeholder sort that should be
+   * passed in unresolvedTypes.  If the datatype was of the simpler
+   * form:
+   *
+   *   DATATYPE
+   *     list = cons(car:list, cdr:list) | nil;
+   *   END;
+   *
+   * then no complicated Type needs to be created, and the above,
+   * simpler form of mkMutualDatatypeTypes() is enough.
+   */
+  std::vector<TypeNode> mkMutualDatatypeTypes(
+      const std::vector<DType>& datatypes,
+      const std::set<TypeNode>& unresolvedTypes,
+      uint32_t flags = DATATYPE_FLAG_NONE);
+
   /**
    * Make a type representing a constructor with the given argument (subfield)
    * types and return type range.
index 5cfbfd1c5e108a78094aa65ab02d70c8b0bd9b06..452894f62bb9a758cd800fd7e2df6a110455d209 100644 (file)
@@ -59,11 +59,16 @@ void SmtNodeManagerListener::nmNotifyNewSortConstructor(TypeNode tn,
 }
 
 void SmtNodeManagerListener::nmNotifyNewDatatypes(
-    const std::vector<DatatypeType>& dtts, uint32_t flags)
+    const std::vector<TypeNode>& dtts, uint32_t flags)
 {
   if ((flags & ExprManager::DATATYPE_FLAG_PLACEHOLDER) == 0)
   {
-    std::vector<Type> types(dtts.begin(), dtts.end());
+    std::vector<Type> types;
+    for (const TypeNode& dt : dtts)
+    {
+      Assert(dt.isDatatype());
+      types.push_back(dt.toType());
+    }
     DatatypeDeclarationCommand c(types);
     d_smt.addToModelCommandAndDump(c);
   }
index 1eba4122828d62777af15f65921a0ff827b94a9a..6054d13af3a7fb6424c5110ec98a249b65822cb0 100644 (file)
@@ -53,7 +53,7 @@ class SmtNodeManagerListener : public NodeManagerListener
   /** Notify when new sort constructor is created */
   void nmNotifyNewSortConstructor(TypeNode tn, uint32_t flags) override;
   /** Notify when list of datatypes is created */
-  void nmNotifyNewDatatypes(const std::vector<DatatypeType>& dtts,
+  void nmNotifyNewDatatypes(const std::vector<TypeNode>& dtts,
                             uint32_t flags) override;
   /** Notify when new variable is created */
   void nmNotifyNewVar(TNode n, uint32_t flags) override;