From 304064c6bb3bf7ea7a7d54b66e2ad152e8fc4525 Mon Sep 17 00:00:00 2001 From: yoni206 Date: Wed, 12 May 2021 22:32:55 -0700 Subject: [PATCH] Adding functions to the python API and testing them -- part 2 (#6517) This PR adds some functions that are missing in the python API, along with unit tests for them. The unit tests are translated from https://github.com/cvc5/cvc5/blob/master/test/unit/api/term_black.cpp --- src/api/python/cvc5.pxd | 6 +++ src/api/python/cvc5.pxi | 17 +++++++ test/python/unit/api/test_term.py | 77 +++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+) diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index b91a9e9c5..a044c79f5 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -341,6 +341,11 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Term() bint operator==(const Term&) except + bint operator!=(const Term&) except + + bint operator<(const Term&) except + + bint operator>(const Term&) except + + bint operator<=(const Term&) except + + bint operator>=(const Term&) except + + size_t getNumChildren() except + Term operator[](size_t idx) except + uint64_t getId() except + Kind getKind() except + @@ -367,6 +372,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": Term operator*() except + const_iterator begin() except + const_iterator end() except + + bint isInteger() except + cdef cppclass TermHashFunction: TermHashFunction() except + diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index fc14c1fca..a512a17a8 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -1460,6 +1460,18 @@ cdef class Term: def __ne__(self, Term other): return self.cterm != other.cterm + def __lt__(self, Term other): + return self.cterm < other.cterm + + def __gt__(self, Term other): + return self.cterm > other.cterm + + def __le__(self, Term other): + return self.cterm <= other.cterm + + def __ge__(self, Term other): + return self.cterm >= other.cterm + def __getitem__(self, int index): cdef Term term = Term(self.solver) if index >= 0: @@ -1483,6 +1495,8 @@ cdef class Term: def __hash__(self): return ctermhash(self.cterm) + def getNumChildren(self): + return self.cterm.getNumChildren() def getId(self): return self.cterm.getId() @@ -1584,6 +1598,9 @@ cdef class Term: term.cterm = self.cterm.iteTerm(then_t.cterm, else_t.cterm) return term + def isInteger(self): + return self.cterm.isInteger() + def toPythonObj(self): ''' Converts a constant value Term to a Python object. diff --git a/test/python/unit/api/test_term.py b/test/python/unit/api/test_term.py index 91424c905..702634807 100644 --- a/test/python/unit/api/test_term.py +++ b/test/python/unit/api/test_term.py @@ -897,6 +897,83 @@ def test_substitute(solver): xpx.substitute(es, rs) +def test_term_compare(solver): + t1 = solver.mkInteger(1) + t2 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) + t3 = solver.mkTerm(kinds.Plus, solver.mkInteger(2), solver.mkInteger(2)) + assert t2 >= t3 + assert t2 <= t3 + assert (t1 > t2) != (t1 < t2) + assert (t1 > t2 or t1 == t2) == (t1 >= t2) + +def test_term_children(solver): + # simple term 2+3 + two = solver.mkInteger(2) + t1 = solver.mkTerm(kinds.Plus, two, solver.mkInteger(3)) + assert t1[0] == two + assert t1.getNumChildren() == 2 + tnull = Term(solver) + with pytest.raises(RuntimeError): + tnull.getNumChildren() + + # apply term f(2) + intSort = solver.getIntegerSort() + fsort = solver.mkFunctionSort(intSort, intSort) + f = solver.mkConst(fsort, "f") + t2 = solver.mkTerm(kinds.ApplyUf, f, two) + # due to our higher-order view of terms, we treat f as a child of kinds.ApplyUf + assert t2.getNumChildren() == 2 + assert t2[0] == f + assert t2[1] == two + with pytest.raises(RuntimeError): + tnull[0] + +def test_is_integer(solver): + int1 = solver.mkInteger("-18446744073709551616") + int2 = solver.mkInteger("-18446744073709551615") + int3 = solver.mkInteger("-4294967296") + int4 = solver.mkInteger("-4294967295") + int5 = solver.mkInteger("-10") + int6 = solver.mkInteger("0") + int7 = solver.mkInteger("10") + int8 = solver.mkInteger("4294967295") + int9 = solver.mkInteger("4294967296") + int10 = solver.mkInteger("18446744073709551615") + int11 = solver.mkInteger("18446744073709551616") + int12 = solver.mkInteger("-0") + + with pytest.raises(RuntimeError): + solver.mkInteger("") + with pytest.raises(RuntimeError): + solver.mkInteger("-") + with pytest.raises(RuntimeError): + solver.mkInteger("-1-") + with pytest.raises(RuntimeError): + solver.mkInteger("0.0") + with pytest.raises(RuntimeError): + solver.mkInteger("-0.1") + with pytest.raises(RuntimeError): + solver.mkInteger("012") + with pytest.raises(RuntimeError): + solver.mkInteger("0000") + with pytest.raises(RuntimeError): + solver.mkInteger("-01") + with pytest.raises(RuntimeError): + solver.mkInteger("-00") + + assert int1.isInteger() + assert int2.isInteger() + assert int3.isInteger() + assert int4.isInteger() + assert int5.isInteger() + assert int6.isInteger() + assert int7.isInteger() + assert int8.isInteger() + assert int9.isInteger() + assert int10.isInteger() + assert int11.isInteger() + + def test_const_array(solver): intsort = solver.getIntegerSort() arrsort = solver.mkArraySort(intsort, intsort) -- 2.30.2