nir: fix crash in loop unroll corner case
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
#include <stdint.h>

#include "nir.h"
#include "nir_builder.h"
#include "nir_control_flow.h"
#include "nir_loop_analyze.h"
28
29
/* Per-iteration size budget used by is_loop_small_enough_to_unroll(): a loop
 * is considered small enough when num_instructions * trip_count does not
 * exceed max_unroll_iterations * LOOP_UNROLL_LIMIT.
 *
 * The value is fairly arbitrary. GLSL IR's equivalent limit is 32
 * instructions per iteration (computed there by counting nodes and applying
 * a magic factor of 5). NIR instructions do not map 1:1 onto GLSL IR nodes,
 * so 25 (about 5 instructions per node) was tried first because it gave
 * roughly the same results. Some loops that GLSL IR unrolled still failed to
 * unroll at 25, so the limit was raised to 26.
 */
#define LOOP_UNROLL_LIMIT 26
38
39 /* Prepare this loop for unrolling by first converting to lcssa and then
40 * converting the phis from the top level of the loop body to regs.
41 * Partially converting out of SSA allows us to unroll the loop without having
42 * to keep track of and update phis along the way which gets tricky and
43 * doesn't add much value over converting to regs.
44 *
45 * The loop may have a continue instruction at the end of the loop which does
46 * nothing. Once we're out of SSA, we can safely delete it so we don't have
47 * to deal with it later.
48 */
49 static void
50 loop_prepare_for_unroll(nir_loop *loop)
51 {
52 nir_convert_loop_to_lcssa(loop);
53
54 /* Lower phis at the top level of the loop body */
55 foreach_list_typed_safe(nir_cf_node, node, node, &loop->body) {
56 if (nir_cf_node_block == node->type) {
57 nir_lower_phis_to_regs_block(nir_cf_node_as_block(node));
58 }
59 }
60
61 /* Lower phis after the loop */
62 nir_block *block_after_loop =
63 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
64
65 nir_lower_phis_to_regs_block(block_after_loop);
66
67 /* Remove continue if its the last instruction in the loop */
68 nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
69 if (last_instr && last_instr->type == nir_instr_type_jump) {
70 assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
71 nir_instr_remove(last_instr);
72 }
73 }
74
75 static void
76 get_first_blocks_in_terminator(nir_loop_terminator *term,
77 nir_block **first_break_block,
78 nir_block **first_continue_block)
79 {
80 if (term->continue_from_then) {
81 *first_continue_block = nir_if_first_then_block(term->nif);
82 *first_break_block = nir_if_first_else_block(term->nif);
83 } else {
84 *first_continue_block = nir_if_first_else_block(term->nif);
85 *first_break_block = nir_if_first_then_block(term->nif);
86 }
87 }
88
/**
 * Unroll a loop where we know exactly how many iterations there are and there
 * is only a single exit point. Note here we can unroll loops with multiple
 * theoretical exits that only have a single terminating exit that we always
 * know is the "real" exit.
 *
 *     loop {
 *         ...instrs...
 *     }
 *
 * And the iteration count is 3, the output will be:
 *
 *     ...instrs... ...instrs... ...instrs...
 *
 * In more detail, relative to the limiting terminator's if:
 *  - the "header" is everything before that if;
 *  - the "body" is everything after it, plus the contents of the if's
 *    continue branch, which are moved out of the if first.
 * The code emitted in place of the loop is:
 *
 *     header, (body, header) x trip_count, break-branch contents
 *
 * where the break-branch contents are emitted without the break itself.
 *
 * Every other terminator is assumed never to fire. Its if is deleted, and the
 * contents of its continue branch are spliced in after it.
 *
 * Called from process_loops() when the trip count is known, or when the
 * limiting terminator is the first of two terminators and its trip count
 * is 0.
 */
