Adding getters to the python API and testing them (#6652)
authoryoni206 <yoni206@users.noreply.github.com>
Wed, 2 Jun 2021 19:30:46 +0000 (12:30 -0700)
committerGitHub <noreply@github.com>
Wed, 2 Jun 2021 19:30:46 +0000 (14:30 -0500)
This PR adds missing API functions from the cpp Term API to the python API.
Corresponding tests are translated from term_black.cpp.

src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/python/unit/api/test_term.py

index fdcbfa9979750236762cb83dd86ed542f012f4e7..2ad8cef5c5eca68e4dc31d6adbdf09d3944a748e 100644 (file)
@@ -176,6 +176,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         Sort mkTupleSort(const vector[Sort]& sorts) except +
         Term mkTerm(Op op) except +
         Term mkTerm(Op op, const vector[Term]& children) except +
+        Term mkTuple(const vector[Sort]& sorts, const vector[Term]& terms) except +
         Op mkOp(Kind kind) except +
         Op mkOp(Kind kind, Kind k) except +
         Op mkOp(Kind kind, const string& arg) except +
@@ -388,6 +389,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
             Term operator*() except +
         const_iterator begin() except +
         const_iterator end() except +
+
+        bint isConstArray() except +
         bint isBooleanValue() except +
         bint getBooleanValue() except +
         bint isStringValue() except +
@@ -398,6 +401,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         string getRealValue() except +
         bint isBitVectorValue() except +
         string getBitVectorValue(uint32_t base) except +
+        bint isAbstractValue() except +
+        string getAbstractValue() except +
         bint isFloatingPointPosZero() except +
         bint isFloatingPointNegZero() except +
         bint isFloatingPointPosInf() except +
@@ -406,7 +411,15 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         bint isFloatingPointValue() except +
 
         tuple[uint32_t, uint32_t, Term] getFloatingPointValue() except +
+        bint isSetValue() except +
+        set[Term] getSetValue() except +
+        bint isSequenceValue() except +
         vector[Term] getSequenceValue() except +
+        bint isUninterpretedValue() except +
+        pair[Sort, int32_t] getUninterpretedValue() except +
+        bint isTupleValue() except +
+        vector[Term] getTupleValue() except +
+
 
     cdef cppclass TermHashFunction:
         TermHashFunction() except +
index 25ded76bbdb2b6a1659dc498cf8cae8994fa3d3b..8599a1cd1d47cbab2004ce64a2b7b8fb5b1d2071 100644 (file)
@@ -658,6 +658,19 @@ cdef class Solver:
             term.cterm = self.csolver.mkTerm((<Op?> op).cop, v)
         return term
 
+    def mkTuple(self, sorts, terms):
+        cdef vector[c_Sort] csorts
+        cdef vector[c_Term] cterms
+
+        for s in sorts:
+            csorts.push_back((<Sort?> s).csort)
+        for s in terms:
+            cterms.push_back((<Term?> s).cterm)
+        cdef Term result = Term(self)
+        result.cterm = self.csolver.mkTuple(csorts, cterms)
+        return result
+
+
     def mkOp(self, kind k, arg0=None, arg1 = None):
         '''
         Supports the following uses:
@@ -1609,19 +1622,6 @@ cdef class Term:
     def isNull(self):
         return self.cterm.isNull()
 
-    def getConstArrayBase(self):
-        cdef Term term = Term(self.solver)
-        term.cterm = self.cterm.getConstArrayBase()
-        return term
-
-    def getSequenceValue(self):
-        elems = []
-        for e in self.cterm.getSequenceValue():
-            term = Term(self.solver)
-            term.cterm = e
-            elems.append(term)
-        return elems
-
     def notTerm(self):
         cdef Term term = Term(self.solver)
         term.cterm = self.cterm.notTerm()
@@ -1657,6 +1657,14 @@ cdef class Term:
         term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm)
         return term
 
+    def isConstArray(self):
+        return self.cterm.isConstArray()
+
+    def getConstArrayBase(self):
+        cdef Term term = Term(self.solver)
+        term.cterm = self.cterm.getConstArrayBase()
+        return term
+
     def isBooleanValue(self):
         return self.cterm.isBooleanValue()
 
@@ -1673,7 +1681,12 @@ cdef class Term:
 
     def isIntegerValue(self):
         return self.cterm.isIntegerValue()
-    
+    def isAbstractValue(self):
+        return self.cterm.isAbstractValue()
+
+    def getAbstractValue(self):
+        return self.cterm.getAbstractValue().decode()
+
     def isFloatingPointPosZero(self):
         return self.cterm.isFloatingPointPosZero()
     
@@ -1698,6 +1711,49 @@ cdef class Term:
         term.cterm = get2(t)
         return (get0(t), get1(t), term)
 
+    def isSetValue(self):
+        return self.cterm.isSetValue()
+
+    def getSetValue(self):
+        elems = set()
+        for e in self.cterm.getSetValue():
+            term = Term(self.solver)
+            term.cterm = e
+            elems.add(term)
+        return elems
+
+    def isSequenceValue(self):
+        return self.cterm.isSequenceValue()
+
+    def getSequenceValue(self):
+        elems = []
+        for e in self.cterm.getSequenceValue():
+            term = Term(self.solver)
+            term.cterm = e
+            elems.append(term)
+        return elems
+
+    def isUninterpretedValue(self):
+        return self.cterm.isUninterpretedValue()
+    def getUninterpretedValue(self):
+        cdef pair[c_Sort, int32_t] p = self.cterm.getUninterpretedValue()
+        cdef Sort sort = Sort(self.solver)
+        sort.csort = p.first
+        i = p.second
+        return (sort, i)
+
+    def isTupleValue(self):
+        return self.cterm.isTupleValue()
+
+    def getTupleValue(self):
+        elems = []
+        for e in self.cterm.getTupleValue():
+            term = Term(self.solver)
+            term.cterm = e
+            elems.append(term)
+        return elems
+
     def getIntegerValue(self):
         return int(self.cterm.getIntegerValue().decode())
 
index 5603655c61f388168e0636021b8565d1f57c63ee..32813e17f390a281c167c56d4de059a43f23feb3 100644 (file)
@@ -930,6 +930,114 @@ def test_term_children(solver):
         tnull[0]
 
 
+def test_get_const_array_base(solver):
+    intsort = solver.getIntegerSort()
+    arrsort = solver.mkArraySort(intsort, intsort)
+    one = solver.mkInteger(1)
+    constarr = solver.mkConstArray(arrsort, one)
+
+    assert constarr.isConstArray()
+    assert one == constarr.getConstArrayBase()
+
+
+def test_get_abstract_value(solver):
+    v1 = solver.mkAbstractValue(1)
+    v2 = solver.mkAbstractValue("15")
+    v3 = solver.mkAbstractValue("18446744073709551617")
+
+    assert v1.isAbstractValue()
+    assert v2.isAbstractValue()
+    assert v3.isAbstractValue()
+    assert "1" == v1.getAbstractValue()
+    assert "15" == v2.getAbstractValue()
+    assert "18446744073709551617" == v3.getAbstractValue()
+
+
+def test_get_tuple(solver):
+    s1 = solver.getIntegerSort()
+    s2 = solver.getRealSort()
+    s3 = solver.getStringSort()
+
+    t1 = solver.mkInteger(15)
+    t2 = solver.mkReal(17, 25)
+    t3 = solver.mkString("abc")
+
+    tup = solver.mkTuple([s1, s2, s3], [t1, t2, t3])
+
+    assert tup.isTupleValue()
+    assert [t1, t2, t3] == tup.getTupleValue()
+
+
+def test_get_set(solver):
+    s = solver.mkSetSort(solver.getIntegerSort())
+
+    i1 = solver.mkInteger(5)
+    i2 = solver.mkInteger(7)
+
+    s1 = solver.mkEmptySet(s)
+    s2 = solver.mkTerm(kinds.Singleton, i1)
+    s3 = solver.mkTerm(kinds.Singleton, i1)
+    s4 = solver.mkTerm(kinds.Singleton, i2)
+    s5 = solver.mkTerm(kinds.Union, s2, solver.mkTerm(kinds.Union, s3, s4))
+
+    assert s1.isSetValue()
+    assert s2.isSetValue()
+    assert s3.isSetValue()
+    assert s4.isSetValue()
+    assert not s5.isSetValue()
+    s5 = solver.simplify(s5)
+    assert s5.isSetValue()
+
+    assert set([]) == s1.getSetValue()
+    assert set([i1]) == s2.getSetValue()
+    assert set([i1]) == s3.getSetValue()
+    assert set([i2]) == s4.getSetValue()
+    assert set([i1, i2]) == s5.getSetValue()
+
+
+def test_get_sequence(solver):
+    s = solver.mkSequenceSort(solver.getIntegerSort())
+
+    i1 = solver.mkInteger(5)
+    i2 = solver.mkInteger(7)
+
+    s1 = solver.mkEmptySequence(s)
+    s2 = solver.mkTerm(kinds.SeqUnit, i1)
+    s3 = solver.mkTerm(kinds.SeqUnit, i1)
+    s4 = solver.mkTerm(kinds.SeqUnit, i2)
+    s5 = solver.mkTerm(kinds.SeqConcat, s2,
+                       solver.mkTerm(kinds.SeqConcat, s3, s4))
+
+    assert s1.isSequenceValue()
+    assert not s2.isSequenceValue()
+    assert not s3.isSequenceValue()
+    assert not s4.isSequenceValue()
+    assert not s5.isSequenceValue()
+
+    s2 = solver.simplify(s2)
+    s3 = solver.simplify(s3)
+    s4 = solver.simplify(s4)
+    s5 = solver.simplify(s5)
+
+    assert [] == s1.getSequenceValue()
+    assert [i1] == s2.getSequenceValue()
+    assert [i1] == s3.getSequenceValue()
+    assert [i2] == s4.getSequenceValue()
+    assert [i1, i1, i2] == s5.getSequenceValue()
+
+
+def test_get_uninterpreted_const(solver):
+    s = solver.mkUninterpretedSort("test")
+    t1 = solver.mkUninterpretedConst(s, 3)
+    t2 = solver.mkUninterpretedConst(s, 5)
+
+    assert t1.isUninterpretedValue()
+    assert t2.isUninterpretedValue()
+
+    assert (s, 3) == t1.getUninterpretedValue()
+    assert (s, 5) == t2.getUninterpretedValue()
+
+
 def test_get_floating_point(solver):
     bvval = solver.mkBitVector("0000110000000011")
     fp = solver.mkFloatingPoint(5, 11, bvval)