5780ae3794bf061aeddaf4187561a08e0f0c3e17
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir/nir_builder.h"
26 #include "nir_constant_expressions.h"
27 #include "nir_control_flow.h"
28 #include "nir_loop_analyze.h"
29
30 /**
31 * Gets the single block that jumps back to the loop header. Already assumes
32 * there is exactly one such block.
33 */
34 static nir_block*
35 find_continue_block(nir_loop *loop)
36 {
37 nir_block *header_block = nir_loop_first_block(loop);
38 nir_block *prev_block =
39 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
40
41 assert(header_block->predecessors->entries == 2);
42
43 struct set_entry *pred_entry;
44 set_foreach(header_block->predecessors, pred_entry) {
45 if (pred_entry->key != prev_block)
46 return (nir_block*)pred_entry->key;
47 }
48
49 unreachable("Continue block not found!");
50 }
51
/**
 * This optimization detects if statements at the tops of loops where the
 * condition is a phi node of two constants and moves half of the if to above
 * the loop and the other half of the if to the end of the loop. A simple for
 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
 * ends up looking something like this:
 *
 * vec1 32 ssa_0 = load_const (0x00000000)
 * vec1 32 ssa_1 = load_const (0xffffffff)
 * loop {
 *    block block_1:
 *    vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
 *    vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
 *    if ssa_2 {
 *       block block_2:
 *       vec1 32 ssa_4 = load_const (0x00000001)
 *       vec1 32 ssa_5 = iadd ssa_2, ssa_4
 *    } else {
 *       block block_3:
 *    }
 *    block block_4:
 *    vec1 32 ssa_6 = load_const (0x00000004)
 *    vec1 32 ssa_7 = ilt ssa_5, ssa_6
 *    if ssa_7 {
 *       block block_5:
 *    } else {
 *       block block_6:
 *       break
 *    }
 *    block block_7:
 * }
 *
 * This turns it into something like this:
 *
 * // Stuff from block 1
 * // Stuff from block 3
 * loop {
 *    block block_1:
 *    vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
 *    vec1 32 ssa_6 = load_const (0x00000004)
 *    vec1 32 ssa_7 = ilt ssa_5, ssa_6
 *    if ssa_7 {
 *       block block_5:
 *    } else {
 *       block block_6:
 *       break
 *    }
 *    block block_7:
 *    // Stuff from block 1
 *    // Stuff from block 2
 *    vec1 32 ssa_4 = load_const (0x00000001)
 *    vec1 32 ssa_5 = iadd ssa_2, ssa_4
 * }
 *
 * Returns true if the transform was applied, false if the loop does not
 * match the pattern above.
 */
