2 tests/test_sequences_and_iterators.cpp -- supporting Python's sequence protocol, iterators,
5 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
7 All rights reserved. Use of this source code is governed by a
8 BSD-style license that can be found in the LICENSE file.
11 #include "pybind11_tests.h"
12 #include "constructor_stats.h"
13 #include <pybind11/operators.h>
14 #include <pybind11/stl.h>
// NonZeroIterator: a minimal iterator over a contiguous array of T. It is
// handed to py::make_iterator / py::make_key_iterator together with a
// NonZeroSentinel, i.e. the begin and end of the range have different types.
// It provides only what those helpers use here: dereference (a const T& to
// the current element) and pre-increment (advance the wrapped pointer by one).
// NOTE(review): the wrapped pointer carries no bound; where the range ends is
// decided solely by operator==(NonZeroIterator, NonZeroSentinel).
17 class NonZeroIterator
{
// Stores a pointer to the first element to visit. The pointee is borrowed,
// not owned (e.g. IntPairs::begin() passes its vector's data()).
20 NonZeroIterator(const T
* ptr
) : ptr_(ptr
) {}
21 const T
& operator*() const { return *ptr_
; }
22 NonZeroIterator
& operator++() { ++ptr_
; return *this; }
// NonZeroSentinel: stateless end marker, passed as the "end" argument of
// py::make_iterator / py::make_key_iterator in the IntPairs bindings.
// Whether the end has been reached is decided by the operator== overload
// taking (NonZeroIterator<std::pair<A, B>>, NonZeroSentinel).
25 class NonZeroSentinel
{};
// End-of-range test for a NonZeroIterator over std::pair<A, B>. It returns
// true (iteration stops) as soon as the current pair has a `first` or `second`
// member that converts to false, e.g. 0 for integral members. The sentinel
// argument is unnamed and unused; it only selects this overload.
// NOTE(review): *it is evaluated unconditionally, so the underlying array must
// contain such a pair; otherwise iteration reads past the end of the data.
27 template<typename A
, typename B
>
28 bool operator==(const NonZeroIterator
<std::pair
<A
, B
>>& it
, const NonZeroSentinel
&) {
29 return !(*it
).first
|| !(*it
).second
;
// Runs a fixed series of checks on the iterator of a pybind11 sequence
// wrapper. It is instantiated below for py::tuple, py::list and py::sequence.
// The checks cover dereference, subscript, ++/--, +=/-=, operator->,
// iterator +/- integer arithmetic, and begin/end distance and ordering.
// Each check's outcome is appended as a bool to the py::list `checks`; the
// function's declared return type is py::list.
// Throws py::value_error asking for at least 5 elements. Presumably this is
// guarded by a size check on `x`, since fixed indices up to x[3] are used.
32 template <typename PythonType
>
33 py::list
test_random_access_iterator(PythonType x
) {
35 throw py::value_error("Please provide at least 5 elements for testing.");
37 auto checks
= py::list();
// Compares with Python-level equality via PyObject_RichCompareBool(a, b, Py_EQ).
// A result of -1 means a Python exception is set, and it is propagated as
// py::error_already_set. Otherwise (result != 0) is appended to `checks`.
38 auto assert_equal
= [&checks
](py::handle a
, py::handle b
) {
39 auto result
= PyObject_RichCompareBool(a
.ptr(), b
.ptr(), Py_EQ
);
40 if (result
== -1) { throw py::error_already_set(); }
41 checks
.append(result
!= 0);
// Dereference, subscript and increment/decrement, each checked against the
// element of `x` the iterator should now refer to.
45 assert_equal(x
[0], *it
);
46 assert_equal(x
[0], it
[0]);
47 assert_equal(x
[1], it
[1]);
49 assert_equal(x
[1], *(++it
));
50 assert_equal(x
[1], *(it
++));
51 assert_equal(x
[2], *it
);
52 assert_equal(x
[3], *(it
+= 1));
53 assert_equal(x
[2], *(--it
));
54 assert_equal(x
[2], *(it
--));
55 assert_equal(x
[1], *it
);
56 assert_equal(x
[0], *(it
-= 1));
// operator-> must forward attribute access to the referenced element.
58 assert_equal(it
->attr("real"), x
[0].attr("real"));
59 assert_equal((it
+ 1)->attr("real"), x
[1].attr("real"));
// Iterator +/- integer arithmetic, including the (integer + iterator) form.
61 assert_equal(x
[1], *(it
+ 1));
62 assert_equal(x
[1], *(1 + it
));
64 assert_equal(x
[1], *(it
- 2));
// Distance and ordering between begin() and end() must agree with size().
66 checks
.append(static_cast<std::size_t>(x
.end() - x
.begin()) == x
.size());
67 checks
.append((x
.begin() + static_cast<std::ptrdiff_t>(x
.size())) == x
.end());
68 checks
.append(x
.begin() < x
.end());
// Populates the sequences_and_iterators test submodule. `m` is the submodule
// handle used below with m.def(...) and py::class_<...>(m, ...).
73 TEST_SUBMODULE(sequences_and_iterators
, m
) {
// Sequence: a fixed-size buffer of floats.
// - The buffer is a heap array `m_data` (allocated with new float[],
//   released with delete[] in the destructor) of length `m_size`.
// - Every constructor, assignment and the destructor reports to the
//   constructor_stats helpers (print_created, print_copy_created, ...).
//   Presumably this lets the Python tests check object lifetimes.
// Size constructor: zero-fills the new buffer with memset.
78 Sequence(size_t size
) : m_size(size
) {
79 print_created(this, "of size", m_size
);
80 m_data
= new float[size
];
81 memset(m_data
, 0, sizeof(float) * size
);
// Copies the contents of a std::vector<float>.
// NOTE(review): `&value[0]` is undefined behaviour when `value` is empty
// (std::vector::operator[] on an empty vector). Using
// std::copy(value.begin(), value.end(), m_data) would avoid it.
83 Sequence(const std::vector
<float> &value
) : m_size(value
.size()) {
84 print_created(this, "of size", m_size
, "from std::vector");
85 m_data
= new float[m_size
];
86 memcpy(m_data
, &value
[0], sizeof(float) * m_size
);
// Deep copy: allocates a new buffer and copies m_size floats from `s`.
88 Sequence(const Sequence
&s
) : m_size(s
.m_size
) {
89 print_copy_created(this);
90 m_data
= new float[m_size
];
91 memcpy(m_data
, s
.m_data
, sizeof(float)*m_size
);
// Move constructor: takes over the source's m_data pointer.
// NOTE(review): the source must be left without ownership of that array
// (e.g. m_data nulled); otherwise both destructors delete[] it. Confirm.
93 Sequence(Sequence
&&s
) : m_size(s
.m_size
), m_data(s
.m_data
) {
94 print_move_created(this);
99 ~Sequence() { print_destroyed(this); delete[] m_data
; }
// Copy assignment: re-allocates m_data and copies the source's floats into it.
101 Sequence
&operator=(const Sequence
&s
) {
105 m_data
= new float[m_size
];
106 memcpy(m_data
, s
.m_data
, sizeof(float)*m_size
);
108 print_copy_assigned(this);
112 Sequence
&operator=(Sequence
&&s
) {
120 print_move_assigned(this);
// Equality: the sizes must match, then the elements are compared pairwise with !=.
124 bool operator==(const Sequence
&s
) const {
125 if (m_size
!= s
.size()) return false;
126 for (size_t i
= 0; i
< m_size
; ++i
)
127 if (m_data
[i
] != s
[i
])
131 bool operator!=(const Sequence
&s
) const { return !operator==(s
); }
// Unchecked element access; the Python bindings do the bounds checking.
133 float operator[](size_t index
) const { return m_data
[index
]; }
134 float &operator[](size_t index
) { return m_data
[index
]; }
// Linear scan for `v`.
136 bool contains(float v
) const {
137 for (size_t i
= 0; i
< m_size
; ++i
)
// Returns a new Sequence holding the elements in reverse order.
143 Sequence
reversed() const {
144 Sequence
result(m_size
);
145 for (size_t i
= 0; i
< m_size
; ++i
)
146 result
[m_size
- i
- 1] = m_data
[i
];
150 size_t size() const { return m_size
; }
// Raw-pointer range over the buffer; used by the "__iter__" binding.
152 const float *begin() const { return m_data
; }
153 const float *end() const { return m_data
+m_size
; }
// Python binding for Sequence, covering:
// - constructors from size_t and from std::vector<float>;
// - item get/set, raising py::index_error when i >= size();
// - len, iteration, membership and reversal;
// - slice get/set and ==/!=.
159 py::class_
<Sequence
>(m
, "Sequence")
160 .def(py::init
<size_t>())
161 .def(py::init
<const std::vector
<float>&>())
162 /// Bare bones interface
163 .def("__getitem__", [](const Sequence
&s
, size_t i
) {
164 if (i
>= s
.size()) throw py::index_error();
167 .def("__setitem__", [](Sequence
&s
, size_t i
, float v
) {
168 if (i
>= s
.size()) throw py::index_error();
171 .def("__len__", &Sequence::size
)
172 /// Optional sequence protocol operations
// The iterator walks raw pointers into the Sequence's buffer (begin()/end()).
// keep_alive<0, 1> ties the Sequence (argument 1) to the returned iterator
// (return value 0).
173 .def("__iter__", [](const Sequence
&s
) { return py::make_iterator(s
.begin(), s
.end()); },
174 py::keep_alive
<0, 1>() /* Essential: keep object alive while iterator exists */)
175 .def("__contains__", [](const Sequence
&s
, float v
) { return s
.contains(v
); })
176 .def("__reversed__", [](const Sequence
&s
) -> Sequence
{ return s
.reversed(); })
177 /// Slicing protocol (optional)
// Slice read:
// - slice.compute() resolves start/step/slicelength for s.size(); a false
//   return is reported as the already-set Python error.
// - The result is a newly allocated Sequence returned by raw pointer.
// NOTE(review): this relies on pybind11's default policy for returned raw
// pointers taking ownership; confirm.
// NOTE(review): step is size_t, so a negative Python step presumably relies on
// unsigned wrap-around in `start += step`.
178 .def("__getitem__", [](const Sequence
&s
, py::slice slice
) -> Sequence
* {
179 size_t start
, stop
, step
, slicelength
;
180 if (!slice
.compute(s
.size(), &start
, &stop
, &step
, &slicelength
))
181 throw py::error_already_set();
182 Sequence
*seq
= new Sequence(slicelength
);
183 for (size_t i
= 0; i
< slicelength
; ++i
) {
184 (*seq
)[i
] = s
[start
]; start
+= step
;
// Slice assignment: the slice length must equal value.size(); otherwise
// std::runtime_error is thrown.
// NOTE(review): the message says "size" where "side" is presumably meant.
// Changing it may affect tests that match the text.
188 .def("__setitem__", [](Sequence
&s
, py::slice slice
, const Sequence
&value
) {
189 size_t start
, stop
, step
, slicelength
;
190 if (!slice
.compute(s
.size(), &start
, &stop
, &step
, &slicelength
))
191 throw py::error_already_set();
192 if (slicelength
!= value
.size())
193 throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
194 for (size_t i
= 0; i
< slicelength
; ++i
) {
195 s
[start
] = value
[i
]; start
+= step
;
199 .def(py::self
== py::self
)
200 .def(py::self
!= py::self
)
201 // Could also define py::self + py::self for concatenation, etc.
205 // Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic
206 // map-like functionality.
209 StringMap() = default;
210 StringMap(std::unordered_map
<std::string
, std::string
> init
)
211 : map(std::move(init
)) {}
// set() inserts or overwrites the entry for `key`.
// NOTE(review): `val` is taken by value and then copied into the map;
// std::move(val) would avoid the second copy.
213 void set(std::string key
, std::string val
) { map
[key
] = val
; }
// get() uses unordered_map::at, so a missing key throws std::out_of_range.
214 std::string
get(std::string key
) const { return map
.at(key
); }
215 size_t size() const { return map
.size(); }
217 std::unordered_map
<std::string
, std::string
> map
;
// Const iterators over the underlying unordered_map. Elements are
// std::pair<const std::string, std::string>; iteration order is unspecified.
219 decltype(map
.cbegin()) begin() const { return map
.cbegin(); }
220 decltype(map
.cend()) end() const { return map
.cend(); }
// Python binding for StringMap:
// - dict-like __getitem__ / __setitem__ / __len__;
// - __iter__ yields keys (make_key_iterator);
// - items() yields (key, value) pairs (make_iterator).
// Both iterators hold const_iterators into the C++ map, so keep_alive<0, 1>
// keeps the StringMap alive while an iterator exists.
222 py::class_
<StringMap
>(m
, "StringMap")
224 .def(py::init
<std::unordered_map
<std::string
, std::string
>>())
// A missing key makes StringMap::get throw std::out_of_range, which is
// re-raised here as py::key_error with the key in the message.
225 .def("__getitem__", [](const StringMap
&map
, std::string key
) {
226 try { return map
.get(key
); }
227 catch (const std::out_of_range
&) {
228 throw py::key_error("key '" + key
+ "' does not exist");
231 .def("__setitem__", &StringMap::set
)
232 .def("__len__", &StringMap::size
)
233 .def("__iter__", [](const StringMap
&map
) { return py::make_key_iterator(map
.begin(), map
.end()); },
234 py::keep_alive
<0, 1>())
235 .def("items", [](const StringMap
&map
) { return py::make_iterator(map
.begin(), map
.end()); },
236 py::keep_alive
<0, 1>())
239 // test_generalized_iterators
// IntPairs owns a vector of (int, int) pairs and exposes only a begin()
// pointer; it has no end(). The bindings pair that pointer with
// NonZeroSentinel, so iteration stops at the first pair that has a zero
// member.
// NOTE(review): if the data contains no such pair, iteration walks past the
// end of data_. Either inputs must always include one, or the iteration
// should also be bounded by data_.size().
242 IntPairs(std::vector
<std::pair
<int, int>> data
) : data_(std::move(data
)) {}
243 const std::pair
<int, int>* begin() const { return data_
.data(); }
245 std::vector
<std::pair
<int, int>> data_
;
// "nonzero" yields the pairs themselves. "nonzero_keys" yields only each
// pair's first element (make_key_iterator). keep_alive<0, 1> keeps the
// IntPairs alive because the iterator points into its data_.
247 py::class_
<IntPairs
>(m
, "IntPairs")
248 .def(py::init
<std::vector
<std::pair
<int, int>>>())
249 .def("nonzero", [](const IntPairs
& s
) {
250 return py::make_iterator(NonZeroIterator
<std::pair
<int, int>>(s
.begin()), NonZeroSentinel());
251 }, py::keep_alive
<0, 1>())
252 .def("nonzero_keys", [](const IntPairs
& s
) {
253 return py::make_key_iterator(NonZeroIterator
<std::pair
<int, int>>(s
.begin()), NonZeroSentinel());
254 }, py::keep_alive
<0, 1>())
259 // Obsolete: special data structure for exposing custom iterator types to python
260 // kept here for illustrative purposes because there might be some use cases which
261 // are not covered by the much simpler py::make_iterator
// Hand-written iterator over a Sequence:
// - `ref` holds a py::object referring to the Python-side Sequence so it
//   stays alive while the iterator exists.
// - next (bound as __next__) throws py::stop_iteration once index reaches
//   seq.size().
263 struct PySequenceIterator
{
264 PySequenceIterator(const Sequence
&seq
, py::object ref
) : seq(seq
), ref(ref
) { }
267 if (index
== seq
.size())
268 throw py::stop_iteration();
273 py::object ref
; // keep a reference
277 py::class_
<PySequenceIterator
>(seq
, "Iterator")
278 .def("__iter__", [](PySequenceIterator
&it
) -> PySequenceIterator
& { return it
; })
279 .def("__next__", &PySequenceIterator::next
);
// NOTE(review): the following lines mix prose with a .def(...) fragment, so
// this illustration cannot compile as written. Confirm it is excluded from
// the build (e.g. by a preprocessor guard).
281 On the actual Sequence object
, the iterator would be constructed as follows
:
282 .def("__iter__", [](py::object s
) { return PySequenceIterator(s
.cast
<const Sequence
&>(), s
); })
285 // test_python_iterator_in_cpp
// Walk Python iterables from C++ in two ways: a range-for over a py::object,
// and explicit comparison of a py::iterator against py::iterator::sentinel().
286 m
.def("object_to_list", [](py::object o
) {
288 for (auto item
: o
) {
294 m
.def("iterator_to_list", [](py::iterator it
) {
296 while (it
!= py::iterator::sentinel()) {
303 // Make sure that py::iterator works with std algorithms
304 m
.def("count_none", [](py::object o
) {
305 return std::count_if(o
.begin(), o
.end(), [](py::handle h
) { return h
.is_none(); });
// NOTE(review): when `o` contains no None, find_if returns o.end() and
// `it->is_none()` dereferences the end iterator. Testing `it != o.end()`
// first would state the intent without relying on what the end iterator
// yields.
308 m
.def("find_none", [](py::object o
) {
309 auto it
= std::find_if(o
.begin(), o
.end(), [](py::handle h
) { return h
.is_none(); });
310 return it
->is_none();
// Counts the dict entries whose value, cast to int, is nonzero.
313 m
.def("count_nonzeros", [](py::dict d
) {
314 return std::count_if(d
.begin(), d
.end(), [](std::pair
<py::handle
, py::handle
> p
) {
315 return p
.second
.cast
<int>() != 0;
// Runs test_random_access_iterator on each pybind11 sequence wrapper type.
319 m
.def("tuple_iterator", &test_random_access_iterator
<py::tuple
>);
320 m
.def("list_iterator", &test_random_access_iterator
<py::list
>);
321 m
.def("sequence_iterator", &test_random_access_iterator
<py::sequence
>);
323 // test_iterator_passthrough
324 // #181: iterator passthrough did not compile
325 m
.def("iterator_passthrough", [](py::iterator s
) -> py::iterator
{
326 return py::make_iterator(std::begin(s
), std::end(s
));
330 // #388: Can't make iterators via make_iterator() with different r/v policies
// `list` is a function-local static, so the iterators returned below refer to
// storage that outlives each call. make_iterator_1 and make_iterator_2 differ
// only in the return_value_policy template argument.
331 static std::vector
<int> list
= { 1, 2, 3 };
332 m
.def("make_iterator_1", []() { return py::make_iterator
<py::return_value_policy::copy
>(list
); });
333 m
.def("make_iterator_2", []() { return py::make_iterator
<py::return_value_policy::automatic
>(list
); });