Update `toPythonObj` to use new getters -- part 1 (#6623)
authoryoni206 <yoni206@users.noreply.github.com>
Mon, 31 May 2021 00:56:14 +0000 (17:56 -0700)
committerGitHub <noreply@github.com>
Mon, 31 May 2021 00:56:14 +0000 (00:56 +0000)
Following #6496 , this PR adds new getters to the python API, as well as tests for them. This makes toPythonObj simpler.

A future PR will add more getters to the python API.

Co-authored-by: Gereon Kremer nafur42@gmail.com
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/api/python/test_to_python_obj.py
test/python/unit/api/test_term.py

index 205b82918367a3387a826aabd5963fcb9927b317..ef65c9070909fc3e3dd07fb1cc503d1d124a2268 100644 (file)
@@ -22,7 +22,10 @@ cdef extern from "<functional>" namespace "std" nogil:
 
 cdef extern from "<string>" namespace "std":
     cdef cppclass wstring:
+        wstring() except +
         wstring(const wchar_t*, size_t) except +
+        const wchar_t* data() except +
+        size_t size() except +
 
 cdef extern from "api/cpp/cvc5.h" namespace "cvc5":
     cdef cppclass Options:
@@ -374,7 +377,16 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
             Term operator*() except +
         const_iterator begin() except +
         const_iterator end() except +
+        bint isBooleanValue() except +
+        bint getBooleanValue() except +
+        bint isStringValue() except +
+        wstring getStringValue() except +
         bint isIntegerValue() except +
+        string getIntegerValue() except +
+        bint isRealValue() except +
+        string getRealValue() except +
+        bint isBitVectorValue() except +
+        string getBitVectorValue(uint32_t base) except +
         vector[Term] getSequenceValue() except +
 
     cdef cppclass TermHashFunction:
index b2942e0b35976794aece5077615b81e429566c2f..7731f4e71ad269b860cff9abe6e029077269a7ae 100644 (file)
@@ -33,6 +33,7 @@ from cvc5kinds cimport Kind as c_Kind
 
 cdef extern from "Python.h":
     wchar_t* PyUnicode_AsWideCharString(object, Py_ssize_t *)
+    object PyUnicode_FromWideChar(const wchar_t*, Py_ssize_t)
     void PyMem_Free(void*)
 
 ################################## DECORATORS #################################
@@ -759,6 +760,47 @@ cdef class Solver:
         term.cterm = self.csolver.mkUniverseSet(sort.csort)
         return term
 
+    @expand_list_arg(num_req_args=0)
+    def mkBitVector(self, *args):
+        '''
+            Supports the following arguments:
+            Term mkBitVector(int size, int val=0)
+            Term mkBitVector(string val, int base = 2)
+            Term mkBitVector(int size, string val, int base)
+         '''
+        cdef Term term = Term(self)
+        if len(args) == 1:
+            size_or_val = args[0]
+            if isinstance(args[0], int):
+                size = args[0]
+                term.cterm = self.csolver.mkBitVector(<uint32_t> size)
+            else:
+                assert isinstance(args[0], str)
+                val = args[0]
+                term.cterm = self.csolver.mkBitVector(<const string&> str(val).encode())
+        elif len(args) == 2:
+            if isinstance(args[0], int):
+                size = args[0]
+                assert isinstance(args[1], int)
+                val = args[1]
+                term.cterm = self.csolver.mkBitVector(<uint32_t> size, <uint32_t> val)
+            else:
+                assert isinstance(args[0], str)
+                assert isinstance(args[1], int)
+                val = args[0]
+                base = args[1]
+                term.cterm = self.csolver.mkBitVector(<const string&> str(val).encode(), <uint32_t> base)
+        elif len(args) == 3:
+                assert isinstance(args[0], int)
+                assert isinstance(args[1], str)
+                assert isinstance(args[2], int)
+                size = args[0]
+                val = args[1]
+                base = args[2]
+                term.cterm = self.csolver.mkBitVector(<uint32_t> size, <const string&> str(val).encode(), <uint32_t> base)
+        return term
+
+
     def mkBitVector(self, size_or_str, val = None):
         cdef Term term = Term(self)
         if isinstance(size_or_str, int):
@@ -1603,9 +1645,38 @@ cdef class Term:
         term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
         return term
 
-    def isInteger(self):
+    def isBooleanValue(self):
+        return self.cterm.isBooleanValue()
+
+    def getBooleanValue(self):
+        return self.cterm.getBooleanValue()
+
+    def isStringValue(self):
+        return self.cterm.isStringValue()
+
+    def getStringValue(self):
+        cdef Py_ssize_t size
+        cdef c_wstring s = self.cterm.getStringValue()
+        return PyUnicode_FromWideChar(s.data(), s.size())
+
+    def isIntegerValue(self):
         return self.cterm.isIntegerValue()
-    
+
+    def getIntegerValue(self):
+        return int(self.cterm.getIntegerValue().decode())
+
+    def isRealValue(self):
+        return self.cterm.isRealValue()
+
+    def getRealValue(self):
+        return float(Fraction(self.cterm.getRealValue().decode()))
+
+    def isBitVectorValue(self):
+        return self.cterm.isBitVectorValue()
+
+    def getBitVectorValue(self, base = 2):
+        return self.cterm.getBitVectorValue(base).decode()
+
     def toPythonObj(self):
         '''
         Converts a constant value Term to a Python object.
@@ -1615,61 +1686,23 @@ cdef class Term:
           Int     -- returns a Python int
           Real    -- returns a Python Fraction
           BV      -- returns a Python int (treats BV as unsigned)
+          String  -- returns a Python Unicode string
           Array   -- returns a Python dict mapping indices to values
                   -- the constant base is returned as the default value
-          String  -- returns a Python Unicode string
         '''
 
-        string_repr = self.cterm.toString().decode()
-        assert string_repr
-        sort = self.getSort()
-        res = None
-        if sort.isBoolean():
-            if string_repr == "true":
-                res = True
-            else:
-                assert string_repr == "false"
-                res = False
-
-        elif sort.isInteger():
-            updated_string_repr = string_repr.strip('()').replace(' ', '')
-            try:
-                res = int(updated_string_repr)
-            except:
-                raise ValueError("Failed to convert"
-                                 " {} to an int".format(string_repr))
-
-        elif sort.isReal():
-            updated_string_repr = string_repr
-            try:
-                # rational format (/ a b) most likely
-                # note: a or b could be negated: (- a)
-                splits = [s.strip('()/')
-                          for s in updated_string_repr.strip('()/') \
-                          .replace('(- ', '(-').split()]
-                assert len(splits) == 2
-                num = int(splits[0])
-                den = int(splits[1])
-                res = Fraction(num, den)
-            except:
-                try:
-                    # could be exact: e.g., 1.0
-                    res = Fraction(updated_string_repr)
-                except:
-                    raise ValueError("Failed to convert "
-                                     "{} to a Fraction".format(string_repr))
-
-        elif sort.isBitVector():
-            # expecting format #b<bits>
-            assert string_repr[:2] == "#b"
-            python_bin_repr = "0" + string_repr[1:]
-            try:
-                res = int(python_bin_repr, 2)
-            except:
-                raise ValueError("Failed to convert bitvector "
-                                 "{} to an int".format(string_repr))
-
-        elif sort.isArray():
+        if self.isBooleanValue():
+            return self.getBooleanValue()
+        elif self.isIntegerValue():
+            return self.getIntegerValue()
+        elif self.isRealValue():
+            return self.getRealValue()
+        elif self.isBitVectorValue():
+            return int(self.getBitVectorValue(), 2)
+        elif self.isStringValue():
+            return self.getStringValue()
+        elif self.getSort().isArray():
+            res = None
             keys = []
             values = []
             base_value = None
@@ -1696,33 +1729,7 @@ cdef class Term:
             for k, v in zip(keys, values):
                 res[k] = v
 
-        elif sort.isString():
-            # Strip leading and trailing double quotes and replace double
-            # double quotes by single quotes
-            string_repr = string_repr[1:-1].replace('""', '"')
-
-            # Convert escape sequences
-            res = ''
-            escape_prefix = '\\u{'
-            i = 0
-            while True:
-              prev_i = i
-              i = string_repr.find(escape_prefix, i)
-              if i == -1:
-                res += string_repr[prev_i:]
-                break
-
-              res += string_repr[prev_i:i]
-              val = string_repr[i + len(escape_prefix):string_repr.find('}', i)]
-              res += chr(int(val, 16))
-              i += len(escape_prefix) + len(val) + 1
-        else:
-            raise ValueError("Cannot convert term {}"
-                             " of sort {} to Python object".format(string_repr,
-                                                                   sort))
-
-        assert res is not None
-        return res
+            return res
 
 
 # Generate rounding modes
index 572453670f860c72a57f414bcbe62d1eb1e7a8fe..2ba685d5072e6faaffb200e6d3bb67e11391c42d 100644 (file)
@@ -115,4 +115,4 @@ def testGetValueReal():
     xval = solver.getValue(x)
     yval = solver.getValue(y)
     assert xval.toPythonObj() == Fraction("6")
-    assert yval.toPythonObj() == Fraction("8.33")
+    assert yval.toPythonObj() == float(Fraction("8.33"))
index 936ff3e1cccbc3c1ad5c1c40894f5d08eada073e..2b6fd8fd6e44323bc8ec1501c19eb46cc1d59eb5 100644 (file)
@@ -898,80 +898,182 @@ def test_substitute(solver):
 
 
 def test_term_compare(solver):
-  t1 = solver.mkInteger(1)
-  t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
-  t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
-  assert t2 >= t3
-  assert t2 <= t3
-  assert (t1 > t2) != (t1 < t2)
-  assert (t1 > t2 or t1 == t2) == (t1 >= t2)
+    t1 = solver.mkInteger(1)
+    t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+    t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2))
+    assert t2 >= t3
+    assert t2 <= t3
+    assert (t1 > t2) != (t1 < t2)
+    assert (t1 > t2 or t1 == t2) == (t1 >= t2)
+
 
 def test_term_children(solver):
