Support for parametric datatype subtyping, so that e.g. (Pair Int Int) is a subtype...
authorMorgan Deters <mdeters@cs.nyu.edu>
Tue, 3 Dec 2013 00:32:14 +0000 (19:32 -0500)
committerMorgan Deters <mdeters@cs.nyu.edu>
Tue, 3 Dec 2013 01:47:48 +0000 (20:47 -0500)
src/expr/type_node.cpp
test/regress/regress0/Makefile.am
test/regress/regress0/bug541.smt2 [new file with mode: 0644]
test/unit/util/datatype_black.h

index 54fd2f3e855b9e04ed1a7263c840472dbef18927..335dd2b6d61e1e7dcfcfed00803997720c924a9b 100644 (file)
@@ -133,6 +133,19 @@ bool TypeNode::isSubtypeOf(TypeNode t) const {
            getArgTypes() == t.getArgTypes() &&
            getRangeType().isSubtypeOf(t.getRangeType());
   }
+  if(isParametricDatatype() && t.isParametricDatatype()) {
+    Assert(getKind() == kind::PARAMETRIC_DATATYPE);
+    Assert(t.getKind() == kind::PARAMETRIC_DATATYPE);
+    if((*this)[0] != t[0] || getNumChildren() != t.getNumChildren()) {
+      return false;
+    }
+    for(size_t i = 1; i < getNumChildren(); ++i) {
+      if(!((*this)[i].isSubtypeOf(t[i]))) {
+        return false;
+      }
+    }
+    return true;
+  }
   if(isPredicateSubtype()) {
     return getSubtypeParentType().isSubtypeOf(t);
   }
@@ -186,7 +199,20 @@ bool TypeNode::isComparableTo(TypeNode t) const {
   } else if(isDatatype() && (t.isTuple() || t.isRecord())) {
     Assert(!isTuple() && !isRecord());// should have been handled above
     return *this == NodeManager::currentNM()->getDatatypeForTupleRecord(t);
+  } else if(isParametricDatatype() && t.isParametricDatatype()) {
+    Assert(getKind() == kind::PARAMETRIC_DATATYPE);
+    Assert(t.getKind() == kind::PARAMETRIC_DATATYPE);
+    if((*this)[0] != t[0] || getNumChildren() != t.getNumChildren()) {
+      return false;
+    }
+    for(size_t i = 1; i < getNumChildren(); ++i) {
+      if(!((*this)[i].isComparableTo(t[i]))) {
+        return false;
+      }
+    }
+    return true;
   }
+
   if(isPredicateSubtype()) {
     return t.isComparableTo(getSubtypeParentType());
   }
@@ -211,6 +237,13 @@ TypeNode TypeNode::getBaseType() const {
     return NodeManager::currentNM()->getDatatypeForTupleRecord(*this);
   } else if (isPredicateSubtype()) {
     return getSubtypeParentType().getBaseType();
+  } else if (isParametricDatatype()) {
+    vector<Type> v;
+    for(size_t i = 1; i < getNumChildren(); ++i) {
+      v.push_back((*this)[i].getBaseType().toType());
+    }
+    TypeNode tn = TypeNode::fromType((*this)[0].getDatatype().getDatatypeType(v));
+    return tn;
   }
   return *this;
 }
