FP value support in python API (#6644)
authoryoni206 <yoni206@users.noreply.github.com>
Tue, 1 Jun 2021 09:40:54 +0000 (02:40 -0700)
committerGitHub <noreply@github.com>
Tue, 1 Jun 2021 09:40:54 +0000 (11:40 +0200)
This PR adds new is* functions from the cpp API to the python API.
In particular, it adds getFloatingPointValue() function from the cpp API.
A test (translated from term_black.cpp) is added.

getFloatingPointValue() returns a tuple, and so this requires importing an instance of tuples into cython.

src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/python/unit/api/test_term.py

index ef65c9070909fc3e3dd07fb1cc503d1d124a2268..87a646666464391eebd71798498242d9172ef602 100644 (file)
@@ -27,6 +27,15 @@ cdef extern from "<string>" namespace "std":
         const wchar_t* data() except +
         size_t size() except +
 
+cdef extern from "<tuple>" namespace "std" nogil:
+    cdef cppclass tuple[T, U, S]:
+        pass
+
+cdef extern from "<tuple>" namespace "std":
+    uint32_t get0 "std::get<0>"(tuple[uint32_t,uint32_t,Term]) except +
+    uint32_t get1 "std::get<1>"(tuple[uint32_t,uint32_t,Term]) except +
+    Term get2 "std::get<2>"(tuple[uint32_t,uint32_t,Term]) except +
+
 cdef extern from "api/cpp/cvc5.h" namespace "cvc5":
     cdef cppclass Options:
         pass
@@ -387,6 +396,14 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         string getRealValue() except +
         bint isBitVectorValue() except +
         string getBitVectorValue(uint32_t base) except +
+        bint isFloatingPointPosZero() except +
+        bint isFloatingPointNegZero() except +
+        bint isFloatingPointPosInf() except +
+        bint isFloatingPointNegInf() except +
+        bint isFloatingPointNaN() except +
+        bint isFloatingPointValue() except +
+
+        tuple[uint32_t, uint32_t, Term] getFloatingPointValue() except +
         vector[Term] getSequenceValue() except +
 
     cdef cppclass TermHashFunction:
index 7731f4e71ad269b860cff9abe6e029077269a7ae..cd9e91e51c720ac1346380db67b4ff61139b9f00 100644 (file)
@@ -28,7 +28,8 @@ from cvc5 cimport ROUND_NEAREST_TIES_TO_AWAY
 from cvc5 cimport Term as c_Term
 from cvc5 cimport hash as c_hash
 from cvc5 cimport wstring as c_wstring
-
+from cvc5 cimport tuple as c_tuple
+from cvc5 cimport get0, get1, get2
 from cvc5kinds cimport Kind as c_Kind
 
 cdef extern from "Python.h":
@@ -1661,6 +1662,30 @@ cdef class Term:
 
     def isIntegerValue(self):
         return self.cterm.isIntegerValue()
+    
+    def isFloatingPointPosZero(self):
+        return self.cterm.isFloatingPointPosZero()
+    
+    def isFloatingPointNegZero(self):
+        return self.cterm.isFloatingPointNegZero()
+    
+    def isFloatingPointPosInf(self):
+        return self.cterm.isFloatingPointPosInf()
+    
+    def isFloatingPointNegInf(self):
+        return self.cterm.isFloatingPointNegInf()
+    
+    def isFloatingPointNaN(self):
+        return self.cterm.isFloatingPointNaN()
+    
+    def isFloatingPointValue(self):
+        return self.cterm.isFloatingPointValue()
+
+    def getFloatingPointValue(self):
+        cdef c_tuple[uint32_t, uint32_t, c_Term] t = self.cterm.getFloatingPointValue()
+        cdef Term term = Term(self.solver)
+        term.cterm = get2(t)
+        return (get0(t), get1(t), term)
 
     def getIntegerValue(self):
         return int(self.cterm.getIntegerValue().decode())
index 2b6fd8fd6e44323bc8ec1501c19eb46cc1d59eb5..5603655c61f388168e0636021b8565d1f57c63ee 100644 (file)
@@ -930,6 +930,25 @@ def test_term_children(solver):
         tnull[0]
 
 
+def test_get_floating_point(solver):
+    bvval = solver.mkBitVector("0000110000000011")
+    fp = solver.mkFloatingPoint(5, 11, bvval)
+
+    assert fp.isFloatingPointValue()
+    assert not fp.isFloatingPointPosZero()
+    assert not fp.isFloatingPointNegZero()
+    assert not fp.isFloatingPointPosInf()
+    assert not fp.isFloatingPointNegInf()
+    assert not fp.isFloatingPointNaN()
+    assert (5, 11, bvval) == fp.getFloatingPointValue()
+
+    assert solver.mkPosZero(5, 11).isFloatingPointPosZero()
+    assert solver.mkNegZero(5, 11).isFloatingPointNegZero()
+    assert solver.mkPosInf(5, 11).isFloatingPointPosInf()
+    assert solver.mkNegInf(5, 11).isFloatingPointNegInf()
+    assert solver.mkNaN(5, 11).isFloatingPointNaN()
+
+
 def test_is_integer(solver):
     int1 = solver.mkInteger("-18446744073709551616")
     int2 = solver.mkInteger("-18446744073709551615")
@@ -1018,11 +1037,11 @@ def test_get_real(solver):
     assert 0 == real1.getRealValue()
     assert 0 == real2.getRealValue()
     assert -17 == real3.getRealValue()
-    assert -3/5 == real4.getRealValue()
-    assert 127/10 == real5.getRealValue()
-    assert 1/4294967297 == real6.getRealValue()
+    assert -3 / 5 == real4.getRealValue()
+    assert 127 / 10 == real5.getRealValue()
+    assert 1 / 4294967297 == real6.getRealValue()
     assert 4294967297 == real7.getRealValue()
-    assert 1/18446744073709551617 == real8.getRealValue()
+    assert 1 / 18446744073709551617 == real8.getRealValue()
     assert float(18446744073709551617) == real9.getRealValue()