find_selectors.h: Correct name for include guard #ifndef.
[gcc.git] / libstdc++-v3 / include / parallel / multiway_mergesort.h
1 // -*- C++ -*-
2
3 // Copyright (C) 2007, 2008 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library. This library is free
6 // software; you can redistribute it and/or modify it under the terms
7 // of the GNU General Public License as published by the Free Software
8 // Foundation; either version 2, or (at your option) any later
9 // version.
10
11 // This library is distributed in the hope that it will be useful, but
12 // WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 // General Public License for more details.
15
16 // You should have received a copy of the GNU General Public License
17 // along with this library; see the file COPYING. If not, write to
18 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
19 // MA 02111-1307, USA.
20
21 // As a special exception, you may use this file as part of a free
22 // software library without restriction. Specifically, if other files
23 // instantiate templates or use macros or inline functions from this
24 // file, or you compile this file and link it with other files to
25 // produce an executable, this file does not by itself cause the
26 // resulting executable to be covered by the GNU General Public
27 // License. This exception does not however invalidate any other
28 // reasons why the executable file might be covered by the GNU General
29 // Public License.
30
31 /** @file parallel/multiway_mergesort.h
32 * @brief Parallel multiway merge sort.
33 * This file is a GNU parallel extension to the Standard C++ Library.
34 */
35
36 // Written by Johannes Singler.
37
38 #ifndef _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H
39 #define _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H 1
40
41 #include <vector>
42
43 #include <parallel/basic_iterator.h>
44 #include <bits/stl_algo.h>
45 #include <parallel/parallel.h>
46 #include <parallel/multiway_merge.h>
47
48 namespace __gnu_parallel
49 {
50
/** @brief Description of a subsequence by its two boundary positions. */
template<typename _DifferenceTp>
  struct Piece
  {
    typedef _DifferenceTp difference_type;

    /** @brief Position at which the subsequence starts. */
    difference_type begin;

    /** @brief Position at which the subsequence ends. */
    difference_type end;
  };
63
64 /** @brief Data accessed by all threads.
65 *
66 * PMWMS = parallel multiway mergesort */
67 template<typename RandomAccessIterator>
68 struct PMWMSSortingData
69 {
70 typedef std::iterator_traits<RandomAccessIterator> traits_type;
71 typedef typename traits_type::value_type value_type;
72 typedef typename traits_type::difference_type difference_type;
73
74 /** @brief Number of threads involved. */
75 thread_index_t num_threads;
76
77 /** @brief Input begin. */
78 RandomAccessIterator source;
79
80 /** @brief Start indices, per thread. */
81 difference_type* starts;
82
83 /** @brief Storage in which to sort. */
84 value_type** temporary;
85
86 /** @brief Samples. */
87 value_type* samples;
88
89 /** @brief Offsets to add to the found positions. */
90 difference_type* offsets;
91
92 /** @brief Pieces of data to merge @c [thread][sequence] */
93 std::vector<Piece<difference_type> >* pieces;
94 };
95
96 /**
97 * @brief Select samples from a sequence.
98 * @param sd Pointer to algorithm data. Result will be placed in
99 * @c sd->samples.
100 * @param num_samples Number of samples to select.
101 */
102 template<typename RandomAccessIterator, typename _DifferenceTp>
103 void
104 determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
105 _DifferenceTp num_samples)
106 {
107 typedef std::iterator_traits<RandomAccessIterator> traits_type;
108 typedef typename traits_type::value_type value_type;
109 typedef _DifferenceTp difference_type;
110
111 thread_index_t iam = omp_get_thread_num();
112
113 difference_type* es = new difference_type[num_samples + 2];
114
115 equally_split(sd->starts[iam + 1] - sd->starts[iam],
116 num_samples + 1, es);
117
118 for (difference_type i = 0; i < num_samples; ++i)
119 ::new(&(sd->samples[iam * num_samples + i]))
120 value_type(sd->source[sd->starts[iam] + es[i + 1]]);
121
122 delete[] es;
123 }
124
/** @brief Split consistently.
 *
 *  The primary template is intentionally empty; the work is done by the
 *  partial specializations on the leading @c bool parameter. */
template<bool exact, typename RandomAccessIterator,
         typename Comparator, typename SortingPlacesIterator>
  struct split_consistently
  { };