@@ -339,7 +372,6 @@ TypeNode TypeNode::leastCommonTypeNode(TypeNode t0, TypeNode t1){
       case kind::ARRAY_TYPE:
       case kind::BITVECTOR_TYPE:
       case kind::SORT_TYPE:
-      case kind::PARAMETRIC_DATATYPE:
       case kind::CONSTRUCTOR_TYPE:
       case kind::SELECTOR_TYPE:
       case kind::TESTER_TYPE:
@@ -444,6 +476,22 @@ TypeNode TypeNode::leastCommonTypeNode(TypeNode t0, TypeNode t1){
         }
         // otherwise no common ancestor
         return TypeNode();
+      case kind::PARAMETRIC_DATATYPE: {
+        if(!t1.isParametricDatatype()) {
+          return TypeNode();
+        }
+        while(t1.getKind() != kind::PARAMETRIC_DATATYPE) {
+          t1 = t1.getSubtypeParentType();
+        }
+        if(t0[0] != t1[0] || t0.getNumChildren() != t1.getNumChildren()) {
+          return TypeNode();
+        }
+        vector<Type> v;
+        for(size_t i = 1; i < t0.getNumChildren(); ++i) {
+          v.push_back(leastCommonTypeNode(t0[i], t1[i]).toType());
+        }
+        return TypeNode::fromType(t0[0].getDatatype().getDatatypeType(v));
+      }
       default:
         Unimplemented("don't have a leastCommonType for types `%s' and `%s'", t0.toString().c_str(), t1.toString().c_str());
         return TypeNode();
index 8ae9b3ae2114c28fd67746259ad949b529dd69bb..4748ca5f9dbca79f5f5d45b91ccb6abbe5769348 100644 (file)
@@ -151,7 +151,8 @@ BUG_TESTS = \
        bug520.smt2 \
        bug521.smt2 \
        bug521.minimized.smt2 \
-       bug522.smt2
+       bug522.smt2 \
+       bug541.smt2
 
 TESTS =        $(SMT_TESTS) $(SMT2_TESTS) $(CVC_TESTS) $(TPTP_TESTS) $(BUG_TESTS)
 
diff --git a/test/regress/regress0/bug541.smt2 b/test/regress/regress0/bug541.smt2
new file mode 100644 (file)
index 0000000..4828239
--- /dev/null
@@ -0,0 +1,6 @@
+; EXPECT: unsat
+(set-logic ALL_SUPPORTED)
+(declare-datatypes (T1 T2) ((Pair (mk-pair (first T1) (second T2)))))
+(assert (= (mk-pair 0.0 0.0) (mk-pair 1.5 2.5)))
+(check-sat)
+(exit)
index d88d72b854efb31633df0495fdd2c64ad5a5e326..0bb98c3f266a46e0689c76845b8165849e0b2605 100644 (file)
@@ -21,6 +21,8 @@
 
 #include "expr/expr.h"
 #include "expr/expr_manager.h"
+#include "expr/expr_manager_scope.h"
+#include "expr/type_node.h"
 
 using namespace CVC4;
 using namespace std;
@@ -28,16 +30,19 @@ using namespace std;
 class DatatypeBlack : public CxxTest::TestSuite {
 
   ExprManager* d_em;
+  ExprManagerScope* d_scope;
 
 public:
 
   void setUp() {
     d_em = new ExprManager();
+    d_scope = new ExprManagerScope(*d_em);
     Debug.on("datatypes");
     Debug.on("groundterms");
   }
 
   void tearDown() {
+    delete d_scope;
     delete d_em;
   }
 
@@ -68,6 +73,7 @@ public:
     TS_ASSERT_THROWS(colorsDT["blue"].getSelector("foo"), IllegalArgumentException);
     TS_ASSERT_THROWS(colorsDT["blue"]["foo"], IllegalArgumentException);
 
+    TS_ASSERT(! colorsType.getDatatype().isParametric());
     TS_ASSERT(colorsType.getDatatype().isFinite());
     TS_ASSERT(colorsType.getDatatype().getCardinality().compare(4) == Cardinality::EQUAL);
     TS_ASSERT(ctor.getType().getCardinality().compare(1) == Cardinality::EQUAL);
@@ -105,6 +111,7 @@ public:
     Expr apply = d_em->mkExpr(kind::APPLY_CONSTRUCTOR, ctor);
     Debug("datatypes") << apply << std::endl;
 
+    TS_ASSERT(! natType.getDatatype().isParametric());
     TS_ASSERT(! natType.getDatatype().isFinite());
     TS_ASSERT(natType.getDatatype().getCardinality().compare(Cardinality::INTEGERS) == Cardinality::EQUAL);
     TS_ASSERT(natType.getDatatype().isWellFounded());
@@ -146,6 +153,7 @@ public:
     TS_ASSERT(treeType.getConstructor("leaf") == ctor);
     TS_ASSERT_THROWS(treeType.getConstructor("leff"), IllegalArgumentException);
 
+    TS_ASSERT(! treeType.getDatatype().isParametric());
     TS_ASSERT(! treeType.getDatatype().isFinite());
     TS_ASSERT(treeType.getDatatype().getCardinality().compare(Cardinality::INTEGERS) == Cardinality::EQUAL);
     TS_ASSERT(treeType.getDatatype().isWellFounded());
@@ -180,6 +188,7 @@ public:
     DatatypeType listType = d_em->mkDatatypeType(list);
     Debug("datatypes") << listType << std::endl;
 
+    TS_ASSERT(! listType.getDatatype().isParametric());
     TS_ASSERT(! listType.getDatatype().isFinite());
     TS_ASSERT(listType.getDatatype().getCardinality().compare(Cardinality::INTEGERS) == Cardinality::EQUAL);
     TS_ASSERT(listType.getDatatype().isWellFounded());
@@ -214,6 +223,7 @@ public:
     DatatypeType listType = d_em->mkDatatypeType(list);
     Debug("datatypes") << listType << std::endl;
 
+    TS_ASSERT(! listType.getDatatype().isParametric());
     TS_ASSERT(! listType.getDatatype().isFinite());
     TS_ASSERT(listType.getDatatype().getCardinality().compare(Cardinality::REALS) == Cardinality::EQUAL);
     TS_ASSERT(listType.getDatatype().isWellFounded());
@@ -378,6 +388,7 @@ public:
       TS_ASSERT((*i).mkGroundTerm( dtts2[0] ).getType() == dtts2[0]);
     }
 
+    TS_ASSERT(! dtts2[1].getDatatype().isParametric());
     TS_ASSERT(! dtts2[1].getDatatype().isFinite());
     TS_ASSERT(dtts2[1].getDatatype().getCardinality().compare(Cardinality::INTEGERS) == Cardinality::EQUAL);
     TS_ASSERT(dtts2[1].getDatatype().isWellFounded());
@@ -408,6 +419,7 @@ public:
     DatatypeType treeType = d_em->mkDatatypeType(tree);
     Debug("datatypes") << treeType << std::endl;
 
+    TS_ASSERT(! treeType.getDatatype().isParametric());
     TS_ASSERT(! treeType.getDatatype().isFinite());
     TS_ASSERT(treeType.getDatatype().getCardinality().compare(Cardinality::INTEGERS) == Cardinality::EQUAL);
     TS_ASSERT(! treeType.getDatatype().isWellFounded());
@@ -423,4 +435,99 @@ public:
     }
   }
 
