Adding unit tests for the datatypes python API (#6658)
authoryoni206 <yoni206@users.noreply.github.com>
Thu, 3 Jun 2021 23:57:29 +0000 (16:57 -0700)
committerGitHub <noreply@github.com>
Thu, 3 Jun 2021 23:57:29 +0000 (16:57 -0700)
This commit adds unit tests that are translated from `datatype_api_black.cpp`. One API function is also added to the python API.
This is the last part of the python api unit tests for datatypes.

Co-authored-by: Andres Noetzli <andres.noetzli@gmail.com>
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/python/unit/api/test_datatype_api.py

index 2ad8cef5c5eca68e4dc31d6adbdf09d3944a748e..fdc1872e7e058494bca9248848b506738c4b4c9a 100644 (file)
@@ -75,6 +75,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         DatatypeSelector operator[](const string& name) except +
         string getName() except +
         Term getConstructorTerm() except +
+        Term getSpecializedConstructorTerm(const Sort& retSort) except +
         Term getTesterTerm() except +
         size_t getNumSelectors() except +
         DatatypeSelector getSelector(const string& name) except +
index 8599a1cd1d47cbab2004ce64a2b7b8fb5b1d2071..3339543f3598d903ddf5ac86b23127e2a0563c0e 100644 (file)
@@ -196,6 +196,11 @@ cdef class DatatypeConstructor:
         term.cterm = self.cdc.getConstructorTerm()
         return term
 
+    def getSpecializedConstructorTerm(self, Sort retSort):
+        cdef Term term = Term(self.solver)
+        term.cterm = self.cdc.getSpecializedConstructorTerm(retSort.csort)
+        return term
+
     def getTesterTerm(self):
         cdef Term term = Term(self.solver)
         term.cterm = self.cdc.getTesterTerm()
index 708942e98b20bd5a02259be50aaa250cb65da194..24a47bd761883e738fb292290bd3cddaa06b743a 100644 (file)
@@ -294,3 +294,237 @@ def test_parametric_datatype(solver):
     assert not pairIntReal.isSubsortOf(pairIntInt)
     assert not pairRealInt.isSubsortOf(pairIntInt)
     assert pairIntInt.isSubsortOf(pairIntInt)
+
+
+def test_datatype_simply_rec(solver):
+    # Create mutual datatypes corresponding to this definition block:
+    #
+    #   DATATYPE
+    #     wlist = leaf(data: list),
+    #     list = cons(car: wlist, cdr: list) | nil,
+    #     ns = elem(ndata: set(wlist)) | elemArray(ndata2: array(list, list))
+    #   END
+
+    # Make unresolved types as placeholders
+    unresTypes = set([])
+    unresWList = solver.mkUninterpretedSort("wlist")
+    unresList = solver.mkUninterpretedSort("list")
+    unresNs = solver.mkUninterpretedSort("ns")
+    unresTypes.add(unresWList)
+    unresTypes.add(unresList)
+    unresTypes.add(unresNs)
+
+    wlist = solver.mkDatatypeDecl("wlist")
+    leaf = solver.mkDatatypeConstructorDecl("leaf")
+    leaf.addSelector("data", unresList)
+    wlist.addConstructor(leaf)
+
+    llist = solver.mkDatatypeDecl("list")
+    cons = solver.mkDatatypeConstructorDecl("cons")
+    cons.addSelector("car", unresWList)
+    cons.addSelector("cdr", unresList)
+    llist.addConstructor(cons)
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    llist.addConstructor(nil)
+
+    ns = solver.mkDatatypeDecl("ns")
+    elem = solver.mkDatatypeConstructorDecl("elem")
+    elem.addSelector("ndata", solver.mkSetSort(unresWList))
+    ns.addConstructor(elem)
+    elemArray = solver.mkDatatypeConstructorDecl("elemArray")
+    elemArray.addSelector("ndata", solver.mkArraySort(unresList, unresList))
+    ns.addConstructor(elemArray)
+
+    dtdecls = [wlist, llist, ns]
+    # this is well-founded and has no nested recursion
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 3
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[2].getDatatype().isWellFounded()
+    assert not dtsorts[0].getDatatype().hasNestedRecursion()
+    assert not dtsorts[1].getDatatype().hasNestedRecursion()
+    assert not dtsorts[2].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     ns2 = elem2(ndata: array(int,ns2)) | nil2
+    #   END
+
+    unresTypes.clear()
+    unresNs2 = solver.mkUninterpretedSort("ns2")
+    unresTypes.add(unresNs2)
+
+    ns2 = solver.mkDatatypeDecl("ns2")
+    elem2 = solver.mkDatatypeConstructorDecl("elem2")
+    elem2.addSelector("ndata",
+                      solver.mkArraySort(solver.getIntegerSort(), unresNs2))
+    ns2.addConstructor(elem2)
+    nil2 = solver.mkDatatypeConstructorDecl("nil2")
+    ns2.addConstructor(nil2)
+
+    dtdecls.clear()
+    dtdecls.append(ns2)
+
+    # this is not well-founded due to non-simple recursion
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 1
+    assert dtsorts[0].getDatatype()[0][0].getRangeSort().isArray()
+    assert dtsorts[0].getDatatype()[0][0].getRangeSort().getArrayElementSort() \
+        == dtsorts[0]
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     list3 = cons3(car: ns3, cdr: list3) | nil3,
+    #     ns3 = elem3(ndata: set(list3))
+    #   END
+
+    unresTypes.clear()
+    unresNs3 = solver.mkUninterpretedSort("ns3")
+    unresTypes.add(unresNs3)
+    unresList3 = solver.mkUninterpretedSort("list3")
+    unresTypes.add(unresList3)
+
+    list3 = solver.mkDatatypeDecl("list3")
+    cons3 = solver.mkDatatypeConstructorDecl("cons3")
+    cons3.addSelector("car", unresNs3)
+    cons3.addSelector("cdr", unresList3)
+    list3.addConstructor(cons3)
+    nil3 = solver.mkDatatypeConstructorDecl("nil3")
+    list3.addConstructor(nil3)
+
+    ns3 = solver.mkDatatypeDecl("ns3")
+    elem3 = solver.mkDatatypeConstructorDecl("elem3")
+    elem3.addSelector("ndata", solver.mkSetSort(unresList3))
+    ns3.addConstructor(elem3)
+
+    dtdecls.clear()
+    dtdecls.append(list3)
+    dtdecls.append(ns3)
+
+    # both are well-founded and have nested recursion
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 2
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+    assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     list4 = cons(car: set(ns4), cdr: list4) | nil,
+    #     ns4 = elem(ndata: list4)
+    #   END
+    unresTypes.clear()
+    unresNs4 = solver.mkUninterpretedSort("ns4")
+    unresTypes.add(unresNs4)
+    unresList4 = solver.mkUninterpretedSort("list4")
+    unresTypes.add(unresList4)
+
+    list4 = solver.mkDatatypeDecl("list4")
+    cons4 = solver.mkDatatypeConstructorDecl("cons4")
+    cons4.addSelector("car", solver.mkSetSort(unresNs4))
+    cons4.addSelector("cdr", unresList4)
+    list4.addConstructor(cons4)
+    nil4 = solver.mkDatatypeConstructorDecl("nil4")
+    list4.addConstructor(nil4)
+
+    ns4 = solver.mkDatatypeDecl("ns4")
+    elem4 = solver.mkDatatypeConstructorDecl("elem3")
+    elem4.addSelector("ndata", unresList4)
+    ns4.addConstructor(elem4)
+
+    dtdecls.clear()
+    dtdecls.append(list4)
+    dtdecls.append(ns4)
+
+    # both are well-founded and have nested recursion
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 2
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+    assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     list5[X] = cons(car: X, cdr: list5[list5[X]]) | nil
+    #   END
+    unresTypes.clear()
+    unresList5 = solver.mkSortConstructorSort("list5", 1)
+    unresTypes.add(unresList5)
+
+    v = []
+    x = solver.mkParamSort("X")
+    v.append(x)
+    list5 = solver.mkDatatypeDecl("list5", v)
+
+    args = [x]
+    urListX = unresList5.instantiate(args)
+    args[0] = urListX
+    urListListX = unresList5.instantiate(args)
+
+    cons5 = solver.mkDatatypeConstructorDecl("cons5")
+    cons5.addSelector("car", x)
+    cons5.addSelector("cdr", urListListX)
+    list5.addConstructor(cons5)
+    nil5 = solver.mkDatatypeConstructorDecl("nil5")
+    list5.addConstructor(nil5)
+
+    dtdecls.clear()
+    dtdecls.append(list5)
+
+    # well-founded and has nested recursion
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 1
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+
+
+def test_datatype_specialized_cons(solver):
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     plist[X] = pcons(car: X, cdr: plist[X]) | pnil
+    #   END
+
+    # Make unresolved types as placeholders
+    unresTypes = set([])
+    unresList = solver.mkSortConstructorSort("plist", 1)
+    unresTypes.add(unresList)
+
+    v = []
+    x = solver.mkParamSort("X")
+    v.append(x)
+    plist = solver.mkDatatypeDecl("plist", v)
+
+    args = [x]
+    urListX = unresList.instantiate(args)
+
+    pcons = solver.mkDatatypeConstructorDecl("pcons")
+    pcons.addSelector("car", x)
+    pcons.addSelector("cdr", urListX)
+    plist.addConstructor(pcons)
+    nil5 = solver.mkDatatypeConstructorDecl("pnil")
+    plist.addConstructor(nil5)
+
+    dtdecls = [plist]
+
+    # make the datatype sorts
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unresTypes)
+    assert len(dtsorts) == 1
+    d = dtsorts[0].getDatatype()
+    nilc = d[0]
+
+    isort = solver.getIntegerSort()
+    iargs = [isort]
+    listInt = dtsorts[0].instantiate(iargs)
+
+    testConsTerm = Term(solver)
+    # get the specialized constructor term for list[Int]
+    testConsTerm = nilc.getSpecializedConstructorTerm(listInt)
+    assert testConsTerm != nilc.getConstructorTerm()
+    # error to get the specialized constructor term for Int
+    with pytest.raises(RuntimeError):
+        nilc.getSpecializedConstructorTerm(isort)