dae5bfc90203621d7d818bb822a0500a15af8da7
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_loop_analyze.h"
28
29
/* Per-iteration size budget used by is_loop_small_enough_to_unroll(): unless
 * unrolling is forced, a loop is only unrolled when
 *
 *    num_instructions * trip_count <= max_unroll_iterations * LOOP_UNROLL_LIMIT
 *
 * The value is chosen fairly arbitrarily. GLSL IR's max iteration limit is
 * 32 instructions (it counts nodes and uses a magic multiplier of 5). There
 * is no 1:1 mapping between GLSL IR and NIR, so 25 was originally picked
 * because it seemed to give about the same results (around 5 instructions
 * per node). Some loops that unrolled with GLSL IR failed to unroll at 25,
 * so it was set to 26.
 *
 * It was later bumped to 96 because that unrolled more loops with a positive
 * effect (Vulkan SSAO demo).
 */
#define LOOP_UNROLL_LIMIT 96
40
41 /* Prepare this loop for unrolling by first converting to lcssa and then
42 * converting the phis from the loops first block and the block that follows
43 * the loop into regs. Partially converting out of SSA allows us to unroll
44 * the loop without having to keep track of and update phis along the way
45 * which gets tricky and doesn't add much value over conveting to regs.
46 *
47 * The loop may have a continue instruction at the end of the loop which does
48 * nothing. Once we're out of SSA, we can safely delete it so we don't have
49 * to deal with it later.
50 */
51 static void
52 loop_prepare_for_unroll(nir_loop *loop)
53 {
54 nir_convert_loop_to_lcssa(loop);
55
56 nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
57
58 nir_block *block_after_loop =
59 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
60
61 nir_lower_phis_to_regs_block(block_after_loop);
62
63 nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
64 if (last_instr && last_instr->type == nir_instr_type_jump) {
65 assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
66 nir_instr_remove(last_instr);
67 }
68 }
69
70 static void
71 get_first_blocks_in_terminator(nir_loop_terminator *term,
72 nir_block **first_break_block,
73 nir_block **first_continue_block)
74 {
75 if (term->continue_from_then) {
76 *first_continue_block = nir_if_first_then_block(term->nif);
77 *first_break_block = nir_if_first_else_block(term->nif);
78 } else {
79 *first_continue_block = nir_if_first_else_block(term->nif);
80 *first_break_block = nir_if_first_then_block(term->nif);
81 }
82 }
83
/**
 * Unroll a loop where we know exactly how many iterations there are and there
 * is only a single exit point. Loops with several theoretical exits can
 * also be unrolled here, provided exactly one of them (the limiting
 * terminator) is the exit that is always taken.
 *
 * The loop is split at the limiting terminator's if:
 *   - "header": everything before that if;
 *   - "body":   the if's continue-branch contents followed by everything
 *               after the if;
 *   - "break":  the if's break-branch contents (minus the break itself).
 *
 * The code emitted in place of the loop is
 *
 *    header (body header) x trip_count  break
 *
 * so for the loop
 *
 *    loop {
 *       ...instrs...
 *    }
 *
 * with an iteration count of 3, the output is roughly:
 *
 *    ...instrs... ...instrs... ...instrs...
 *
 * Every terminator other than the limiting one is removed. Its continue
 * branch is kept inline because its exit condition can never be met first.
 */
