Some additions to the datatypes python API (#6640)
authoryoni206 <yoni206@users.noreply.github.com>
Tue, 1 Jun 2021 21:24:43 +0000 (14:24 -0700)
committerGitHub <noreply@github.com>
Tue, 1 Jun 2021 21:24:43 +0000 (21:24 +0000)
This commit makes the following additions, in order to sync the python API with the cpp API.

1. adding `getName` functions to datatypes related classes
2. allowing `mkDatatypeSorts` with 1 or 2 arguments (previously allowed only 2).
3. In case there is a second argument to `mkDatatypeSorts`, we make sure it is a set.
4. Corresponding changes to the tests.

src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/python/unit/api/test_datatype_api.py
test/python/unit/api/test_solver.py

index 87a646666464391eebd71798498242d9172ef602..fdcbfa9979750236762cb83dd86ed542f012f4e7 100644 (file)
@@ -49,6 +49,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         DatatypeConstructor getConstructor(const string& name) except +
         Term getConstructorTerm(const string& name) except +
         DatatypeSelector getSelector(const string& name) except +
+        string getName() except +
         size_t getNumConstructors() except +
         bint isParametric() except +
         bint isCodatatype() except +
@@ -100,6 +101,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         size_t getNumConstructors() except +
         bint isParametric() except +
         string toString() except +
+        string getName() except +
 
 
     cdef cppclass DatatypeSelector:
index cd9e91e51c720ac1346380db67b4ff61139b9f00..25ded76bbdb2b6a1659dc498cf8cae8994fa3d3b 100644 (file)
@@ -1,4 +1,4 @@
-from collections import defaultdict
+from collections import defaultdict, Set
 from fractions import Fraction
 import sys
 
@@ -6,7 +6,7 @@ from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
 from libc.stddef cimport wchar_t
 
 from libcpp.pair cimport pair
-from libcpp.set cimport set
+from libcpp.set cimport set as c_set
 from libcpp.string cimport string
 from libcpp.vector cimport vector
 
@@ -123,6 +123,9 @@ cdef class Datatype:
         ds.cds = self.cd.getSelector(name.encode())
         return ds
 
+    def getName(self):
+        return self.cd.getName().decode()
+
     def getNumConstructors(self):
         """:return: number of constructors."""
         return self.cd.getNumConstructors()
@@ -259,6 +262,9 @@ cdef class DatatypeDecl:
     def isParametric(self):
         return self.cdd.isParametric()
 
+    def getName(self):
+        return self.cdd.getName().decode()
+
     def __str__(self):
         return self.cdd.toString().decode()
 
@@ -502,19 +508,24 @@ cdef class Solver:
         sort.csort = self.csolver.mkDatatypeSort(dtypedecl.cdd)
         return sort
 
-    def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts):
-        sorts = []
+    def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts = None):
+        """:return: A list of datatype sorts that correspond to dtypedecls and unresolvedSorts"""
+        if unresolvedSorts == None:
+            unresolvedSorts = set([])
+        else:
+            assert isinstance(unresolvedSorts, Set)
 
+        sorts = []
         cdef vector[c_DatatypeDecl] decls
         for decl in dtypedecls:
             decls.push_back((<DatatypeDecl?> decl).cdd)
 
-        cdef set[c_Sort] usorts
+        cdef c_set[c_Sort] usorts
         for usort in unresolvedSorts:
             usorts.insert((<Sort?> usort).csort)
 
         csorts = self.csolver.mkDatatypeSorts(
-            <const vector[c_DatatypeDecl]&> decls, <const set[c_Sort]&> usorts)
+            <const vector[c_DatatypeDecl]&> decls, <const c_set[c_Sort]&> usorts)
         for csort in csorts:
           sort = Sort(self)
           sort.csort = csort
index f0c1c0ea9adc0e1dfd5242a1989924f068720d54..708942e98b20bd5a02259be50aaa250cb65da194 100644 (file)
@@ -84,6 +84,7 @@ def test_mk_datatype_sorts(solver):
     for i in range(0, len(dtdecls)):
         assert dtsorts[i].isDatatype()
         assert not dtsorts[i].getDatatype().isFinite()
+        assert dtsorts[i].getDatatype().getName() == dtdecls[i].getName()
     # verify the resolution was correct
     dtTree = dtsorts[0].getDatatype()
     dtcTreeNode = dtTree[0]
@@ -98,6 +99,8 @@ def test_mk_datatype_sorts(solver):
     dtdeclsBad = []
     emptyD = solver.mkDatatypeDecl("emptyD")
     dtdeclsBad.append(emptyD)
+    with pytest.raises(RuntimeError):
+        solver.mkDatatypeSorts(dtdeclsBad)
 
 
 def test_datatype_structs(solver):
@@ -177,6 +180,8 @@ def test_datatype_names(solver):
 
     # create datatype sort to test
     dtypeSpec = solver.mkDatatypeDecl("list")
+    dtypeSpec.getName()
+    assert dtypeSpec.getName() == "list"
     cons = solver.mkDatatypeConstructorDecl("cons")
     cons.addSelector("head", intSort)
     cons.addSelectorSelf("tail")
@@ -185,6 +190,7 @@ def test_datatype_names(solver):
     dtypeSpec.addConstructor(nil)
     dtypeSort = solver.mkDatatypeSort(dtypeSpec)
     dt = dtypeSort.getDatatype()
+    assert dt.getName() == "list"
     dt.getConstructor("nil")
     dt["cons"]
     with pytest.raises(RuntimeError):
@@ -209,6 +215,10 @@ def test_datatype_names(solver):
     with pytest.raises(RuntimeError):
         dt.getSelector("cons")
 
+    # possible to construct null datatype declarations if not using mkDatatypeDecl
+    with pytest.raises(RuntimeError):
+        DatatypeDecl(solver).getName()
+
 
 def test_parametric_datatype(solver):
     v = []
index c7224022e1332011bf2318d53c8f7f735b173030..67174ad8e67fa3852929f32d9b480b31a0ec3ac7 100644 (file)
@@ -143,19 +143,19 @@ def test_mk_datatype_sorts(solver):
     dtypeSpec2.addConstructor(nil2)
 
     decls = [dtypeSpec1, dtypeSpec2]
-    solver.mkDatatypeSorts(decls, [])
+    solver.mkDatatypeSorts(decls, set([]))
 
     with pytest.raises(RuntimeError):
-        slv.mkDatatypeSorts(decls, [])
+        slv.mkDatatypeSorts(decls, set([]))
 
     throwsDtypeSpec = solver.mkDatatypeDecl("list")
     throwsDecls = [throwsDtypeSpec]
     with pytest.raises(RuntimeError):
-        solver.mkDatatypeSorts(throwsDecls, [])
+        solver.mkDatatypeSorts(throwsDecls, set([]))
 
     # with unresolved sorts
     unresList = solver.mkUninterpretedSort("ulist")
-    unresSorts = [unresList]
+    unresSorts = set([unresList])
     ulist = solver.mkDatatypeDecl("ulist")
     ucons = solver.mkDatatypeConstructorDecl("ucons")
     ucons.addSelector("car", unresList)