static void
simple_unroll(nir_loop *loop)
{
   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
   assert(nir_is_trivial_loop_if(limiting_term->nif,
                                 limiting_term->break_block));

   loop_prepare_for_unroll(loop);

   /* Strip every terminator other than the limiting one. */
   list_for_each_entry(nir_loop_terminator, terminator,
                       &loop->info->loop_terminator_list,
                       loop_terminator_link) {

      /* Remove all but the limiting terminator as we know the other exit
       * conditions can never be met. Note we need to extract any instructions
       * in the continue from branch and insert them into the loop body before
       * removing it.
       */
      if (terminator->nif != limiting_term->nif) {
         nir_block *first_break_block;
         nir_block *first_continue_block;
         get_first_blocks_in_terminator(terminator, &first_break_block,
                                        &first_continue_block);

         assert(nir_is_trivial_loop_if(terminator->nif,
                                       terminator->break_block));

         /* Hoist the continue branch's contents to just after the if. */
         nir_cf_list continue_from_lst;
         nir_cf_extract(&continue_from_lst,
                        nir_before_block(first_continue_block),
                        nir_after_block(terminator->continue_from_block));
         nir_cf_reinsert(&continue_from_lst,
                         nir_after_cf_node(&terminator->nif->cf_node));

         /* The if now only holds the never-taken break branch; drop it. */
         nir_cf_node_remove(&terminator->nif->cf_node);
      }
   }

   nir_block *first_break_block;
   nir_block *first_continue_block;
   get_first_blocks_in_terminator(limiting_term, &first_break_block,
                                  &first_continue_block);

   /* Pluck out the loop header: everything before the limiting if. */
   nir_block *header_blk = nir_loop_first_block(loop);
   nir_cf_list lp_header;
   nir_cf_extract(&lp_header, nir_before_block(header_blk),
                  nir_before_cf_node(&limiting_term->nif->cf_node));

   /* Add the continue from block of the limiting terminator to the loop body
    * by moving its contents to just after the limiting if.
    */
   nir_cf_list continue_from_lst;
   nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
                  nir_after_block(limiting_term->continue_from_block));
   nir_cf_reinsert(&continue_from_lst,
                   nir_after_cf_node(&limiting_term->nif->cf_node));

   /* Pluck out the loop body: everything after the limiting if. */
   nir_cf_list loop_body;
   nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
                  nir_after_block(nir_loop_last_block(loop)));

   /* Shared by every clone below. Presumably this lets each copy refer to
    * values defined by the previous copy. NOTE(review): confirm against
    * nir_cf_list_clone().
    */
   struct hash_table *remap_table =
      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                              _mesa_key_pointer_equal);

   /* Clone the loop header */
   nir_cf_list cloned_header;
   nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                     remap_table);

   /* Insert cloned loop header before the loop */
   nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));

   /* Temp list to store the cloned loop body as we unroll */
   nir_cf_list unrolled_lp_body;

   /* Emit trip_count copies of (body, header) before the loop, in order */
   for (unsigned i = 0; i < loop->info->trip_count; i++) {
      /* Clone loop body */
      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
                        remap_table);

      /* Insert unrolled loop body before the loop */
      nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));

      /* Clone loop header */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      /* Insert loop header after loop body */
      nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
   }

   /* Remove the break from the loop terminator and add instructions from
    * the break block after the unrolled loop.
    */
   nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
   nir_instr_remove(break_instr);
   nir_cf_list break_list;
   nir_cf_extract(&break_list, nir_before_block(first_break_block),
                  nir_after_block(limiting_term->break_block));

   /* Clone so things get properly remapped */
   nir_cf_list cloned_break_list;
   nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
                     remap_table);

   nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));

   /* Remove the loop */
   nir_cf_node_remove(&loop->cf_node);

   /* Delete the original loop body, break block & header */
   nir_cf_delete(&lp_header);
   nir_cf_delete(&loop_body);
   nir_cf_delete(&break_list);

   _mesa_hash_table_destroy(remap_table, NULL);
}
224
225 static void
226 move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
227 {
228 /* Move the rest of the loop inside the continue-from-block */
229 nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
230
231 /* Remove the break */
232 nir_instr_remove(nir_block_last_instr(term->break_block));
233 }
234
235 static nir_cursor
236 get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
237 {
238 if (node->type == nir_cf_node_loop) {
239 return nir_before_cf_node(node);
240 } else {
241 nir_if *if_stmt = nir_cf_node_as_if(node);
242 if (continue_from_then) {
243 return nir_after_block(nir_if_last_then_block(if_stmt));
244 } else {
245 return nir_after_block(nir_if_last_else_block(if_stmt));
246 }
247 }
248 }
249
/**
 * Unroll a loop with two exits when the trip count of one of the exits is
 * unknown. If continue_from_then is true, the loop is repeated only when the
 * "then" branch of the if is taken; otherwise it is repeated only
 * when the "else" branch of the if is taken.
 *
 * \param loop                  the loop to unroll; it is removed on return.
 * \param unlimit_term          the terminator with no known trip count.
 *                              Each unrolled iteration is nested inside the
 *                              continue branch of a clone of its if.
 * \param limiting_term_second  true when the limiting terminator
 *                              (loop->info->limiting_terminator) comes after
 *                              unlimit_term in the loop body.
 *
 * The "header" is the code before the first terminator's if: unlimit_term
 * when limiting_term_second, otherwise the limiting terminator.
 *  - limiting_term_second: everything after the limiting if is moved into
 *    its continue branch and its break is dropped. Header + body are cloned
 *    trip_count + 1 times.
 *  - otherwise: the limiting if is removed. Its continue-branch contents stay
 *    in the loop, and its break-branch contents (minus the break) are saved
 *    in limit_break_list. Header + body are cloned trip_count times, and a
 *    final header plus the saved break contents go in the innermost
 *    continue branch.
 *
 * For example, if the input is:
 *
 *     loop {
 *         ...phis/condition...
 *         if condition {
 *             ...then instructions...
 *         } else {
 *             ...continue instructions...
 *             break
 *         }
 *         ...body...
 *     }
 *
 * And the iteration count is 3, and unlimit_term->continue_from_then is true,
 * then the output will be:
 *
 *     ...condition...
 *     if condition {
 *         ...then instructions...
 *         ...body...
 *         if condition {
 *             ...then instructions...
 *             ...body...
 *             if condition {
 *                 ...then instructions...
 *                 ...body...
 *             } else {
 *                 ...continue instructions...
 *             }
 *         } else {
 *             ...continue instructions...
 *         }
 *     } else {
 *         ...continue instructions...
 *     }
 */