static bool
opt_peel_loop_initial_if(nir_loop *loop)
{
   nir_block *header_block = nir_loop_first_block(loop);
   MAYBE_UNUSED nir_block *prev_block =
      nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));

   /* It would be insane if this were not true */
   assert(_mesa_set_search(header_block->predecessors, prev_block));

   /* The loop must have exactly one continue block which could be a block
    * ending in a continue instruction or the "natural" continue from the
    * last block in the loop back to the top.
    */
   if (header_block->predecessors->entries != 2)
      return false;

   nir_block *continue_block = find_continue_block(loop);

   /* The first non-block CF node in the loop must be an if, immediately
    * after the header block.
    */
   nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
   if (!if_node || if_node->type != nir_cf_node_if)
      return false;

   nir_if *nif = nir_cf_node_as_if(if_node);
   assert(nif->condition.is_ssa);

   /* The condition must be a phi that lives in the loop header itself. */
   nir_ssa_def *cond = nif->condition.ssa;
   if (cond->parent_instr->type != nir_instr_type_phi)
      return false;

   nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
   if (cond->parent_instr->block != header_block)
      return false;

   /* We already know we have exactly one continue */
   assert(exec_list_length(&cond_phi->srcs) == 2);

   /* Both phi sources must be constants: one value on loop entry, one on
    * each continue.
    */
   uint32_t entry_val = 0, continue_val = 0;
   nir_foreach_phi_src(src, cond_phi) {
      assert(src->src.is_ssa);
      nir_const_value *const_src = nir_src_as_const_value(src->src);
      if (!const_src)
         return false;

      if (src->pred == continue_block) {
         continue_val = const_src->u32[0];
      } else {
         assert(src->pred == prev_block);
         entry_val = const_src->u32[0];
      }
   }

   /* If they both execute or both don't execute, this is a job for
    * nir_dead_cf, not this pass.
    */
   if ((entry_val && continue_val) || (!entry_val && !continue_val))
      return false;

   /* entry_list is the side of the if taken on the first iteration (moves
    * above the loop); continue_list is the side taken on every subsequent
    * iteration (moves to the bottom of the loop).
    */
   struct exec_list *continue_list, *entry_list;
   if (continue_val) {
      continue_list = &nif->then_list;
      entry_list = &nif->else_list;
   } else {
      continue_list = &nif->else_list;
      entry_list = &nif->then_list;
   }

   /* We want to be moving the contents of entry_list to above the loop so it
    * can't contain any break or continue instructions.
    */
   foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
      nir_foreach_block_in_cf_node(block, cf_node) {
         nir_instr *last_instr = nir_block_last_instr(block);
         if (last_instr && last_instr->type == nir_instr_type_jump)
            return false;
      }
   }

   /* Before we do anything, convert the loop to LCSSA.  We're about to
    * replace a bunch of SSA defs with registers and this will prevent any of
    * it from leaking outside the loop.
    */
   nir_convert_loop_to_lcssa(loop);

   nir_block *after_if_block =
      nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));

   /* Get rid of phis in the header block since we will be duplicating it */
   nir_lower_phis_to_regs_block(header_block);
   /* Get rid of phis after the if since dominance will change */
   nir_lower_phis_to_regs_block(after_if_block);

   /* Get rid of SSA defs in the pieces we're about to move around */
   nir_lower_ssa_defs_to_regs_block(header_block);
   nir_foreach_block_in_cf_node(block, &nif->cf_node)
      nir_lower_ssa_defs_to_regs_block(block);

   /* Pull the header out of the loop, place a clone of it (plus the
    * entry-side half of the if) before the loop, and reinsert the original
    * header at the bottom of the loop, followed by the continue-side half
    * of the if.
    */
   nir_cf_list header, tmp;
   nir_cf_extract(&header, nir_before_block(header_block),
                  nir_after_block(header_block));

   nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
   nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
   nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
                  nir_after_cf_list(entry_list));
   nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));

   nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));

   /* Get continue block again as the previous reinsert might have removed the block. */
   continue_block = find_continue_block(loop);

   nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
                  nir_after_cf_list(continue_list));
   nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));

   /* The if is now empty on both sides; delete it. */
   nir_cf_node_remove(&nif->cf_node);

   return true;
}
226
227 static bool
228 is_block_empty(nir_block *block)
229 {
230 return nir_cf_node_is_last(&block->cf_node) &&
231 exec_list_is_empty(&block->instr_list);
232 }
233
/**
 * This optimization turns:
 *
 * if (cond) {
 * } else {
 *    do_work();
 * }
 *
 * into:
 *
 * if (!cond) {
 *    do_work();
 * } else {
 * }
 *
 * Only applies when the condition is a comparison ALU instruction, so the
 * inserted inot() can presumably be folded into the comparison by later
 * algebraic optimization.  Returns true on progress.
 */
static bool
opt_if_simplification(nir_builder *b, nir_if *nif)
{
   /* Only simplify if the then block is empty and the else block is not. */
   if (!is_block_empty(nir_if_first_then_block(nif)) ||
       is_block_empty(nir_if_first_else_block(nif)))
      return false;

   /* Make sure the condition is a comparison operation. */
   nir_instr *src_instr = nif->condition.ssa->parent_instr;
   if (src_instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu_instr = nir_instr_as_alu(src_instr);
   if (!nir_alu_instr_is_comparison(alu_instr))
      return false;

   /* Insert the inverted instruction and rewrite the condition. */
   b->cursor = nir_after_instr(&alu_instr->instr);

   nir_ssa_def *new_condition =
      nir_inot(b, &alu_instr->dest.dest.ssa);

   nir_if_rewrite_condition(nif, nir_src_for_ssa(new_condition));

   /* Grab pointers to the last then/else blocks for fixing up the phis. */
   nir_block *then_block = nir_if_last_then_block(nif);
   nir_block *else_block = nir_if_last_else_block(nif);

   /* Walk all the phis in the block immediately following the if statement and
    * swap the blocks.  This keeps each phi source associated with the value
    * that branch produced even though the branches are about to trade places.
    */
   nir_block *after_if_block =
      nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));

   nir_foreach_instr(instr, after_if_block) {
      if (instr->type != nir_instr_type_phi)
         continue;

      nir_phi_instr *phi = nir_instr_as_phi(instr);

      foreach_list_typed(nir_phi_src, src, node, &phi->srcs) {
         if (src->pred == else_block) {
            src->pred = then_block;
         } else if (src->pred == then_block) {
            src->pred = else_block;
         }
      }
   }

   /* Finally, move the else block to the then block. */
   nir_cf_list tmp;
   nir_cf_extract(&tmp, nir_before_cf_list(&nif->else_list),
                  nir_after_cf_list(&nif->else_list));
   nir_cf_reinsert(&tmp, nir_before_cf_list(&nif->then_list));

   return true;
}
307
/**
 * This optimization simplifies potential loop terminators which then allows
 * other passes such as opt_if_simplification() and loop unrolling to progress
 * further:
 *
 * if (cond) {
 *    ... then block instructions ...
 * } else {
 *    ...
 *    break;
 * }
 *
 * into:
 *
 * if (cond) {
 * } else {
 *    ...
 *    break;
 * }
 * ... then block instructions ...
 *
 * (and the mirrored case where the then branch holds the break).
 * Returns true on progress.
 */