-  # simple term 2+3
-  two = solver.mkInteger(2)
-  t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3))
-  assert t1[0] == two
-  assert t1.getNumChildren() == 2
-  tnull = Term(solver)
-  with pytest.raises(RuntimeError):
-      tnull.getNumChildren()
-
-  # apply term f(2)
-  intSort = solver.getIntegerSort()
-  fsort = solver.mkFunctionSort(intSort, intSort)
-  f = solver.mkConst(fsort, "f")
-  t2 = solver.mkTerm(kinds.ApplyUf, f, two)
-  # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf
-  assert t2.getNumChildren() == 2
-  assert t2[0] == f
-  assert t2[1] == two
-  with pytest.raises(RuntimeError):
-      tnull[0]
+    # simple term 2+3
+    two = solver.mkInteger(2)
+    t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3))
+    assert t1[0] == two
+    assert t1.getNumChildren() == 2
+    tnull = Term(solver)
+    with pytest.raises(RuntimeError):
+        tnull.getNumChildren()
+
+    # apply term f(2)
+    intSort = solver.getIntegerSort()
+    fsort = solver.mkFunctionSort(intSort, intSort)
+    f = solver.mkConst(fsort, "f")
+    t2 = solver.mkTerm(kinds.ApplyUf, f, two)
+    # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf
+    assert t2.getNumChildren() == 2
+    assert t2[0] == f
+    assert t2[1] == two
+    with pytest.raises(RuntimeError):
+        tnull[0]
+
 
 def test_is_integer(solver):