static void
simple_unroll(nir_loop *loop)
{
   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
   assert(nir_is_trivial_loop_if(limiting_term->nif,
                                 limiting_term->break_block));

   loop_prepare_for_unroll(loop);

   /* Strip out every terminator except the limiting one. */
   list_for_each_entry(nir_loop_terminator, terminator,
                       &loop->info->loop_terminator_list,
                       loop_terminator_link) {

      /* Remove all but the limiting terminator, as we know the other exit
       * conditions can never be met. Before removing such an if, extract
       * the instructions in its continue-from branch and insert them into
       * the loop body just after the if.
       */
      if (terminator->nif != limiting_term->nif) {
         nir_block *first_break_block;
         nir_block *first_continue_block;
         get_first_blocks_in_terminator(terminator, &first_break_block,
                                        &first_continue_block);

         assert(nir_is_trivial_loop_if(terminator->nif,
                                       terminator->break_block));

         nir_cf_list continue_from_lst;
         nir_cf_extract(&continue_from_lst,
                        nir_before_block(first_continue_block),
                        nir_after_block(terminator->continue_from_block));
         nir_cf_reinsert(&continue_from_lst,
                         nir_after_cf_node(&terminator->nif->cf_node));

         nir_cf_node_remove(&terminator->nif->cf_node);
      }
   }

   nir_block *first_break_block;
   nir_block *first_continue_block;
   get_first_blocks_in_terminator(limiting_term, &first_break_block,
                                  &first_continue_block);

   /* Pluck out the loop header: everything from the start of the loop up to
    * (but not including) the limiting terminator's if.
    */
   nir_block *header_blk = nir_loop_first_block(loop);
   nir_cf_list lp_header;
   nir_cf_extract(&lp_header, nir_before_block(header_blk),
                  nir_before_cf_node(&limiting_term->nif->cf_node));

   /* Move the contents of the limiting terminator's continue-from branch to
    * just after its if, so they become the start of the loop body.
    */
   nir_cf_list continue_from_lst;
   nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
                  nir_after_block(limiting_term->continue_from_block));
   nir_cf_reinsert(&continue_from_lst,
                   nir_after_cf_node(&limiting_term->nif->cf_node));

   /* Pluck out the loop body: everything after the limiting terminator's if
    * up to the end of the loop.
    */
   nir_cf_list loop_body;
   nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
                  nir_after_block(nir_loop_last_block(loop)));

   /* Remap table shared by every nir_cf_list_clone() call below. */
   struct hash_table *remap_table =
      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                              _mesa_key_pointer_equal);

   /* Clone the loop header */
   nir_cf_list cloned_header;
   nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                     remap_table);

   /* Insert cloned loop header before the loop */
   nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));

   /* Temp list to store the cloned loop body as we unroll */
   nir_cf_list unrolled_lp_body;

   /* Emit trip_count copies of (body, header), each inserted just before the
    * loop so they appear in program order after the first header.
    */
   for (unsigned i = 0; i < loop->info->trip_count; i++) {
      /* Clone loop body */
      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
                        remap_table);

      /* Insert unrolled loop body before the loop */
      nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));

      /* Clone loop header */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      /* Insert loop header after loop body */
      nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
   }

   /* Remove the break from the loop terminator and add instructions from
    * the break block after the unrolled loop.
    */
   nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
   nir_instr_remove(break_instr);
   nir_cf_list break_list;
   nir_cf_extract(&break_list, nir_before_block(first_break_block),
                  nir_after_block(limiting_term->break_block));

   /* Clone so things get properly remapped */
   nir_cf_list cloned_break_list;
   nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
                     remap_table);

   nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));

   /* Remove the loop */
   nir_cf_node_remove(&loop->cf_node);

   /* Delete the original loop body, break block & header; only clones of
    * them were inserted into the shader.
    */
   nir_cf_delete(&lp_header);
   nir_cf_delete(&loop_body);
   nir_cf_delete(&break_list);

   _mesa_hash_table_destroy(remap_table, NULL);
}
219
220 static void
221 move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
222 {
223 /* Move the rest of the loop inside the continue-from-block */
224 nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
225
226 /* Remove the break */
227 nir_instr_remove(nir_block_last_instr(term->break_block));
228 }
229
230 static nir_cursor
231 get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
232 {
233 if (node->type == nir_cf_node_loop) {
234 return nir_before_cf_node(node);
235 } else {
236 nir_if *if_stmt = nir_cf_node_as_if(node);
237 if (continue_from_then) {
238 return nir_after_block(nir_if_last_then_block(if_stmt));
239 } else {
240 return nir_after_block(nir_if_last_else_block(if_stmt));
241 }
242 }
243 }
244
/**
 * Unroll a loop with two exits when the trip count of one of the exits is
 * unknown. If unlimit_term->continue_from_then is true, the loop is repeated
 * only when the "then" branch of that if is taken; otherwise it is repeated
 * only when the "else" branch is taken.
 *
 * @loop:                 the loop to unroll. loop->info->limiting_terminator
 *                        and loop->info->trip_count describe the exit whose
 *                        trip count is known.
 * @unlimit_term:         the other terminator, whose trip count is unknown.
 * @limiting_term_second: true when the limiting terminator comes after
 *                        @unlimit_term in the loop (process_loops() decides
 *                        this from terminator list order); false when it
 *                        comes first.
 *
 * For example, if the input is:
 *
 *    loop {
 *       ...phis/condition...
 *       if condition {
 *          ...then instructions...
 *       } else {
 *          ...else instructions...
 *          break
 *       }
 *       ...body...
 *    }
 *
 * and the iteration count is 3 and unlimit_term->continue_from_then is true,
 * then the output will be:
 *
 *    ...condition...
 *    if condition {
 *       ...then instructions...
 *       ...body...
 *       if condition {
 *          ...then instructions...
 *          ...body...
 *          if condition {
 *             ...then instructions...
 *             ...body...
 *          } else {
 *             ...else instructions...
 *          }
 *       } else {
 *          ...else instructions...
 *       }
 *    } else {
 *       ...else instructions...
 *    }
 *
 * When the limiting terminator comes first (!limiting_term_second), one more
 * copy of the header follows in the innermost continue branch, then the
 * limiting terminator's break-branch instructions.
 */
