From: yoni206 Date: Wed, 2 Jun 2021 19:30:46 +0000 (-0700) Subject: Adding getters to the python API and testing them (#6652) X-Git-Tag: cvc5-1.0.0~1653 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=85a300898d7815973c064fe2c7b5b33473a71a5c;p=cvc5.git Adding getters to the python API and testing them (#6652) This PR adds missing API functions from the cpp Term API to the python API. Corresponding tests are translated from term_black.cpp. --- diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index fdcbfa997..2ad8cef5c 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -176,6 +176,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Sort mkTupleSort(const vector[Sort]& sorts) except + Term mkTerm(Op op) except + Term mkTerm(Op op, const vector[Term]& children) except + + Term mkTuple(const vector[Sort]& sorts, const vector[Term]& terms) except + Op mkOp(Kind kind) except + Op mkOp(Kind kind, Kind k) except + Op mkOp(Kind kind, const string& arg) except + @@ -388,6 +389,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Term operator*() except + const_iterator begin() except + const_iterator end() except + + + bint isConstArray() except + bint isBooleanValue() except + bint getBooleanValue() except + bint isStringValue() except + @@ -398,6 +401,8 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": string getRealValue() except + bint isBitVectorValue() except + string getBitVectorValue(uint32_t base) except + + bint isAbstractValue() except + + string getAbstractValue() except + bint isFloatingPointPosZero() except + bint isFloatingPointNegZero() except + bint isFloatingPointPosInf() except + @@ -406,7 +411,15 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": bint isFloatingPointValue() except + tuple[uint32_t, uint32_t, Term] getFloatingPointValue() except + + bint isSetValue() except + + set[Term] getSetValue() except + + bint isSequenceValue() except + vector[Term] getSequenceValue() except + + bint isUninterpretedValue() except + + pair[Sort, int32_t] getUninterpretedValue() except + + bint isTupleValue() except + + vector[Term] getTupleValue() except + + cdef cppclass TermHashFunction: TermHashFunction() except + diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index 25ded76bb..8599a1cd1 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -658,6 +658,19 @@ cdef class Solver: term.cterm = self.csolver.mkTerm(( op).cop, v) return term + def mkTuple(self, sorts, terms): + cdef vector[c_Sort] csorts + cdef vector[c_Term] cterms + + for s in sorts: + csorts.push_back(( s).csort) + for s in terms: + cterms.push_back(( s).cterm) + cdef Term result = Term(self) + result.cterm = self.csolver.mkTuple(csorts, cterms) + return result + + def mkOp(self, kind k, arg0=None, arg1 = None): ''' Supports the following uses: @@ -1609,19 +1622,6 @@ cdef class Term: def isNull(self): return self.cterm.isNull() - def getConstArrayBase(self): - cdef Term term = Term(self.solver) - term.cterm = self.cterm.getConstArrayBase() - return term - - def getSequenceValue(self): - elems = [] - for e in self.cterm.getSequenceValue(): - term = Term(self.solver) - term.cterm = e - elems.append(term) - return elems - def notTerm(self): cdef Term term = Term(self.solver) term.cterm = self.cterm.notTerm() @@ -1657,6 +1657,14 @@ cdef class Term: term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm) return term + def isConstArray(self): + return self.cterm.isConstArray() + + def getConstArrayBase(self): + cdef Term term = Term(self.solver) + term.cterm = self.cterm.getConstArrayBase() + return term + def isBooleanValue(self): return self.cterm.isBooleanValue() @@ -1673,7 +1681,12 @@ cdef class Term: def isIntegerValue(self): return self.cterm.isIntegerValue() - + def isAbstractValue(self): + return self.cterm.isAbstractValue() + + def getAbstractValue(self): + return self.cterm.getAbstractValue().decode() + def isFloatingPointPosZero(self): return self.cterm.isFloatingPointPosZero() @@ -1698,6 +1711,49 @@ cdef class Term: term.cterm = get2(t) return (get0(t), get1(t), term) + def isSetValue(self): + return self.cterm.isSetValue() + + def getSetValue(self): + elems = set() + for e in self.cterm.getSetValue(): + term = Term(self.solver) + term.cterm = e + elems.add(term) + return elems + + def isSequenceValue(self): + return self.cterm.isSequenceValue() + + def getSequenceValue(self): + elems = [] + for e in self.cterm.getSequenceValue(): + term = Term(self.solver) + term.cterm = e + elems.append(term) + return elems + + def isUninterpretedValue(self): + return self.cterm.isUninterpretedValue() + + def getUninterpretedValue(self): + cdef pair[c_Sort, int32_t] p = self.cterm.getUninterpretedValue() + cdef Sort sort = Sort(self.solver) + sort.csort = p.first + i = p.second + return (sort, i) + + def isTupleValue(self): + return self.cterm.isTupleValue() + + def getTupleValue(self): + elems = [] + for e in self.cterm.getTupleValue(): + term = Term(self.solver) + term.cterm = e + elems.append(term) + return elems + def getIntegerValue(self): return int(self.cterm.getIntegerValue().decode()) diff --git a/test/python/unit/api/test_term.py b/test/python/unit/api/test_term.py index 5603655c6..32813e17f 100644 --- a/test/python/unit/api/test_term.py +++ b/test/python/unit/api/test_term.py @@ -930,6 +930,114 @@ def test_term_children(solver): tnull[0] +def test_get_const_array_base(solver): + intsort = solver.getIntegerSort() + arrsort = solver.mkArraySort(intsort, intsort) + one = solver.mkInteger(1) + constarr = solver.mkConstArray(arrsort, one) + + assert constarr.isConstArray() + assert one == constarr.getConstArrayBase() + + +def test_get_abstract_value(solver): + v1 = solver.mkAbstractValue(1) + v2 = solver.mkAbstractValue("15") + v3 = solver.mkAbstractValue("18446744073709551617") + + assert v1.isAbstractValue() + assert v2.isAbstractValue() + assert v3.isAbstractValue() + assert "1" == v1.getAbstractValue() + assert "15" == v2.getAbstractValue() + assert "18446744073709551617" == v3.getAbstractValue() + + +def test_get_tuple(solver): + s1 = solver.getIntegerSort() + s2 = solver.getRealSort() + s3 = solver.getStringSort() + + t1 = solver.mkInteger(15) + t2 = solver.mkReal(17, 25) + t3 = solver.mkString("abc") + + tup = solver.mkTuple([s1, s2, s3], [t1, t2, t3]) + + assert tup.isTupleValue() + assert [t1, t2, t3] == tup.getTupleValue() + + +def test_get_set(solver): + s = solver.mkSetSort(solver.getIntegerSort()) + + i1 = solver.mkInteger(5) + i2 = solver.mkInteger(7) + + s1 = solver.mkEmptySet(s) + s2 = solver.mkTerm(kinds.Singleton, i1) + s3 = solver.mkTerm(kinds.Singleton, i1) + s4 = solver.mkTerm(kinds.Singleton, i2) + s5 = solver.mkTerm(kinds.Union, s2, solver.mkTerm(kinds.Union, s3, s4)) + + assert s1.isSetValue() + assert s2.isSetValue() + assert s3.isSetValue() + assert s4.isSetValue() + assert not s5.isSetValue() + s5 = solver.simplify(s5) + assert s5.isSetValue() + + assert set([]) == s1.getSetValue() + assert set([i1]) == s2.getSetValue() + assert set([i1]) == s3.getSetValue() + assert set([i2]) == s4.getSetValue() + assert set([i1, i2]) == s5.getSetValue() + + +def test_get_sequence(solver): + s = solver.mkSequenceSort(solver.getIntegerSort()) + + i1 = solver.mkInteger(5) + i2 = solver.mkInteger(7) + + s1 = solver.mkEmptySequence(s) + s2 = solver.mkTerm(kinds.SeqUnit, i1) + s3 = solver.mkTerm(kinds.SeqUnit, i1) + s4 = solver.mkTerm(kinds.SeqUnit, i2) + s5 = solver.mkTerm(kinds.SeqConcat, s2, + solver.mkTerm(kinds.SeqConcat, s3, s4)) + + assert s1.isSequenceValue() + assert not s2.isSequenceValue() + assert not s3.isSequenceValue() + assert not s4.isSequenceValue() + assert not s5.isSequenceValue() + + s2 = solver.simplify(s2) + s3 = solver.simplify(s3) + s4 = solver.simplify(s4) + s5 = solver.simplify(s5) + + assert [] == s1.getSequenceValue() + assert [i1] == s2.getSequenceValue() + assert [i1] == s3.getSequenceValue() + assert [i2] == s4.getSequenceValue() + assert [i1, i1, i2] == s5.getSequenceValue() + + +def test_get_uninterpreted_const(solver): + s = solver.mkUninterpretedSort("test") + t1 = solver.mkUninterpretedConst(s, 3) + t2 = solver.mkUninterpretedConst(s, 5) + + assert t1.isUninterpretedValue() + assert t2.isUninterpretedValue() + + assert (s, 3) == t1.getUninterpretedValue() + assert (s, 5) == t2.getUninterpretedValue() + + def test_get_floating_point(solver): bvval = solver.mkBitVector("0000110000000011") fp = solver.mkFloatingPoint(5, 11, bvval)