From 6edc06b3fb6367e8366cab13340228e2bebfca1e Mon Sep 17 00:00:00 2001 From: yoni206 Date: Sun, 30 May 2021 17:56:14 -0700 Subject: [PATCH] Update `toPythonObj` to use new getters -- part 1 (#6623) Following #6496 , this PR adds new getters to the python API, as well as tests for them. This makes toPythonObj simpler. A future PR will add more getters to the python API. Co-authored-by: Gereon Kremer nafur42@gmail.com --- src/api/python/cvc5.pxd | 12 ++ src/api/python/cvc5.pxi | 167 +++++++++--------- test/api/python/test_to_python_obj.py | 2 +- test/python/unit/api/test_term.py | 242 ++++++++++++++++++-------- 4 files changed, 272 insertions(+), 151 deletions(-) diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index 205b82918..ef65c9070 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -22,7 +22,10 @@ cdef extern from "" namespace "std" nogil: cdef extern from "" namespace "std": cdef cppclass wstring: + wstring() except + wstring(const wchar_t*, size_t) except + + const wchar_t* data() except + + size_t size() except + cdef extern from "api/cpp/cvc5.h" namespace "cvc5": cdef cppclass Options: @@ -374,7 +377,16 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Term operator*() except + const_iterator begin() except + const_iterator end() except + + bint isBooleanValue() except + + bint getBooleanValue() except + + bint isStringValue() except + + wstring getStringValue() except + bint isIntegerValue() except + + string getIntegerValue() except + + bint isRealValue() except + + string getRealValue() except + + bint isBitVectorValue() except + + string getBitVectorValue(uint32_t base) except + vector[Term] getSequenceValue() except + cdef cppclass TermHashFunction: diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index b2942e0b3..7731f4e71 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -33,6 +33,7 @@ from cvc5kinds cimport Kind as c_Kind cdef extern from "Python.h": wchar_t* PyUnicode_AsWideCharString(object, Py_ssize_t *) + object PyUnicode_FromWideChar(const wchar_t*, Py_ssize_t) void PyMem_Free(void*) ################################## DECORATORS ################################# @@ -759,6 +760,47 @@ cdef class Solver: term.cterm = self.csolver.mkUniverseSet(sort.csort) return term + @expand_list_arg(num_req_args=0) + def mkBitVector(self, *args): + ''' + Supports the following arguments: + Term mkBitVector(int size, int val=0) + Term mkBitVector(string val, int base = 2) + Term mkBitVector(int size, string val, int base) + ''' + cdef Term term = Term(self) + if len(args) == 1: + size_or_val = args[0] + if isinstance(args[0], int): + size = args[0] + term.cterm = self.csolver.mkBitVector( size) + else: + assert isinstance(args[0], str) + val = args[0] + term.cterm = self.csolver.mkBitVector( str(val).encode()) + elif len(args) == 2: + if isinstance(args[0], int): + size = args[0] + assert isinstance(args[1], int) + val = args[1] + term.cterm = self.csolver.mkBitVector( size, val) + else: + assert isinstance(args[0], str) + assert isinstance(args[1], int) + val = args[0] + base = args[1] + term.cterm = self.csolver.mkBitVector( str(val).encode(), base) + elif len(args) == 3: + assert isinstance(args[0], int) + assert isinstance(args[1], str) + assert isinstance(args[2], int) + size = args[0] + val = args[1] + base = args[2] + term.cterm = self.csolver.mkBitVector( size, str(val).encode(), base) + return term + + def mkBitVector(self, size_or_str, val = None): cdef Term term = Term(self) if isinstance(size_or_str, int): @@ -1603,9 +1645,38 @@ cdef class Term: term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm) return term - def isInteger(self): + def isBooleanValue(self): + return self.cterm.isBooleanValue() + + def getBooleanValue(self): + return self.cterm.getBooleanValue() + + def isStringValue(self): + return self.cterm.isStringValue() + + def getStringValue(self): + cdef Py_ssize_t size + cdef c_wstring s = self.cterm.getStringValue() + return PyUnicode_FromWideChar(s.data(), s.size()) + + def isIntegerValue(self): return self.cterm.isIntegerValue() - + + def getIntegerValue(self): + return int(self.cterm.getIntegerValue().decode()) + + def isRealValue(self): + return self.cterm.isRealValue() + + def getRealValue(self): + return float(Fraction(self.cterm.getRealValue().decode())) + + def isBitVectorValue(self): + return self.cterm.isBitVectorValue() + + def getBitVectorValue(self, base = 2): + return self.cterm.getBitVectorValue(base).decode() + def toPythonObj(self): ''' Converts a constant value Term to a Python object. @@ -1615,61 +1686,23 @@ cdef class Term: Int -- returns a Python int Real -- returns a Python Fraction BV -- returns a Python int (treats BV as unsigned) + String -- returns a Python Unicode string Array -- returns a Python dict mapping indices to values -- the constant base is returned as the default value - String -- returns a Python Unicode string ''' - string_repr = self.cterm.toString().decode() - assert string_repr - sort = self.getSort() - res = None - if sort.isBoolean(): - if string_repr == "true": - res = True - else: - assert string_repr == "false" - res = False - - elif sort.isInteger(): - updated_string_repr = string_repr.strip('()').replace(' ', '') - try: - res = int(updated_string_repr) - except: - raise ValueError("Failed to convert" - " {} to an int".format(string_repr)) - - elif sort.isReal(): - updated_string_repr = string_repr - try: - # rational format (/ a b) most likely - # note: a or b could be negated: (- a) - splits = [s.strip('()/') - for s in updated_string_repr.strip('()/') \ - .replace('(- ', '(-').split()] - assert len(splits) == 2 - num = int(splits[0]) - den = int(splits[1]) - res = Fraction(num, den) - except: - try: - # could be exact: e.g., 1.0 - res = Fraction(updated_string_repr) - except: - raise ValueError("Failed to convert " - "{} to a Fraction".format(string_repr)) - - elif sort.isBitVector(): - # expecting format #b - assert string_repr[:2] == "#b" - python_bin_repr = "0" + string_repr[1:] - try: - res = int(python_bin_repr, 2) - except: - raise ValueError("Failed to convert bitvector " - "{} to an int".format(string_repr)) - - elif sort.isArray(): + if self.isBooleanValue(): + return self.getBooleanValue() + elif self.isIntegerValue(): + return self.getIntegerValue() + elif self.isRealValue(): + return self.getRealValue() + elif self.isBitVectorValue(): + return int(self.getBitVectorValue(), 2) + elif self.isStringValue(): + return self.getStringValue() + elif self.getSort().isArray(): + res = None keys = [] values = [] base_value = None @@ -1696,33 +1729,7 @@ cdef class Term: for k, v in zip(keys, values): res[k] = v - elif sort.isString(): - # Strip leading and trailing double quotes and replace double - # double quotes by single quotes - string_repr = string_repr[1:-1].replace('""', '"') - - # Convert escape sequences - res = '' - escape_prefix = '\\u{' - i = 0 - while True: - prev_i = i - i = string_repr.find(escape_prefix, i) - if i == -1: - res += string_repr[prev_i:] - break - - res += string_repr[prev_i:i] - val = string_repr[i + len(escape_prefix):string_repr.find('}', i)] - res += chr(int(val, 16)) - i += len(escape_prefix) + len(val) + 1 - else: - raise ValueError("Cannot convert term {}" - " of sort {} to Python object".format(string_repr, - sort)) - - assert res is not None - return res + return res # Generate rounding modes diff --git a/test/api/python/test_to_python_obj.py b/test/api/python/test_to_python_obj.py index 572453670..2ba685d50 100644 --- a/test/api/python/test_to_python_obj.py +++ b/test/api/python/test_to_python_obj.py @@ -115,4 +115,4 @@ def testGetValueReal(): xval = solver.getValue(x) yval = solver.getValue(y) assert xval.toPythonObj() == Fraction("6") - assert yval.toPythonObj() == Fraction("8.33") + assert yval.toPythonObj() == float(Fraction("8.33")) diff --git a/test/python/unit/api/test_term.py b/test/python/unit/api/test_term.py index 936ff3e1c..2b6fd8fd6 100644 --- a/test/python/unit/api/test_term.py +++ b/test/python/unit/api/test_term.py @@ -898,80 +898,182 @@ def test_substitute(solver): def test_term_compare(solver): - t1 = solver.mkInteger(1) - t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) - t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) - assert t2 >= t3 - assert t2 <= t3 - assert (t1 > t2) != (t1 < t2) - assert (t1 > t2 or t1 == t2) == (t1 >= t2) + t1 = solver.mkInteger(1) + t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) + t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) + assert t2 >= t3 + assert t2 <= t3 + assert (t1 > t2) != (t1 < t2) + assert (t1 > t2 or t1 == t2) == (t1 >= t2) + def test_term_children(solver): - # simple term 2+3 - two = solver.mkInteger(2) - t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3)) - assert t1[0] == two - assert t1.getNumChildren() == 2 - tnull = Term(solver) - with pytest.raises(RuntimeError): - tnull.getNumChildren() - - # apply term f(2) - intSort = solver.getIntegerSort() - fsort = solver.mkFunctionSort(intSort, intSort) - f = solver.mkConst(fsort, "f") - t2 = solver.mkTerm(kinds.ApplyUf, f, two) - # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf - assert t2.getNumChildren() == 2 - assert t2[0] == f - assert t2[1] == two - with pytest.raises(RuntimeError): - tnull[0] + # simple term 2+3 + two = solver.mkInteger(2) + t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3)) + assert t1[0] == two + assert t1.getNumChildren() == 2 + tnull = Term(solver) + with pytest.raises(RuntimeError): + tnull.getNumChildren() + + # apply term f(2) + intSort = solver.getIntegerSort() + fsort = solver.mkFunctionSort(intSort, intSort) + f = solver.mkConst(fsort, "f") + t2 = solver.mkTerm(kinds.ApplyUf, f, two) + # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf + assert t2.getNumChildren() == 2 + assert t2[0] == f + assert t2[1] == two + with pytest.raises(RuntimeError): + tnull[0] + def test_is_integer(solver): - int1 = solver.mkInteger("-18446744073709551616") - int2 = solver.mkInteger("-18446744073709551615") - int3 = solver.mkInteger("-4294967296") - int4 = solver.mkInteger("-4294967295") - int5 = solver.mkInteger("-10") - int6 = solver.mkInteger("0") - int7 = solver.mkInteger("10") - int8 = solver.mkInteger("4294967295") - int9 = solver.mkInteger("4294967296") - int10 = solver.mkInteger("18446744073709551615") - int11 = solver.mkInteger("18446744073709551616") - int12 = solver.mkInteger("-0") - - with pytest.raises(RuntimeError): - solver.mkInteger("") - with pytest.raises(RuntimeError): - solver.mkInteger("-") - with pytest.raises(RuntimeError): - solver.mkInteger("-1-") - with pytest.raises(RuntimeError): - solver.mkInteger("0.0") - with pytest.raises(RuntimeError): - solver.mkInteger("-0.1") - with pytest.raises(RuntimeError): - solver.mkInteger("012") - with pytest.raises(RuntimeError): - solver.mkInteger("0000") - with pytest.raises(RuntimeError): - solver.mkInteger("-01") - with pytest.raises(RuntimeError): - solver.mkInteger("-00") - - assert int1.isInteger() - assert int2.isInteger() - assert int3.isInteger() - assert int4.isInteger() - assert int5.isInteger() - assert int6.isInteger() - assert int7.isInteger() - assert int8.isInteger() - assert int9.isInteger() - assert int10.isInteger() - assert int11.isInteger() + int1 = solver.mkInteger("-18446744073709551616") + int2 = solver.mkInteger("-18446744073709551615") + int3 = solver.mkInteger("-4294967296") + int4 = solver.mkInteger("-4294967295") + int5 = solver.mkInteger("-10") + int6 = solver.mkInteger("0") + int7 = solver.mkInteger("10") + int8 = solver.mkInteger("4294967295") + int9 = solver.mkInteger("4294967296") + int10 = solver.mkInteger("18446744073709551615") + int11 = solver.mkInteger("18446744073709551616") + int12 = solver.mkInteger("-0") + + with pytest.raises(RuntimeError): + solver.mkInteger("") + with pytest.raises(RuntimeError): + solver.mkInteger("-") + with pytest.raises(RuntimeError): + solver.mkInteger("-1-") + with pytest.raises(RuntimeError): + solver.mkInteger("0.0") + with pytest.raises(RuntimeError): + solver.mkInteger("-0.1") + with pytest.raises(RuntimeError): + solver.mkInteger("012") + with pytest.raises(RuntimeError): + solver.mkInteger("0000") + with pytest.raises(RuntimeError): + solver.mkInteger("-01") + with pytest.raises(RuntimeError): + solver.mkInteger("-00") + + assert int1.isIntegerValue() + assert int2.isIntegerValue() + assert int3.isIntegerValue() + assert int4.isIntegerValue() + assert int5.isIntegerValue() + assert int6.isIntegerValue() + assert int7.isIntegerValue() + assert int8.isIntegerValue() + assert int9.isIntegerValue() + assert int10.isIntegerValue() + assert int11.isIntegerValue() + + assert int1.getIntegerValue() == -18446744073709551616 + assert int2.getIntegerValue() == -18446744073709551615 + assert int3.getIntegerValue() == -4294967296 + assert int4.getIntegerValue() == -4294967295 + assert int5.getIntegerValue() == -10 + assert int6.getIntegerValue() == 0 + assert int7.getIntegerValue() == 10 + assert int8.getIntegerValue() == 4294967295 + assert int9.getIntegerValue() == 4294967296 + assert int10.getIntegerValue() == 18446744073709551615 + assert int11.getIntegerValue() == 18446744073709551616 + + +def test_get_string(solver): + s1 = solver.mkString("abcde") + assert s1.isStringValue() + assert s1.getStringValue() == str("abcde") + + +def test_get_real(solver): + real1 = solver.mkReal("0") + real2 = solver.mkReal(".0") + real3 = solver.mkReal("-17") + real4 = solver.mkReal("-3/5") + real5 = solver.mkReal("12.7") + real6 = solver.mkReal("1/4294967297") + real7 = solver.mkReal("4294967297") + real8 = solver.mkReal("1/18446744073709551617") + real9 = solver.mkReal("18446744073709551617") + + assert real1.isRealValue() + assert real2.isRealValue() + assert real3.isRealValue() + assert real4.isRealValue() + assert real5.isRealValue() + assert real6.isRealValue() + assert real7.isRealValue() + assert real8.isRealValue() + assert real9.isRealValue() + + assert 0 == real1.getRealValue() + assert 0 == real2.getRealValue() + assert -17 == real3.getRealValue() + assert -3/5 == real4.getRealValue() + assert 127/10 == real5.getRealValue() + assert 1/4294967297 == real6.getRealValue() + assert 4294967297 == real7.getRealValue() + assert 1/18446744073709551617 == real8.getRealValue() + assert float(18446744073709551617) == real9.getRealValue() + + +def test_get_boolean(solver): + b1 = solver.mkBoolean(True) + b2 = solver.mkBoolean(False) + + assert b1.isBooleanValue() + assert b2.isBooleanValue() + assert b1.getBooleanValue() + assert not b2.getBooleanValue() + + +def test_get_bit_vector(solver): + b1 = solver.mkBitVector(8, 15) + b2 = solver.mkBitVector("00001111", 2) + b3 = solver.mkBitVector("15", 10) + b4 = solver.mkBitVector("0f", 16) + b5 = solver.mkBitVector(8, "00001111", 2) + b6 = solver.mkBitVector(8, "15", 10) + b7 = solver.mkBitVector(8, "0f", 16) + + assert b1.isBitVectorValue() + assert b2.isBitVectorValue() + assert b3.isBitVectorValue() + assert b4.isBitVectorValue() + assert b5.isBitVectorValue() + assert b6.isBitVectorValue() + assert b7.isBitVectorValue() + + assert "00001111" == b1.getBitVectorValue(2) + assert "15" == b1.getBitVectorValue(10) + assert "f" == b1.getBitVectorValue(16) + assert "00001111" == b2.getBitVectorValue(2) + assert "15" == b2.getBitVectorValue(10) + assert "f" == b2.getBitVectorValue(16) + assert "1111" == b3.getBitVectorValue(2) + assert "15" == b3.getBitVectorValue(10) + assert "f" == b3.getBitVectorValue(16) + assert "00001111" == b4.getBitVectorValue(2) + assert "15" == b4.getBitVectorValue(10) + assert "f" == b4.getBitVectorValue(16) + assert "00001111" == b5.getBitVectorValue(2) + assert "15" == b5.getBitVectorValue(10) + assert "f" == b5.getBitVectorValue(16) + assert "00001111" == b6.getBitVectorValue(2) + assert "15" == b6.getBitVectorValue(10) + assert "f" == b6.getBitVectorValue(16) + assert "00001111" == b7.getBitVectorValue(2) + assert "15" == b7.getBitVectorValue(10) + assert "f" == b7.getBitVectorValue(16) def test_const_array(solver): -- 2.30.2