random: Add __gnu_cxx::beta_distribution<> class.
[gcc.git] / libstdc++-v3 / include / ext / random.tcc
1 // Random number extensions -*- C++ -*-
2
3 // Copyright (C) 2012 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library. This library is free
6 // software; you can redistribute it and/or modify it under the
7 // terms of the GNU General Public License as published by the
8 // Free Software Foundation; either version 3, or (at your option)
9 // any later version.
10
11 // This library is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
15
16 // Under Section 7 of GPL version 3, you are granted additional
17 // permissions described in the GCC Runtime Library Exception, version
18 // 3.1, as published by the Free Software Foundation.
19
20 // You should have received a copy of the GNU General Public License and
21 // a copy of the GCC Runtime Library Exception along with this program;
22 // see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
23 // <http://www.gnu.org/licenses/>.
24
25 /** @file ext/random.tcc
26 * This is an internal header file, included by other library headers.
27 * Do not attempt to use it directly. @headername{ext/random}
28 */
29
30 #ifndef _EXT_RANDOM_TCC
31 #define _EXT_RANDOM_TCC 1
32
33 #pragma GCC system_header
34
35
36 namespace __gnu_cxx _GLIBCXX_VISIBILITY(default)
37 {
38 _GLIBCXX_BEGIN_NAMESPACE_VERSION
39
40
41 template<typename _UIntType, size_t __m,
42 size_t __pos1, size_t __sl1, size_t __sl2,
43 size_t __sr1, size_t __sr2,
44 uint32_t __msk1, uint32_t __msk2,
45 uint32_t __msk3, uint32_t __msk4,
46 uint32_t __parity1, uint32_t __parity2,
47 uint32_t __parity3, uint32_t __parity4>
48 void simd_fast_mersenne_twister_engine<_UIntType, __m,
49 __pos1, __sl1, __sl2, __sr1, __sr2,
50 __msk1, __msk2, __msk3, __msk4,
51 __parity1, __parity2, __parity3,
52 __parity4>::
53 seed(_UIntType __seed)
54 {
55 _M_state32[0] = static_cast<uint32_t>(__seed);
56 for (size_t __i = 1; __i < _M_nstate32; ++__i)
57 _M_state32[__i] = (1812433253UL
58 * (_M_state32[__i - 1] ^ (_M_state32[__i - 1] >> 30))
59 + __i);
60 _M_pos = state_size;
61 _M_period_certification();
62 }
63
64
namespace {

  /// Mixing function used by the first pass of the seed-sequence
  /// initializer: fold bits 27..31 into bits 0..4, then multiply by
  /// 1664525 modulo 2^32.
  inline uint32_t _Func1(uint32_t __x)
  {
    const uint32_t __folded = __x ^ (__x >> 27);
    return __folded * UINT32_C(1664525);
  }

  /// Mixing function used by the second pass of the seed-sequence
  /// initializer: the same fold, multiplied by 1566083941 modulo 2^32.
  inline uint32_t _Func2(uint32_t __x)
  {
    const uint32_t __folded = __x ^ (__x >> 27);
    return __folded * UINT32_C(1566083941);
  }

}
78
79
80 template<typename _UIntType, size_t __m,
81 size_t __pos1, size_t __sl1, size_t __sl2,
82 size_t __sr1, size_t __sr2,
83 uint32_t __msk1, uint32_t __msk2,
84 uint32_t __msk3, uint32_t __msk4,
85 uint32_t __parity1, uint32_t __parity2,
86 uint32_t __parity3, uint32_t __parity4>
87 template<typename _Sseq>
88 typename std::enable_if<std::is_class<_Sseq>::value>::type
89 simd_fast_mersenne_twister_engine<_UIntType, __m,
90 __pos1, __sl1, __sl2, __sr1, __sr2,
91 __msk1, __msk2, __msk3, __msk4,
92 __parity1, __parity2, __parity3,
93 __parity4>::
94 seed(_Sseq& __q)
95 {
96 size_t __lag;
97
98 if (_M_nstate32 >= 623)
99 __lag = 11;
100 else if (_M_nstate32 >= 68)
101 __lag = 7;
102 else if (_M_nstate32 >= 39)
103 __lag = 5;
104 else
105 __lag = 3;
106 const size_t __mid = (_M_nstate32 - __lag) / 2;
107
108 std::fill(_M_state32, _M_state32 + _M_nstate32, UINT32_C(0x8b8b8b8b));
109 uint32_t __arr[_M_nstate32];
110 __q.generate(__arr + 0, __arr + _M_nstate32);
111
112 uint32_t __r = _Func1(_M_state32[0] ^ _M_state32[__mid]
113 ^ _M_state32[_M_nstate32 - 1]);
114 _M_state32[__mid] += __r;
115 __r += _M_nstate32;
116 _M_state32[__mid + __lag] += __r;
117 _M_state32[0] = __r;
118
119 for (size_t __i = 1, __j = 0; __j < _M_nstate32; ++__j)
120 {
121 __r = _Func1(_M_state32[__i]
122 ^ _M_state32[(__i + __mid) % _M_nstate32]
123 ^ _M_state32[(__i + _M_nstate32 - 1) % _M_nstate32]);
124 _M_state32[(__i + __mid) % _M_nstate32] += __r;
125 __r += __arr[__j] + __i;
126 _M_state32[(__i + __mid + __lag) % _M_nstate32] += __r;
127 _M_state32[__i] = __r;
128 __i = (__i + 1) % _M_nstate32;
129 }
130 for (size_t __j = 0; __j < _M_nstate32; ++__j)
131 {
132 const size_t __i = (__j + 1) % _M_nstate32;
133 __r = _Func2(_M_state32[__i]
134 + _M_state32[(__i + __mid) % _M_nstate32]
135 + _M_state32[(__i + _M_nstate32 - 1) % _M_nstate32]);
136 _M_state32[(__i + __mid) % _M_nstate32] ^= __r;
137 __r -= __i;
138 _M_state32[(__i + __mid + __lag) % _M_nstate32] ^= __r;
139 _M_state32[__i] = __r;
140 }
141
142 _M_pos = state_size;
143 _M_period_certification();
144 }
145
146
147 template<typename _UIntType, size_t __m,
148 size_t __pos1, size_t __sl1, size_t __sl2,
149 size_t __sr1, size_t __sr2,
150 uint32_t __msk1, uint32_t __msk2,
151 uint32_t __msk3, uint32_t __msk4,
152 uint32_t __parity1, uint32_t __parity2,
153 uint32_t __parity3, uint32_t __parity4>
154 void simd_fast_mersenne_twister_engine<_UIntType, __m,
155 __pos1, __sl1, __sl2, __sr1, __sr2,
156 __msk1, __msk2, __msk3, __msk4,
157 __parity1, __parity2, __parity3,
158 __parity4>::
159 _M_period_certification(void)
160 {
161 static const uint32_t __parity[4] = { __parity1, __parity2,
162 __parity3, __parity4 };
163 uint32_t __inner = 0;
164 for (size_t __i = 0; __i < 4; ++__i)
165 if (__parity[__i] != 0)
166 __inner ^= _M_state32[__i] & __parity[__i];
167
168 if (__builtin_parity(__inner) & 1)
169 return;
170 for (size_t __i = 0; __i < 4; ++__i)
171 if (__parity[__i] != 0)
172 {
173 _M_state32[__i] ^= 1 << (__builtin_ffs(__parity[__i]) - 1);
174 return;
175 }
176 __builtin_unreachable();
177 }
178
179
180 template<typename _UIntType, size_t __m,
181 size_t __pos1, size_t __sl1, size_t __sl2,
182 size_t __sr1, size_t __sr2,
183 uint32_t __msk1, uint32_t __msk2,
184 uint32_t __msk3, uint32_t __msk4,
185 uint32_t __parity1, uint32_t __parity2,
186 uint32_t __parity3, uint32_t __parity4>
187 void simd_fast_mersenne_twister_engine<_UIntType, __m,
188 __pos1, __sl1, __sl2, __sr1, __sr2,
189 __msk1, __msk2, __msk3, __msk4,
190 __parity1, __parity2, __parity3,
191 __parity4>::
192 discard(unsigned long long __z)
193 {
194 while (__z > state_size - _M_pos)
195 {
196 __z -= state_size - _M_pos;
197
198 _M_gen_rand();
199 }
200
201 _M_pos += __z;
202 }
203
204
205 #ifdef __SSE2__
206
namespace {

  /**
   * One step of the SFMT recurrence on 128-bit SSE2 registers.
   *
   * Returns the XOR of five terms:
   *   a,
   *   a shifted left by __sl2 bytes (whole register),
   *   b shifted right by __sr1 bits in each 32-bit lane, masked with
   *     (__msk1, __msk2, __msk3, __msk4) in lanes 0..3,
   *   c shifted right by __sr2 bytes (whole register),
   *   d shifted left by __sl1 bits in each 32-bit lane.
   */
  template<size_t __sl1, size_t __sl2, size_t __sr1, size_t __sr2,
           uint32_t __msk1, uint32_t __msk2, uint32_t __msk3, uint32_t __msk4>
    inline __m128i __sse2_recursion(__m128i __a, __m128i __b,
                                    __m128i __c, __m128i __d)
    {
      const __m128i __mask = _mm_set_epi32(__msk4, __msk3, __msk2, __msk1);

      const __m128i __a_shl = _mm_slli_si128(__a, __sl2);
      const __m128i __b_term = _mm_and_si128(_mm_srli_epi32(__b, __sr1),
                                             __mask);
      const __m128i __c_shr = _mm_srli_si128(__c, __sr2);
      const __m128i __d_shl = _mm_slli_epi32(__d, __sl1);

      __m128i __acc = _mm_xor_si128(__a, __a_shl);
      __acc = _mm_xor_si128(__acc, __b_term);
      __acc = _mm_xor_si128(__acc, __c_shr);
      return _mm_xor_si128(__acc, __d_shl);
    }

}
226
227
228 template<typename _UIntType, size_t __m,
229 size_t __pos1, size_t __sl1, size_t __sl2,
230 size_t __sr1, size_t __sr2,
231 uint32_t __msk1, uint32_t __msk2,
232 uint32_t __msk3, uint32_t __msk4,
233 uint32_t __parity1, uint32_t __parity2,
234 uint32_t __parity3, uint32_t __parity4>
235 void simd_fast_mersenne_twister_engine<_UIntType, __m,
236 __pos1, __sl1, __sl2, __sr1, __sr2,
237 __msk1, __msk2, __msk3, __msk4,
238 __parity1, __parity2, __parity3,
239 __parity4>::
240 _M_gen_rand(void)
241 {
242 __m128i __r1 = _mm_load_si128(&_M_state[_M_nstate - 2]);
243 __m128i __r2 = _mm_load_si128(&_M_state[_M_nstate - 1]);
244
245 size_t __i;
246 for (__i = 0; __i < _M_nstate - __pos1; ++__i)
247 {
248 __m128i __r = __sse2_recursion<__sl1, __sl2, __sr1, __sr2,
249 __msk1, __msk2, __msk3, __msk4>
250 (_M_state[__i], _M_state[__i + __pos1], __r1, __r2);
251 _mm_store_si128(&_M_state[__i], __r);
252 __r1 = __r2;
253 __r2 = __r;
254 }
255 for (; __i < _M_nstate; ++__i)
256 {
257 __m128i __r = __sse2_recursion<__sl1, __sl2, __sr1, __sr2,
258 __msk1, __msk2, __msk3, __msk4>
259 (_M_state[__i], _M_state[__i + __pos1 - _M_nstate], __r1, __r2);
260 _mm_store_si128(&_M_state[__i], __r);
261 __r1 = __r2;
262 __r2 = __r;
263 }
264
265 _M_pos = 0;
266 }
267
268
269 #else
270
namespace {

  /**
   * Shift the 128-bit value in @p __in right by @p __shift bytes and
   * store the result in @p __out.
   *
   * Both arrays hold four 32-bit words, least significant word first.
   * This is the portable counterpart of the _mm_srli_si128 call in the
   * SSE2 path, and it agrees with that intrinsic for every shift count:
   * 0 is the identity and 16 or more bytes yields zero.
   *
   * The 0 and >= 8 cases are handled explicitly.  A plain two-word
   * implementation would shift a 64-bit value by 64 bits for 0
   * (undefined behaviour) and would give wrong results for 8 or more.
   */
  template<size_t __shift>
    inline void __rshift(uint32_t *__out, const uint32_t *__in)
    {
      const uint64_t __th = ((static_cast<uint64_t>(__in[3]) << 32)
                             | static_cast<uint64_t>(__in[2]));
      const uint64_t __tl = ((static_cast<uint64_t>(__in[1]) << 32)
                             | static_cast<uint64_t>(__in[0]));

      // Bit counts are reduced modulo 64 so that no branch, including
      // branches that are dead for this __shift, has an out-of-range shift.
      const unsigned __bits = (__shift * 8) & 63;
      const unsigned __carry = (64 - __bits) & 63;

      uint64_t __oh, __ol;
      if (__shift == 0)
        {
          __oh = __th;
          __ol = __tl;
        }
      else if (__shift < 8)
        {
          __oh = __th >> __bits;
          __ol = (__tl >> __bits) | (__th << __carry);
        }
      else if (__shift < 16)
        {
          // Only the high half survives, landing in the low half.
          __oh = 0;
          __ol = __th >> __bits;
        }
      else
        {
          __oh = 0;
          __ol = 0;
        }

      __out[0] = static_cast<uint32_t>(__ol);
      __out[1] = static_cast<uint32_t>(__ol >> 32);
      __out[2] = static_cast<uint32_t>(__oh);
      __out[3] = static_cast<uint32_t>(__oh >> 32);
    }


  /**
   * Shift the 128-bit value in @p __in left by @p __shift bytes and
   * store the result in @p __out.
   *
   * Word layout and shift-count handling are the same as in __rshift.
   * This is the portable counterpart of the _mm_slli_si128 call in the
   * SSE2 path.
   */
  template<size_t __shift>
    inline void __lshift(uint32_t *__out, const uint32_t *__in)
    {
      const uint64_t __th = ((static_cast<uint64_t>(__in[3]) << 32)
                             | static_cast<uint64_t>(__in[2]));
      const uint64_t __tl = ((static_cast<uint64_t>(__in[1]) << 32)
                             | static_cast<uint64_t>(__in[0]));

      const unsigned __bits = (__shift * 8) & 63;
      const unsigned __carry = (64 - __bits) & 63;

      uint64_t __oh, __ol;
      if (__shift == 0)
        {
          __oh = __th;
          __ol = __tl;
        }
      else if (__shift < 8)
        {
          __oh = (__th << __bits) | (__tl >> __carry);
          __ol = __tl << __bits;
        }
      else if (__shift < 16)
        {
          // Only the low half survives, landing in the high half.
          __oh = __tl << __bits;
          __ol = 0;
        }
      else
        {
          __oh = 0;
          __ol = 0;
        }

      __out[0] = static_cast<uint32_t>(__ol);
      __out[1] = static_cast<uint32_t>(__ol >> 32);
      __out[2] = static_cast<uint32_t>(__oh);
      __out[3] = static_cast<uint32_t>(__oh >> 32);
    }


  /**
   * One step of the SFMT recurrence on four 32-bit words (portable path).
   *
   * For each lane k:
   *   r[k] = a[k] ^ (a << __sl2 bytes)[k] ^ ((b[k] >> __sr1) & msk_k)
   *        ^ (c >> __sr2 bytes)[k] ^ (d[k] << __sl1)
   * where msk_0..msk_3 are __msk1..__msk4.  @p __r may alias @p __a:
   * both byte shifts are computed before anything is written, and lane k
   * of the result reads only lane k of @p __a.
   */
  template<size_t __sl1, size_t __sl2, size_t __sr1, size_t __sr2,
           uint32_t __msk1, uint32_t __msk2, uint32_t __msk3, uint32_t __msk4>
    inline void __recursion(uint32_t *__r,
                            const uint32_t *__a, const uint32_t *__b,
                            const uint32_t *__c, const uint32_t *__d)
    {
      uint32_t __x[4];
      uint32_t __y[4];

      __lshift<__sl2>(__x, __a);
      __rshift<__sr2>(__y, __c);
      __r[0] = (__a[0] ^ __x[0] ^ ((__b[0] >> __sr1) & __msk1)
                ^ __y[0] ^ (__d[0] << __sl1));
      __r[1] = (__a[1] ^ __x[1] ^ ((__b[1] >> __sr1) & __msk2)
                ^ __y[1] ^ (__d[1] << __sl1));
      __r[2] = (__a[2] ^ __x[2] ^ ((__b[2] >> __sr1) & __msk3)
                ^ __y[2] ^ (__d[2] << __sl1));
      __r[3] = (__a[3] ^ __x[3] ^ ((__b[3] >> __sr1) & __msk4)
                ^ __y[3] ^ (__d[3] << __sl1));
    }

}
331
332
333 template<typename _UIntType, size_t __m,
334 size_t __pos1, size_t __sl1, size_t __sl2,
335 size_t __sr1, size_t __sr2,
336 uint32_t __msk1, uint32_t __msk2,
337 uint32_t __msk3, uint32_t __msk4,
338 uint32_t __parity1, uint32_t __parity2,
339 uint32_t __parity3, uint32_t __parity4>
340 void simd_fast_mersenne_twister_engine<_UIntType, __m,
341 __pos1, __sl1, __sl2, __sr1, __sr2,
342 __msk1, __msk2, __msk3, __msk4,
343 __parity1, __parity2, __parity3,
344 __parity4>::
345 _M_gen_rand(void)
346 {
347 const uint32_t *__r1 = &_M_state32[_M_nstate32 - 8];
348 const uint32_t *__r2 = &_M_state32[_M_nstate32 - 4];
349 static constexpr size_t __pos1_32 = __pos1 * 4;
350
351 size_t __i;
352 for (__i = 0; __i < _M_nstate32 - __pos1_32; __i += 4)
353 {
354 __recursion<__sl1, __sl2, __sr1, __sr2,
355 __msk1, __msk2, __msk3, __msk4>
356 (&_M_state32[__i], &_M_state32[__i],
357 &_M_state32[__i + __pos1_32], __r1, __r2);
358 __r1 = __r2;
359 __r2 = &_M_state32[__i];
360 }
361
362 for (; __i < _M_nstate32; __i += 4)
363 {
364 __recursion<__sl1, __sl2, __sr1, __sr2,
365 __msk1, __msk2, __msk3, __msk4>
366 (&_M_state32[__i], &_M_state32[__i],
367 &_M_state32[__i + __pos1_32 - _M_nstate32], __r1, __r2);
368 __r1 = __r2;
369 __r2 = &_M_state32[__i];
370 }
371
372 _M_pos = 0;
373 }
374
375 #endif
376
377
378 template<typename _UIntType, size_t __m,
379 size_t __pos1, size_t __sl1, size_t __sl2,
380 size_t __sr1, size_t __sr2,
381 uint32_t __msk1, uint32_t __msk2,
382 uint32_t __msk3, uint32_t __msk4,
383 uint32_t __parity1, uint32_t __parity2,
384 uint32_t __parity3, uint32_t __parity4,
385 typename _CharT, typename _Traits>
386 std::basic_ostream<_CharT, _Traits>&
387 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
388 const __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType,
389 __m, __pos1, __sl1, __sl2, __sr1, __sr2,
390 __msk1, __msk2, __msk3, __msk4,
391 __parity1, __parity2, __parity3, __parity4>& __x)
392 {
393 typedef std::basic_ostream<_CharT, _Traits> __ostream_type;
394 typedef typename __ostream_type::ios_base __ios_base;
395
396 const typename __ios_base::fmtflags __flags = __os.flags();
397 const _CharT __fill = __os.fill();
398 const _CharT __space = __os.widen(' ');
399 __os.flags(__ios_base::dec | __ios_base::fixed | __ios_base::left);
400 __os.fill(__space);
401
402 for (size_t __i = 0; __i < __x._M_nstate32; ++__i)
403 __os << __x._M_state32[__i] << __space;
404 __os << __x._M_pos;
405
406 __os.flags(__flags);
407 __os.fill(__fill);
408 return __os;
409 }
410
411
412 template<typename _UIntType, size_t __m,
413 size_t __pos1, size_t __sl1, size_t __sl2,
414 size_t __sr1, size_t __sr2,
415 uint32_t __msk1, uint32_t __msk2,
416 uint32_t __msk3, uint32_t __msk4,
417 uint32_t __parity1, uint32_t __parity2,
418 uint32_t __parity3, uint32_t __parity4,
419 typename _CharT, typename _Traits>
420 std::basic_istream<_CharT, _Traits>&
421 operator>>(std::basic_istream<_CharT, _Traits>& __is,
422 __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType,
423 __m, __pos1, __sl1, __sl2, __sr1, __sr2,
424 __msk1, __msk2, __msk3, __msk4,
425 __parity1, __parity2, __parity3, __parity4>& __x)
426 {
427 typedef std::basic_istream<_CharT, _Traits> __istream_type;
428 typedef typename __istream_type::ios_base __ios_base;
429
430 const typename __ios_base::fmtflags __flags = __is.flags();
431 __is.flags(__ios_base::dec | __ios_base::skipws);
432
433 for (size_t __i = 0; __i < __x._M_nstate32; ++__i)
434 __is >> __x._M_state32[__i];
435 __is >> __x._M_pos;
436
437 __is.flags(__flags);
438 return __is;
439 }
440
441
442 /**
443 * Iteration method due to M.D. J<o:>hnk.
444 *
445 * M.D. J<o:>hnk, Erzeugung von betaverteilten und gammaverteilten
446 * Zufallszahlen, Metrika, Volume 8, 1964
447 */
448 template<typename _RealType>
449 template<typename _UniformRandomNumberGenerator>
450 typename beta_distribution<_RealType>::result_type
451 beta_distribution<_RealType>::
452 operator()(_UniformRandomNumberGenerator& __urng,
453 const param_type& __param)
454 {
455 std::__detail::_Adaptor<_UniformRandomNumberGenerator, result_type>
456 __aurng(__urng);
457
458 result_type __x, __y;
459 do
460 {
461 __x = std::exp(std::log(__aurng()) / __param.alpha());
462 __y = std::exp(std::log(__aurng()) / __param.beta());
463 }
464 while (__x + __y > result_type(1));
465
466 return __x / (__x + __y);
467 }
468
469 template<typename _RealType>
470 template<typename _OutputIterator,
471 typename _UniformRandomNumberGenerator>
472 void
473 beta_distribution<_RealType>::
474 __generate_impl(_OutputIterator __f, _OutputIterator __t,
475 _UniformRandomNumberGenerator& __urng,
476 const param_type& __param)
477 {
478 __glibcxx_function_requires(_OutputIteratorConcept<_OutputIterator>)
479
480 std::__detail::_Adaptor<_UniformRandomNumberGenerator, result_type>
481 __aurng(__urng);
482
483 while (__f != __t)
484 {
485 result_type __x, __y;
486 do
487 {
488 __x = std::exp(std::log(__aurng()) / __param.alpha());
489 __y = std::exp(std::log(__aurng()) / __param.beta());
490 }
491 while (__x + __y > result_type(1));
492
493 *__f++ = __x / (__x + __y);
494 }
495 }
496
497 template<typename _RealType, typename _CharT, typename _Traits>
498 std::basic_ostream<_CharT, _Traits>&
499 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
500 const __gnu_cxx::beta_distribution<_RealType>& __x)
501 {
502 typedef std::basic_ostream<_CharT, _Traits> __ostream_type;
503 typedef typename __ostream_type::ios_base __ios_base;
504
505 const typename __ios_base::fmtflags __flags = __os.flags();
506 const _CharT __fill = __os.fill();
507 const std::streamsize __precision = __os.precision();
508 const _CharT __space = __os.widen(' ');
509 __os.flags(__ios_base::scientific | __ios_base::left);
510 __os.fill(__space);
511 __os.precision(std::numeric_limits<_RealType>::max_digits10);
512
513 __os << __x.alpha() << __space << __x.beta();
514
515 __os.flags(__flags);
516 __os.fill(__fill);
517 __os.precision(__precision);
518 return __os;
519 }
520
521 template<typename _RealType, typename _CharT, typename _Traits>
522 std::basic_istream<_CharT, _Traits>&
523 operator>>(std::basic_istream<_CharT, _Traits>& __is,
524 __gnu_cxx::beta_distribution<_RealType>& __x)
525 {
526 typedef std::basic_istream<_CharT, _Traits> __istream_type;
527 typedef typename __istream_type::ios_base __ios_base;
528
529 const typename __ios_base::fmtflags __flags = __is.flags();
530 __is.flags(__ios_base::dec | __ios_base::skipws);
531
532 _RealType __alpha_val, __beta_val;
533 __is >> __alpha_val >> __beta_val;
534 __x.param(typename __gnu_cxx::beta_distribution<_RealType>::
535 param_type(__alpha_val, __beta_val));
536
537 __is.flags(__flags);
538 return __is;
539 }
540
541 _GLIBCXX_END_NAMESPACE_VERSION
542 } // namespace
543
544
545 #endif // _EXT_RANDOM_TCC