[Python API] Add missing methods to Datatype/Term (#4998)
authorAndres Noetzli <andres.noetzli@gmail.com>
Wed, 2 Sep 2020 21:02:26 +0000 (14:02 -0700)
committerGitHub <noreply@github.com>
Wed, 2 Sep 2020 21:02:26 +0000 (16:02 -0500)
Fixes #4942. The Python API was missing some methods related to
datatypes. Most importantly, it was missing mkDatatypeSorts, which
meant that datatypes with unresolved placeholders could not be resolved.
This commit adds missing methods and ports the corresponding tests of
datatype_api_black.h to Python. The commit also adds support for
__getitem__ in Term.

src/api/cvc4cpp.cpp
src/api/cvc4cpp.h
src/api/python/cvc4.pxd
src/api/python/cvc4.pxi
test/unit/api/python/test_datatype_api.py [new file with mode: 0644]
test/unit/api/python/test_term.py

index 6c39bfb912faeeb0816c580c2aab7ec88fa63db8..5b33844392190ef0d70f417b392a294468b1bc58 100644 (file)
@@ -3148,8 +3148,8 @@ Term Solver::mkTermHelper(Kind kind, const std::vector<Term>& children) const
 }
 
 std::vector<Sort> Solver::mkDatatypeSortsInternal(
-    std::vector<DatatypeDecl>& dtypedecls,
-    std::set<Sort>& unresolvedSorts) const
+    const std::vector<DatatypeDecl>& dtypedecls,
+    const std::set<Sort>& unresolvedSorts) const
 {
   NodeManagerScope scope(getNodeManager());
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
@@ -3367,8 +3367,9 @@ std::vector<Sort> Solver::mkDatatypeSorts(
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
 
-std::vector<Sort> Solver::mkDatatypeSorts(std::vector<DatatypeDecl>& dtypedecls,
-                                          std::set<Sort>& unresolvedSorts) const
+std::vector<Sort> Solver::mkDatatypeSorts(
+    const std::vector<DatatypeDecl>& dtypedecls,
+    const std::set<Sort>& unresolvedSorts) const
 {
   CVC4_API_SOLVER_TRY_CATCH_BEGIN;
   return mkDatatypeSortsInternal(dtypedecls, unresolvedSorts);
index d92660920b95b0f20d7c35fc7f414fbddd86e262..acf34abf921216cccc600120a519115db80ce1e6 100644 (file)
@@ -2254,8 +2254,9 @@ class CVC4_PUBLIC Solver
    * @param unresolvedSorts the list of unresolved sorts
    * @return the datatype sorts
    */
-  std::vector<Sort> mkDatatypeSorts(std::vector<DatatypeDecl>& dtypedecls,
-                                    std::set<Sort>& unresolvedSorts) const;
+  std::vector<Sort> mkDatatypeSorts(
+      const std::vector<DatatypeDecl>& dtypedecls,
+      const std::set<Sort>& unresolvedSorts) const;
 
   /**
    * Create function sort.
@@ -3353,8 +3354,8 @@ class CVC4_PUBLIC Solver
    * @return the datatype sorts
    */
   std::vector<Sort> mkDatatypeSortsInternal(
-      std::vector<DatatypeDecl>& dtypedecls,
-      std::set<Sort>& unresolvedSorts) const;
+      const std::vector<DatatypeDecl>& dtypedecls,
+      const std::set<Sort>& unresolvedSorts) const;
 
   /**
    * Synthesize n-ary function following specified syntactic constraints.
index 841fbb44d489aa1582ffad40bbf054da46afd96f..76dcc53179928d68b53a5de51295bc63b466b6a7 100644 (file)
@@ -1,6 +1,7 @@
 # import dereference and increment operators
 from cython.operator cimport dereference as deref, preincrement as inc
 from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
+from libcpp.set cimport set
 from libcpp.string cimport string
 from libcpp.vector cimport vector
 from libcpp.pair cimport pair
@@ -27,6 +28,12 @@ cdef extern from "api/cvc4cpp.h" namespace "CVC4::api":
         Term getConstructorTerm(const string& name) except +
         size_t getNumConstructors() except +
         bint isParametric() except +
+        bint isCodatatype() except +
+        bint isTuple() except +
+        bint isRecord() except +
+        bint isFinite() except +
+        bint isWellFounded() except +
+        bint hasNestedRecursion() except +
         string toString() except +
         cppclass const_iterator:
             const_iterator() except +
@@ -127,6 +134,8 @@ cdef extern from "api/cvc4cpp.h" namespace "CVC4::api":
         Sort mkBitVectorSort(uint32_t size) except +
         Sort mkFloatingPointSort(uint32_t exp, uint32_t sig) except +
         Sort mkDatatypeSort(DatatypeDecl dtypedecl) except +
+        vector[Sort] mkDatatypeSorts(const vector[DatatypeDecl]& dtypedecls,
+                                     const set[Sort]& unresolvedSorts) except +
         Sort mkFunctionSort(Sort domain, Sort codomain) except +
         Sort mkFunctionSort(const vector[Sort]& sorts, Sort codomain) except +
         Sort mkParamSort(const string& symbol) except +
@@ -313,6 +322,7 @@ cdef extern from "api/cvc4cpp.h" namespace "CVC4::api":
         Term()
         bint operator==(const Term&) except +
         bint operator!=(const Term&) except +
+        Term operator[](size_t idx) except +
         Kind getKind() except +
         Sort getSort() except +
         Term substitute(const vector[Term] es, const vector[Term] & reps) except +
index 3caead057f835431cead6c8dfca0a37edab1df85..8c4bfe5e51429cafeb4206801fe4850b900d5a0e 100644 (file)
@@ -3,6 +3,7 @@ import sys
 from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
 
 from libcpp.pair cimport pair
+from libcpp.set cimport set
 from libcpp.string cimport string
 from libcpp.vector cimport vector
 
@@ -110,6 +111,24 @@ cdef class Datatype:
     def isParametric(self):
         return self.cd.isParametric()
 
+    def isCodatatype(self):
+        return self.cd.isCodatatype()
+
+    def isTuple(self):
+        return self.cd.isTuple()
+
+    def isRecord(self):
+        return self.cd.isRecord()
+
+    def isFinite(self):
+        return self.cd.isFinite()
+
+    def isWellFounded(self):
+        return self.cd.isWellFounded()
+
+    def hasNestedRecursion(self):
+        return self.cd.hasNestedRecursion()
+
     def __str__(self):
         return self.cd.toString().decode()
 
@@ -449,6 +468,26 @@ cdef class Solver:
         sort.csort = self.csolver.mkDatatypeSort(dtypedecl.cdd)
         return sort
 
+    def mkDatatypeSorts(self, list dtypedecls, unresolvedSorts):
+        sorts = []
+
+        cdef vector[c_DatatypeDecl] decls
+        for decl in dtypedecls:
+            decls.push_back((<DatatypeDecl?> decl).cdd)
+
+        cdef set[c_Sort] usorts
+        for usort in unresolvedSorts:
+            usorts.insert((<Sort?> usort).csort)
+
+        csorts = self.csolver.mkDatatypeSorts(
+            <const vector[c_DatatypeDecl]&> decls, <const set[c_Sort]&> usorts)
+        for csort in csorts:
+          sort = Sort(self)
+          sort.csort = csort
+          sorts.append(sort)
+
+        return sorts
+
     def mkFunctionSort(self, sorts, Sort codomain):
 
         cdef Sort sort = Sort(self)
@@ -1354,6 +1393,14 @@ cdef class Term:
     def __ne__(self, Term other):
         return self.cterm != other.cterm
 
+    def __getitem__(self, int index):
+        cdef Term term = Term(self.solver)
+        if index >= 0:
+            term.cterm = self.cterm[index]
+        else:
+            raise ValueError("Expecting a non-negative integer or string")
+        return term
+
     def __str__(self):
         return self.cterm.toString().decode()
 
diff --git a/test/unit/api/python/test_datatype_api.py b/test/unit/api/python/test_datatype_api.py
new file mode 100644 (file)
index 0000000..a5499ff
--- /dev/null
@@ -0,0 +1,171 @@
+import pytest
+
+import pycvc4
+from pycvc4 import kinds
+
+
+def test_datatype_simply_rec():
+    solver = pycvc4.Solver()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #
+    #   DATATYPE
+    #     wlist = leaf(data: list),
+    #     list = cons(car: wlist, cdr: list) | nil,
+    #     ns = elem(ndata: set(wlist)) | elemArray(ndata2: array(list, list))
+    #   END;
+
+    # Make unresolved types as placeholders
+    unres_wlist = solver.mkUninterpretedSort('wlist')
+    unres_list = solver.mkUninterpretedSort('list')
+    unres_ns = solver.mkUninterpretedSort('ns')
+    unres_types = set([unres_wlist, unres_list, unres_ns])
+
+    wlist = solver.mkDatatypeDecl('wlist')
+    leaf = solver.mkDatatypeConstructorDecl('leaf')
+    leaf.addSelector('data', unres_list)
+    wlist.addConstructor(leaf)
+
+    dlist = solver.mkDatatypeDecl('list')
+    cons = solver.mkDatatypeConstructorDecl('cons')
+    cons.addSelector('car', unres_wlist)
+    cons.addSelector('cdr', unres_list)
+    dlist.addConstructor(cons)
+    nil = solver.mkDatatypeConstructorDecl("nil")
+    dlist.addConstructor(nil)
+
+    ns = solver.mkDatatypeDecl('ns')
+    elem = solver.mkDatatypeConstructorDecl('elem')
+    elem.addSelector('ndata', solver.mkSetSort(unres_wlist))
+    ns.addConstructor(elem)
+    elem_array = solver.mkDatatypeConstructorDecl('elemArray')
+    elem_array.addSelector('ndata', solver.mkArraySort(unres_list, unres_list))
+    ns.addConstructor(elem_array)
+
+    # this is well-founded and has no nested recursion
+    dtdecls = [wlist, dlist, ns]
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
+    assert len(dtsorts) == 3
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[2].getDatatype().isWellFounded()
+    assert not dtsorts[0].getDatatype().hasNestedRecursion()
+    assert not dtsorts[1].getDatatype().hasNestedRecursion()
+    assert not dtsorts[2].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #   DATATYPE
+    #     ns2 = elem2(ndata: array(int,ns2)) | nil2
+    #   END;
+    unres_ns2 = solver.mkUninterpretedSort('ns2')
+    unres_types = set([unres_ns2])
+
+    ns2 = solver.mkDatatypeDecl('ns2')
+    elem2 = solver.mkDatatypeConstructorDecl('elem2')
+    elem2.addSelector('ndata',
+                      solver.mkArraySort(solver.getIntegerSort(), unres_ns2))
+    ns2.addConstructor(elem2)
+    nil2 = solver.mkDatatypeConstructorDecl('nil2')
+    ns2.addConstructor(nil2)
+
+    # this is not well-founded due to non-simple recursion
+    dtdecls = [ns2]
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
+    assert len(dtsorts) == 1
+    assert dtsorts[0].getDatatype()[0][0].getRangeSort().isArray()
+    elem_sort = dtsorts[0].getDatatype()[0][0].getRangeSort().getArrayElementSort()
+    assert elem_sort == dtsorts[0]
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #  DATATYPE
+    #    list3 = cons3(car: ns3, cdr: list3) | nil3,
+    #    ns3 = elem3(ndata: set(list3))
+    #  END
+    unres_ns3 = solver.mkUninterpretedSort('ns3')
+    unres_list3 = solver.mkUninterpretedSort('list3')
+    unres_types = set([unres_ns3, unres_list3])
+
+    list3 = solver.mkDatatypeDecl('list3')
+    cons3 = solver.mkDatatypeConstructorDecl('cons3')
+    cons3.addSelector('car', unres_ns3)
+    cons3.addSelector('cdr', unres_list3)
+    list3.addConstructor(cons3)
+    nil3 = solver.mkDatatypeConstructorDecl('nil3')
+    list3.addConstructor(nil3)
+
+    ns3 = solver.mkDatatypeDecl('ns3')
+    elem3 = solver.mkDatatypeConstructorDecl('elem3')
+    elem3.addSelector('ndata', solver.mkSetSort(unres_list3))
+    ns3.addConstructor(elem3)
+
+    # both are well-founded and have nested recursion
+    dtdecls = [list3, ns3]
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
+    assert len(dtsorts) == 2
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+    assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #  DATATYPE
+    #    list4 = cons(car: set(ns4), cdr: list4) | nil,
+    #    ns4 = elem(ndata: list4)
+    #  END
+    unres_ns4 = solver.mkUninterpretedSort('ns4')
+    unres_list4 = solver.mkUninterpretedSort('list4')
+    unres_types = set([unres_ns4, unres_list4])
+
+    list4 = solver.mkDatatypeDecl('list4')
+    cons4 = solver.mkDatatypeConstructorDecl('cons4')
+    cons4.addSelector('car', solver.mkSetSort(unres_ns4))
+    cons4.addSelector('cdr', unres_list4)
+    list4.addConstructor(cons4)
+    nil4 = solver.mkDatatypeConstructorDecl('nil4')
+    list4.addConstructor(nil4)
+
+    ns4 = solver.mkDatatypeDecl('ns4')
+    elem4 = solver.mkDatatypeConstructorDecl('elem3')
+    elem4.addSelector('ndata', unres_list4)
+    ns4.addConstructor(elem4)
+
+    # both are well-founded and have nested recursion
+    dtdecls = [list4, ns4]
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
+    assert len(dtsorts) == 2
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[1].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
+    assert dtsorts[1].getDatatype().hasNestedRecursion()
+
+    # Create mutual datatypes corresponding to this definition block:
+    #  DATATYPE
+    #    list5[X] = cons(car: X, cdr: list5[list5[X]]) | nil
+    #  END
+    unres_list5 = solver.mkSortConstructorSort('list5', 1)
+    unres_types = set([unres_list5])
+
+    x = solver.mkParamSort('X')
+    v = [x]
+    list5 = solver.mkDatatypeDecl('list5', v)
+
+    args = [x]
+    ur_list_x = unres_list5.instantiate(args)
+    args = [ur_list_x]
+    ur_list_list_x = unres_list5.instantiate(args)
+
+    cons5 = solver.mkDatatypeConstructorDecl('cons5')
+    cons5.addSelector('car', x)
+    cons5.addSelector('cdr', ur_list_list_x)
+    list5.addConstructor(cons5)
+    nil5 = solver.mkDatatypeConstructorDecl('nil5')
+    list5.addConstructor(nil5)
+
+    # well-founded and has nested recursion
+    dtdecls = [list5]
+    dtsorts = solver.mkDatatypeSorts(dtdecls, unres_types)
+    assert len(dtsorts) == 1
+    assert dtsorts[0].getDatatype().isWellFounded()
+    assert dtsorts[0].getDatatype().hasNestedRecursion()
index b135e451079e3276cfc6ad84f0c6d4e32f51bf45..ca8d4c74159c3fd3b285a09bf12fcf4696346c5b 100644 (file)
@@ -4,6 +4,17 @@ import pycvc4
 from pycvc4 import kinds
 
 
+def test_getitem():
+    solver = pycvc4.Solver()
+    intsort = solver.getIntegerSort()
+    x = solver.mkConst(intsort, 'x')
+    y = solver.mkConst(intsort, 'y')
+    xpy = solver.mkTerm(kinds.Plus, x, y)
+
+    assert xpy[0] == x
+    assert xpy[1] == y
+
+
 def test_get_kind():
     solver = pycvc4.Solver()
     intsort = solver.getIntegerSort()