static bool
opt_if_loop_terminator(nir_if *nif)
{
   nir_block *break_blk = NULL;
   nir_block *continue_from_blk = NULL;
   bool continue_from_then = true;

   nir_block *last_then = nir_if_last_then_block(nif);
   nir_block *last_else = nir_if_last_else_block(nif);

   /* Figure out which branch breaks out of the loop; the other branch is
    * the one whose body we can hoist after the if.
    */
   if (nir_block_ends_in_break(last_then)) {
      break_blk = last_then;
      continue_from_blk = last_else;
      continue_from_then = false;
   } else if (nir_block_ends_in_break(last_else)) {
      break_blk = last_else;
      continue_from_blk = last_then;
   }

   /* Continue if the if-statement contained no jumps at all */
   if (!break_blk)
      return false;

   /* If the continue from block is empty then return as there is nothing to
    * move.
    */
   nir_block *first_continue_from_blk = continue_from_then ?
      nir_if_first_then_block(nif) :
      nir_if_first_else_block(nif);
   if (is_block_empty(first_continue_from_blk))
      return false;

   /* The if must be a simple loop terminator (see nir_is_trivial_loop_if())
    * for the hoist to be safe — NOTE(review): relies on that helper's
    * definition of "trivial"; confirm against nir_loop_analyze.h.
    */
   if (!nir_is_trivial_loop_if(nif, break_blk))
      return false;

   /* Finally, move the continue from branch after the if-statement. */
   nir_cf_list tmp;
   nir_cf_extract(&tmp, nir_before_block(first_continue_from_blk),
                  nir_after_block(continue_from_blk));
   nir_cf_reinsert(&tmp, nir_after_cf_node(&nif->cf_node));

   return true;
}
372
373 static void
374 replace_if_condition_use_with_const(nir_builder *b, nir_src *use,
375 nir_const_value nir_boolean,
376 bool if_condition)
377 {
378 /* Create const */
379 nir_ssa_def *const_def = nir_build_imm(b, 1, 32, nir_boolean);
380
381 /* Rewrite use to use const */
382 nir_src new_src = nir_src_for_ssa(const_def);
383 if (if_condition)
384 nir_if_rewrite_condition(use->parent_if, new_src);
385 else
386 nir_instr_rewrite_src(use->parent_instr, use, new_src);
387 }
388
389 static bool
390 evaluate_if_condition(nir_if *nif, nir_cursor cursor, uint32_t *value)
391 {
392 nir_block *use_block = nir_cursor_current_block(cursor);
393 if (nir_block_dominates(nir_if_first_then_block(nif), use_block)) {
394 *value = NIR_TRUE;
395 return true;
396 } else if (nir_block_dominates(nir_if_first_else_block(nif), use_block)) {
397 *value = NIR_FALSE;
398 return true;
399 } else {
400 return false;
401 }
402 }
403
/*
 * This propagates if condition evaluation down the chain of some alu
 * instructions. For example by checking the use of some of the following alu
 * instruction we can eventually replace ssa_107 with NIR_TRUE.
 *
 * loop {
 *    block block_1:
 *    vec1 32 ssa_85 = load_const (0x00000002)
 *    vec1 32 ssa_86 = ieq ssa_48, ssa_85
 *    vec1 32 ssa_87 = load_const (0x00000001)
 *    vec1 32 ssa_88 = ieq ssa_48, ssa_87
 *    vec1 32 ssa_89 = ior ssa_86, ssa_88
 *    vec1 32 ssa_90 = ieq ssa_48, ssa_0
 *    vec1 32 ssa_91 = ior ssa_89, ssa_90
 *    if ssa_86 {
 *       block block_2:
 *        ...
 *       break
 *    } else {
 *       block block_3:
 *    }
 *    block block_4:
 *    if ssa_88 {
 *       block block_5:
 *        ...
 *       break
 *    } else {
 *       block block_6:
 *    }
 *    block block_7:
 *    if ssa_90 {
 *       block block_8:
 *        ...
 *       break
 *    } else {
 *       block block_9:
 *    }
 *    block block_10:
 *    vec1 32 ssa_107 = inot ssa_91
 *    if ssa_107 {
 *       block block_11:
 *       break
 *    } else {
 *       block block_12:
 *    }
 * }
 */
