Add `getStatistics()` to python API (#8343)
authorGereon Kremer <gkremer@cs.stanford.edu>
Sun, 20 Mar 2022 16:39:41 +0000 (17:39 +0100)
committerGitHub <noreply@github.com>
Sun, 20 Mar 2022 16:39:41 +0000 (09:39 -0700)
This PR adds Solver.getStatistics() to the python API. To make the usage a bit more pythonic, we do not expose the iterator interface of api::Statistics but instead offer .get() which returns the whole range as a dictionary. The ::get() method to obtain a single statistic value is instead implemented via __getitem__().
The PR also includes the corresponding unit tests.

src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/api/python/cvc5.pxd
src/api/python/cvc5.pxi
test/unit/api/python/test_solver.py

index 085eefa8b0078d5e0a24dd4183eafd047e74b65d..1b48b600d6073c0adb24a9bdce8de2c08dbf8394 100644 (file)
@@ -4751,6 +4751,7 @@ struct Stat::StatData
   StatData() : data() {}
 };
 
+Stat::Stat() {}
 Stat::~Stat() {}
 Stat::Stat(const Stat& s)
     : d_internal(s.d_internal),
@@ -4771,33 +4772,39 @@ bool Stat::isDefault() const { return d_default; }
 
 bool Stat::isInt() const
 {
+  if (!d_data) return false;
   return std::holds_alternative<int64_t>(d_data->data);
 }
 int64_t Stat::getInt() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_RECOVERABLE_CHECK(static_cast<bool>(d_data)) << "Stat holds no value";
   CVC5_API_RECOVERABLE_CHECK(isInt()) << "Expected Stat of type int64_t.";
   return std::get<int64_t>(d_data->data);
   CVC5_API_TRY_CATCH_END;
 }
 bool Stat::isDouble() const
 {
+  if (!d_data) return false;
   return std::holds_alternative<double>(d_data->data);
 }
 double Stat::getDouble() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_RECOVERABLE_CHECK(static_cast<bool>(d_data)) << "Stat holds no value";
   CVC5_API_RECOVERABLE_CHECK(isDouble()) << "Expected Stat of type double.";
   return std::get<double>(d_data->data);
   CVC5_API_TRY_CATCH_END;
 }
 bool Stat::isString() const
 {
+  if (!d_data) return false;
   return std::holds_alternative<std::string>(d_data->data);
 }
 const std::string& Stat::getString() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_RECOVERABLE_CHECK(static_cast<bool>(d_data)) << "Stat holds no value";
   CVC5_API_RECOVERABLE_CHECK(isString())
       << "Expected Stat of type std::string.";
   return std::get<std::string>(d_data->data);