static void
complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
               bool limiting_term_second)
{
   assert(nir_is_trivial_loop_if(unlimit_term->nif,
                                 unlimit_term->break_block));

   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
   assert(nir_is_trivial_loop_if(limiting_term->nif,
                                 limiting_term->break_block));

   loop_prepare_for_unroll(loop);

   nir_block *header_blk = nir_loop_first_block(loop);

   nir_cf_list lp_header;
   /* Only filled in (and only used) when !limiting_term_second. */
   nir_cf_list limit_break_list;
   unsigned num_times_to_clone;
   if (limiting_term_second) {
      /* Pluck out the loop header: everything before unlimit_term's if. */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&unlimit_term->nif->cf_node));

      /* We need some special handling when it's the second terminator
       * causing us to exit the loop, for example:
       *
       * for (int i = 0; i < uniform_lp_count; i++) {
       *    colour = vec4(0.0, 1.0, 0.0, 1.0);
       *
       *    if (i == 1) {
       *       break;
       *    }
       *    ... any further code is unreachable after i == 1 ...
       * }
       *
       * Everything after the limiting terminator's if is moved into its
       * continue branch, and its break is removed.
       */
      nir_cf_list after_lt;
      nir_if *limit_if = limiting_term->nif;
      nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
                     nir_after_block(nir_loop_last_block(loop)));
      move_cf_list_into_loop_term(&after_lt, limiting_term);

      /* When the second terminator is the limiting one, the trip count is
       * the number of complete passes over the loop before the break is
       * hit. So even with trip count == 0 we execute the code above the
       * break once. Bump the count so the code below clones something: with
       * trip count == 1 the code above the break runs twice and the code
       * below it once, so we clone twice, and so on.
       */
      num_times_to_clone = loop->info->trip_count + 1;
   } else {
      /* Pluck out the loop header: everything before the limiting
       * terminator's if.
       */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&limiting_term->nif->cf_node));

      nir_block *first_break_block;
      nir_block *first_continue_block;
      get_first_blocks_in_terminator(limiting_term, &first_break_block,
                                     &first_continue_block);

      /* Remove the break, then extract the instructions from the break block
       * so they can be inserted in the innermost continue branch of the
       * unrolled loop.
       */
      nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
      nir_instr_remove(break_instr);
      nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
                     nir_after_block(limiting_term->break_block));

      /* Hoist the limiting terminator's continue branch out after its if,
       * then drop the (now empty-of-interest) if entirely.
       */
      nir_cf_list continue_list;
      nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
                     nir_after_block(limiting_term->continue_from_block));

      nir_cf_reinsert(&continue_list,
                      nir_after_cf_node(&limiting_term->nif->cf_node));

      nir_cf_node_remove(&limiting_term->nif->cf_node);

      num_times_to_clone = loop->info->trip_count;
   }

   /* In the terminator we have no trip count for, move everything after
    * the terminator into its continue-from branch.
    */
   nir_cf_list loop_end;
   nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
                  nir_after_block(nir_loop_last_block(loop)));
   move_cf_list_into_loop_term(&loop_end, unlimit_term);

   /* Pluck out the loop body: whatever remains in the loop, ending with
    * unlimit_term's if and the block after it.
    */
   nir_cf_list loop_body;
   nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
                  nir_after_block(nir_loop_last_block(loop)));

   /* Remap table shared by every nir_cf_list_clone() call below. */
   struct hash_table *remap_table =
      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                              _mesa_key_pointer_equal);

   /* Set unroll_loc to the loop as we will insert the unrolled loop before it
    */
   nir_cf_node *unroll_loc = &loop->cf_node;

   /* Temp lists to store the cloned loop as we unroll */
   nir_cf_list unrolled_lp_body;
   nir_cf_list cloned_header;

   /* Each iteration emits (header, body) at the current insertion point,
    * then moves the insertion point into the continue branch of the
    * just-cloned unlimit_term if, which builds the nesting shown above.
    */
   for (unsigned i = 0; i < num_times_to_clone; i++) {
      /* Clone loop header */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Recompute the cursor after the insertion above. Presumably this is
       * because reinserting may change the blocks the old cursor referred
       * to. TODO confirm.
       */
      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Clone loop body */
      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
                        remap_table);

      /* The cloned body must end with an empty block following the clone
       * of unlimit_term's if.
       */
      unroll_loc = exec_node_data(nir_cf_node,
                                  exec_list_get_tail(&unrolled_lp_body.list),
                                  node);
      assert(unroll_loc->type == nir_cf_node_block &&
             exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));

      /* Get the unrolled if node */
      unroll_loc = nir_cf_node_prev(unroll_loc);

      /* Insert unrolled loop body */
      nir_cf_reinsert(&unrolled_lp_body, cursor);
   }

   if (!limiting_term_second) {
      /* process_loops() routes trip_count == 0 with the limiting terminator
       * first to simple_unroll(), so the loop above ran at least once and
       * unroll_loc is a cloned if.
       */
      assert(unroll_loc->type == nir_cf_node_if);

      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Clone so things get properly remapped, and insert break block from
       * the limiting terminator.
       */
      nir_cf_list cloned_break_blk;
      nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
                        loop->cf_node.parent, remap_table);

      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      nir_cf_reinsert(&cloned_break_blk, cursor);
      nir_cf_delete(&limit_break_list);
   }

   /* The loop has been unrolled so remove it. */
   nir_cf_node_remove(&loop->cf_node);

   /* Delete the original loop header and body */
   nir_cf_delete(&lp_header);
   nir_cf_delete(&loop_body);

   _mesa_hash_table_destroy(remap_table, NULL);
}
462
463 static bool
464 is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
465 {
466 unsigned max_iter = shader->options->max_unroll_iterations;
467
468 if (li->trip_count > max_iter)
469 return false;
470
471 if (li->force_unroll)
472 return true;
473
474 bool loop_not_too_large =
475 li->num_instructions * li->trip_count <= max_iter * LOOP_UNROLL_LIMIT;
476
477 return loop_not_too_large;
478 }
479
480 static bool
481 process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
482 {
483 bool progress = false;
484 nir_loop *loop;
485
486 switch (cf_node->type) {
487 case nir_cf_node_block:
488 return progress;
489 case nir_cf_node_if: {
490 nir_if *if_stmt = nir_cf_node_as_if(cf_node);
491 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
492 progress |= process_loops(sh, nested_node, innermost_loop);
493 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
494 progress |= process_loops(sh, nested_node, innermost_loop);
495 return progress;
496 }
497 case nir_cf_node_loop: {
498 loop = nir_cf_node_as_loop(cf_node);
499 foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
500 progress |= process_loops(sh, nested_node, innermost_loop);
501 break;
502 }
503 default:
504 unreachable("unknown cf node type");
505 }
506
507 if (*innermost_loop) {
508 /* Don't attempt to unroll outer loops or a second inner loop in
509 * this pass wait until the next pass as we have altered the cf.
510 */
511 *innermost_loop = false;
512
513 if (loop->info->limiting_terminator == NULL)
514 return progress;
515
516 if (!is_loop_small_enough_to_unroll(sh, loop->info))
517 return progress;
518
519 if (loop->info->is_trip_count_known) {
520 simple_unroll(loop);
521 progress = true;
522 } else {
523 /* Attempt to unroll loops with two terminators. */
524 unsigned num_lt = list_length(&loop->info->loop_terminator_list);
525 if (num_lt == 2) {
526 bool limiting_term_second = true;
527 nir_loop_terminator *terminator =
528 list_last_entry(&loop->info->loop_terminator_list,
529 nir_loop_terminator, loop_terminator_link);
530
531
532 if (terminator->nif == loop->info->limiting_terminator->nif) {
533 limiting_term_second = false;
534 terminator =
535 list_first_entry(&loop->info->loop_terminator_list,
536 nir_loop_terminator, loop_terminator_link);
537 }
538
539 /* If the first terminator has a trip count of zero and is the
540 * limiting terminator just do a simple unroll as the second
541 * terminator can never be reached.
542 */
543 if (loop->info->trip_count == 0 && !limiting_term_second) {
544 simple_unroll(loop);
545 } else {
546 complex_unroll(loop, terminator, limiting_term_second);
547 }
548 progress = true;
549 }
550 }
551 }
552
553 return progress;
554 }
555
556 static bool
557 nir_opt_loop_unroll_impl(nir_function_impl *impl,
558 nir_variable_mode indirect_mask)
559 {
560 bool progress = false;
561 nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
562 nir_metadata_require(impl, nir_metadata_block_index);
563
564 foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
565 bool innermost_loop = true;
566 progress |= process_loops(impl->function->shader, node,
567 &innermost_loop);
568 }
569
570 if (progress)
571 nir_lower_regs_to_ssa_impl(impl);
572
573 return progress;
574 }
575
576 bool
577 nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
578 {
579 bool progress = false;
580
581 nir_foreach_function(function, shader) {
582 if (function->impl) {
583 progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
584 }
585 }
586 return progress;
587 }