Optionally permit creation of non-flat function types (#6010)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 26 Feb 2021 17:44:23 +0000 (11:44 -0600)
committerGitHub <noreply@github.com>
Fri, 26 Feb 2021 17:44:23 +0000 (14:44 -0300)
This is required for creating the representation of closues in LFSC, which are of the form ((forall x T) P) where notice that forall has non-flat function type (-> Int Sort (-> Bool Bool)).

src/expr/node_manager.cpp
src/expr/node_manager.h

index 59afec4a6883a122b52e7c1f748593c7007650a9..883febd6f03a094adece3fe6a45cdbe3a7061f4c 100644 (file)
@@ -769,40 +769,44 @@ TypeNode NodeManager::RecTypeCache::getRecordType( NodeManager * nm, const Recor
       nm, rec, index + 1);
 }
 
-TypeNode NodeManager::mkFunctionType(const std::vector<TypeNode>& sorts)
+TypeNode NodeManager::mkFunctionType(const std::vector<TypeNode>& sorts,
+                                     bool reqFlat)
 {
   Assert(sorts.size() >= 2);
-  CheckArgument(!sorts[sorts.size() - 1].isFunction(),
+  CheckArgument(!reqFlat || !sorts[sorts.size() - 1].isFunction(),
                 sorts[sorts.size() - 1],
                 "must flatten function types");
   return mkTypeNode(kind::FUNCTION_TYPE, sorts);
 }
 
-TypeNode NodeManager::mkPredicateType(const std::vector<TypeNode>& sorts)
+TypeNode NodeManager::mkPredicateType(const std::vector<TypeNode>& sorts,
+                                      bool reqFlat)
 {
   Assert(sorts.size() >= 1);
   std::vector<TypeNode> sortNodes;
   sortNodes.insert(sortNodes.end(), sorts.begin(), sorts.end());
   sortNodes.push_back(booleanType());
-  return mkFunctionType(sortNodes);
+  return mkFunctionType(sortNodes, reqFlat);
 }
 
 TypeNode NodeManager::mkFunctionType(const TypeNode& domain,
-                                     const TypeNode& range)
+                                     const TypeNode& range,
+                                     bool reqFlat)
 {
   std::vector<TypeNode> sorts;
   sorts.push_back(domain);
   sorts.push_back(range);
-  return mkFunctionType(sorts);
+  return mkFunctionType(sorts, reqFlat);
 }
 
 TypeNode NodeManager::mkFunctionType(const std::vector<TypeNode>& argTypes,
-                                     const TypeNode& range)
+                                     const TypeNode& range,
+                                     bool reqFlat)
 {
   Assert(argTypes.size() >= 1);
   std::vector<TypeNode> sorts(argTypes);
   sorts.push_back(range);
-  return mkFunctionType(sorts);
+  return mkFunctionType(sorts, reqFlat);
 }
 
 TypeNode NodeManager::mkTupleType(const std::vector<TypeNode>& types) {
index 89cd61e09bafe135c84c005075faa5cbaee39fa7..076b6d164ac861245ce831f13be505861aceec4a 100644 (file)
@@ -886,9 +886,15 @@ class NodeManager {
    *
    * @param domain the domain type
    * @param range the range type
+   * @param reqFlat If true, we require flat function types, e.g. the
+   * range type cannot be a function. User-generated function types and those
+   * used in solving must be flat, although some use cases (e.g. LFSC proof
+   * conversion) require non-flat function types.
    * @returns the functional type domain -> range
    */
-  TypeNode mkFunctionType(const TypeNode& domain, const TypeNode& range);
+  TypeNode mkFunctionType(const TypeNode& domain,
+                          const TypeNode& range,
+                          bool reqFlat = true);
 
   /**
    * Make a function type with input types from
@@ -896,18 +902,25 @@ class NodeManager {
    *
    * @param argTypes the domain is a tuple (argTypes[0], ..., argTypes[n])
    * @param range the range type
+   * @param reqFlat Same as above
    * @returns the functional type (argTypes[0], ..., argTypes[n]) -> range
    */
   TypeNode mkFunctionType(const std::vector<TypeNode>& argTypes,
-                          const TypeNode& range);
+                          const TypeNode& range,
+                          bool reqFlat = true);
 
   /**
    * Make a function type with input types from
    * <code>sorts[0..sorts.size()-2]</code> and result type
    * <code>sorts[sorts.size()-1]</code>. <code>sorts</code> must have
    * at least 2 elements.
+   *
+   * @param sorts The argument and range sort of the function type, where the
+   * range type is the last in this vector.
+   * @param reqFlat Same as above
    */
-  TypeNode mkFunctionType(const std::vector<TypeNode>& sorts);
+  TypeNode mkFunctionType(const std::vector<TypeNode>& sorts,
+                          bool reqFlat = true);
 
   /**
    * Make a predicate type with input types from
@@ -915,7 +928,8 @@ class NodeManager {
    * <code>BOOLEAN</code>. <code>sorts</code> must have at least one
    * element.
    */
-  TypeNode mkPredicateType(const std::vector<TypeNode>& sorts);
+  TypeNode mkPredicateType(const std::vector<TypeNode>& sorts,
+                           bool reqFlat = true);
 
   /**
    * Make a tuple type with types from