131
132 /** @brief Split by exact splitting. */
133 template<typename RandomAccessIterator, typename Comparator,
134 typename SortingPlacesIterator>
135 struct split_consistently
136 <true, RandomAccessIterator, Comparator, SortingPlacesIterator>
137 {
138 void operator()(
139 const thread_index_t iam,
140 PMWMSSortingData<RandomAccessIterator>* sd,
141 Comparator& comp,
142 const typename
143 std::iterator_traits<RandomAccessIterator>::difference_type
144 num_samples)
145 const
146 {
147 # pragma omp barrier
148
149 std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
150 seqs(sd->num_threads);
151 for (thread_index_t s = 0; s < sd->num_threads; s++)
152 seqs[s] = std::make_pair(sd->temporary[s],
153 sd->temporary[s]
154 + (sd->starts[s + 1] - sd->starts[s]));
155
156 std::vector<SortingPlacesIterator> offsets(sd->num_threads);
157
158 // if not last thread
159 if (iam < sd->num_threads - 1)
160 multiseq_partition(seqs.begin(), seqs.end(),
161 sd->starts[iam + 1], offsets.begin(), comp);
162
163 for (int seq = 0; seq < sd->num_threads; seq++)
164 {
165 // for each sequence
166 if (iam < (sd->num_threads - 1))
167 sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
168 else
169 // very end of this sequence
170 sd->pieces[iam][seq].end =
171 sd->starts[seq + 1] - sd->starts[seq];
172 }
173
174 # pragma omp barrier
175
176 for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
177 {
178 // For each sequence.
179 if (iam > 0)
180 sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
181 else
182 // Absolute beginning.
183 sd->pieces[iam][seq].begin = 0;
184 }
185 }
186 };
187
188 /** @brief Split by sampling. */
189 template<typename RandomAccessIterator, typename Comparator,
190 typename SortingPlacesIterator>
191 struct split_consistently<false, RandomAccessIterator, Comparator,
192 SortingPlacesIterator>
193 {
194 void operator()(
195 const thread_index_t iam,
196 PMWMSSortingData<RandomAccessIterator>* sd,
197 Comparator& comp,
198 const typename
199 std::iterator_traits<RandomAccessIterator>::difference_type
200 num_samples)
201 const
202 {
203 typedef std::iterator_traits<RandomAccessIterator> traits_type;
204 typedef typename traits_type::value_type value_type;
205 typedef typename traits_type::difference_type difference_type;
206
207 determine_samples(sd, num_samples);
208
209 # pragma omp barrier
210
211 # pragma omp single
212 __gnu_sequential::sort(sd->samples,
213 sd->samples + (num_samples * sd->num_threads),
214 comp);
215
216 # pragma omp barrier
217
218 for (thread_index_t s = 0; s < sd->num_threads; ++s)
219 {
220 // For each sequence.
221 if (num_samples * iam > 0)
222 sd->pieces[iam][s].begin =
223 std::lower_bound(sd->temporary[s],
224 sd->temporary[s]
225 + (sd->starts[s + 1] - sd->starts[s]),
226 sd->samples[num_samples * iam],
227 comp)
228 - sd->temporary[s];
229 else
230 // Absolute beginning.
231 sd->pieces[iam][s].begin = 0;
232
233 if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
234 sd->pieces[iam][s].end =
235 std::lower_bound(sd->temporary[s],
236 sd->temporary[s]
237 + (sd->starts[s + 1] - sd->starts[s]),
238 sd->samples[num_samples * (iam + 1)],
239 comp)
240 - sd->temporary[s];
241 else
242 // Absolute end.
243 sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
244 }
245 }
246 };
247
/** @brief Sorting functor selected at compile time by @c stable.
 *
 *  The primary template is empty; see the two partial specializations. */
template<bool stable, typename RandomAccessIterator, typename Comparator>
  struct possibly_stable_sort
  { };
252
253 template<typename RandomAccessIterator, typename Comparator>
254 struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
255 {
256 void operator()(const RandomAccessIterator& begin,
257 const RandomAccessIterator& end, Comparator& comp) const
258 {
259 __gnu_sequential::stable_sort(begin, end, comp);
260 }
261 };
262
263 template<typename RandomAccessIterator, typename Comparator>
264 struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
265 {
266 void operator()(const RandomAccessIterator begin,
267 const RandomAccessIterator end, Comparator& comp) const
268 {
269 __gnu_sequential::sort(begin, end, comp);
270 }
271 };
272
/** @brief Multiway-merge functor selected at compile time by @c stable.
 *
 *  The primary template is empty; see the two partial specializations. */
template<bool stable, typename SeqRandomAccessIterator,
         typename RandomAccessIterator, typename Comparator,
         typename DiffType>
  struct possibly_stable_multiway_merge
  { };
