Revert "nir: bump loop unroll limit to 96."
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_loop_analyze.h"
28
29
/* This limit is chosen fairly arbitrarily.  The GLSL IR max iteration is 32
 * instructions (multiplying counting nodes by the magic number 5).  There is
 * no 1:1 mapping between GLSL IR and NIR, so 25 was picked because it seemed
 * to give about the same results (around 5 instructions per node).  However,
 * some loops that would unroll with GLSL IR fail to unroll when this is set
 * to 25, so it is set to 26.
 */
#define LOOP_UNROLL_LIMIT 26
38
39 /* Prepare this loop for unrolling by first converting to lcssa and then
40 * converting the phis from the loops first block and the block that follows
41 * the loop into regs. Partially converting out of SSA allows us to unroll
42 * the loop without having to keep track of and update phis along the way
43 * which gets tricky and doesn't add much value over conveting to regs.
44 *
45 * The loop may have a continue instruction at the end of the loop which does
46 * nothing. Once we're out of SSA, we can safely delete it so we don't have
47 * to deal with it later.
48 */
49 static void
50 loop_prepare_for_unroll(nir_loop *loop)
51 {
52 nir_convert_loop_to_lcssa(loop);
53
54 nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
55
56 nir_block *block_after_loop =
57 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
58
59 nir_lower_phis_to_regs_block(block_after_loop);
60
61 nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
62 if (last_instr && last_instr->type == nir_instr_type_jump) {
63 assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
64 nir_instr_remove(last_instr);
65 }
66 }
67
68 static void
69 get_first_blocks_in_terminator(nir_loop_terminator *term,
70 nir_block **first_break_block,
71 nir_block **first_continue_block)
72 {
73 if (term->continue_from_then) {
74 *first_continue_block = nir_if_first_then_block(term->nif);
75 *first_break_block = nir_if_first_else_block(term->nif);
76 } else {
77 *first_continue_block = nir_if_first_else_block(term->nif);
78 *first_break_block = nir_if_first_then_block(term->nif);
79 }
80 }
81
82 /**
83 * Unroll a loop where we know exactly how many iterations there are and there
84 * is only a single exit point. Note here we can unroll loops with multiple
85 * theoretical exits that only have a single terminating exit that we always
86 * know is the "real" exit.
87 *
88 * loop {
89 * ...instrs...
90 * }
91 *
92 * And the iteration count is 3, the output will be:
93 *
94 * ...instrs... ...instrs... ...instrs...
95 */
96 static void
97 simple_unroll(nir_loop *loop)
98 {
99 nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
100 assert(nir_is_trivial_loop_if(limiting_term->nif,
101 limiting_term->break_block));
102
103 loop_prepare_for_unroll(loop);
104
105 /* Skip over loop terminator and get the loop body. */
106 list_for_each_entry(nir_loop_terminator, terminator,
107 &loop->info->loop_terminator_list,
108 loop_terminator_link) {
109
110 /* Remove all but the limiting terminator as we know the other exit
111 * conditions can never be met. Note we need to extract any instructions
112 * in the continue from branch and insert then into the loop body before
113 * removing it.
114 */
115 if (terminator->nif != limiting_term->nif) {
116 nir_block *first_break_block;
117 nir_block *first_continue_block;
118 get_first_blocks_in_terminator(terminator, &first_break_block,
119 &first_continue_block);
120
121 assert(nir_is_trivial_loop_if(terminator->nif,
122 terminator->break_block));
123
124 nir_cf_list continue_from_lst;
125 nir_cf_extract(&continue_from_lst,
126 nir_before_block(first_continue_block),
127 nir_after_block(terminator->continue_from_block));
128 nir_cf_reinsert(&continue_from_lst,
129 nir_after_cf_node(&terminator->nif->cf_node));
130
131 nir_cf_node_remove(&terminator->nif->cf_node);
132 }
133 }
134
135 nir_block *first_break_block;
136 nir_block *first_continue_block;
137 get_first_blocks_in_terminator(limiting_term, &first_break_block,
138 &first_continue_block);
139
140 /* Pluck out the loop header */
141 nir_block *header_blk = nir_loop_first_block(loop);
142 nir_cf_list lp_header;
143 nir_cf_extract(&lp_header, nir_before_block(header_blk),
144 nir_before_cf_node(&limiting_term->nif->cf_node));
145
146 /* Add the continue from block of the limiting terminator to the loop body
147 */
148 nir_cf_list continue_from_lst;
149 nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
150 nir_after_block(limiting_term->continue_from_block));
151 nir_cf_reinsert(&continue_from_lst,
152 nir_after_cf_node(&limiting_term->nif->cf_node));
153
154 /* Pluck out the loop body */
155 nir_cf_list loop_body;
156 nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
157 nir_after_block(nir_loop_last_block(loop)));
158
159 struct hash_table *remap_table =
160 _mesa_hash_table_create(NULL, _mesa_hash_pointer,
161 _mesa_key_pointer_equal);
162
163 /* Clone the loop header */
164 nir_cf_list cloned_header;
165 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
166 remap_table);
167
168 /* Insert cloned loop header before the loop */
169 nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
170
171 /* Temp list to store the cloned loop body as we unroll */
172 nir_cf_list unrolled_lp_body;
173
174 /* Clone loop header and append to the loop body */
175 for (unsigned i = 0; i < loop->info->trip_count; i++) {
176 /* Clone loop body */
177 nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
178 remap_table);
179
180 /* Insert unrolled loop body before the loop */
181 nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));
182
183 /* Clone loop header */
184 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
185 remap_table);
186
187 /* Insert loop header after loop body */
188 nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
189 }
190
191 /* Remove the break from the loop terminator and add instructions from
192 * the break block after the unrolled loop.
193 */
194 nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
195 nir_instr_remove(break_instr);
196 nir_cf_list break_list;
197 nir_cf_extract(&break_list, nir_before_block(first_break_block),
198 nir_after_block(limiting_term->break_block));
199
200 /* Clone so things get properly remapped */
201 nir_cf_list cloned_break_list;
202 nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
203 remap_table);
204
205 nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));
206
207 /* Remove the loop */
208 nir_cf_node_remove(&loop->cf_node);
209
210 /* Delete the original loop body, break block & header */
211 nir_cf_delete(&lp_header);
212 nir_cf_delete(&loop_body);
213 nir_cf_delete(&break_list);
214
215 _mesa_hash_table_destroy(remap_table, NULL);
216 }
217
218 static void
219 move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
220 {
221 /* Move the rest of the loop inside the continue-from-block */
222 nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
223
224 /* Remove the break */
225 nir_instr_remove(nir_block_last_instr(term->break_block));
226 }
227
228 static nir_cursor
229 get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
230 {
231 if (node->type == nir_cf_node_loop) {
232 return nir_before_cf_node(node);
233 } else {
234 nir_if *if_stmt = nir_cf_node_as_if(node);
235 if (continue_from_then) {
236 return nir_after_block(nir_if_last_then_block(if_stmt));
237 } else {
238 return nir_after_block(nir_if_last_else_block(if_stmt));
239 }
240 }
241 }
242
243 /**
244 * Unroll a loop with two exists when the trip count of one of the exits is
245 * unknown. If continue_from_then is true, the loop is repeated only when the
246 * "then" branch of the if is taken; otherwise it is repeated only
247 * when the "else" branch of the if is taken.
248 *
249 * For example, if the input is:
250 *
251 * loop {
252 * ...phis/condition...
253 * if condition {
254 * ...then instructions...
255 * } else {
256 * ...continue instructions...
257 * break
258 * }
259 * ...body...
260 * }
261 *
262 * And the iteration count is 3, and unlimit_term->continue_from_then is true,
263 * then the output will be:
264 *
265 * ...condition...
266 * if condition {
267 * ...then instructions...
268 * ...body...
269 * if condition {
270 * ...then instructions...
271 * ...body...
272 * if condition {
273 * ...then instructions...
274 * ...body...
275 * } else {
276 * ...continue instructions...
277 * }
278 * } else {
279 * ...continue instructions...
280 * }
281 * } else {
282 * ...continue instructions...
283 * }
284 */
285 static void
286 complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
287 bool limiting_term_second)
288 {
289 assert(nir_is_trivial_loop_if(unlimit_term->nif,
290 unlimit_term->break_block));
291
292 nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
293 assert(nir_is_trivial_loop_if(limiting_term->nif,
294 limiting_term->break_block));
295
296 loop_prepare_for_unroll(loop);
297
298 nir_block *header_blk = nir_loop_first_block(loop);
299
300 nir_cf_list lp_header;
301 nir_cf_list limit_break_list;
302 unsigned num_times_to_clone;
303 if (limiting_term_second) {
304 /* Pluck out the loop header */
305 nir_cf_extract(&lp_header, nir_before_block(header_blk),
306 nir_before_cf_node(&unlimit_term->nif->cf_node));
307
308 /* We need some special handling when its the second terminator causing
309 * us to exit the loop for example:
310 *
311 * for (int i = 0; i < uniform_lp_count; i++) {
312 * colour = vec4(0.0, 1.0, 0.0, 1.0);
313 *
314 * if (i == 1) {
315 * break;
316 * }
317 * ... any further code is unreachable after i == 1 ...
318 * }
319 */
320 nir_cf_list after_lt;
321 nir_if *limit_if = limiting_term->nif;
322 nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
323 nir_after_block(nir_loop_last_block(loop)));
324 move_cf_list_into_loop_term(&after_lt, limiting_term);
325
326 /* Because the trip count is the number of times we pass over the entire
327 * loop before hitting a break when the second terminator is the
328 * limiting terminator we can actually execute code inside the loop when
329 * trip count == 0 e.g. the code above the break. So we need to bump
330 * the trip_count in order for the code below to clone anything. When
331 * trip count == 1 we execute the code above the break twice and the
332 * code below it once so we need clone things twice and so on.
333 */
334 num_times_to_clone = loop->info->trip_count + 1;
335 } else {
336 /* Pluck out the loop header */
337 nir_cf_extract(&lp_header, nir_before_block(header_blk),
338 nir_before_cf_node(&limiting_term->nif->cf_node));
339
340 nir_block *first_break_block;
341 nir_block *first_continue_block;
342 get_first_blocks_in_terminator(limiting_term, &first_break_block,
343 &first_continue_block);
344
345 /* Remove the break then extract instructions from the break block so we
346 * can insert them in the innermost else of the unrolled loop.
347 */
348 nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
349 nir_instr_remove(break_instr);
350 nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
351 nir_after_block(limiting_term->break_block));
352
353 nir_cf_list continue_list;
354 nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
355 nir_after_block(limiting_term->continue_from_block));
356
357 nir_cf_reinsert(&continue_list,
358 nir_after_cf_node(&limiting_term->nif->cf_node));
359
360 nir_cf_node_remove(&limiting_term->nif->cf_node);
361
362 num_times_to_clone = loop->info->trip_count;
363 }
364
365 /* In the terminator that we have no trip count for move everything after
366 * the terminator into the continue from branch.
367 */
368 nir_cf_list loop_end;
369 nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
370 nir_after_block(nir_loop_last_block(loop)));
371 move_cf_list_into_loop_term(&loop_end, unlimit_term);
372
373 /* Pluck out the loop body. */
374 nir_cf_list loop_body;
375 nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
376 nir_after_block(nir_loop_last_block(loop)));
377
378 struct hash_table *remap_table =
379 _mesa_hash_table_create(NULL, _mesa_hash_pointer,
380 _mesa_key_pointer_equal);
381
382 /* Set unroll_loc to the loop as we will insert the unrolled loop before it
383 */
384 nir_cf_node *unroll_loc = &loop->cf_node;
385
386 /* Temp lists to store the cloned loop as we unroll */
387 nir_cf_list unrolled_lp_body;
388 nir_cf_list cloned_header;
389
390 for (unsigned i = 0; i < num_times_to_clone; i++) {
391 /* Clone loop header */
392 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
393 remap_table);
394
395 nir_cursor cursor =
396 get_complex_unroll_insert_location(unroll_loc,
397 unlimit_term->continue_from_then);
398
399 /* Insert cloned loop header */
400 nir_cf_reinsert(&cloned_header, cursor);
401
402 cursor =
403 get_complex_unroll_insert_location(unroll_loc,
404 unlimit_term->continue_from_then);
405
406 /* Clone loop body */
407 nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
408 remap_table);
409
410 unroll_loc = exec_node_data(nir_cf_node,
411 exec_list_get_tail(&unrolled_lp_body.list),
412 node);
413 assert(unroll_loc->type == nir_cf_node_block &&
414 exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));
415
416 /* Get the unrolled if node */
417 unroll_loc = nir_cf_node_prev(unroll_loc);
418
419 /* Insert unrolled loop body */
420 nir_cf_reinsert(&unrolled_lp_body, cursor);
421 }
422
423 if (!limiting_term_second) {
424 assert(unroll_loc->type == nir_cf_node_if);
425
426 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
427 remap_table);
428
429 nir_cursor cursor =
430 get_complex_unroll_insert_location(unroll_loc,
431 unlimit_term->continue_from_then);
432
433 /* Insert cloned loop header */
434 nir_cf_reinsert(&cloned_header, cursor);
435
436 /* Clone so things get properly remapped, and insert break block from
437 * the limiting terminator.
438 */
439 nir_cf_list cloned_break_blk;
440 nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
441 loop->cf_node.parent, remap_table);
442
443 cursor =
444 get_complex_unroll_insert_location(unroll_loc,
445 unlimit_term->continue_from_then);
446
447 nir_cf_reinsert(&cloned_break_blk, cursor);
448 nir_cf_delete(&limit_break_list);
449 }
450
451 /* The loop has been unrolled so remove it. */
452 nir_cf_node_remove(&loop->cf_node);
453
454 /* Delete the original loop header and body */
455 nir_cf_delete(&lp_header);
456 nir_cf_delete(&loop_body);
457
458 _mesa_hash_table_destroy(remap_table, NULL);
459 }
460
461 static bool
462 is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
463 {
464 unsigned max_iter = shader->options->max_unroll_iterations;
465
466 if (li->trip_count > max_iter)
467 return false;
468
469 if (li->force_unroll)
470 return true;
471
472 bool loop_not_too_large =
473 li->num_instructions * li->trip_count <= max_iter * LOOP_UNROLL_LIMIT;
474
475 return loop_not_too_large;
476 }
477
478 static bool
479 process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
480 {
481 bool progress = false;
482 nir_loop *loop;
483
484 switch (cf_node->type) {
485 case nir_cf_node_block:
486 return progress;
487 case nir_cf_node_if: {
488 nir_if *if_stmt = nir_cf_node_as_if(cf_node);
489 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
490 progress |= process_loops(sh, nested_node, innermost_loop);
491 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
492 progress |= process_loops(sh, nested_node, innermost_loop);
493 return progress;
494 }
495 case nir_cf_node_loop: {
496 loop = nir_cf_node_as_loop(cf_node);
497 foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
498 progress |= process_loops(sh, nested_node, innermost_loop);
499 break;
500 }
501 default:
502 unreachable("unknown cf node type");
503 }
504
505 if (*innermost_loop) {
506 /* Don't attempt to unroll outer loops or a second inner loop in
507 * this pass wait until the next pass as we have altered the cf.
508 */
509 *innermost_loop = false;
510
511 if (loop->info->limiting_terminator == NULL)
512 return progress;
513
514 if (!is_loop_small_enough_to_unroll(sh, loop->info))
515 return progress;
516
517 if (loop->info->is_trip_count_known) {
518 simple_unroll(loop);
519 progress = true;
520 } else {
521 /* Attempt to unroll loops with two terminators. */
522 unsigned num_lt = list_length(&loop->info->loop_terminator_list);
523 if (num_lt == 2) {
524 bool limiting_term_second = true;
525 nir_loop_terminator *terminator =
526 list_last_entry(&loop->info->loop_terminator_list,
527 nir_loop_terminator, loop_terminator_link);
528
529
530 if (terminator->nif == loop->info->limiting_terminator->nif) {
531 limiting_term_second = false;
532 terminator =
533 list_first_entry(&loop->info->loop_terminator_list,
534 nir_loop_terminator, loop_terminator_link);
535 }
536
537 /* If the first terminator has a trip count of zero and is the
538 * limiting terminator just do a simple unroll as the second
539 * terminator can never be reached.
540 */
541 if (loop->info->trip_count == 0 && !limiting_term_second) {
542 simple_unroll(loop);
543 } else {
544 complex_unroll(loop, terminator, limiting_term_second);
545 }
546 progress = true;
547 }
548 }
549 }
550
551 return progress;
552 }
553
554 static bool
555 nir_opt_loop_unroll_impl(nir_function_impl *impl,
556 nir_variable_mode indirect_mask)
557 {
558 bool progress = false;
559 nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
560 nir_metadata_require(impl, nir_metadata_block_index);
561
562 foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
563 bool innermost_loop = true;
564 progress |= process_loops(impl->function->shader, node,
565 &innermost_loop);
566 }
567
568 if (progress)
569 nir_lower_regs_to_ssa_impl(impl);
570
571 return progress;
572 }
573
574 bool
575 nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
576 {
577 bool progress = false;
578
579 nir_foreach_function(function, shader) {
580 if (function->impl) {
581 progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
582 }
583 }
584 return progress;
585 }