From: yoni206 Date: Tue, 1 Jun 2021 21:24:43 +0000 (-0700) Subject: Some additions to the datatypes python API (#6640) X-Git-Tag: cvc5-1.0.0~1667 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=172573dba45f7d231ec06a3a3992f41cf794b75e;p=cvc5.git Some additions to the datatypes python API (#6640) This commit makes the following additions, in order to sync the python API with the cpp API. 1. adding `getName` functions to datatypes related classes 2. allowing `mkDatatypeSorts` with 1 or 2 arguments (previously allowed only 2). 3. In case there is a second argument to `mkDatatypeSorts`, we make sure it is a set. 4. Corresponding changes to the tests. --- diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index 87a646666..fdcbfa997 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -49,6 +49,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": DatatypeConstructor getConstructor(const string& name) except + Term getConstructorTerm(const string& name) except + DatatypeSelector getSelector(const string& name) except + + string getName() except + size_t getNumConstructors() except + bint isParametric() except + bint isCodatatype() except + @@ -100,6 +101,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": size_t getNumConstructors() except + bint isParametric() except + string toString() except + + string getName() except + cdef cppclass DatatypeSelector: diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index cd9e91e51..25ded76bb 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -1,4 +1,4 @@ -from collections import defaultdict +from collections import defaultdict, Set from fractions import Fraction import sys @@ -6,7 +6,7 @@ from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t from libc.stddef cimport wchar_t from libcpp.pair cimport pair -from libcpp.set cimport set +from libcpp.set cimport set as c_set from libcpp.string cimport string from libcpp.vector cimport vector @@ -123,6 +123,9 @@ cdef class Datatype: ds.cds = self.cd.getSelector(name.encode()) return ds + def getName(self): + return self.cd.getName().decode() + def getNumConstructors(self): """:return: number of constructors.""" return self.cd.getNumConstructors() @@ -259,6 +262,9 @@ cdef class DatatypeDecl: def isParametric(self): return self.cdd.isParametric() + def getName(self): + return self.cdd.getName().decode() + def __str__(self): return self.cdd.toString().decode() @@ -502,19 +508,24 @@ cdef class Solver: sort.csort = self.csolver.mkDatatypeSort(dtypedecl.cdd) return sort - def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts): - sorts = [] + def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts = None): + """:return: A list of datatype sorts that correspond to dtypedecls and unresolvedSorts""" + if unresolvedSorts == None: + unresolvedSorts = set([]) + else: + assert isinstance(unresolvedSorts, Set) + sorts = [] cdef vector[c_DatatypeDecl] decls for decl in dtypedecls: decls.push_back(( decl).cdd) - cdef set[c_Sort] usorts + cdef c_set[c_Sort] usorts for usort in unresolvedSorts: usorts.insert(( usort).csort) csorts = self.csolver.mkDatatypeSorts( - decls, usorts) + decls, usorts) for csort in csorts: sort = Sort(self) sort.csort = csort diff --git a/test/python/unit/api/test_datatype_api.py b/test/python/unit/api/test_datatype_api.py index f0c1c0ea9..708942e98 100644 --- a/test/python/unit/api/test_datatype_api.py +++ b/test/python/unit/api/test_datatype_api.py @@ -84,6 +84,7 @@ def test_mk_datatype_sorts(solver): for i in range(0, len(dtdecls)): assert dtsorts[i].isDatatype() assert not dtsorts[i].getDatatype().isFinite() + assert dtsorts[i].getDatatype().getName() == dtdecls[i].getName() # verify the resolution was correct dtTree = dtsorts[0].getDatatype() dtcTreeNode = dtTree[0] @@ -98,6 +99,8 @@ def test_mk_datatype_sorts(solver): dtdeclsBad = [] emptyD = solver.mkDatatypeDecl("emptyD") dtdeclsBad.append(emptyD) + with pytest.raises(RuntimeError): + solver.mkDatatypeSorts(dtdeclsBad) def test_datatype_structs(solver): @@ -177,6 +180,8 @@ def test_datatype_names(solver): # create datatype sort to test dtypeSpec = solver.mkDatatypeDecl("list") + dtypeSpec.getName() + assert dtypeSpec.getName() == "list" cons = solver.mkDatatypeConstructorDecl("cons") cons.addSelector("head", intSort) cons.addSelectorSelf("tail") @@ -185,6 +190,7 @@ def test_datatype_names(solver): dtypeSpec.addConstructor(nil) dtypeSort = solver.mkDatatypeSort(dtypeSpec) dt = dtypeSort.getDatatype() + assert dt.getName() == "list" dt.getConstructor("nil") dt["cons"] with pytest.raises(RuntimeError): @@ -209,6 +215,10 @@ def test_datatype_names(solver): with pytest.raises(RuntimeError): dt.getSelector("cons") + # possible to construct null datatype declarations if not using mkDatatypeDecl + with pytest.raises(RuntimeError): + DatatypeDecl(solver).getName() + def test_parametric_datatype(solver): v = [] diff --git a/test/python/unit/api/test_solver.py b/test/python/unit/api/test_solver.py index c7224022e..67174ad8e 100644 --- a/test/python/unit/api/test_solver.py +++ b/test/python/unit/api/test_solver.py @@ -143,19 +143,19 @@ def test_mk_datatype_sorts(solver): dtypeSpec2.addConstructor(nil2) decls = [dtypeSpec1, dtypeSpec2] - solver.mkDatatypeSorts(decls, []) + solver.mkDatatypeSorts(decls, set([])) with pytest.raises(RuntimeError): - slv.mkDatatypeSorts(decls, []) + slv.mkDatatypeSorts(decls, set([])) throwsDtypeSpec = solver.mkDatatypeDecl("list") throwsDecls = [throwsDtypeSpec] with pytest.raises(RuntimeError): - solver.mkDatatypeSorts(throwsDecls, []) + solver.mkDatatypeSorts(throwsDecls, set([])) # with unresolved sorts unresList = solver.mkUninterpretedSort("ulist") - unresSorts = [unresList] + unresSorts = set([unresList]) ulist = solver.mkDatatypeDecl("ulist") ucons = solver.mkDatatypeConstructorDecl("ucons") ucons.addSelector("car", unresList)