/*
 * use_src: the use of nif's condition that feeds the ALU instruction 'alu'.
 * alu_use: the use of alu's result at which we try to evaluate.
 * is_if_condition: true when alu_use is another if-statement's condition.
 *
 * Returns true if a use was rewritten.
 */
static bool
propagate_condition_eval(nir_builder *b, nir_if *nif, nir_src *use_src,
                         nir_src *alu_use, nir_alu_instr *alu,
                         bool is_if_condition)
{
   bool progress = false;

   /* Only .u32[0] of bool_value is written/consumed below (1-component,
    * 32-bit boolean).
    */
   nir_const_value bool_value;
   b->cursor = nir_before_src(alu_use, is_if_condition);
   if (nir_op_infos[alu->op].num_inputs == 1) {
      assert(alu->op == nir_op_inot || alu->op == nir_op_b2i);

      if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
         assert(nir_src_bit_size(alu->src[0].src) == 32);

         /* Unary op: constant-fold the whole ALU result and replace this
          * use with the folded constant.
          */
         nir_const_value result =
            nir_eval_const_opcode(alu->op, 1, 32, &bool_value);

         replace_if_condition_use_with_const(b, alu_use, result,
                                             is_if_condition);
         progress = true;
      }
   } else {
      assert(alu->op == nir_op_ior || alu->op == nir_op_iand);

      if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
         /* Binary op: only the operand that is the if condition becomes a
          * constant; rebuild the ALU with the other operand intact.
          */
         nir_ssa_def *def[2];
         for (unsigned i = 0; i < 2; i++) {
            if (alu->src[i].src.ssa == use_src->ssa) {
               def[i] = nir_build_imm(b, 1, 32, bool_value);
            } else {
               def[i] = alu->src[i].src.ssa;
            }
         }

         nir_ssa_def *nalu =
            nir_build_alu(b, alu->op, def[0], def[1], NULL, NULL);

         /* Rewrite use to use new alu instruction */
         nir_src new_src = nir_src_for_ssa(nalu);

         if (is_if_condition)
            nir_if_rewrite_condition(alu_use->parent_if, new_src);
         else
            nir_instr_rewrite_src(alu_use->parent_instr, alu_use, new_src);

         progress = true;
      }
   }

   return progress;
}
503
504 static bool
505 can_propagate_through_alu(nir_src *src)
506 {
507 if (src->parent_instr->type == nir_instr_type_alu &&
508 (nir_instr_as_alu(src->parent_instr)->op == nir_op_ior ||
509 nir_instr_as_alu(src->parent_instr)->op == nir_op_iand ||
510 nir_instr_as_alu(src->parent_instr)->op == nir_op_inot ||
511 nir_instr_as_alu(src->parent_instr)->op == nir_op_b2i))
512 return true;
513
514 return false;
515 }
516
/**
 * Tries to replace the given use of nif's condition with a constant: inside
 * the then-branch the condition is known true, inside the else-branch known
 * false (see evaluate_if_condition()).  When the use feeds a simple boolean
 * ALU instruction (see can_propagate_through_alu()), additionally propagates
 * the evaluation to the uses of that ALU's result.
 *
 * is_if_condition: true when use_src is another if-statement's condition
 * rather than a regular instruction source.
 *
 * Returns true if any use was rewritten.
 */