279
280 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
281 typename Comparator, typename DiffType>
282 struct possibly_stable_multiway_merge
283 <true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
284 DiffType>
285 {
286 void operator()(const SeqRandomAccessIterator& seqs_begin,
287 const SeqRandomAccessIterator& seqs_end,
288 const RandomAccessIterator& target,
289 Comparator& comp,
290 DiffType length_am) const
291 {
292 stable_multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
293 sequential_tag());
294 }
295 };
296
297 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
298 typename Comparator, typename DiffType>
299 struct possibly_stable_multiway_merge
300 <false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
301 DiffType>
302 {
303 void operator()(const SeqRandomAccessIterator& seqs_begin,
304 const SeqRandomAccessIterator& seqs_end,
305 const RandomAccessIterator& target,
306 Comparator& comp,
307 DiffType length_am) const
308 {
309 multiway_merge(seqs_begin, seqs_end, target, length_am, comp,
310 sequential_tag());
311 }
312 };
313
314 /** @brief PMWMS code executed by each thread.
315 * @param sd Pointer to algorithm data.
316 * @param comp Comparator.
317 */
318 template<bool stable, bool exact, typename RandomAccessIterator,
319 typename Comparator>
320 void
321 parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
322 Comparator& comp)
323 {
324 typedef std::iterator_traits<RandomAccessIterator> traits_type;
325 typedef typename traits_type::value_type value_type;
326 typedef typename traits_type::difference_type difference_type;
327
328 thread_index_t iam = omp_get_thread_num();
329
330 // Length of this thread's chunk, before merging.
331 difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
332
333 // Sort in temporary storage, leave space for sentinel.
334
335 typedef value_type* SortingPlacesIterator;
336
337 sd->temporary[iam] =
338 static_cast<value_type*>(
339 ::operator new(sizeof(value_type) * (length_local + 1)));
340
341 // Copy there.
342 std::uninitialized_copy(sd->source + sd->starts[iam],
343 sd->source + sd->starts[iam] + length_local,
344 sd->temporary[iam]);
345
346 possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
347 (sd->temporary[iam], sd->temporary[iam] + length_local, comp);
348
349 // Invariant: locally sorted subsequence in sd->temporary[iam],
350 // sd->temporary[iam] + length_local.
351
352 // No barrier here: Synchronization is done by the splitting routine.
353
354 difference_type num_samples =
355 _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
356 split_consistently
357 <exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
358 (iam, sd, comp, num_samples);
359
360 // Offset from target begin, length after merging.
361 difference_type offset = 0, length_am = 0;
362 for (thread_index_t s = 0; s < sd->num_threads; s++)
363 {
364 length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
365 offset += sd->pieces[iam][s].begin;
366 }
367
368 typedef std::vector<
369 std::pair<SortingPlacesIterator, SortingPlacesIterator> >
370 seq_vector_type;
371 seq_vector_type seqs(sd->num_threads);
372
373 for (int s = 0; s < sd->num_threads; ++s)
374 {
375 seqs[s] =
376 std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
377 sd->temporary[s] + sd->pieces[iam][s].end);
378 }
379
380 possibly_stable_multiway_merge<
381 stable,
382 typename seq_vector_type::iterator,
383 RandomAccessIterator,
384 Comparator, difference_type>()
385 (seqs.begin(), seqs.end(),
386 sd->source + offset, comp,
387 length_am);
388
389 # pragma omp barrier
390
391 ::operator delete(sd->temporary[iam]);
392 }
393
394 /** @brief PMWMS main call.
395 * @param begin Begin iterator of sequence.
396 * @param end End iterator of sequence.
397 * @param comp Comparator.
398 * @param n Length of sequence.
399 * @param num_threads Number of threads to use.
400 */
401 template<bool stable, bool exact, typename RandomAccessIterator,
402 typename Comparator>
403 void
404 parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
405 Comparator comp,
406 thread_index_t num_threads)
407 {
408 _GLIBCXX_CALL(end - begin)
409
410 typedef std::iterator_traits<RandomAccessIterator> traits_type;
411 typedef typename traits_type::value_type value_type;
412 typedef typename traits_type::difference_type difference_type;
413
414 difference_type n = end - begin;
415
416 if (n <= 1)
417 return;
418
419 // at least one element per thread
420 if (num_threads > n)
421 num_threads = static_cast<thread_index_t>(n);
422
423 // shared variables
424 PMWMSSortingData<RandomAccessIterator> sd;
425 difference_type* starts;
426
427 # pragma omp parallel num_threads(num_threads)
428 {
429 num_threads = omp_get_num_threads(); //no more threads than requested
430
431 # pragma omp single
432 {
433 sd.num_threads = num_threads;
434 sd.source = begin;
435
436 sd.temporary = new value_type*[num_threads];
437
438 if (!exact)
439 {
440 difference_type size =
441 (_Settings::get().sort_mwms_oversampling * num_threads - 1)
442 * num_threads;
443 sd.samples = static_cast<value_type*>(
444 ::operator new(size * sizeof(value_type)));
445 }
446 else
447 sd.samples = NULL;
448
449 sd.offsets = new difference_type[num_threads - 1];
450 sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
451 for (int s = 0; s < num_threads; ++s)
452 sd.pieces[s].resize(num_threads);
453 starts = sd.starts = new difference_type[num_threads + 1];
454
455 difference_type chunk_length = n / num_threads;
456 difference_type split = n % num_threads;
457 difference_type pos = 0;
458 for (int i = 0; i < num_threads; ++i)
459 {
460 starts[i] = pos;
461 pos += (i < split) ? (chunk_length + 1) : chunk_length;
462 }
463 starts[num_threads] = pos;
464 } //single
465
466 // Now sort in parallel.
467 parallel_sort_mwms_pu<stable, exact>(&sd, comp);
468 } //parallel
469
470 delete[] starts;
471 delete[] sd.temporary;
472
473 if (!exact)
474 ::operator delete(sd.samples);
475
476 delete[] sd.offsets;
477 delete[] sd.pieces;
478 }
479 } //namespace __gnu_parallel
480
481 #endif /* _GLIBCXX_PARALLEL_MULTIWAY_MERGESORT_H */