+  void testParametricDatatype() {
+    vector<Type> v;
+    Type t1, t2;
+    v.push_back(t1 = d_em->mkSort("T1"));
+    v.push_back(t2 = d_em->mkSort("T2"));
+    Datatype pair("pair", v);
+
+    DatatypeConstructor mkpair("mk-pair");
+    mkpair.addArg("first", t1);
+    mkpair.addArg("second", t2);
+    pair.addConstructor(mkpair);
+    DatatypeType pairType = d_em->mkDatatypeType(pair);
+
+    TS_ASSERT(pairType.getDatatype().isParametric());
+    v.clear();
+    v.push_back(d_em->integerType());
+    v.push_back(d_em->integerType());
+    DatatypeType pairIntInt = pairType.getDatatype().getDatatypeType(v);
+    v.clear();
+    v.push_back(d_em->realType());
+    v.push_back(d_em->realType());
+    DatatypeType pairRealReal = pairType.getDatatype().getDatatypeType(v);
+    v.clear();
+    v.push_back(d_em->realType());
+    v.push_back(d_em->integerType());
+    DatatypeType pairRealInt = pairType.getDatatype().getDatatypeType(v);
+    v.clear();
+    v.push_back(d_em->integerType());
+    v.push_back(d_em->realType());
+    DatatypeType pairIntReal = pairType.getDatatype().getDatatypeType(v);
+
+    TS_ASSERT_DIFFERS(pairIntInt, pairRealReal);
+    TS_ASSERT_DIFFERS(pairIntReal, pairRealReal);
+    TS_ASSERT_DIFFERS(pairRealInt, pairRealReal);
+    TS_ASSERT_DIFFERS(pairIntInt, pairIntReal);
+    TS_ASSERT_DIFFERS(pairIntInt, pairRealInt);
+    TS_ASSERT_DIFFERS(pairIntReal, pairRealInt);
+
+    TS_ASSERT_EQUALS(pairRealReal.getBaseType(), pairRealReal);
+    TS_ASSERT_EQUALS(pairRealInt.getBaseType(), pairRealReal);
+    TS_ASSERT_EQUALS(pairIntReal.getBaseType(), pairRealReal);
+    TS_ASSERT_EQUALS(pairIntInt.getBaseType(), pairRealReal);
+
+    TS_ASSERT(pairRealReal.isComparableTo(pairRealReal));
+    TS_ASSERT(pairIntReal.isComparableTo(pairRealReal));
+    TS_ASSERT(pairRealInt.isComparableTo(pairRealReal));
+    TS_ASSERT(pairIntInt.isComparableTo(pairRealReal));
+    TS_ASSERT(pairRealReal.isComparableTo(pairRealInt));
+    TS_ASSERT(pairIntReal.isComparableTo(pairRealInt));
+    TS_ASSERT(pairRealInt.isComparableTo(pairRealInt));
+    TS_ASSERT(pairIntInt.isComparableTo(pairRealInt));
+    TS_ASSERT(pairRealReal.isComparableTo(pairIntReal));
+    TS_ASSERT(pairIntReal.isComparableTo(pairIntReal));
+    TS_ASSERT(pairRealInt.isComparableTo(pairIntReal));
+    TS_ASSERT(pairIntInt.isComparableTo(pairIntReal));
+    TS_ASSERT(pairRealReal.isComparableTo(pairIntInt));
+    TS_ASSERT(pairIntReal.isComparableTo(pairIntInt));
+    TS_ASSERT(pairRealInt.isComparableTo(pairIntInt));
+    TS_ASSERT(pairIntInt.isComparableTo(pairIntInt));
+
+    TS_ASSERT(pairRealReal.isSubtypeOf(pairRealReal));
+    TS_ASSERT(pairIntReal.isSubtypeOf(pairRealReal));
+    TS_ASSERT(pairRealInt.isSubtypeOf(pairRealReal));
+    TS_ASSERT(pairIntInt.isSubtypeOf(pairRealReal));
+    TS_ASSERT(!pairRealReal.isSubtypeOf(pairRealInt));
+    TS_ASSERT(!pairIntReal.isSubtypeOf(pairRealInt));
+    TS_ASSERT(pairRealInt.isSubtypeOf(pairRealInt));
+    TS_ASSERT(pairIntInt.isSubtypeOf(pairRealInt));
+    TS_ASSERT(!pairRealReal.isSubtypeOf(pairIntReal));
+    TS_ASSERT(pairIntReal.isSubtypeOf(pairIntReal));
+    TS_ASSERT(!pairRealInt.isSubtypeOf(pairIntReal));
+    TS_ASSERT(pairIntInt.isSubtypeOf(pairIntReal));
+    TS_ASSERT(!pairRealReal.isSubtypeOf(pairIntInt));
+    TS_ASSERT(!pairIntReal.isSubtypeOf(pairIntInt));
+    TS_ASSERT(!pairRealInt.isSubtypeOf(pairIntInt));
+    TS_ASSERT(pairIntInt.isSubtypeOf(pairIntInt));
+
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealReal), TypeNode::fromType(pairRealReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntReal), TypeNode::fromType(pairRealReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealInt), TypeNode::fromType(pairRealReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntInt), TypeNode::fromType(pairRealReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealReal), TypeNode::fromType(pairRealInt)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntReal), TypeNode::fromType(pairRealInt)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealInt), TypeNode::fromType(pairRealInt)).toType(), pairRealInt);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntInt), TypeNode::fromType(pairRealInt)).toType(), pairRealInt);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealReal), TypeNode::fromType(pairIntReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntReal), TypeNode::fromType(pairIntReal)).toType(), pairIntReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealInt), TypeNode::fromType(pairIntReal)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntInt), TypeNode::fromType(pairIntReal)).toType(), pairIntReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealReal), TypeNode::fromType(pairIntInt)).toType(), pairRealReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntReal), TypeNode::fromType(pairIntInt)).toType(), pairIntReal);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairRealInt), TypeNode::fromType(pairIntInt)).toType(), pairRealInt);
+    TS_ASSERT_EQUALS(TypeNode::leastCommonTypeNode(TypeNode::fromType(pairIntInt), TypeNode::fromType(pairIntInt)).toType(), pairIntInt);
+  }
+
 };/* class DatatypeBlack */