Fix the subtyping relation for functions (#3494)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 5 Dec 2019 01:13:08 +0000 (19:13 -0600)
committerAina Niemetz <aina.niemetz@gmail.com>
Thu, 5 Dec 2019 01:13:08 +0000 (17:13 -0800)
src/expr/type_node.cpp
test/regress/CMakeLists.txt
test/regress/regress0/ho/fun-subtyping.smt2 [new file with mode: 0644]

index 0bf86240b42571c90b1b6bb319ec4a62b1e6cdf9..1ef5030ce33efd03a2462fdc60adfec19abceb4b 100644 (file)
@@ -314,6 +314,15 @@ bool TypeNode::isSubtypeOf(TypeNode t) const {
   if(isSet() && t.isSet()) {
     return getSetElementType().isSubtypeOf(t.getSetElementType());
   }
+  if (isFunction() && t.isFunction())
+  {
+    if (!isComparableTo(t))
+    {
+      // incomparable, not subtype
+      return false;
+    }
+    return getRangeType().isSubtypeOf(t.getRangeType());
+  }
   // this should only return true for types T1, T2 where we handle equalities between T1 and T2
   // (more cases go here, if we want to support such cases)
   return false;
@@ -329,6 +338,11 @@ bool TypeNode::isComparableTo(TypeNode t) const {
   if(isSet() && t.isSet()) {
     return getSetElementType().isComparableTo(t.getSetElementType());
   }
+  if (isFunction() && t.isFunction())
+  {
+    // comparable if they have a common type node
+    return !leastCommonTypeNode(*this, t).isNull();
+  }
   return false;
 }
 
@@ -514,26 +528,60 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) {
   // t0.getKind() == kind::TYPE_CONSTANT &&
   // t1.getKind() == kind::TYPE_CONSTANT
   switch(t0.getKind()) {
-  case kind::BITVECTOR_TYPE:
-  case kind::FLOATINGPOINT_TYPE:
-  case kind::SORT_TYPE:
-  case kind::CONSTRUCTOR_TYPE:
-  case kind::SELECTOR_TYPE:
-  case kind::TESTER_TYPE:
-  case kind::FUNCTION_TYPE:
-  case kind::ARRAY_TYPE:
-  case kind::DATATYPE_TYPE:
-  case kind::PARAMETRIC_DATATYPE:
-    return TypeNode();
-  case kind::SET_TYPE: {
-    // take the least common subtype of element types
-    TypeNode elementType;
-    if(t1.isSet() && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull() ) {
-      return NodeManager::currentNM()->mkSetType(elementType);
-    } else {
-      return TypeNode();
+    case kind::FUNCTION_TYPE:
+    {
+      if (t1.getKind() != kind::FUNCTION_TYPE)
+      {
+        return TypeNode();
+      }
+      // must have equal arguments
+      std::vector<TypeNode> t0a = t0.getArgTypes();
+      std::vector<TypeNode> t1a = t1.getArgTypes();
+      if (t0a.size() != t1a.size())
+      {
+        // different arities
+        return TypeNode();
+      }
+      for (unsigned i = 0, nargs = t0a.size(); i < nargs; i++)
+      {
+        if (t0a[i] != t1a[i])
+        {
+          // an argument is different
+          return TypeNode();
+        }
+      }
+      TypeNode t0r = t0.getRangeType();
+      TypeNode t1r = t1.getRangeType();
+      TypeNode tr = commonTypeNode(t0r, t1r, isLeast);
+      std::vector<TypeNode> ftypes;
+      ftypes.insert(ftypes.end(), t0a.begin(), t0a.end());
+      ftypes.push_back(tr);
+      return NodeManager::currentNM()->mkFunctionType(ftypes);
+    }
+    break;
+    case kind::BITVECTOR_TYPE:
+    case kind::FLOATINGPOINT_TYPE:
+    case kind::SORT_TYPE:
+    case kind::CONSTRUCTOR_TYPE:
+    case kind::SELECTOR_TYPE:
+    case kind::TESTER_TYPE:
+    case kind::ARRAY_TYPE:
+    case kind::DATATYPE_TYPE:
+    case kind::PARAMETRIC_DATATYPE: return TypeNode();
+    case kind::SET_TYPE:
+    {
+      // take the least common subtype of element types
+      TypeNode elementType;
+      if (t1.isSet()
+          && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull())
+      {
+        return NodeManager::currentNM()->mkSetType(elementType);
+      }
+      else
+      {
+        return TypeNode();
+      }
     }
-  }
   case kind::SEXPR_TYPE:
     Unimplemented()
         << "haven't implemented leastCommonType for symbolic expressions yet";
index 7d3fb2d5cd1788fcde6af88adf3e81644bed3896..911943c64c85379cddb77a95ea14f3e45cbba067 100644 (file)
@@ -490,6 +490,7 @@ set(regress_0_tests
   regress0/ho/finite-fun-ext.smt2
   regress0/ho/fta0144-alpha-eq.smt2
   regress0/ho/fta0210.smt2
+  regress0/ho/fun-subtyping.smt2
   regress0/ho/ho-exponential-model.smt2
   regress0/ho/ho-match-fun-suffix.smt2
   regress0/ho/ho-matching-enum.smt2
diff --git a/test/regress/regress0/ho/fun-subtyping.smt2 b/test/regress/regress0/ho/fun-subtyping.smt2
new file mode 100644 (file)
index 0000000..8eae3d0
--- /dev/null
@@ -0,0 +1,12 @@
+; COMMAND-LINE: --uf-ho
+; EXPECT: sat
+(set-logic ALL)
+(declare-fun g (Int) Real)
+(declare-fun h (Int) Real)
+(assert (not (= g h)))
+; g will be given a model value of lambda x. 0.0, which is interpreted as
+; a function Int -> Int; however since function types T -> U are subtypes of
+; T -> U' where U is a subtype of U', this example works.
+(assert (= (g 0) 0.0))
+(assert (= (h 0) 0.5))
+(check-sat)