Python API: bugfix + translating tests from cpp unit tests (#6559)
authoryoni206 <yoni206@users.noreply.github.com>
Fri, 28 May 2021 20:28:56 +0000 (13:28 -0700)
committerGitHub <noreply@github.com>
Fri, 28 May 2021 20:28:56 +0000 (20:28 +0000)
This PR fixes an issue in the python API for datatypes, and also introduces tests translated from https://github.com/cvc5/cvc5/blob/master/test/unit/api/datatype_api_black.cpp

The next PR will translate more tests and will also introduce missing functions in the python API for datatypes.

src/api/python/cvc5.pxi
test/api/python/CMakeLists.txt
test/api/python/test_datatype_api.py [deleted file]
test/python/CMakeLists.txt
test/python/unit/api/test_datatype_api.py [new file with mode: 0644]

index 2fac78552166ee1a4766d059991c600f7a1a9f88..b2942e0b35976794aece5077615b81e429566c2f 100644 (file)
@@ -178,7 +178,7 @@ cdef class DatatypeConstructor:
         if isinstance(index, int) and index >= 0:
             ds.cds = self.cdc[(<int?> index)]
         elif isinstance(index, str):
-            ds.cds = self.cdc[(<const string &> name.encode())]
+            ds.cds = self.cdc[(<const string &> index.encode())]
         else:
             raise ValueError("Expecting a non-negative integer or string")
         return ds
index 7f05bf1307efbd89986908aef1c3efa486ddb212..4e77c0eadf01f7018a41817eec609753d4395a62 100644 (file)
@@ -38,5 +38,4 @@ macro(cvc5_add_python_api_test name filename)
 
 endmacro()
 
-cvc5_add_python_api_test(pytest_datatype_api test_datatype_api.py)
 cvc5_add_python_api_test(pytest_to_python_obj test_to_python_obj.py)
diff --git a/test/api/python/test_datatype_api.py b/test/api/python/test_datatype_api.py
deleted file mode 100644 (file)
index 06f4e0f..0000000
+++ /dev/null
@@ -1,184 +0,0 @@
-###############################################################################
-# Top contributors (to current version):
-#   Andres Noetzli, Makai Mann
-#
-# This file is part of the cvc5 project.
-#
-# Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
-# in the top-level source directory and their institutional affiliations.
-# All rights reserved.  See the file COPYING in the top-level source
-# directory for licensing information.
-# #############################################################################
-##
-
-import pytest
-
-import pycvc5
-from pycvc5 import kinds
-
-
-def test_datatype_simply_rec():
-    solver = pycvc5.Solver()
-
-    # Create mutual datatypes corresponding to this definition block:
-    #
-    #   DATATYPE
-    #     wlist = leaf(data: list),
-    #     list = cons(car: wlist, cdr: list) | nil,
-    #     ns = elem(ndata: set(wlist)) | elemArray(ndata2: array(list, list))
-    #   END;
-
-    # Make unresolved types as placeholders
-    unres_wlist = solver.mkUninterpretedSort('wlist')
-    unres_list = solver.mkUninterpretedSort('list')
-    unres_ns = solver.mkUninterpretedSort('ns')
-    unres_types = set([unres_wlist, unres_list, unres_ns])
-
-    wlist = solver.mkDatatypeDecl('wlist')
-    leaf = solver.mkDatatypeConstructorDecl('leaf')
-    leaf.addSelector('data', unres_list)
-    wlist.addConstructor(leaf)
-
-    dlist = solver.mkDatatypeDecl('list')
-    cons = solver.mkDatatypeConstructorDecl('cons')
-    cons.addSelector('car', unres_wlist)
-    cons.addSelector('cdr', unres_list)
-    dlist.addConstructor(cons)
-    nil = solver.mkDatatypeConstructorDecl("nil")
-    dlist.addConstructor(nil)
-
-    ns = solver.mkDatatypeDecl('ns')
-    elem = solver.mkDatatypeConstructorDecl('elem')
-    elem.addSelector('ndata', solver.mkSetSort(unres_wlist))
-    ns.addConstructor(elem)
-    elem_array = solver.mkDatatypeConstructorDecl('elemArray')
-    elem_array.addSelector('ndata', solver.mkArraySort(unres_list, unres_list))
-    ns.addConstructor(elem_array)
-
-    # this is well-founded and has no nested recursion
-    dtdecls = [wlist, dlist, ns]
-    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
-    assert len(dtsorts) == 3
-    assert dtsorts[0].getDatatype().isWellFounded()
-    assert dtsorts[1].getDatatype().isWellFounded()
-    assert dtsorts[2].getDatatype().isWellFounded()
-    assert not dtsorts[0].getDatatype().hasNestedRecursion()
-    assert not dtsorts[1].getDatatype().hasNestedRecursion()
-    assert not dtsorts[2].getDatatype().hasNestedRecursion()
-
-    # Create mutual datatypes corresponding to this definition block:
-    #   DATATYPE
-    #     ns2 = elem2(ndata: array(int,ns2)) | nil2
-    #   END;
-    unres_ns2 = solver.mkUninterpretedSort('ns2')
-    unres_types = set([unres_ns2])
-
-    ns2 = solver.mkDatatypeDecl('ns2')
-    elem2 = solver.mkDatatypeConstructorDecl('elem2')
-    elem2.addSelector('ndata',
-                      solver.mkArraySort(solver.getIntegerSort(), unres_ns2))
-    ns2.addConstructor(elem2)
-    nil2 = solver.mkDatatypeConstructorDecl('nil2')
-    ns2.addConstructor(nil2)
-
-    # this is not well-founded due to non-simple recursion
-    dtdecls = [ns2]
-    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
-    assert len(dtsorts) == 1
-    assert dtsorts[0].getDatatype()[0][0].getRangeSort().isArray()
-    elem_sort = dtsorts[0].getDatatype()[0][0].getRangeSort().getArrayElementSort()
-    assert elem_sort == dtsorts[0]
-    assert dtsorts[0].getDatatype().isWellFounded()
-    assert dtsorts[0].getDatatype().hasNestedRecursion()
-
-    # Create mutual datatypes corresponding to this definition block:
-    #  DATATYPE
-    #    list3 = cons3(car: ns3, cdr: list3) | nil3,
-    #    ns3 = elem3(ndata: set(list3))
-    #  END
-    unres_ns3 = solver.mkUninterpretedSort('ns3')
-    unres_list3 = solver.mkUninterpretedSort('list3')
-    unres_types = set([unres_ns3, unres_list3])
-
-    list3 = solver.mkDatatypeDecl('list3')
-    cons3 = solver.mkDatatypeConstructorDecl('cons3')
-    cons3.addSelector('car', unres_ns3)
-    cons3.addSelector('cdr', unres_list3)
-    list3.addConstructor(cons3)
-    nil3 = solver.mkDatatypeConstructorDecl('nil3')
-    list3.addConstructor(nil3)
-
-    ns3 = solver.mkDatatypeDecl('ns3')
-    elem3 = solver.mkDatatypeConstructorDecl('elem3')
-    elem3.addSelector('ndata', solver.mkSetSort(unres_list3))
-    ns3.addConstructor(elem3)
-
-    # both are well-founded and have nested recursion
-    dtdecls = [list3, ns3]
-    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
-    assert len(dtsorts) == 2
-    assert dtsorts[0].getDatatype().isWellFounded()
-    assert dtsorts[1].getDatatype().isWellFounded()
-    assert dtsorts[0].getDatatype().hasNestedRecursion()
-    assert dtsorts[1].getDatatype().hasNestedRecursion()
-
-    # Create mutual datatypes corresponding to this definition block:
-    #  DATATYPE
-    #    list4 = cons(car: set(ns4), cdr: list4) | nil,
-    #    ns4 = elem(ndata: list4)
-    #  END
-    unres_ns4 = solver.mkUninterpretedSort('ns4')
-    unres_list4 = solver.mkUninterpretedSort('list4')
-    unres_types = set([unres_ns4, unres_list4])
-
-    list4 = solver.mkDatatypeDecl('list4')
-    cons4 = solver.mkDatatypeConstructorDecl('cons4')
-    cons4.addSelector('car', solver.mkSetSort(unres_ns4))
-    cons4.addSelector('cdr', unres_list4)
-    list4.addConstructor(cons4)
-    nil4 = solver.mkDatatypeConstructorDecl('nil4')
-    list4.addConstructor(nil4)
-
-    ns4 = solver.mkDatatypeDecl('ns4')
-    elem4 = solver.mkDatatypeConstructorDecl('elem3')
-    elem4.addSelector('ndata', unres_list4)
-    ns4.addConstructor(elem4)
-
-    # both are well-founded and have nested recursion
-    dtdecls = [list4, ns4]
-    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
-    assert len(dtsorts) == 2
-    assert dtsorts[0].getDatatype().isWellFounded()
-    assert dtsorts[1].getDatatype().isWellFounded()
-    assert dtsorts[0].getDatatype().hasNestedRecursion()
-    assert dtsorts[1].getDatatype().hasNestedRecursion()
-
-    # Create mutual datatypes corresponding to this definition block:
-    #  DATATYPE
-    #    list5[X] = cons(car: X, cdr: list5[list5[X]]) | nil
-    #  END
-    unres_list5 = solver.mkSortConstructorSort('list5', 1)
-    unres_types = set([unres_list5])
-
-    x = solver.mkParamSort('X')
-    v = [x]
-    list5 = solver.mkDatatypeDecl('list5', v)
-
-    args = [x]
-    ur_list_x = unres_list5.instantiate(args)
-    args = [ur_list_x]
-    ur_list_list_x = unres_list5.instantiate(args)
-
-    cons5 = solver.mkDatatypeConstructorDecl('cons5')
-    cons5.addSelector('car', x)
-    cons5.addSelector('cdr', ur_list_list_x)
-    list5.addConstructor(cons5)
-    nil5 = solver.mkDatatypeConstructorDecl('nil5')
-    list5.addConstructor(nil5)
-
-    # well-founded and has nested recursion
-    dtdecls = [list5]
-    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
-    assert len(dtsorts) == 1
-    assert dtsorts[0].getDatatype().isWellFounded()
-    assert dtsorts[0].getDatatype().hasNestedRecursion()
index 6091e327587eb0f81b6b86c4ac27dfe9c8ae5a79..88fd817f2eeba102297b8e62beb7c9b36a309b76 100644 (file)
@@ -30,4 +30,5 @@ endmacro()
 cvc5_add_python_api_test(pytest_solver unit/api/test_solver.py)
 cvc5_add_python_api_test(pytest_sort unit/api/test_sort.py)
 cvc5_add_python_api_test(pytest_term unit/api/test_term.py)
+cvc5_add_python_api_test(pytest_datatype_api unit/api/test_datatype_api.py)
 cvc5_add_python_api_test(pytest_grammar unit/api/test_grammar.py)
diff --git a/test/python/unit/api/test_datatype_api.py b/test/python/unit/api/test_datatype_api.py
new file mode 100644 (file)
index 0000000..f0c1c0e
--- /dev/null
@@ -0,0 +1,286 @@
+###############################################################################
+# Top contributors (to current version):
+#   Yoni Zohar
+#
+# This file is part of the cvc5 project.
+#
+# Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+# in the top-level source directory and their institutional affiliations.
+# All rights reserved.  See the file COPYING in the top-level source
+# directory for licensing information.
+# #############################################################################
+##
+
+import pytest
+import pycvc5
+from pycvc5 import kinds
+from pycvc5 import Sort, Term, DatatypeDecl
+
+
+@pytest.fixture
+def solver():
+    return pycvc5.Solver()
+
+
+def test_mk_datatype_sort(solver):
+    dtypeSpec = solver.mkDatatypeDecl("list")
+    cons = solver.mkDatatypeConstructorDecl("cons")
+    cons.addSelector("head", solver.getIntegerSort())
+    dtypeSpec.addConstructor(cons)
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    dtypeSpec.addConstructor(nil)
+    listSort = solver.mkDatatypeSort(dtypeSpec)
+    d = listSort.getDatatype()
+    consConstr = d[0]
+    nilConstr = d[1]
+    with pytest.raises(RuntimeError):
+        d[2]
+    consConstr.getConstructorTerm()
+    nilConstr.getConstructorTerm()
+
+
+def test_mk_datatype_sorts(solver):
+    # Create two mutual datatypes corresponding to this definition
+    # block:
+    #
+    #   DATATYPE
+    #     tree = node(left: tree, right: tree) | leaf(data: list),
+    #     list = cons(car: tree, cdr: list) | nil
+    #   END
+    #
+
+    #Make unresolved types as placeholders
+    unresTypes = set([])
+    unresTree = solver.mkUninterpretedSort("tree")
+    unresList = solver.mkUninterpretedSort("list")
+    unresTypes.add(unresTree)
+    unresTypes.add(unresList)
+
+    tree = solver.mkDatatypeDecl("tree")
+    node = solver.mkDatatypeConstructorDecl("node")
+    node.addSelector("left", unresTree)
+    node.addSelector("right", unresTree)
+    tree.addConstructor(node)
+
+    leaf = solver.mkDatatypeConstructorDecl("leaf")
+    leaf.addSelector("data", unresList)
+    tree.addConstructor(leaf)
+
+    llist = solver.mkDatatypeDecl("list")
+    cons = solver.mkDatatypeConstructorDecl("cons")
+    cons.addSelector("car", unresTree)
+    cons.addSelector("cdr", unresTree)
+    llist.addConstructor(cons)
+
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    llist.addConstructor(nil)
+
+    dtdecls = []
+    dtdecls.append(tree)
+    dtdecls.append(llist)
+    dtsorts = []
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == len(dtdecls)
+    for i in range(0, len(dtdecls)):
+        assert dtsorts[i].isDatatype()
+        assert not dtsorts[i].getDatatype().isFinite()
+    # verify the resolution was correct
+    dtTree = dtsorts[0].getDatatype()
+    dtcTreeNode = dtTree[0]
+    assert dtcTreeNode.getName() == "node"
+    dtsTreeNodeLeft = dtcTreeNode[0]
+    assert dtsTreeNodeLeft.getName() == "left"
+    # argument type should have resolved to be recursive
+    assert dtsTreeNodeLeft.getRangeSort().isDatatype()
+    assert dtsTreeNodeLeft.getRangeSort() == dtsorts[0]
+
+    # fails due to empty datatype
+    dtdeclsBad = []
+    emptyD = solver.mkDatatypeDecl("emptyD")
+    dtdeclsBad.append(emptyD)
+
+
+def test_datatype_structs(solver):
+    intSort = solver.getIntegerSort()
+    boolSort = solver.getBooleanSort()
+
+    # create datatype sort to test
+    dtypeSpec = solver.mkDatatypeDecl("list")
+    cons = solver.mkDatatypeConstructorDecl("cons")
+    cons.addSelector("head", intSort)
+    cons.addSelectorSelf("tail")
+    nullSort = Sort(solver)
+    with pytest.raises(RuntimeError):
+        cons.addSelector("null", nullSort)
+    dtypeSpec.addConstructor(cons)
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    dtypeSpec.addConstructor(nil)
+    dtypeSort = solver.mkDatatypeSort(dtypeSpec)
+    dt = dtypeSort.getDatatype()
+    assert not dt.isCodatatype()
+    assert not dt.isTuple()
+    assert not dt.isRecord()
+    assert not dt.isFinite()
+    assert dt.isWellFounded()
+    # get constructor
+    dcons = dt[0]
+    consTerm = dcons.getConstructorTerm()
+    assert dcons.getNumSelectors() == 2
+
+    # create datatype sort to test
+    dtypeSpecEnum = solver.mkDatatypeDecl("enum")
+    ca = solver.mkDatatypeConstructorDecl("A")
+    dtypeSpecEnum.addConstructor(ca)
+    cb = solver.mkDatatypeConstructorDecl("B")
+    dtypeSpecEnum.addConstructor(cb)
+    cc = solver.mkDatatypeConstructorDecl("C")
+    dtypeSpecEnum.addConstructor(cc)
+    dtypeSortEnum = solver.mkDatatypeSort(dtypeSpecEnum)
+    dtEnum = dtypeSortEnum.getDatatype()
+    assert not dtEnum.isTuple()
+    assert dtEnum.isFinite()
+
+    # create codatatype
+    dtypeSpecStream = solver.mkDatatypeDecl("stream", True)
+    consStream = solver.mkDatatypeConstructorDecl("cons")
+    consStream.addSelector("head", intSort)
+    consStream.addSelectorSelf("tail")
+    dtypeSpecStream.addConstructor(consStream)
+    dtypeSortStream = solver.mkDatatypeSort(dtypeSpecStream)
+    dtStream = dtypeSortStream.getDatatype()
+    assert dtStream.isCodatatype()
+    assert not dtStream.isFinite()
+    # codatatypes may be well-founded
+    assert dtStream.isWellFounded()
+
+    # create tuple
+    tupSort = solver.mkTupleSort([boolSort])
+    dtTuple = tupSort.getDatatype()
+    assert dtTuple.isTuple()
+    assert not dtTuple.isRecord()
+    assert dtTuple.isFinite()
+    assert dtTuple.isWellFounded()
+
+    # create record
+    fields = [("b", boolSort), ("i", intSort)]
+    recSort = solver.mkRecordSort(fields)
+    assert recSort.isDatatype()
+    dtRecord = recSort.getDatatype()
+    assert not dtRecord.isTuple()
+    assert dtRecord.isRecord()
+    assert not dtRecord.isFinite()
+    assert dtRecord.isWellFounded()
+
+
+def test_datatype_names(solver):
+    intSort = solver.getIntegerSort()
+
+    # create datatype sort to test
+    dtypeSpec = solver.mkDatatypeDecl("list")
+    cons = solver.mkDatatypeConstructorDecl("cons")
+    cons.addSelector("head", intSort)
+    cons.addSelectorSelf("tail")
+    dtypeSpec.addConstructor(cons)
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    dtypeSpec.addConstructor(nil)
+    dtypeSort = solver.mkDatatypeSort(dtypeSpec)
+    dt = dtypeSort.getDatatype()
+    dt.getConstructor("nil")
+    dt["cons"]
+    with pytest.raises(RuntimeError):
+        dt.getConstructor("head")
+    with pytest.raises(RuntimeError):
+        dt.getConstructor("")
+
+    dcons = dt[0]
+    assert dcons.getName() == "cons"
+    dcons.getSelector("head")
+    dcons["tail"]
+    with pytest.raises(RuntimeError):
+        dcons.getSelector("cons")
+
+    # get selector
+    dselTail = dcons[1]
+    assert dselTail.getName() == "tail"
+    assert dselTail.getRangeSort() == dtypeSort
+
+    # get selector from datatype
+    dt.getSelector("head")
+    with pytest.raises(RuntimeError):
+        dt.getSelector("cons")
+
+
+def test_parametric_datatype(solver):
+    v = []
+    t1 = solver.mkParamSort("T1")
+    t2 = solver.mkParamSort("T2")
+    v.append(t1)
+    v.append(t2)
+    pairSpec = solver.mkDatatypeDecl("pair", v)
+
+    mkpair = solver.mkDatatypeConstructorDecl("mk-pair")
+    mkpair.addSelector("first", t1)
+    mkpair.addSelector("second", t2)
+    pairSpec.addConstructor(mkpair)
+
+    pairType = solver.mkDatatypeSort(pairSpec)
+
+    assert pairType.getDatatype().isParametric()
+
+    v.clear()
+    v.append(solver.getIntegerSort())
+    v.append(solver.getIntegerSort())
+    pairIntInt = pairType.instantiate(v)
+    v.clear()
+    v.append(solver.getRealSort())
+    v.append(solver.getRealSort())
+    pairRealReal = pairType.instantiate(v)
+    v.clear()
+    v.append(solver.getRealSort())
+    v.append(solver.getIntegerSort())
+    pairRealInt = pairType.instantiate(v)
+    v.clear()
+    v.append(solver.getIntegerSort())
+    v.append(solver.getRealSort())
+    pairIntReal = pairType.instantiate(v)
+
+    assert pairIntInt != pairRealReal
+    assert pairIntReal != pairRealReal
+    assert pairRealInt != pairRealReal
+    assert pairIntInt != pairIntReal
+    assert pairIntInt != pairRealInt
+    assert pairIntReal != pairRealInt
+
+    assert pairRealReal.isComparableTo(pairRealReal)
+    assert not pairIntReal.isComparableTo(pairRealReal)
+    assert not pairRealInt.isComparableTo(pairRealReal)
+    assert not pairIntInt.isComparableTo(pairRealReal)
+    assert not pairRealReal.isComparableTo(pairRealInt)
+    assert not pairIntReal.isComparableTo(pairRealInt)
+    assert pairRealInt.isComparableTo(pairRealInt)
+    assert not pairIntInt.isComparableTo(pairRealInt)
+    assert not pairRealReal.isComparableTo(pairIntReal)
+    assert pairIntReal.isComparableTo(pairIntReal)
+    assert not pairRealInt.isComparableTo(pairIntReal)
+    assert not pairIntInt.isComparableTo(pairIntReal)
+    assert not pairRealReal.isComparableTo(pairIntInt)
+    assert not pairIntReal.isComparableTo(pairIntInt)
+    assert not pairRealInt.isComparableTo(pairIntInt)
+    assert pairIntInt.isComparableTo(pairIntInt)
+
+    assert pairRealReal.isSubsortOf(pairRealReal)
+    assert not pairIntReal.isSubsortOf(pairRealReal)
+    assert not pairRealInt.isSubsortOf(pairRealReal)
+    assert not pairIntInt.isSubsortOf(pairRealReal)
+    assert not pairRealReal.isSubsortOf(pairRealInt)
+    assert not pairIntReal.isSubsortOf(pairRealInt)
+    assert pairRealInt.isSubsortOf(pairRealInt)
+    assert not pairIntInt.isSubsortOf(pairRealInt)
+    assert not pairRealReal.isSubsortOf(pairIntReal)
+    assert pairIntReal.isSubsortOf(pairIntReal)
+    assert not pairRealInt.isSubsortOf(pairIntReal)
+    assert not pairIntInt.isSubsortOf(pairIntReal)
+    assert not pairRealReal.isSubsortOf(pairIntInt)
+    assert not pairIntReal.isSubsortOf(pairIntInt)
+    assert not pairRealInt.isSubsortOf(pairIntInt)
+    assert pairIntInt.isSubsortOf(pairIntInt)