From ae789c1d976b21bac4217a83f5ad9615b8f5e0f5 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 4 Dec 2019 19:13:08 -0600 Subject: [PATCH] Fix the subtyping relation for functions (#3494) --- src/expr/type_node.cpp | 86 ++++++++++++++++----- test/regress/CMakeLists.txt | 1 + test/regress/regress0/ho/fun-subtyping.smt2 | 12 +++ 3 files changed, 80 insertions(+), 19 deletions(-) create mode 100644 test/regress/regress0/ho/fun-subtyping.smt2 diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 0bf86240b..1ef5030ce 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -314,6 +314,15 @@ bool TypeNode::isSubtypeOf(TypeNode t) const { if(isSet() && t.isSet()) { return getSetElementType().isSubtypeOf(t.getSetElementType()); } + if (isFunction() && t.isFunction()) + { + if (!isComparableTo(t)) + { + // incomparable, not subtype + return false; + } + return getRangeType().isSubtypeOf(t.getRangeType()); + } // this should only return true for types T1, T2 where we handle equalities between T1 and T2 // (more cases go here, if we want to support such cases) return false; @@ -329,6 +338,11 @@ bool TypeNode::isComparableTo(TypeNode t) const { if(isSet() && t.isSet()) { return getSetElementType().isComparableTo(t.getSetElementType()); } + if (isFunction() && t.isFunction()) + { + // comparable if they have a common type node + return !leastCommonTypeNode(*this, t).isNull(); + } return false; } @@ -514,26 +528,60 @@ TypeNode TypeNode::commonTypeNode(TypeNode t0, TypeNode t1, bool isLeast) { // t0.getKind() == kind::TYPE_CONSTANT && // t1.getKind() == kind::TYPE_CONSTANT switch(t0.getKind()) { - case kind::BITVECTOR_TYPE: - case kind::FLOATINGPOINT_TYPE: - case kind::SORT_TYPE: - case kind::CONSTRUCTOR_TYPE: - case kind::SELECTOR_TYPE: - case kind::TESTER_TYPE: - case kind::FUNCTION_TYPE: - case kind::ARRAY_TYPE: - case kind::DATATYPE_TYPE: - case kind::PARAMETRIC_DATATYPE: - return TypeNode(); - case kind::SET_TYPE: { - // take the least common subtype of element types - TypeNode elementType; - if(t1.isSet() && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull() ) { - return NodeManager::currentNM()->mkSetType(elementType); - } else { - return TypeNode(); + case kind::FUNCTION_TYPE: + { + if (t1.getKind() != kind::FUNCTION_TYPE) + { + return TypeNode(); + } + // must have equal arguments + std::vector t0a = t0.getArgTypes(); + std::vector t1a = t1.getArgTypes(); + if (t0a.size() != t1a.size()) + { + // different arities + return TypeNode(); + } + for (unsigned i = 0, nargs = t0a.size(); i < nargs; i++) + { + if (t0a[i] != t1a[i]) + { + // an argument is different + return TypeNode(); + } + } + TypeNode t0r = t0.getRangeType(); + TypeNode t1r = t1.getRangeType(); + TypeNode tr = commonTypeNode(t0r, t1r, isLeast); + std::vector ftypes; + ftypes.insert(ftypes.end(), t0a.begin(), t0a.end()); + ftypes.push_back(tr); + return NodeManager::currentNM()->mkFunctionType(ftypes); + } + break; + case kind::BITVECTOR_TYPE: + case kind::FLOATINGPOINT_TYPE: + case kind::SORT_TYPE: + case kind::CONSTRUCTOR_TYPE: + case kind::SELECTOR_TYPE: + case kind::TESTER_TYPE: + case kind::ARRAY_TYPE: + case kind::DATATYPE_TYPE: + case kind::PARAMETRIC_DATATYPE: return TypeNode(); + case kind::SET_TYPE: + { + // take the least common subtype of element types + TypeNode elementType; + if (t1.isSet() + && !(elementType = commonTypeNode(t0[0], t1[0], isLeast)).isNull()) + { + return NodeManager::currentNM()->mkSetType(elementType); + } + else + { + return TypeNode(); + } } - } case kind::SEXPR_TYPE: Unimplemented() << "haven't implemented leastCommonType for symbolic expressions yet"; diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 7d3fb2d5c..911943c64 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -490,6 +490,7 @@ set(regress_0_tests regress0/ho/finite-fun-ext.smt2 regress0/ho/fta0144-alpha-eq.smt2 regress0/ho/fta0210.smt2 + regress0/ho/fun-subtyping.smt2 regress0/ho/ho-exponential-model.smt2 regress0/ho/ho-match-fun-suffix.smt2 regress0/ho/ho-matching-enum.smt2 diff --git a/test/regress/regress0/ho/fun-subtyping.smt2 b/test/regress/regress0/ho/fun-subtyping.smt2 new file mode 100644 index 000000000..8eae3d073 --- /dev/null +++ b/test/regress/regress0/ho/fun-subtyping.smt2 @@ -0,0 +1,12 @@ +; COMMAND-LINE: --uf-ho +; EXPECT: sat +(set-logic ALL) +(declare-fun g (Int) Real) +(declare-fun h (Int) Real) +(assert (not (= g h))) +; g will be given a model value of lambda x. 0.0, which is interpreted as +; a function Int -> Int; however since function types T -> U are subtypes of +; T -> U' where U is a subtype of U', this example works. +(assert (= (g 0) 0.0)) +(assert (= (h 0) 0.5)) +(check-sat) -- 2.30.2