nir: Rename convert_to_ssa lower_regs_to_ssa
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_loop_analyze.h"
28
29 /* Prepare this loop for unrolling by first converting to lcssa and then
30 * converting the phis from the loops first block and the block that follows
31 * the loop into regs. Partially converting out of SSA allows us to unroll
32 * the loop without having to keep track of and update phis along the way
33 * which gets tricky and doesn't add much value over conveting to regs.
34 *
35 * The loop may have a continue instruction at the end of the loop which does
36 * nothing. Once we're out of SSA, we can safely delete it so we don't have
37 * to deal with it later.
38 */
39 static void
40 loop_prepare_for_unroll(nir_loop *loop)
41 {
42 nir_convert_loop_to_lcssa(loop);
43
44 nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
45
46 nir_block *block_after_loop =
47 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
48
49 nir_lower_phis_to_regs_block(block_after_loop);
50
51 nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
52 if (last_instr && last_instr->type == nir_instr_type_jump) {
53 assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
54 nir_instr_remove(last_instr);
55 }
56 }
57
58 static void
59 get_first_blocks_in_terminator(nir_loop_terminator *term,
60 nir_block **first_break_block,
61 nir_block **first_continue_block)
62 {
63 if (term->continue_from_then) {
64 *first_continue_block = nir_if_first_then_block(term->nif);
65 *first_break_block = nir_if_first_else_block(term->nif);
66 } else {
67 *first_continue_block = nir_if_first_else_block(term->nif);
68 *first_break_block = nir_if_first_then_block(term->nif);
69 }
70 }
71
72 /**
73 * Unroll a loop where we know exactly how many iterations there are and there
74 * is only a single exit point. Note here we can unroll loops with multiple
75 * theoretical exits that only have a single terminating exit that we always
76 * know is the "real" exit.
77 *
78 * loop {
79 * ...instrs...
80 * }
81 *
82 * And the iteration count is 3, the output will be:
83 *
84 * ...instrs... ...instrs... ...instrs...
85 */
86 static void
87 simple_unroll(nir_loop *loop)
88 {
89 nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
90 assert(nir_is_trivial_loop_if(limiting_term->nif,
91 limiting_term->break_block));
92
93 loop_prepare_for_unroll(loop);
94
95 /* Skip over loop terminator and get the loop body. */
96 list_for_each_entry(nir_loop_terminator, terminator,
97 &loop->info->loop_terminator_list,
98 loop_terminator_link) {
99
100 /* Remove all but the limiting terminator as we know the other exit
101 * conditions can never be met. Note we need to extract any instructions
102 * in the continue from branch and insert then into the loop body before
103 * removing it.
104 */
105 if (terminator->nif != limiting_term->nif) {
106 nir_block *first_break_block;
107 nir_block *first_continue_block;
108 get_first_blocks_in_terminator(terminator, &first_break_block,
109 &first_continue_block);
110
111 assert(nir_is_trivial_loop_if(terminator->nif,
112 terminator->break_block));
113
114 nir_cf_list continue_from_lst;
115 nir_cf_extract(&continue_from_lst,
116 nir_before_block(first_continue_block),
117 nir_after_block(terminator->continue_from_block));
118 nir_cf_reinsert(&continue_from_lst,
119 nir_after_cf_node(&terminator->nif->cf_node));
120
121 nir_cf_node_remove(&terminator->nif->cf_node);
122 }
123 }
124
125 nir_block *first_break_block;
126 nir_block *first_continue_block;
127 get_first_blocks_in_terminator(limiting_term, &first_break_block,
128 &first_continue_block);
129
130 /* Pluck out the loop header */
131 nir_block *header_blk = nir_loop_first_block(loop);
132 nir_cf_list lp_header;
133 nir_cf_extract(&lp_header, nir_before_block(header_blk),
134 nir_before_cf_node(&limiting_term->nif->cf_node));
135
136 /* Add the continue from block of the limiting terminator to the loop body
137 */
138 nir_cf_list continue_from_lst;
139 nir_cf_extract(&continue_from_lst, nir_before_block(first_continue_block),
140 nir_after_block(limiting_term->continue_from_block));
141 nir_cf_reinsert(&continue_from_lst,
142 nir_after_cf_node(&limiting_term->nif->cf_node));
143
144 /* Pluck out the loop body */
145 nir_cf_list loop_body;
146 nir_cf_extract(&loop_body, nir_after_cf_node(&limiting_term->nif->cf_node),
147 nir_after_block(nir_loop_last_block(loop)));
148
149 struct hash_table *remap_table =
150 _mesa_hash_table_create(NULL, _mesa_hash_pointer,
151 _mesa_key_pointer_equal);
152
153 /* Clone the loop header */
154 nir_cf_list cloned_header;
155 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
156 remap_table);
157
158 /* Insert cloned loop header before the loop */
159 nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
160
161 /* Temp list to store the cloned loop body as we unroll */
162 nir_cf_list unrolled_lp_body;
163
164 /* Clone loop header and append to the loop body */
165 for (unsigned i = 0; i < loop->info->trip_count; i++) {
166 /* Clone loop body */
167 nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
168 remap_table);
169
170 /* Insert unrolled loop body before the loop */
171 nir_cf_reinsert(&unrolled_lp_body, nir_before_cf_node(&loop->cf_node));
172
173 /* Clone loop header */
174 nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
175 remap_table);
176
177 /* Insert loop header after loop body */
178 nir_cf_reinsert(&cloned_header, nir_before_cf_node(&loop->cf_node));
179 }
180
181 /* Remove the break from the loop terminator and add instructions from
182 * the break block after the unrolled loop.
183 */
184 nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
185 nir_instr_remove(break_instr);
186 nir_cf_list break_list;
187 nir_cf_extract(&break_list, nir_before_block(first_break_block),
188 nir_after_block(limiting_term->break_block));
189
190 /* Clone so things get properly remapped */
191 nir_cf_list cloned_break_list;
192 nir_cf_list_clone(&cloned_break_list, &break_list, loop->cf_node.parent,
193 remap_table);
194
195 nir_cf_reinsert(&cloned_break_list, nir_before_cf_node(&loop->cf_node));
196
197 /* Remove the loop */
198 nir_cf_node_remove(&loop->cf_node);
199
200 /* Delete the original loop body, break block & header */
201 nir_cf_delete(&lp_header);
202 nir_cf_delete(&loop_body);
203 nir_cf_delete(&break_list);
204
205 _mesa_hash_table_destroy(remap_table, NULL);
206 }
207
208 static void
209 move_cf_list_into_loop_term(nir_cf_list *lst, nir_loop_terminator *term)
210 {
211 /* Move the rest of the loop inside the continue-from-block */
212 nir_cf_reinsert(lst, nir_after_block(term->continue_from_block));
213
214 /* Remove the break */
215 nir_instr_remove(nir_block_last_instr(term->break_block));
216 }
217
218 static nir_cursor
219 get_complex_unroll_insert_location(nir_cf_node *node, bool continue_from_then)
220 {
221 if (node->type == nir_cf_node_loop) {
222 return nir_before_cf_node(node);
223 } else {
224 nir_if *if_stmt = nir_cf_node_as_if(node);
225 if (continue_from_then) {
226 return nir_after_block(nir_if_last_then_block(if_stmt));
227 } else {
228 return nir_after_block(nir_if_last_else_block(if_stmt));
229 }
230 }
231 }
232
/**
 * Unroll a loop with two exits when the trip count of one of the exits is
 * unknown.  If continue_from_then is true, the loop is repeated only when the
 * "then" branch of the if is taken; otherwise it is repeated only when the
 * "else" branch of the if is taken.
 *
 * For example, if the input is:
 *
 *    loop {
 *       ...phis/condition...
 *       if condition {
 *          ...then instructions...
 *       } else {
 *          ...continue instructions...
 *          break
 *       }
 *       ...body...
 *    }
 *
 * and the iteration count is 3 and unlimit_term->continue_from_then is true,
 * then the output will be:
 *
 *    ...condition...
 *    if condition {
 *       ...then instructions...
 *       ...body...
 *       if condition {
 *          ...then instructions...
 *          ...body...
 *          if condition {
 *             ...then instructions...
 *             ...body...
 *          } else {
 *             ...continue instructions...
 *          }
 *       } else {
 *          ...continue instructions...
 *       }
 *    } else {
 *       ...continue instructions...
 *    }
 *
 * @param loop                  the loop to unroll; it is removed from the CF
 *                              tree once its contents have been cloned out.
 * @param unlimit_term          the terminator with no known trip count.  Each
 *                              unrolled copy is nested inside the continue
 *                              branch of the previous copy of this if.
 * @param limiting_term_second  true when the limiting terminator (the one
 *                              with the known trip count) comes after
 *                              unlimit_term in the loop body.
 *
 * NOTE: when limiting_term_second is false, trip_count must be non-zero.
 * With zero clones unroll_loc is still the loop node and the
 * "unroll_loc->type == nir_cf_node_if" assert below fails.
 */
