Add method to get Python object from constant value term in Python API (#5083)
authormakaimann <makaim@stanford.edu>
Tue, 22 Sep 2020 18:58:03 +0000 (11:58 -0700)
committerGitHub <noreply@github.com>
Tue, 22 Sep 2020 18:58:03 +0000 (11:58 -0700)
This PR addresses issue https://github.com/CVC4/CVC4/issues/5014. It simply interprets the SMT-LIB string representation and produces a Python object. It currently supports booleans, ints, reals, bit-vectors, and arrays. The method (`toPythonObj`) is only valid to call if `isConst` returns true.

src/api/python/cvc4.pxi
test/unit/api/python/test_to_python_obj.py [new file with mode: 0644]

index 8c4bfe5e51429cafeb4206801fe4850b900d5a0e..9e4102b909802f553aeab9f62caff174ac9b065b 100644 (file)
@@ -1,3 +1,5 @@
+from collections import defaultdict
+from fractions import Fraction
 import sys
 
 from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
@@ -1503,6 +1505,99 @@ cdef class Term:
         term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
         return term
 
+    def toPythonObj(self):
+        '''
+        Converts a constant value Term to a Python object.
+        Requires isConst to hold.
+
+        Currently supports:
+          Boolean -- returns a Python bool
+          Int     -- returns a Python int
+          Real    -- returns a Python Fraction
+          BV      -- returns a Python int (treats BV as unsigned)
+          Array   -- returns a Python dict mapping indices to values
+                  -- the constant base is returned as the default value
+        '''
+
+        if not self.isConst():
+            raise RuntimeError("Cannot call toPythonObj on a non-const Term")
+
+        string_repr = self.cterm.toString().decode()
+        assert string_repr
+        sort = self.getSort()
+        res = None
+        if sort.isBoolean():
+            if string_repr == "true":
+                res = True
+            else:
+                assert string_repr == "false"
+                res = False
+        elif sort.isInteger():
+            updated_string_repr = string_repr.strip('()').replace(' ', '')
+            try:
+                res = int(updated_string_repr)
+            except:
+                raise ValueError("Failed to convert"
+                                 " {} to an int".format(string_repr))
+        elif sort.isReal():
+            updated_string_repr = string_repr
+            try:
+                # expecting format (/ a b)
+                # note: a or b could be negated: (- a)
+                splits = [s.strip('()/')
+                          for s in updated_string_repr.strip('()/') \
+                          .replace('(- ', '(-').split()]
+                assert len(splits) == 2
+                num = int(splits[0])
+                den = int(splits[1])
+                res = Fraction(num, den)
+            except:
+                raise ValueError("Failed to convert "
+                                 "{} to a Fraction".format(string_repr))
+
+        elif sort.isBitVector():
+            # expecting format #b<bits>
+            assert string_repr[:2] == "#b"
+            python_bin_repr = "0" + string_repr[1:]
+            try:
+                res = int(python_bin_repr, 2)
+            except:
+                raise ValueError("Failed to convert bitvector "
+                                 "{} to an int".format(string_repr))
+        elif sort.isArray():
+            keys = []
+            values = []
+            base_value = None
+            to_visit = [self]
+            # Array models are represented as a series of store operations
+            # on a constant array
+            while to_visit:
+                t = to_visit.pop()
+                if t.getKind() == kinds.Store:
+                    # save the mappings
+                    keys.append(t[1].toPythonObj())
+                    values.append(t[2].toPythonObj())
+                    to_visit.append(t[0])
+                else:
+                    assert t.getKind() == kinds.ConstArray
+                    base_value = t.getConstArrayBase().toPythonObj()
+
+            assert len(keys) == len(values)
+            assert base_value is not None
+
+            # put everything in a dictionary with the constant
+            # base as the result for any index not included in the stores
+            res = defaultdict(lambda : base_value)
+            for k, v in zip(keys, values):
+                res[k] = v
+        else:
+            raise ValueError("Cannot convert term {}"
+                             " of sort {} to Python object".format(string_repr,
+                                                                   sort))
+
+        assert res is not None
+        return res
+
 
 # Generate rounding modes
 cdef __rounding_modes = {
diff --git a/test/unit/api/python/test_to_python_obj.py b/test/unit/api/python/test_to_python_obj.py
new file mode 100644 (file)
index 0000000..69b7b92
--- /dev/null
@@ -0,0 +1,66 @@
+from fractions import Fraction
+import pytest
+
+import pycvc4
+from pycvc4 import kinds
+
+
+def testGetBool():
+    solver = pycvc4.Solver()
+    t = solver.mkTrue()
+    f = solver.mkFalse()
+    assert t.toPythonObj() == True
+    assert f.toPythonObj() == False
+
+
+def testGetInt():
+    solver = pycvc4.Solver()
+    two = solver.mkReal(2)
+    assert two.toPythonObj() == 2
+
+
+def testGetReal():
+    solver = pycvc4.Solver()
+    half = solver.mkReal("1/2")
+    assert half.toPythonObj() == Fraction(1, 2)
+
+    neg34 = solver.mkReal("-3/4")
+    assert neg34.toPythonObj() == Fraction(-3, 4)
+
+    neg1 = solver.mkReal("-1")
+    assert neg1.toPythonObj() == -1
+
+
+def testGetBV():
+    solver = pycvc4.Solver()
+    three = solver.mkBitVector(8, 3)
+    assert three.toPythonObj() == 3
+
+
+def testGetArray():
+    solver = pycvc4.Solver()
+    arrsort = solver.mkArraySort(solver.getRealSort(), solver.getRealSort())
+    zero_array = solver.mkConstArray(arrsort, solver.mkReal(0))
+    stores = solver.mkTerm(kinds.Store, zero_array, solver.mkReal(1), solver.mkReal(2))
+    stores = solver.mkTerm(kinds.Store, stores, solver.mkReal(2), solver.mkReal(3))
+    stores = solver.mkTerm(kinds.Store, stores, solver.mkReal(4), solver.mkReal(5))
+
+    assert stores.isConst()
+
+    array_dict = stores.toPythonObj()
+
+    print(array_dict)
+
+    assert array_dict[1] == 2
+    assert array_dict[2] == 3
+    assert array_dict[4] == 5
+    # an index that wasn't stored at should give zero
+    assert array_dict[8] == 0
+
+
+def testGetSymbol():
+    solver = pycvc4.Solver()
+    x = solver.mkConst(solver.getBooleanSort(), "x")
+
+    with pytest.raises(RuntimeError):
+        x.toPythonObj()