Adding functions to the python API and testing them -- part 2 (#6517)
authoryoni206 <yoni206@users.noreply.github.com>
Thu, 13 May 2021 05:32:55 +0000 (22:32 -0700)
committerGitHub <noreply@github.com>
Thu, 13 May 2021 05:32:55 +0000 (05:32 +0000)
This PR adds some functions that are missing in the python API, along with unit tests for them.
The unit tests are translated from https://github.com/cvc5/cvc5/blob/master/test/unit/api/term_black.cpp

src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/python/unit/api/test_term.py

index b91a9e9c557f388731f52b2aed15c009f9e302ea..a044c79f5ddfbcb2e8e1f981225936975a6e819e 100644 (file)
@@ -341,6 +341,11 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         Term()
         bint operator==(const Term&) except +
         bint operator!=(const Term&) except +
+        bint operator<(const Term&) except +
+        bint operator>(const Term&) except +
+        bint operator<=(const Term&) except +
+        bint operator>=(const Term&) except +
+        size_t getNumChildren() except +
         Term operator[](size_t idx) except +
         uint64_t getId() except +
         Kind getKind() except +
@@ -367,6 +372,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
             Term operator*() except +
         const_iterator begin() except +
         const_iterator end() except +
+        bint isInteger() except +
 
     cdef cppclass TermHashFunction:
         TermHashFunction() except +
index fc14c1fcaf70300ed48d8b8693e8c705251742d6..a512a17a8521586db31e1f7b773413a058cfa249 100644 (file)
@@ -1460,6 +1460,18 @@ cdef class Term:
     def __ne__(self, Term other):
         return self.cterm != other.cterm
 
+    def __lt__(self, Term other):
+        return self.cterm < other.cterm
+
+    def __gt__(self, Term other):
+        return self.cterm > other.cterm
+
+    def __le__(self, Term other):
+        return self.cterm <= other.cterm
+
+    def __ge__(self, Term other):
+        return self.cterm >= other.cterm
+
     def __getitem__(self, int index):
         cdef Term term = Term(self.solver)
         if index >= 0:
@@ -1483,6 +1495,8 @@ cdef class Term:
     def __hash__(self):
         return ctermhash(self.cterm)
 
+    def getNumChildren(self):
+        return self.cterm.getNumChildren()
 
     def getId(self):
         return self.cterm.getId()
@@ -1584,6 +1598,9 @@ cdef class Term:
         term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
         return term
 
+    def isInteger(self):
+        return self.cterm.isInteger()
+    
     def toPythonObj(self):
         '''
         Converts a constant value Term to a Python object.
index 91424c905b7cf9a790e8ff38380299cce8fb137c..70263480785ab5b895731ed76933ac4f21b2dbc3 100644 (file)
@@ -897,6 +897,83 @@ def test_substitute(solver):
         xpx.substitute(es, rs)
 
 
+def test_term_compare(solver):
+  t1 = solver.mkInteger(1)
+  t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+  t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+  assert t2 >= t3
+  assert t2 <= t3
+  assert (t1 > t2) != (t1 < t2)
+  assert (t1 > t2 or t1 == t2) == (t1 >= t2)
+
+def test_term_children(solver):
+  # simple term 2+3
+  two = solver.mkInteger(2)
+  t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3))
+  assert t1[0] == two
+  assert t1.getNumChildren() == 2
+  tnull = Term(solver)
+  with pytest.raises(RuntimeError):
+      tnull.getNumChildren()
+
+  # apply term f(2)
+  intSort = solver.getIntegerSort()
+  fsort = solver.mkFunctionSort(intSort, intSort)
+  f = solver.mkConst(fsort, "f")
+  t2 = solver.mkTerm(kinds.ApplyUf, f, two)
+  # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf
+  assert t2.getNumChildren() == 2
+  assert t2[0] == f
+  assert t2[1] == two
+  with pytest.raises(RuntimeError):
+      tnull[0]
+
+def test_is_integer(solver):
+  int1 = solver.mkInteger("-18446744073709551616")
+  int2 = solver.mkInteger("-18446744073709551615")
+  int3 = solver.mkInteger("-4294967296")
+  int4 = solver.mkInteger("-4294967295")
+  int5 = solver.mkInteger("-10")
+  int6 = solver.mkInteger("0")
+  int7 = solver.mkInteger("10")
+  int8 = solver.mkInteger("4294967295")
+  int9 = solver.mkInteger("4294967296")
+  int10 = solver.mkInteger("18446744073709551615")
+  int11 = solver.mkInteger("18446744073709551616")
+  int12 = solver.mkInteger("-0")
+
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("-")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("-1-")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("0.0")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("-0.1")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("012")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("0000")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("-01")
+  with pytest.raises(RuntimeError):
+      solver.mkInteger("-00")
+
+  assert int1.isInteger()
+  assert int2.isInteger()
+  assert int3.isInteger()
+  assert int4.isInteger()
+  assert int5.isInteger()
+  assert int6.isInteger()
+  assert int7.isInteger()
+  assert int8.isInteger()
+  assert int9.isInteger()
+  assert int10.isInteger()
+  assert int11.isInteger()
+
+
 def test_const_array(solver):
     intsort = solver.getIntegerSort()
     arrsort = solver.mkArraySort(intsort, intsort)