static bool
evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
                       bool is_if_condition)
{
   bool progress = false;

   /* Only .u32[0] of value is written/consumed (1-component 32-bit bool). */
   nir_const_value value;
   b->cursor = nir_before_src(use_src, is_if_condition);

   if (evaluate_if_condition(nif, b->cursor, &value.u32[0])) {
      replace_if_condition_use_with_const(b, use_src, value, is_if_condition);
      progress = true;
   }

   if (!is_if_condition && can_propagate_through_alu(use_src)) {
      nir_alu_instr *alu = nir_instr_as_alu(use_src->parent_instr);

      /* _safe iteration: propagation rewrites sources while we walk the
       * use lists.
       */
      nir_foreach_use_safe(alu_use, &alu->dest.dest.ssa) {
         progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
                                              false);
      }

      nir_foreach_if_use_safe(alu_use, &alu->dest.dest.ssa) {
         progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
                                              true);
      }
   }

   return progress;
}
547
548 static bool
549 opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif)
550 {
551 bool progress = false;
552
553 /* Evaluate any uses of the if condition inside the if branches */
554 assert(nif->condition.is_ssa);
555 nir_foreach_use_safe(use_src, nif->condition.ssa) {
556 progress |= evaluate_condition_use(b, nif, use_src, false);
557 }
558
559 nir_foreach_if_use_safe(use_src, nif->condition.ssa) {
560 if (use_src->parent_if != nif)
561 progress |= evaluate_condition_use(b, nif, use_src, true);
562 }
563
564 return progress;
565 }
566
567 static bool
568 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
569 {
570 bool progress = false;
571 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
572 switch (cf_node->type) {
573 case nir_cf_node_block:
574 break;
575
576 case nir_cf_node_if: {
577 nir_if *nif = nir_cf_node_as_if(cf_node);
578 progress |= opt_if_cf_list(b, &nif->then_list);
579 progress |= opt_if_cf_list(b, &nif->else_list);
580 progress |= opt_if_loop_terminator(nif);
581 progress |= opt_if_simplification(b, nif);
582 break;
583 }
584
585 case nir_cf_node_loop: {
586 nir_loop *loop = nir_cf_node_as_loop(cf_node);
587 progress |= opt_if_cf_list(b, &loop->body);
588 progress |= opt_peel_loop_initial_if(loop);
589 break;
590 }
591
592 case nir_cf_node_function:
593 unreachable("Invalid cf type");
594 }
595 }
596
597 return progress;
598 }
599
600 /**
601 * These optimisations depend on nir_metadata_block_index and therefore must
602 * not do anything to cause the metadata to become invalid.
603 */
604 static bool
605 opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list)
606 {
607 bool progress = false;
608 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
609 switch (cf_node->type) {
610 case nir_cf_node_block:
611 break;
612
613 case nir_cf_node_if: {
614 nir_if *nif = nir_cf_node_as_if(cf_node);
615 progress |= opt_if_safe_cf_list(b, &nif->then_list);
616 progress |= opt_if_safe_cf_list(b, &nif->else_list);
617 progress |= opt_if_evaluate_condition_use(b, nif);
618 break;
619 }
620
621 case nir_cf_node_loop: {
622 nir_loop *loop = nir_cf_node_as_loop(cf_node);
623 progress |= opt_if_safe_cf_list(b, &loop->body);
624 break;
625 }
626
627 case nir_cf_node_function:
628 unreachable("Invalid cf type");
629 }
630 }
631
632 return progress;
633 }
634
635 bool
636 nir_opt_if(nir_shader *shader)
637 {
638 bool progress = false;
639
640 nir_foreach_function(function, shader) {
641 if (function->impl == NULL)
642 continue;
643
644 nir_builder b;
645 nir_builder_init(&b, function->impl);
646
647 nir_metadata_require(function->impl, nir_metadata_block_index |
648 nir_metadata_dominance);
649 progress = opt_if_safe_cf_list(&b, &function->impl->body);
650 nir_metadata_preserve(function->impl, nir_metadata_block_index |
651 nir_metadata_dominance);
652
653 if (opt_if_cf_list(&b, &function->impl->body)) {
654 nir_metadata_preserve(function->impl, nir_metadata_none);
655
656 /* If that made progress, we're no longer really in SSA form. We
657 * need to convert registers back into SSA defs and clean up SSA defs
658 * that don't dominate their uses.
659 */
660 nir_lower_regs_to_ssa_impl(function->impl);
661
662 /* Calling nir_convert_loop_to_lcssa() in opt_peel_loop_initial_if()
663 * adds extra phi nodes which may not be valid if they're used for
664 * something such as a deref. Remove any unneeded phis.
665 */
666 nir_opt_remove_phis_impl(function->impl);
667
668 progress = true;
669 }
670 }
671
672 return progress;
673 }