-  int1 = solver.mkInteger("-18446744073709551616")
-  int2 = solver.mkInteger("-18446744073709551615")
-  int3 = solver.mkInteger("-4294967296")
-  int4 = solver.mkInteger("-4294967295")
-  int5 = solver.mkInteger("-10")
-  int6 = solver.mkInteger("0")
-  int7 = solver.mkInteger("10")
-  int8 = solver.mkInteger("4294967295")
-  int9 = solver.mkInteger("4294967296")
-  int10 = solver.mkInteger("18446744073709551615")
-  int11 = solver.mkInteger("18446744073709551616")
-  int12 = solver.mkInteger("-0")
-
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("-")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("-1-")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("0.0")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("-0.1")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("012")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("0000")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("-01")
-  with pytest.raises(RuntimeError):
-      solver.mkInteger("-00")
-
-  assert int1.isInteger()
-  assert int2.isInteger()
-  assert int3.isInteger()
-  assert int4.isInteger()
-  assert int5.isInteger()
-  assert int6.isInteger()
-  assert int7.isInteger()
-  assert int8.isInteger()
-  assert int9.isInteger()
-  assert int10.isInteger()
-  assert int11.isInteger()
+    int1 = solver.mkInteger("-18446744073709551616")
+    int2 = solver.mkInteger("-18446744073709551615")
+    int3 = solver.mkInteger("-4294967296")
+    int4 = solver.mkInteger("-4294967295")
+    int5 = solver.mkInteger("-10")
+    int6 = solver.mkInteger("0")
+    int7 = solver.mkInteger("10")
+    int8 = solver.mkInteger("4294967295")
+    int9 = solver.mkInteger("4294967296")
+    int10 = solver.mkInteger("18446744073709551615")
+    int11 = solver.mkInteger("18446744073709551616")
+    int12 = solver.mkInteger("-0")
+
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("-")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("-1-")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("0.0")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("-0.1")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("012")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("0000")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("-01")
+    with pytest.raises(RuntimeError):
+        solver.mkInteger("-00")
+
+    assert int1.isIntegerValue()
+    assert int2.isIntegerValue()
+    assert int3.isIntegerValue()
+    assert int4.isIntegerValue()
+    assert int5.isIntegerValue()
+    assert int6.isIntegerValue()
+    assert int7.isIntegerValue()
+    assert int8.isIntegerValue()
+    assert int9.isIntegerValue()
+    assert int10.isIntegerValue()
+    assert int11.isIntegerValue()
+
+    assert int1.getIntegerValue() == -18446744073709551616
+    assert int2.getIntegerValue() == -18446744073709551615
+    assert int3.getIntegerValue() == -4294967296
+    assert int4.getIntegerValue() == -4294967295
+    assert int5.getIntegerValue() == -10
+    assert int6.getIntegerValue() == 0
+    assert int7.getIntegerValue() == 10
+    assert int8.getIntegerValue() == 4294967295
+    assert int9.getIntegerValue() == 4294967296
+    assert int10.getIntegerValue() == 18446744073709551615
+    assert int11.getIntegerValue() == 18446744073709551616
+
+
+def test_get_string(solver):
+    s1 = solver.mkString("abcde")
+    assert s1.isStringValue()
+    assert s1.getStringValue() == str("abcde")
+
+
+def test_get_real(solver):
+    real1 = solver.mkReal("0")
+    real2 = solver.mkReal(".0")
+    real3 = solver.mkReal("-17")
+    real4 = solver.mkReal("-3/5")
+    real5 = solver.mkReal("12.7")
+    real6 = solver.mkReal("1/4294967297")
+    real7 = solver.mkReal("4294967297")
+    real8 = solver.mkReal("1/18446744073709551617")
+    real9 = solver.mkReal("18446744073709551617")
+
+    assert real1.isRealValue()
+    assert real2.isRealValue()
+    assert real3.isRealValue()
+    assert real4.isRealValue()
+    assert real5.isRealValue()
+    assert real6.isRealValue()
+    assert real7.isRealValue()
+    assert real8.isRealValue()
+    assert real9.isRealValue()
+
+    assert 0 == real1.getRealValue()
+    assert 0 == real2.getRealValue()
+    assert -17 == real3.getRealValue()
+    assert -3/5 == real4.getRealValue()
+    assert 127/10 == real5.getRealValue()
+    assert 1/4294967297 == real6.getRealValue()
+    assert 4294967297 == real7.getRealValue()
+    assert 1/18446744073709551617 == real8.getRealValue()
+    assert float(18446744073709551617) == real9.getRealValue()
+
+
+def test_get_boolean(solver):
+    b1 = solver.mkBoolean(True)
+    b2 = solver.mkBoolean(False)
+
+    assert b1.isBooleanValue()
+    assert b2.isBooleanValue()
+    assert b1.getBooleanValue()
+    assert not b2.getBooleanValue()
+
+
+def test_get_bit_vector(solver):
+    b1 = solver.mkBitVector(8, 15)
+    b2 = solver.mkBitVector("00001111", 2)
+    b3 = solver.mkBitVector("15", 10)
+    b4 = solver.mkBitVector("0f", 16)
+    b5 = solver.mkBitVector(8, "00001111", 2)
+    b6 = solver.mkBitVector(8, "15", 10)
+    b7 = solver.mkBitVector(8, "0f", 16)
+
+    assert b1.isBitVectorValue()
+    assert b2.isBitVectorValue()
+    assert b3.isBitVectorValue()
+    assert b4.isBitVectorValue()
+    assert b5.isBitVectorValue()
+    assert b6.isBitVectorValue()
+    assert b7.isBitVectorValue()
+
+    assert "00001111" == b1.getBitVectorValue(2)
+    assert "15" == b1.getBitVectorValue(10)
+    assert "f" == b1.getBitVectorValue(16)
+    assert "00001111" == b2.getBitVectorValue(2)
+    assert "15" == b2.getBitVectorValue(10)
+    assert "f" == b2.getBitVectorValue(16)
+    assert "1111" == b3.getBitVectorValue(2)
+    assert "15" == b3.getBitVectorValue(10)
+    assert "f" == b3.getBitVectorValue(16)
+    assert "00001111" == b4.getBitVectorValue(2)
+    assert "15" == b4.getBitVectorValue(10)
+    assert "f" == b4.getBitVectorValue(16)
+    assert "00001111" == b5.getBitVectorValue(2)
+    assert "15" == b5.getBitVectorValue(10)
+    assert "f" == b5.getBitVectorValue(16)
+    assert "00001111" == b6.getBitVectorValue(2)
+    assert "15" == b6.getBitVectorValue(10)
+    assert "f" == b6.getBitVectorValue(16)
+    assert "00001111" == b7.getBitVectorValue(2)
+    assert "15" == b7.getBitVectorValue(10)
+    assert "f" == b7.getBitVectorValue(16)
 
 
 def test_const_array(solver):