static void
complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
               bool limiting_term_second)
{
   assert(nir_is_trivial_loop_if(unlimit_term->nif,
                                 unlimit_term->break_block));

   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
   assert(nir_is_trivial_loop_if(limiting_term->nif,
                                 limiting_term->break_block));

   loop_prepare_for_unroll(loop);

   nir_block *header_blk = nir_loop_first_block(loop);

   nir_cf_list lp_header;
   /* Only initialized (and only used) when !limiting_term_second. */
   nir_cf_list limit_break_list;
   unsigned num_times_to_clone;
   if (limiting_term_second) {
      /* Pluck out the loop header */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&unlimit_term->nif->cf_node));

      /* We need some special handling when it's the second terminator causing
       * us to exit the loop for example:
       *
       *   for (int i = 0; i < uniform_lp_count; i++) {
       *      colour = vec4(0.0, 1.0, 0.0, 1.0);
       *
       *      if (i == 1) {
       *         break;
       *      }
       *      ... any further code is unreachable after i == 1 ...
       *   }
       */
      nir_cf_list after_lt;
      nir_if *limit_if = limiting_term->nif;
      nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
                     nir_after_block(nir_loop_last_block(loop)));
      move_cf_list_into_loop_term(&after_lt, limiting_term);

      /* Because the trip count is the number of times we pass over the entire
       * loop before hitting a break when the second terminator is the
       * limiting terminator we can actually execute code inside the loop when
       * trip count == 0 e.g. the code above the break. So we need to bump
       * the trip_count in order for the code below to clone anything. When
       * trip count == 1 we execute the code above the break twice and the
       * code below it once so we need clone things twice and so on.
       */
      num_times_to_clone = loop->info->trip_count + 1;
   } else {
      /* Pluck out the loop header */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&limiting_term->nif->cf_node));

      nir_block *first_break_block;
      nir_block *first_continue_block;
      get_first_blocks_in_terminator(limiting_term, &first_break_block,
                                     &first_continue_block);

      /* Remove the break then extract instructions from the break block so we
       * can insert them in the innermost else of the unrolled loop.
       */
      nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
      nir_instr_remove(break_instr);
      nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
                     nir_after_block(limiting_term->break_block));

      /* Move the continue branch's contents out to just after the limiting
       * if, then delete the (now empty) if.
       */
      nir_cf_list continue_list;
      nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
                     nir_after_block(limiting_term->continue_from_block));

      nir_cf_reinsert(&continue_list,
                      nir_after_cf_node(&limiting_term->nif->cf_node));

      nir_cf_node_remove(&limiting_term->nif->cf_node);

      num_times_to_clone = loop->info->trip_count;
   }

   /* In the terminator that we have no trip count for move everything after
    * the terminator into the continue from branch.
    */
   nir_cf_list loop_end;
   nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
                  nir_after_block(nir_loop_last_block(loop)));
   move_cf_list_into_loop_term(&loop_end, unlimit_term);

   /* Pluck out the loop body (whatever remains after the header was
    * extracted, ending with unlimit_term's if).
    */
   nir_cf_list loop_body;
   nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
                  nir_after_block(nir_loop_last_block(loop)));

   struct hash_table *remap_table =
      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                              _mesa_key_pointer_equal);

   /* Set unroll_loc to the loop as we will insert the unrolled loop before it
    */
   nir_cf_node *unroll_loc = &loop->cf_node;

   /* Temp lists to store the cloned loop as we unroll */
   nir_cf_list unrolled_lp_body;
   nir_cf_list cloned_header;

   for (unsigned i = 0; i < num_times_to_clone; i++) {
      /* Clone loop header */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Recompute the cursor: it must land after the header just inserted. */
      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Clone loop body */
      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
                        remap_table);

      /* The cloned body ends with an empty block that follows the cloned
       * unlimit_term if; the next iteration is nested inside that if.
       */
      unroll_loc = exec_node_data(nir_cf_node,
                                  exec_list_get_tail(&unrolled_lp_body.list),
                                  node);
      assert(unroll_loc->type == nir_cf_node_block &&
             exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));

      /* Get the unrolled if node */
      unroll_loc = nir_cf_node_prev(unroll_loc);

      /* Insert unrolled loop body */
      nir_cf_reinsert(&unrolled_lp_body, cursor);
   }

   if (!limiting_term_second) {
      assert(unroll_loc->type == nir_cf_node_if);

      /* One final header copy, in the innermost continue branch. */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Clone so things get properly remapped, and insert break block from
       * the limiting terminator.
       */
      nir_cf_list cloned_break_blk;
      nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
                        loop->cf_node.parent, remap_table);

      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      nir_cf_reinsert(&cloned_break_blk, cursor);
      nir_cf_delete(&limit_break_list);
   }

   /* The loop has been unrolled so remove it. */
   nir_cf_node_remove(&loop->cf_node);

   /* Delete the original loop header and body */
   nir_cf_delete(&lp_header);
   nir_cf_delete(&loop_body);

   _mesa_hash_table_destroy(remap_table, NULL);
}
467
468 static bool
469 is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
470 {
471 unsigned max_iter = shader->options->max_unroll_iterations;
472
473 if (li->trip_count > max_iter)
474 return false;
475
476 if (li->force_unroll)
477 return true;
478
479 bool loop_not_too_large =
480 li->num_instructions * li->trip_count <= max_iter * LOOP_UNROLL_LIMIT;
481
482 return loop_not_too_large;
483 }
484
485 static bool
486 process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
487 {
488 bool progress = false;
489 nir_loop *loop;
490
491 switch (cf_node->type) {
492 case nir_cf_node_block:
493 return progress;
494 case nir_cf_node_if: {
495 nir_if *if_stmt = nir_cf_node_as_if(cf_node);
496 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
497 progress |= process_loops(sh, nested_node, innermost_loop);
498 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
499 progress |= process_loops(sh, nested_node, innermost_loop);
500 return progress;
501 }
502 case nir_cf_node_loop: {
503 loop = nir_cf_node_as_loop(cf_node);
504 foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
505 progress |= process_loops(sh, nested_node, innermost_loop);
506 break;
507 }
508 default:
509 unreachable("unknown cf node type");
510 }
511
512 if (*innermost_loop) {
513 /* Don't attempt to unroll outer loops or a second inner loop in
514 * this pass wait until the next pass as we have altered the cf.
515 */
516 *innermost_loop = false;
517
518 if (loop->info->limiting_terminator == NULL)
519 return progress;
520
521 if (!is_loop_small_enough_to_unroll(sh, loop->info))
522 return progress;
523
524 if (loop->info->is_trip_count_known) {
525 simple_unroll(loop);
526 progress = true;
527 } else {
528 /* Attempt to unroll loops with two terminators. */
529 unsigned num_lt = list_length(&loop->info->loop_terminator_list);
530 if (num_lt == 2) {
531 bool limiting_term_second = true;
532 nir_loop_terminator *terminator =
533 list_last_entry(&loop->info->loop_terminator_list,
534 nir_loop_terminator, loop_terminator_link);
535
536
537 if (terminator->nif == loop->info->limiting_terminator->nif) {
538 limiting_term_second = false;
539 terminator =
540 list_first_entry(&loop->info->loop_terminator_list,
541 nir_loop_terminator, loop_terminator_link);
542 }
543
544 /* If the first terminator has a trip count of zero and is the
545 * limiting terminator just do a simple unroll as the second
546 * terminator can never be reached.
547 */
548 if (loop->info->trip_count == 0 && !limiting_term_second) {
549 simple_unroll(loop);
550 } else {
551 complex_unroll(loop, terminator, limiting_term_second);
552 }
553 progress = true;
554 }
555 }
556 }
557
558 return progress;
559 }
560
561 static bool
562 nir_opt_loop_unroll_impl(nir_function_impl *impl,
563 nir_variable_mode indirect_mask)
564 {
565 bool progress = false;
566 nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
567 nir_metadata_require(impl, nir_metadata_block_index);
568
569 foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
570 bool innermost_loop = true;
571 progress |= process_loops(impl->function->shader, node,
572 &innermost_loop);
573 }
574
575 if (progress)
576 nir_lower_regs_to_ssa_impl(impl);
577
578 return progress;
579 }
580
581 bool
582 nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
583 {
584 bool progress = false;
585
586 nir_foreach_function(function, shader) {
587 if (function->impl) {
588 progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
589 }
590 }
591 return progress;
592 }