static void
complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
               bool limiting_term_second)
{
   assert(nir_is_trivial_loop_if(unlimit_term->nif,
                                 unlimit_term->break_block));

   nir_loop_terminator *limiting_term = loop->info->limiting_terminator;
   assert(nir_is_trivial_loop_if(limiting_term->nif,
                                 limiting_term->break_block));

   loop_prepare_for_unroll(loop);

   nir_block *header_blk = nir_loop_first_block(loop);

   nir_cf_list lp_header;
   /* Only filled in (and only used) when the limiting terminator is first. */
   nir_cf_list limit_break_list;
   unsigned num_times_to_clone;
   if (limiting_term_second) {
      /* Pluck out the loop header: everything before the unlimited
       * terminator's if.
       */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&unlimit_term->nif->cf_node));

      /* We need some special handling when it's the second terminator
       * causing us to exit the loop, for example:
       *
       *    for (int i = 0; i < uniform_lp_count; i++) {
       *       colour = vec4(0.0, 1.0, 0.0, 1.0);
       *
       *       if (i == 1) {
       *          break;
       *       }
       *       ... any further code is unreachable after i == 1 ...
       *    }
       *
       * Everything after the limiting if is moved into its continue branch,
       * and its break is dropped.
       */
      nir_cf_list after_lt;
      nir_if *limit_if = limiting_term->nif;
      nir_cf_extract(&after_lt, nir_after_cf_node(&limit_if->cf_node),
                     nir_after_block(nir_loop_last_block(loop)));
      move_cf_list_into_loop_term(&after_lt, limiting_term);

      /* The trip count is the number of times we pass over the entire loop
       * before hitting a break.  When the second terminator is the limiting
       * terminator, code inside the loop (the code above the break) can run
       * even when trip count == 0.  So trip_count is bumped so that the code
       * below clones something.  When trip count == 1, the code above the
       * break runs twice and the code below it once, so things need to be
       * cloned twice, and so on.
       */
      num_times_to_clone = loop->info->trip_count + 1;
   } else {
      /* Pluck out the loop header: everything before the limiting
       * terminator's if.
       */
      nir_cf_extract(&lp_header, nir_before_block(header_blk),
                     nir_before_cf_node(&limiting_term->nif->cf_node));

      nir_block *first_break_block;
      nir_block *first_continue_block;
      get_first_blocks_in_terminator(limiting_term, &first_break_block,
                                     &first_continue_block);

      /* Remove the break, then extract the instructions from the break block
       * so they can be inserted in the innermost else of the unrolled loop.
       */
      nir_instr *break_instr = nir_block_last_instr(limiting_term->break_block);
      nir_instr_remove(break_instr);
      nir_cf_extract(&limit_break_list, nir_before_block(first_break_block),
                     nir_after_block(limiting_term->break_block));

      /* Hoist the continue side of the limiting if to just after the if, then
       * remove the now-empty if.
       */
      nir_cf_list continue_list;
      nir_cf_extract(&continue_list, nir_before_block(first_continue_block),
                     nir_after_block(limiting_term->continue_from_block));

      nir_cf_reinsert(&continue_list,
                      nir_after_cf_node(&limiting_term->nif->cf_node));

      nir_cf_node_remove(&limiting_term->nif->cf_node);

      num_times_to_clone = loop->info->trip_count;
   }

   /* For the terminator that has no trip count, move everything after it
    * into its continue-from branch.
    */
   nir_cf_list loop_end;
   nir_cf_extract(&loop_end, nir_after_cf_node(&unlimit_term->nif->cf_node),
                  nir_after_block(nir_loop_last_block(loop)));
   move_cf_list_into_loop_term(&loop_end, unlimit_term);

   /* Pluck out the loop body (what remains of the loop, starting with the
    * unlimited terminator's if).
    */
   nir_cf_list loop_body;
   nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
                  nir_after_block(nir_loop_last_block(loop)));

   struct hash_table *remap_table =
      _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                              _mesa_key_pointer_equal);

   /* Set unroll_loc to the loop, because the first unrolled copy is inserted
    * before it.  Afterwards it tracks the most recently cloned unlimited if.
    */
   nir_cf_node *unroll_loc = &loop->cf_node;

   /* Temp lists to store the cloned loop as we unroll */
   nir_cf_list unrolled_lp_body;
   nir_cf_list cloned_header;

   for (unsigned i = 0; i < num_times_to_clone; i++) {
      /* Clone loop header */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Recompute the insertion point after the reinsert (presumably because
       * the reinsert can change which block ends the branch).
       */
      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Clone loop body */
      nir_cf_list_clone(&unrolled_lp_body, &loop_body, loop->cf_node.parent,
                        remap_table);

      /* The cloned body ends with an empty block directly preceded by the
       * cloned unlimited if.  That if is where the next copy is nested.
       */
      unroll_loc = exec_node_data(nir_cf_node,
                                  exec_list_get_tail(&unrolled_lp_body.list),
                                  node);
      assert(unroll_loc->type == nir_cf_node_block &&
             exec_list_is_empty(&nir_cf_node_as_block(unroll_loc)->instr_list));

      /* Get the unrolled if node */
      unroll_loc = nir_cf_node_prev(unroll_loc);

      /* Insert unrolled loop body */
      nir_cf_reinsert(&unrolled_lp_body, cursor);
   }

   if (!limiting_term_second) {
      assert(unroll_loc->type == nir_cf_node_if);

      /* One more header copy runs before the limiting exit is taken. */
      nir_cf_list_clone(&cloned_header, &lp_header, loop->cf_node.parent,
                        remap_table);

      nir_cursor cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      /* Insert cloned loop header */
      nir_cf_reinsert(&cloned_header, cursor);

      /* Clone so things get properly remapped, and insert the break block
       * from the limiting terminator.
       */
      nir_cf_list cloned_break_blk;
      nir_cf_list_clone(&cloned_break_blk, &limit_break_list,
                        loop->cf_node.parent, remap_table);

      cursor =
         get_complex_unroll_insert_location(unroll_loc,
                                            unlimit_term->continue_from_then);

      nir_cf_reinsert(&cloned_break_blk, cursor);
      nir_cf_delete(&limit_break_list);
   }

   /* The loop has been unrolled so remove it. */
   nir_cf_node_remove(&loop->cf_node);

   /* Delete the original loop header and body */
   nir_cf_delete(&lp_header);
   nir_cf_delete(&loop_body);

   _mesa_hash_table_destroy(remap_table, NULL);
}
450
451 static bool
452 is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
453 {
454 unsigned max_iter = shader->options->max_unroll_iterations;
455
456 if (li->trip_count > max_iter)
457 return false;
458
459 if (li->force_unroll)
460 return true;
461
462 bool loop_not_too_large =
463 li->num_instructions * li->trip_count <= max_iter * 25;
464
465 return loop_not_too_large;
466 }
467
468 static bool
469 process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *innermost_loop)
470 {
471 bool progress = false;
472 nir_loop *loop;
473
474 switch (cf_node->type) {
475 case nir_cf_node_block:
476 return progress;
477 case nir_cf_node_if: {
478 nir_if *if_stmt = nir_cf_node_as_if(cf_node);
479 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
480 progress |= process_loops(sh, nested_node, innermost_loop);
481 foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
482 progress |= process_loops(sh, nested_node, innermost_loop);
483 return progress;
484 }
485 case nir_cf_node_loop: {
486 loop = nir_cf_node_as_loop(cf_node);
487 foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
488 progress |= process_loops(sh, nested_node, innermost_loop);
489 break;
490 }
491 default:
492 unreachable("unknown cf node type");
493 }
494
495 if (*innermost_loop) {
496 /* Don't attempt to unroll outer loops or a second inner loop in
497 * this pass wait until the next pass as we have altered the cf.
498 */
499 *innermost_loop = false;
500
501 if (loop->info->limiting_terminator == NULL)
502 return progress;
503
504 if (!is_loop_small_enough_to_unroll(sh, loop->info))
505 return progress;
506
507 if (loop->info->is_trip_count_known) {
508 simple_unroll(loop);
509 progress = true;
510 } else {
511 /* Attempt to unroll loops with two terminators. */
512 unsigned num_lt = list_length(&loop->info->loop_terminator_list);
513 if (num_lt == 2) {
514 bool limiting_term_second = true;
515 nir_loop_terminator *terminator =
516 list_last_entry(&loop->info->loop_terminator_list,
517 nir_loop_terminator, loop_terminator_link);
518
519
520 if (terminator->nif == loop->info->limiting_terminator->nif) {
521 limiting_term_second = false;
522 terminator =
523 list_first_entry(&loop->info->loop_terminator_list,
524 nir_loop_terminator, loop_terminator_link);
525 }
526
527 /* If the first terminator has a trip count of zero and is the
528 * limiting terminator just do a simple unroll as the second
529 * terminator can never be reached.
530 */
531 if (loop->info->trip_count == 0 && !limiting_term_second) {
532 simple_unroll(loop);
533 } else {
534 complex_unroll(loop, terminator, limiting_term_second);
535 }
536 progress = true;
537 }
538 }
539 }
540
541 return progress;
542 }
543
544 static bool
545 nir_opt_loop_unroll_impl(nir_function_impl *impl,
546 nir_variable_mode indirect_mask)
547 {
548 bool progress = false;
549 nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
550 nir_metadata_require(impl, nir_metadata_block_index);
551
552 foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
553 bool innermost_loop = true;
554 progress |= process_loops(impl->function->shader, node,
555 &innermost_loop);
556 }
557
558 if (progress)
559 nir_lower_regs_to_ssa_impl(impl);
560
561 return progress;
562 }
563
564 bool
565 nir_opt_loop_unroll(nir_shader *shader, nir_variable_mode indirect_mask)
566 {
567 bool progress = false;
568
569 nir_foreach_function(function, shader) {
570 if (function->impl) {
571 progress |= nir_opt_loop_unroll_impl(function->impl, indirect_mask);
572 }
573 }
574 return progress;
575 }