a455212568a6e79d16917c56c7f85182e9183703
[gem5.git] / ext / pybind11 / tests / test_sequences_and_iterators.cpp
1 /*
2 tests/test_sequences_and_iterators.cpp -- supporting Python's sequence protocol, iterators,
3 etc.
4
5 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
6
7 All rights reserved. Use of this source code is governed by a
8 BSD-style license that can be found in the LICENSE file.
9 */
10
11 #include "pybind11_tests.h"
12 #include "constructor_stats.h"
13 #include <pybind11/operators.h>
14 #include <pybind11/stl.h>
15
/// Minimal forward iterator over a contiguous array of `T`.
///
/// It does not know where the array ends: the only way a range built from it
/// terminates is the comparison with `NonZeroSentinel` defined below.
template <typename T>
class NonZeroIterator {
public:
    NonZeroIterator(const T *start) : cursor_(start) {}

    /// Element currently pointed at.
    const T &operator*() const { return cursor_[0]; }

    /// Step to the next element of the array.
    NonZeroIterator &operator++() {
        cursor_ += 1;
        return *this;
    }

private:
    const T *cursor_;
};

/// Empty tag type standing for the end of a `NonZeroIterator` range.
class NonZeroSentinel {};

/// An iterator over pairs "reaches" the sentinel as soon as it points at a pair
/// whose first or second member is zero (falsy). The underlying data must
/// therefore contain such a pair; otherwise, iteration walks off the array.
template <typename A, typename B>
bool operator==(const NonZeroIterator<std::pair<A, B>> &it, const NonZeroSentinel &) {
    const std::pair<A, B> &current = *it;
    return !current.first || !current.second;
}
31
32 template <typename PythonType>
33 py::list test_random_access_iterator(PythonType x) {
34 if (x.size() < 5)
35 throw py::value_error("Please provide at least 5 elements for testing.");
36
37 auto checks = py::list();
38 auto assert_equal = [&checks](py::handle a, py::handle b) {
39 auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ);
40 if (result == -1) { throw py::error_already_set(); }
41 checks.append(result != 0);
42 };
43
44 auto it = x.begin();
45 assert_equal(x[0], *it);
46 assert_equal(x[0], it[0]);
47 assert_equal(x[1], it[1]);
48
49 assert_equal(x[1], *(++it));
50 assert_equal(x[1], *(it++));
51 assert_equal(x[2], *it);
52 assert_equal(x[3], *(it += 1));
53 assert_equal(x[2], *(--it));
54 assert_equal(x[2], *(it--));
55 assert_equal(x[1], *it);
56 assert_equal(x[0], *(it -= 1));
57
58 assert_equal(it->attr("real"), x[0].attr("real"));
59 assert_equal((it + 1)->attr("real"), x[1].attr("real"));
60
61 assert_equal(x[1], *(it + 1));
62 assert_equal(x[1], *(1 + it));
63 it += 3;
64 assert_equal(x[1], *(it - 2));
65
66 checks.append(static_cast<std::size_t>(x.end() - x.begin()) == x.size());
67 checks.append((x.begin() + static_cast<std::ptrdiff_t>(x.size())) == x.end());
68 checks.append(x.begin() < x.end());
69
70 return checks;
71 }
72
73 TEST_SUBMODULE(sequences_and_iterators, m) {
74
75 // test_sequence
76 class Sequence {
77 public:
78 Sequence(size_t size) : m_size(size) {
79 print_created(this, "of size", m_size);
80 m_data = new float[size];
81 memset(m_data, 0, sizeof(float) * size);
82 }
83 Sequence(const std::vector<float> &value) : m_size(value.size()) {
84 print_created(this, "of size", m_size, "from std::vector");
85 m_data = new float[m_size];
86 memcpy(m_data, &value[0], sizeof(float) * m_size);
87 }
88 Sequence(const Sequence &s) : m_size(s.m_size) {
89 print_copy_created(this);
90 m_data = new float[m_size];
91 memcpy(m_data, s.m_data, sizeof(float)*m_size);
92 }
93 Sequence(Sequence &&s) : m_size(s.m_size), m_data(s.m_data) {
94 print_move_created(this);
95 s.m_size = 0;
96 s.m_data = nullptr;
97 }
98
99 ~Sequence() { print_destroyed(this); delete[] m_data; }
100
101 Sequence &operator=(const Sequence &s) {
102 if (&s != this) {
103 delete[] m_data;
104 m_size = s.m_size;
105 m_data = new float[m_size];
106 memcpy(m_data, s.m_data, sizeof(float)*m_size);
107 }
108 print_copy_assigned(this);
109 return *this;
110 }
111
112 Sequence &operator=(Sequence &&s) {
113 if (&s != this) {
114 delete[] m_data;
115 m_size = s.m_size;
116 m_data = s.m_data;
117 s.m_size = 0;
118 s.m_data = nullptr;
119 }
120 print_move_assigned(this);
121 return *this;
122 }
123
124 bool operator==(const Sequence &s) const {
125 if (m_size != s.size()) return false;
126 for (size_t i = 0; i < m_size; ++i)
127 if (m_data[i] != s[i])
128 return false;
129 return true;
130 }
131 bool operator!=(const Sequence &s) const { return !operator==(s); }
132
133 float operator[](size_t index) const { return m_data[index]; }
134 float &operator[](size_t index) { return m_data[index]; }
135
136 bool contains(float v) const {
137 for (size_t i = 0; i < m_size; ++i)
138 if (v == m_data[i])
139 return true;
140 return false;
141 }
142
143 Sequence reversed() const {
144 Sequence result(m_size);
145 for (size_t i = 0; i < m_size; ++i)
146 result[m_size - i - 1] = m_data[i];
147 return result;
148 }
149
150 size_t size() const { return m_size; }
151
152 const float *begin() const { return m_data; }
153 const float *end() const { return m_data+m_size; }
154
155 private:
156 size_t m_size;
157 float *m_data;
158 };
159 py::class_<Sequence>(m, "Sequence")
160 .def(py::init<size_t>())
161 .def(py::init<const std::vector<float>&>())
162 /// Bare bones interface
163 .def("__getitem__", [](const Sequence &s, size_t i) {
164 if (i >= s.size()) throw py::index_error();
165 return s[i];
166 })
167 .def("__setitem__", [](Sequence &s, size_t i, float v) {
168 if (i >= s.size()) throw py::index_error();
169 s[i] = v;
170 })
171 .def("__len__", &Sequence::size)
172 /// Optional sequence protocol operations
173 .def("__iter__", [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); },
174 py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
175 .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); })
176 .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); })
177 /// Slicing protocol (optional)
178 .def("__getitem__", [](const Sequence &s, py::slice slice) -> Sequence* {
179 size_t start, stop, step, slicelength;
180 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength))
181 throw py::error_already_set();
182 Sequence *seq = new Sequence(slicelength);
183 for (size_t i = 0; i < slicelength; ++i) {
184 (*seq)[i] = s[start]; start += step;
185 }
186 return seq;
187 })
188 .def("__setitem__", [](Sequence &s, py::slice slice, const Sequence &value) {
189 size_t start, stop, step, slicelength;
190 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength))
191 throw py::error_already_set();
192 if (slicelength != value.size())
193 throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
194 for (size_t i = 0; i < slicelength; ++i) {
195 s[start] = value[i]; start += step;
196 }
197 })
198 /// Comparisons
199 .def(py::self == py::self)
200 .def(py::self != py::self)
201 // Could also define py::self + py::self for concatenation, etc.
202 ;
203
204 // test_map_iterator
205 // Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic
206 // map-like functionality.
207 class StringMap {
208 public:
209 StringMap() = default;
210 StringMap(std::unordered_map<std::string, std::string> init)
211 : map(std::move(init)) {}
212
213 void set(std::string key, std::string val) { map[key] = val; }
214 std::string get(std::string key) const { return map.at(key); }
215 size_t size() const { return map.size(); }
216 private:
217 std::unordered_map<std::string, std::string> map;
218 public:
219 decltype(map.cbegin()) begin() const { return map.cbegin(); }
220 decltype(map.cend()) end() const { return map.cend(); }
221 };
222 py::class_<StringMap>(m, "StringMap")
223 .def(py::init<>())
224 .def(py::init<std::unordered_map<std::string, std::string>>())
225 .def("__getitem__", [](const StringMap &map, std::string key) {
226 try { return map.get(key); }
227 catch (const std::out_of_range&) {
228 throw py::key_error("key '" + key + "' does not exist");
229 }
230 })
231 .def("__setitem__", &StringMap::set)
232 .def("__len__", &StringMap::size)
233 .def("__iter__", [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); },
234 py::keep_alive<0, 1>())
235 .def("items", [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); },
236 py::keep_alive<0, 1>())
237 ;
238
239 // test_generalized_iterators
240 class IntPairs {
241 public:
242 IntPairs(std::vector<std::pair<int, int>> data) : data_(std::move(data)) {}
243 const std::pair<int, int>* begin() const { return data_.data(); }
244 private:
245 std::vector<std::pair<int, int>> data_;
246 };
247 py::class_<IntPairs>(m, "IntPairs")
248 .def(py::init<std::vector<std::pair<int, int>>>())
249 .def("nonzero", [](const IntPairs& s) {
250 return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
251 }, py::keep_alive<0, 1>())
252 .def("nonzero_keys", [](const IntPairs& s) {
253 return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel());
254 }, py::keep_alive<0, 1>())
255 ;
256
257
258 #if 0
259 // Obsolete: special data structure for exposing custom iterator types to python
260 // kept here for illustrative purposes because there might be some use cases which
261 // are not covered by the much simpler py::make_iterator
262
263 struct PySequenceIterator {
264 PySequenceIterator(const Sequence &seq, py::object ref) : seq(seq), ref(ref) { }
265
266 float next() {
267 if (index == seq.size())
268 throw py::stop_iteration();
269 return seq[index++];
270 }
271
272 const Sequence &seq;
273 py::object ref; // keep a reference
274 size_t index = 0;
275 };
276
277 py::class_<PySequenceIterator>(seq, "Iterator")
278 .def("__iter__", [](PySequenceIterator &it) -> PySequenceIterator& { return it; })
279 .def("__next__", &PySequenceIterator::next);
280
281 On the actual Sequence object, the iterator would be constructed as follows:
282 .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); })
283 #endif
284
285 // test_python_iterator_in_cpp
286 m.def("object_to_list", [](py::object o) {
287 auto l = py::list();
288 for (auto item : o) {
289 l.append(item);
290 }
291 return l;
292 });
293
294 m.def("iterator_to_list", [](py::iterator it) {
295 auto l = py::list();
296 while (it != py::iterator::sentinel()) {
297 l.append(*it);
298 ++it;
299 }
300 return l;
301 });
302
303 // Make sure that py::iterator works with std algorithms
304 m.def("count_none", [](py::object o) {
305 return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
306 });
307
308 m.def("find_none", [](py::object o) {
309 auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });
310 return it->is_none();
311 });
312
313 m.def("count_nonzeros", [](py::dict d) {
314 return std::count_if(d.begin(), d.end(), [](std::pair<py::handle, py::handle> p) {
315 return p.second.cast<int>() != 0;
316 });
317 });
318
319 m.def("tuple_iterator", &test_random_access_iterator<py::tuple>);
320 m.def("list_iterator", &test_random_access_iterator<py::list>);
321 m.def("sequence_iterator", &test_random_access_iterator<py::sequence>);
322
323 // test_iterator_passthrough
324 // #181: iterator passthrough did not compile
325 m.def("iterator_passthrough", [](py::iterator s) -> py::iterator {
326 return py::make_iterator(std::begin(s), std::end(s));
327 });
328
329 // test_iterator_rvp
330 // #388: Can't make iterators via make_iterator() with different r/v policies
331 static std::vector<int> list = { 1, 2, 3 };
332 m.def("make_iterator_1", []() { return py::make_iterator<py::return_value_policy::copy>(list); });
333 m.def("make_iterator_2", []() { return py::make_iterator<py::return_value_policy::automatic>(list); });
334 }