From 1cfc902c459ca2e4f8f2946564de6f684164ad97 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Sun, 20 Mar 2022 17:39:41 +0100 Subject: [PATCH] Add `getStatistics()` to python API (#8343) This PR adds Solver.getStatistics() to the python API. To make the usage a bit more pythonic, we do not expose the iterator interface of api::Statistics but instead offer .get() which returns the whole range as a dictionary. The ::get() method to obtain a single statistic value is instead implemented via __getitem__(). The PR also includes the corresponding unit tests. --- src/api/cpp/cvc5.cpp | 9 +++++ src/api/cpp/cvc5.h | 13 +++++-- src/api/python/cvc5.pxd | 25 +++++++++++++ src/api/python/cvc5.pxi | 55 +++++++++++++++++++++++++++++ test/unit/api/python/test_solver.py | 15 ++++++++ 5 files changed, 114 insertions(+), 3 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 085eefa8b..1b48b600d 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -4751,6 +4751,7 @@ struct Stat::StatData StatData() : data() {} }; +Stat::Stat() {} Stat::~Stat() {} Stat::Stat(const Stat& s) : d_internal(s.d_internal), @@ -4771,33 +4772,39 @@ bool Stat::isDefault() const { return d_default; } bool Stat::isInt() const { + if (!d_data) return false; return std::holds_alternative(d_data->data); } int64_t Stat::getInt() const { CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_RECOVERABLE_CHECK(static_cast(d_data)) << "Stat holds no value"; CVC5_API_RECOVERABLE_CHECK(isInt()) << "Expected Stat of type int64_t."; return std::get(d_data->data); CVC5_API_TRY_CATCH_END; } bool Stat::isDouble() const { + if (!d_data) return false; return std::holds_alternative(d_data->data); } double Stat::getDouble() const { CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_RECOVERABLE_CHECK(static_cast(d_data)) << "Stat holds no value"; CVC5_API_RECOVERABLE_CHECK(isDouble()) << "Expected Stat of type double."; return std::get(d_data->data); CVC5_API_TRY_CATCH_END; } bool Stat::isString() const { + if (!d_data) return false; return std::holds_alternative(d_data->data); } const std::string& Stat::getString() const { CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_RECOVERABLE_CHECK(static_cast(d_data)) << "Stat holds no value"; CVC5_API_RECOVERABLE_CHECK(isString()) << "Expected Stat of type std::string."; return std::get(d_data->data); @@ -4805,11 +4812,13 @@ const std::string& Stat::getString() const } bool Stat::isHistogram() const { + if (!d_data) return false; return std::holds_alternative(d_data->data); } const Stat::HistogramData& Stat::getHistogram() const { CVC5_API_TRY_CATCH_BEGIN; + CVC5_API_RECOVERABLE_CHECK(static_cast(d_data)) << "Stat holds no value"; CVC5_API_RECOVERABLE_CHECK(isHistogram()) << "Expected Stat of type histogram."; return std::get(d_data->data); diff --git a/src/api/cpp/cvc5.h b/src/api/cpp/cvc5.h index eedc10b80..7ea90c774 100644 --- a/src/api/cpp/cvc5.h +++ b/src/api/cpp/cvc5.h @@ -2840,8 +2840,12 @@ class CVC5_EXPORT Stat friend std::ostream& operator<<(std::ostream& os, const Stat& sv); /** Representation of a histogram: maps names to frequencies. */ using HistogramData = std::map; - /** Can only be obtained from a `Statistics` object. */ - Stat() = delete; + /** + * Create an empty statistics object. On such an object all ``isX()`` return + * false and all ``getX()`` throw an API exception. It solely exists because + * it makes implementing bindings for other languages much easier. + */ + Stat(); /** Copy constructor */ Stat(const Stat& s); /** Destructor */ @@ -2937,6 +2941,7 @@ class CVC5_EXPORT Statistics { public: friend class Statistics; + iterator() = default; BaseType::const_reference operator*() const; BaseType::const_pointer operator->() const; iterator& operator++(); @@ -2958,6 +2963,9 @@ class CVC5_EXPORT Statistics bool d_showDefault = false; }; + /** Creates an empty statistics object. */ + Statistics() = default; + /** * Retrieve the statistic with the given name. * Asserts that a statistic with the given name actually exists and throws @@ -2978,7 +2986,6 @@ class CVC5_EXPORT Statistics iterator end() const; private: - Statistics() = default; Statistics(const StatisticsRegistry& reg); /** Internal data */ BaseType d_stats; diff --git a/src/api/python/cvc5.pxd b/src/api/python/cvc5.pxd index 07250f143..71ec13081 100644 --- a/src/api/python/cvc5.pxd +++ b/src/api/python/cvc5.pxd @@ -2,6 +2,7 @@ from cython.operator cimport dereference as deref, preincrement as inc from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t from libc.stddef cimport wchar_t +from libcpp.map cimport map as c_map from libcpp.set cimport set from libcpp.string cimport string from libcpp.vector cimport vector @@ -302,6 +303,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": void blockModel() except + void blockModelValues(const vector[Term]& terms) except + string getInstantiations() except + + Statistics getStatistics() except + cdef cppclass Grammar: Grammar() except + @@ -380,6 +382,29 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api": SortHashFunction() except + size_t operator()(const Sort & s) except + + cdef cppclass Stat: + bint isInternal() except + + bint isDefault() except + + bint isInt() except + + int64_t getInt() except + + bint isDouble() except + + double getDouble() except + + bint isString() except + + string getString() except + + bint isHistogram() except + + c_map[string,uint64_t] getHistogram() except + + + cdef cppclass Statistics: + Statistics() except + + cppclass iterator: + iterator() except + + bint operator!=(const iterator& it) except + + iterator& operator++() except + + pair[string, Stat]& operator*() except +; + iterator begin(bint internal, bint defaulted) except + + iterator end() except + + Stat get(string name) except + + cdef cppclass Term: Term() bint operator==(const Term&) except + diff --git a/src/api/python/cvc5.pxi b/src/api/python/cvc5.pxi index f682af8aa..c8d552ea3 100644 --- a/src/api/python/cvc5.pxi +++ b/src/api/python/cvc5.pxi @@ -3,6 +3,8 @@ from fractions import Fraction from functools import wraps import sys +from cython.operator cimport dereference, preincrement + from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t from libc.stddef cimport wchar_t @@ -22,6 +24,8 @@ from cvc5 cimport RoundingMode as c_RoundingMode from cvc5 cimport UnknownExplanation as c_UnknownExplanation from cvc5 cimport Op as c_Op from cvc5 cimport Solver as c_Solver +from cvc5 cimport Statistics as c_Statistics +from cvc5 cimport Stat as c_Stat from cvc5 cimport Grammar as c_Grammar from cvc5 cimport Sort as c_Sort from cvc5 cimport ROUND_NEAREST_TIES_TO_EVEN, ROUND_TOWARD_POSITIVE @@ -2397,6 +2401,16 @@ cdef class Solver: """ return self.csolver.getInstantiations() + def getStatistics(self): + """ + Returns a snapshot of the current state of the statistic values of this + solver. The returned object is completely decoupled from the solver and + will not change when the solver is used again. + """ + res = Statistics() + res.cstats = self.csolver.getStatistics() + return res + cdef class Sort: """ @@ -2919,6 +2933,47 @@ cdef class Sort: return tuple_sorts +cdef class Statistics: + """ + The cvc5 Statistics. + Wrapper class for :cpp:class:`cvc5::api::Statistics`. + Obtain a single statistic value using ``stats["name"]`` and a dictionary + with all (visible) statistics using ``stats.get(internal=False, defaulted=False)``. + """ + cdef c_Statistics cstats + + cdef __stat_to_dict(self, const c_Stat& s): + res = None + if s.isInt(): + res = s.getInt() + elif s.isDouble(): + res = s.getDouble() + elif s.isString(): + res = s.getString().decode() + elif s.isHistogram(): + res = { h.first.decode(): h.second for h in s.getHistogram() } + return { + 'defaulted': s.isDefault(), + 'internal': s.isInternal(), + 'value': res + } + + def __getitem__(self, str name): + """Get the statistics information for the statistic called ``name``.""" + return self.__stat_to_dict(self.cstats.get(name.encode())) + + def get(self, bint internal = False, bint defaulted = False): + """Get all statistics. See :cpp:class:`cvc5::api::Statistics::begin()` for more information.""" + cdef c_Statistics.iterator it = self.cstats.begin(internal, defaulted) + cdef pair[string,c_Stat]* s + res = {} + while it != self.cstats.end(): + s = &dereference(it) + res[s.first.decode()] = self.__stat_to_dict(s.second) + preincrement(it) + return res + + cdef class Term: """ A cvc5 Term. diff --git a/test/unit/api/python/test_solver.py b/test/unit/api/python/test_solver.py index aca628d28..a05421ab2 100644 --- a/test/unit/api/python/test_solver.py +++ b/test/unit/api/python/test_solver.py @@ -1572,6 +1572,21 @@ def test_block_model_values5(solver): solver.checkSat() solver.blockModelValues([x]) +def test_get_statistics(solver): + intSort = solver.getIntegerSort() + x = solver.mkConst(intSort, "x") + y = solver.mkConst(intSort, "y") + zero = solver.mkInteger(0) + ten = solver.mkInteger(10) + f0 = solver.mkTerm(Kind.Geq, x, ten) + f1 = solver.mkTerm(Kind.Or, solver.mkTerm(Kind.Geq, zero, x), solver.mkTerm(Kind.Geq, y, zero)) + solver.assertFormula(f0) + solver.assertFormula(f1) + solver.checkSat() + s = solver.getStatistics() + assert s['api::TERM'] == {'defaulted': False, 'internal': False, 'value': {'GEQ': 3, 'OR': 1}} + assert s.get(True, False) != {} + def test_set_info(solver): with pytest.raises(RuntimeError): solver.setInfo("cvc5-lagic", "QF_BV") -- 2.30.2