@@ -4805,11 +4812,13 @@ const std::string& Stat::getString() const
 }
 bool Stat::isHistogram() const
 {
+  if (!d_data) return false;
   return std::holds_alternative<HistogramData>(d_data->data);
 }
 const Stat::HistogramData& Stat::getHistogram() const
 {
   CVC5_API_TRY_CATCH_BEGIN;
+  CVC5_API_RECOVERABLE_CHECK(static_cast<bool>(d_data)) << "Stat holds no value";
   CVC5_API_RECOVERABLE_CHECK(isHistogram())
       << "Expected Stat of type histogram.";
   return std::get<HistogramData>(d_data->data);
index eedc10b809ea4fba410c78e2c1ce7fb50ac24994..7ea90c7749621609a38d96459977c35e9bbc3b44 100644 (file)
@@ -2840,8 +2840,12 @@ class CVC5_EXPORT Stat
   friend std::ostream& operator<<(std::ostream& os, const Stat& sv);
   /** Representation of a histogram: maps names to frequencies. */
   using HistogramData = std::map<std::string, uint64_t>;
-  /** Can only be obtained from a `Statistics` object. */
-  Stat() = delete;
+  /**
+   * Create an empty statistics object. On such an object all ``isX()`` return
+   * false and all ``getX()`` throw an API exception. It solely exists because
+   * it makes implementing bindings for other languages much easier.
+   */
+  Stat();
   /** Copy constructor */
   Stat(const Stat& s);
   /** Destructor */
@@ -2937,6 +2941,7 @@ class CVC5_EXPORT Statistics
   {
    public:
     friend class Statistics;
+    iterator() = default;
     BaseType::const_reference operator*() const;
     BaseType::const_pointer operator->() const;
     iterator& operator++();
@@ -2958,6 +2963,9 @@ class CVC5_EXPORT Statistics
     bool d_showDefault = false;
   };
 
+  /** Creates an empty statistics object. */
+  Statistics() = default;
+
   /**
    * Retrieve the statistic with the given name.
    * Asserts that a statistic with the given name actually exists and throws
@@ -2978,7 +2986,6 @@ class CVC5_EXPORT Statistics
   iterator end() const;
 
  private:
-  Statistics() = default;
   Statistics(const StatisticsRegistry& reg);
   /** Internal data */
   BaseType d_stats;
index 07250f143cb93ee7f56c5ee1a8c6282b6df3d31e..71ec13081dc779766ad4ace73e59d74e477a91c2 100644 (file)
@@ -2,6 +2,7 @@
 from cython.operator cimport dereference as deref, preincrement as inc
 from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
 from libc.stddef cimport wchar_t
+from libcpp.map cimport map as c_map
 from libcpp.set cimport set
 from libcpp.string cimport string
 from libcpp.vector cimport vector
@@ -302,6 +303,7 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         void blockModel() except +
         void blockModelValues(const vector[Term]& terms) except +
         string getInstantiations() except +
+        Statistics getStatistics() except +
 
     cdef cppclass Grammar:
         Grammar() except +
@@ -380,6 +382,29 @@ cdef extern from "api/cpp/cvc5.h" namespace "cvc5::api":
         SortHashFunction() except +
         size_t operator()(const Sort & s) except +
 
+    cdef cppclass Stat:
+        bint isInternal() except +
+        bint isDefault() except +
+        bint isInt() except +
+        int64_t getInt() except +
+        bint isDouble() except +
+        double getDouble() except +
+        bint isString() except +
+        string getString() except +
+        bint isHistogram() except +
+        c_map[string,uint64_t] getHistogram() except +
+
+    cdef cppclass Statistics:
+        Statistics() except +
+        cppclass iterator:
+            iterator() except +
+            bint operator!=(const iterator& it) except +
+            iterator& operator++() except +
+            pair[string, Stat]& operator*() except +;
+        iterator begin(bint internal, bint defaulted) except +
+        iterator end() except +
+        Stat get(string name) except +
+
     cdef cppclass Term:
         Term()
         bint operator==(const Term&) except +
index f682af8aac104c40352829f62faa6afd55decc43..c8d552ea3bd9d626e6dce889d496fbdd12c272ff 100644 (file)
@@ -3,6 +3,8 @@ from fractions import Fraction
 from functools import wraps
 import sys
 
+from cython.operator cimport dereference, preincrement
+
 from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t
 from libc.stddef cimport wchar_t
 
@@ -22,6 +24,8 @@ from cvc5 cimport RoundingMode as c_RoundingMode
 from cvc5 cimport UnknownExplanation as c_UnknownExplanation
 from cvc5 cimport Op as c_Op
 from cvc5 cimport Solver as c_Solver
+from cvc5 cimport Statistics as c_Statistics
+from cvc5 cimport Stat as c_Stat
 from cvc5 cimport Grammar as c_Grammar
 from cvc5 cimport Sort as c_Sort
 from cvc5 cimport ROUND_NEAREST_TIES_TO_EVEN, ROUND_TOWARD_POSITIVE
@@ -2397,6 +2401,16 @@ cdef class Solver:
         """
         return self.csolver.getInstantiations()
 
+    def getStatistics(self):
+        """
+        Returns a snapshot of the current state of the statistic values of this
+        solver. The returned object is completely decoupled from the solver and
+        will not change when the solver is used again.
+        """
+        res = Statistics()
+        res.cstats = self.csolver.getStatistics()
+        return res
+
 
 cdef class Sort:
     """
@@ -2919,6 +2933,47 @@ cdef class Sort:
         return tuple_sorts
 
 
+cdef class Statistics:
+    """
+    The cvc5 Statistics.
+    Wrapper class for :cpp:class:`cvc5::api::Statistics`.
+    Obtain a single statistic value using ``stats["name"]`` and a dictionary
+    with all (visible) statistics using ``stats.get(internal=False, defaulted=False)``.
+    """
+    cdef c_Statistics cstats
+
+    cdef __stat_to_dict(self, const c_Stat& s):
+        res = None
+        if s.isInt():
+            res = s.getInt()
+        elif s.isDouble():
+            res = s.getDouble()
+        elif s.isString():
+            res = s.getString().decode()
+        elif s.isHistogram():
+            res = { h.first.decode(): h.second for h in s.getHistogram() }
+        return {
+            'defaulted': s.isDefault(),
+            'internal': s.isInternal(),
+            'value': res
+        }
+
+    def __getitem__(self, str name):
+        """Get the statistics information for the statistic called ``name``."""
+        return self.__stat_to_dict(self.cstats.get(name.encode()))
+
+    def get(self, bint internal = False, bint defaulted = False):
+        """Get all statistics. See :cpp:class:`cvc5::api::Statistics::begin()` for more information."""
+        cdef c_Statistics.iterator it = self.cstats.begin(internal, defaulted)
+        cdef pair[string,c_Stat]* s
+        res = {}
+        while it != self.cstats.end():
+            s = &dereference(it)
+            res[s.first.decode()] = self.__stat_to_dict(s.second)
+            preincrement(it)
+        return res
+
+
 cdef class Term:
     """
     A cvc5 Term.
index aca628d280628f444c6626267e1fd1c30752eb5b..a05421ab21dbd81ade9e235cc317cbed60b7ac5d 100644 (file)
@@ -1572,6 +1572,21 @@ def test_block_model_values5(solver):
     solver.checkSat()
     solver.blockModelValues([x])
 
+def test_get_statistics(solver):
+    intSort = solver.getIntegerSort()
+    x = solver.mkConst(intSort, "x")
+    y = solver.mkConst(intSort, "y")
+    zero = solver.mkInteger(0)
+    ten = solver.mkInteger(10)
+    f0 = solver.mkTerm(Kind.Geq, x, ten)
+    f1 = solver.mkTerm(Kind.Or, solver.mkTerm(Kind.Geq, zero, x), solver.mkTerm(Kind.Geq, y, zero))
+    solver.assertFormula(f0)
+    solver.assertFormula(f1)
+    solver.checkSat()
+    s = solver.getStatistics()
+    assert s['api::TERM'] == {'defaulted': False, 'internal': False, 'value': {'GEQ': 3, 'OR': 1}}
+    assert s.get(True, False) != {}
+
 def test_set_info(solver):
     with pytest.raises(RuntimeError):
         solver.setInfo("cvc5-lagic", "QF_BV")