nir: Copy propagation between blocks
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir/nir_builder.h"
26 #include "nir_constant_expressions.h"
27 #include "nir_control_flow.h"
28 #include "nir_loop_analyze.h"
29
30 /**
31 * Gets the single block that jumps back to the loop header. Already assumes
32 * there is exactly one such block.
33 */
34 static nir_block*
35 find_continue_block(nir_loop *loop)
36 {
37 nir_block *header_block = nir_loop_first_block(loop);
38 nir_block *prev_block =
39 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
40
41 assert(header_block->predecessors->entries == 2);
42
43 struct set_entry *pred_entry;
44 set_foreach(header_block->predecessors, pred_entry) {
45 if (pred_entry->key != prev_block)
46 return (nir_block*)pred_entry->key;
47 }
48
49 unreachable("Continue block not found!");
50 }
51
52 /**
53 * This optimization detects if statements at the tops of loops where the
54 * condition is a phi node of two constants and moves half of the if to above
55 * the loop and the other half of the if to the end of the loop. A simple for
56 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
57 * ends up looking something like this:
58 *
59 * vec1 32 ssa_0 = load_const (0x00000000)
60 * vec1 32 ssa_1 = load_const (0xffffffff)
61 * loop {
62 * block block_1:
63 * vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
64 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
65 * if ssa_2 {
66 * block block_2:
67 * vec1 32 ssa_4 = load_const (0x00000001)
68 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
69 * } else {
70 * block block_3:
71 * }
72 * block block_4:
73 * vec1 32 ssa_6 = load_const (0x00000004)
74 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
75 * if ssa_7 {
76 * block block_5:
77 * } else {
78 * block block_6:
79 * break
80 * }
81 * block block_7:
82 * }
83 *
84 * This turns it into something like this:
85 *
86 * // Stuff from block 1
87 * // Stuff from block 3
88 * loop {
89 * block block_1:
90 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
91 * vec1 32 ssa_6 = load_const (0x00000004)
92 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
93 * if ssa_7 {
94 * block block_5:
95 * } else {
96 * block block_6:
97 * break
98 * }
99 * block block_7:
100 * // Stuff from block 1
101 * // Stuff from block 2
102 * vec1 32 ssa_4 = load_const (0x00000001)
103 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
104 * }
105 */
static bool
opt_peel_loop_initial_if(nir_loop *loop)
{
   nir_block *header_block = nir_loop_first_block(loop);
   MAYBE_UNUSED nir_block *prev_block =
      nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));

   /* It would be insane if this were not true */
   assert(_mesa_set_search(header_block->predecessors, prev_block));

   /* The loop must have exactly one continue block which could be a block
    * ending in a continue instruction or the "natural" continue from the
    * last block in the loop back to the top.
    */
   if (header_block->predecessors->entries != 2)
      return false;

   nir_block *continue_block = find_continue_block(loop);

   /* The first thing inside the loop after the header must be the if whose
    * condition selects between "first iteration" and "continued iteration".
    */
   nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
   if (!if_node || if_node->type != nir_cf_node_if)
      return false;

   nir_if *nif = nir_cf_node_as_if(if_node);
   assert(nif->condition.is_ssa);

   /* The condition must be a phi defined in the loop header itself; that is
    * what lets us evaluate it separately for the entry edge and the
    * continue edge.
    */
   nir_ssa_def *cond = nif->condition.ssa;
   if (cond->parent_instr->type != nir_instr_type_phi)
      return false;

   nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
   if (cond->parent_instr->block != header_block)
      return false;

   /* We already know we have exactly one continue */
   assert(exec_list_length(&cond_phi->srcs) == 2);

   /* Both phi sources must be constants so the condition is statically known
    * on each incoming edge.
    */
   uint32_t entry_val = 0, continue_val = 0;
   nir_foreach_phi_src(src, cond_phi) {
      assert(src->src.is_ssa);
      nir_const_value *const_src = nir_src_as_const_value(src->src);
      if (!const_src)
         return false;

      if (src->pred == continue_block) {
         continue_val = const_src->u32[0];
      } else {
         assert(src->pred == prev_block);
         entry_val = const_src->u32[0];
      }
   }

   /* If they both execute or both don't execute, this is a job for
    * nir_dead_cf, not this pass.
    */
   if ((entry_val && continue_val) || (!entry_val && !continue_val))
      return false;

   /* entry_list is the side of the if taken on the first iteration;
    * continue_list is the side taken on every subsequent iteration.
    */
   struct exec_list *continue_list, *entry_list;
   if (continue_val) {
      continue_list = &nif->then_list;
      entry_list = &nif->else_list;
   } else {
      continue_list = &nif->else_list;
      entry_list = &nif->then_list;
   }

   /* We want to be moving the contents of entry_list to above the loop so it
    * can't contain any break or continue instructions.
    */
   foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
      nir_foreach_block_in_cf_node(block, cf_node) {
         nir_instr *last_instr = nir_block_last_instr(block);
         if (last_instr && last_instr->type == nir_instr_type_jump)
            return false;
      }
   }

   /* We're about to re-arrange a bunch of blocks so make sure that we don't
    * have deref uses which cross block boundaries. We don't want a deref
    * accidentally ending up in a phi.
    */
   nir_rematerialize_derefs_in_use_blocks_impl(
      nir_cf_node_get_function(&loop->cf_node));

   /* Before we do anything, convert the loop to LCSSA. We're about to
    * replace a bunch of SSA defs with registers and this will prevent any of
    * it from leaking outside the loop.
    */
   nir_convert_loop_to_lcssa(loop);

   nir_block *after_if_block =
      nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));

   /* Get rid of phis in the header block since we will be duplicating it */
   nir_lower_phis_to_regs_block(header_block);
   /* Get rid of phis after the if since dominance will change */
   nir_lower_phis_to_regs_block(after_if_block);

   /* Get rid of SSA defs in the pieces we're about to move around */
   nir_lower_ssa_defs_to_regs_block(header_block);
   nir_foreach_block_in_cf_node(block, &nif->cf_node)
      nir_lower_ssa_defs_to_regs_block(block);

   /* Pull the header out of the loop; a clone of it goes before the loop
    * (first-iteration copy) and the original is re-inserted at the bottom of
    * the loop (continue-edge copy) further down.
    */
   nir_cf_list header, tmp;
   nir_cf_extract(&header, nir_before_block(header_block),
                  nir_after_block(header_block));

   nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
   nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
   /* The entry (first-iteration) side of the if moves above the loop, right
    * after the cloned header.
    */
   nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
                  nir_after_cf_list(entry_list));
   nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));

   nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));

   /* Get continue block again as the previous reinsert might have removed the block. */
   continue_block = find_continue_block(loop);

   /* The continue side of the if moves to the end of the loop body, after
    * the re-inserted header.
    */
   nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
                  nir_after_cf_list(continue_list));
   nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));

   /* Both sides of the if are now empty, so the if itself can go away. */
   nir_cf_node_remove(&nif->cf_node);

   return true;
}
233
234 static bool
235 is_block_empty(nir_block *block)
236 {
237 return nir_cf_node_is_last(&block->cf_node) &&
238 exec_list_is_empty(&block->instr_list);
239 }
240
241 /**
242 * This optimization turns:
243 *
244 * if (cond) {
245 * } else {
246 * do_work();
247 * }
248 *
249 * into:
250 *
251 * if (!cond) {
252 * do_work();
253 * } else {
254 * }
255 */
256 static bool
257 opt_if_simplification(nir_builder *b, nir_if *nif)
258 {
259 /* Only simplify if the then block is empty and the else block is not. */
260 if (!is_block_empty(nir_if_first_then_block(nif)) ||
261 is_block_empty(nir_if_first_else_block(nif)))
262 return false;
263
264 /* Make sure the condition is a comparison operation. */
265 nir_instr *src_instr = nif->condition.ssa->parent_instr;
266 if (src_instr->type != nir_instr_type_alu)
267 return false;
268
269 nir_alu_instr *alu_instr = nir_instr_as_alu(src_instr);
270 if (!nir_alu_instr_is_comparison(alu_instr))
271 return false;
272
273 /* Insert the inverted instruction and rewrite the condition. */
274 b->cursor = nir_after_instr(&alu_instr->instr);
275
276 nir_ssa_def *new_condition =
277 nir_inot(b, &alu_instr->dest.dest.ssa);
278
279 nir_if_rewrite_condition(nif, nir_src_for_ssa(new_condition));
280
281 /* Grab pointers to the last then/else blocks for fixing up the phis. */
282 nir_block *then_block = nir_if_last_then_block(nif);
283 nir_block *else_block = nir_if_last_else_block(nif);
284
285 /* Walk all the phis in the block immediately following the if statement and
286 * swap the blocks.
287 */
288 nir_block *after_if_block =
289 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
290
291 nir_foreach_instr(instr, after_if_block) {
292 if (instr->type != nir_instr_type_phi)
293 continue;
294
295 nir_phi_instr *phi = nir_instr_as_phi(instr);
296
297 foreach_list_typed(nir_phi_src, src, node, &phi->srcs) {
298 if (src->pred == else_block) {
299 src->pred = then_block;
300 } else if (src->pred == then_block) {
301 src->pred = else_block;
302 }
303 }
304 }
305
306 /* Finally, move the else block to the then block. */
307 nir_cf_list tmp;
308 nir_cf_extract(&tmp, nir_before_cf_list(&nif->else_list),
309 nir_after_cf_list(&nif->else_list));
310 nir_cf_reinsert(&tmp, nir_before_cf_list(&nif->then_list));
311
312 return true;
313 }
314
315 /**
316 * This optimization simplifies potential loop terminators which then allows
317 * other passes such as opt_if_simplification() and loop unrolling to progress
318 * further:
319 *
320 * if (cond) {
321 * ... then block instructions ...
322 * } else {
323 * ...
324 * break;
325 * }
326 *
327 * into:
328 *
329 * if (cond) {
330 * } else {
331 * ...
332 * break;
333 * }
334 * ... then block instructions ...
335 */
336 static bool
337 opt_if_loop_terminator(nir_if *nif)
338 {
339 nir_block *break_blk = NULL;
340 nir_block *continue_from_blk = NULL;
341 bool continue_from_then = true;
342
343 nir_block *last_then = nir_if_last_then_block(nif);
344 nir_block *last_else = nir_if_last_else_block(nif);
345
346 if (nir_block_ends_in_break(last_then)) {
347 break_blk = last_then;
348 continue_from_blk = last_else;
349 continue_from_then = false;
350 } else if (nir_block_ends_in_break(last_else)) {
351 break_blk = last_else;
352 continue_from_blk = last_then;
353 }
354
355 /* Continue if the if-statement contained no jumps at all */
356 if (!break_blk)
357 return false;
358
359 /* If the continue from block is empty then return as there is nothing to
360 * move.
361 */
362 nir_block *first_continue_from_blk = continue_from_then ?
363 nir_if_first_then_block(nif) :
364 nir_if_first_else_block(nif);
365 if (is_block_empty(first_continue_from_blk))
366 return false;
367
368 if (!nir_is_trivial_loop_if(nif, break_blk))
369 return false;
370
371 /* Finally, move the continue from branch after the if-statement. */
372 nir_cf_list tmp;
373 nir_cf_extract(&tmp, nir_before_block(first_continue_from_blk),
374 nir_after_block(continue_from_blk));
375 nir_cf_reinsert(&tmp, nir_after_cf_node(&nif->cf_node));
376
377 return true;
378 }
379
380 static void
381 replace_if_condition_use_with_const(nir_builder *b, nir_src *use,
382 nir_const_value nir_boolean,
383 bool if_condition)
384 {
385 /* Create const */
386 nir_ssa_def *const_def = nir_build_imm(b, 1, 32, nir_boolean);
387
388 /* Rewrite use to use const */
389 nir_src new_src = nir_src_for_ssa(const_def);
390 if (if_condition)
391 nir_if_rewrite_condition(use->parent_if, new_src);
392 else
393 nir_instr_rewrite_src(use->parent_instr, use, new_src);
394 }
395
396 static bool
397 evaluate_if_condition(nir_if *nif, nir_cursor cursor, uint32_t *value)
398 {
399 nir_block *use_block = nir_cursor_current_block(cursor);
400 if (nir_block_dominates(nir_if_first_then_block(nif), use_block)) {
401 *value = NIR_TRUE;
402 return true;
403 } else if (nir_block_dominates(nir_if_first_else_block(nif), use_block)) {
404 *value = NIR_FALSE;
405 return true;
406 } else {
407 return false;
408 }
409 }
410
411 /*
412 * This propagates if condition evaluation down the chain of some alu
413 * instructions. For example by checking the use of some of the following alu
414 * instruction we can eventually replace ssa_107 with NIR_TRUE.
415 *
416 * loop {
417 * block block_1:
418 * vec1 32 ssa_85 = load_const (0x00000002)
419 * vec1 32 ssa_86 = ieq ssa_48, ssa_85
420 * vec1 32 ssa_87 = load_const (0x00000001)
421 * vec1 32 ssa_88 = ieq ssa_48, ssa_87
422 * vec1 32 ssa_89 = ior ssa_86, ssa_88
423 * vec1 32 ssa_90 = ieq ssa_48, ssa_0
424 * vec1 32 ssa_91 = ior ssa_89, ssa_90
425 * if ssa_86 {
426 * block block_2:
427 * ...
428 * break
429 * } else {
430 * block block_3:
431 * }
432 * block block_4:
433 * if ssa_88 {
434 * block block_5:
435 * ...
436 * break
437 * } else {
438 * block block_6:
439 * }
440 * block block_7:
441 * if ssa_90 {
442 * block block_8:
443 * ...
444 * break
445 * } else {
446 * block block_9:
447 * }
448 * block block_10:
449 * vec1 32 ssa_107 = inot ssa_91
450 * if ssa_107 {
451 * block block_11:
452 * break
453 * } else {
454 * block block_12:
455 * }
456 * }
457 */
/* Attempts to fold the known value of nif's condition into one use of the
 * given ALU instruction (which itself consumes the condition via use_src).
 * For unary ops the folded result is a constant; for binary ops a new ALU
 * instruction is built with the condition operand replaced by an immediate.
 * Returns true if the use was rewritten.
 */
static bool
propagate_condition_eval(nir_builder *b, nir_if *nif, nir_src *use_src,
                         nir_src *alu_use, nir_alu_instr *alu,
                         bool is_if_condition)
{
   bool progress = false;

   nir_const_value bool_value;
   /* Position the cursor at the use so dominance against the if branches is
    * evaluated at the right place.
    */
   b->cursor = nir_before_src(alu_use, is_if_condition);
   if (nir_op_infos[alu->op].num_inputs == 1) {
      /* Only the unary ops accepted by can_propagate_through_alu() */
      assert(alu->op == nir_op_inot || alu->op == nir_op_b2i);

      if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
         assert(nir_src_bit_size(alu->src[0].src) == 32);

         /* The sole operand is the known condition, so the whole ALU op
          * folds to a constant.
          */
         nir_const_value result =
            nir_eval_const_opcode(alu->op, 1, 32, &bool_value);

         replace_if_condition_use_with_const(b, alu_use, result,
                                             is_if_condition);
         progress = true;
      }
   } else {
      /* Only the binary ops accepted by can_propagate_through_alu() */
      assert(alu->op == nir_op_ior || alu->op == nir_op_iand);

      if (evaluate_if_condition(nif, b->cursor, &bool_value.u32[0])) {
         /* Rebuild the ALU op with the condition operand(s) replaced by the
          * known constant; the other operand is kept as-is.
          */
         nir_ssa_def *def[2];
         for (unsigned i = 0; i < 2; i++) {
            if (alu->src[i].src.ssa == use_src->ssa) {
               def[i] = nir_build_imm(b, 1, 32, bool_value);
            } else {
               def[i] = alu->src[i].src.ssa;
            }
         }

         nir_ssa_def *nalu =
            nir_build_alu(b, alu->op, def[0], def[1], NULL, NULL);

         /* Rewrite use to use new alu instruction */
         nir_src new_src = nir_src_for_ssa(nalu);

         if (is_if_condition)
            nir_if_rewrite_condition(alu_use->parent_if, new_src);
         else
            nir_instr_rewrite_src(alu_use->parent_instr, alu_use, new_src);

         progress = true;
      }
   }

   return progress;
}
510
511 static bool
512 can_propagate_through_alu(nir_src *src)
513 {
514 if (src->parent_instr->type == nir_instr_type_alu &&
515 (nir_instr_as_alu(src->parent_instr)->op == nir_op_ior ||
516 nir_instr_as_alu(src->parent_instr)->op == nir_op_iand ||
517 nir_instr_as_alu(src->parent_instr)->op == nir_op_inot ||
518 nir_instr_as_alu(src->parent_instr)->op == nir_op_b2i))
519 return true;
520
521 return false;
522 }
523
524 static bool
525 evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
526 bool is_if_condition)
527 {
528 bool progress = false;
529
530 nir_const_value value;
531 b->cursor = nir_before_src(use_src, is_if_condition);
532
533 if (evaluate_if_condition(nif, b->cursor, &value.u32[0])) {
534 replace_if_condition_use_with_const(b, use_src, value, is_if_condition);
535 progress = true;
536 }
537
538 if (!is_if_condition && can_propagate_through_alu(use_src)) {
539 nir_alu_instr *alu = nir_instr_as_alu(use_src->parent_instr);
540
541 nir_foreach_use_safe(alu_use, &alu->dest.dest.ssa) {
542 progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
543 false);
544 }
545
546 nir_foreach_if_use_safe(alu_use, &alu->dest.dest.ssa) {
547 progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu,
548 true);
549 }
550 }
551
552 return progress;
553 }
554
555 static bool
556 opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif)
557 {
558 bool progress = false;
559
560 /* Evaluate any uses of the if condition inside the if branches */
561 assert(nif->condition.is_ssa);
562 nir_foreach_use_safe(use_src, nif->condition.ssa) {
563 progress |= evaluate_condition_use(b, nif, use_src, false);
564 }
565
566 nir_foreach_if_use_safe(use_src, nif->condition.ssa) {
567 if (use_src->parent_if != nif)
568 progress |= evaluate_condition_use(b, nif, use_src, true);
569 }
570
571 return progress;
572 }
573
574 static bool
575 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
576 {
577 bool progress = false;
578 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
579 switch (cf_node->type) {
580 case nir_cf_node_block:
581 break;
582
583 case nir_cf_node_if: {
584 nir_if *nif = nir_cf_node_as_if(cf_node);
585 progress |= opt_if_cf_list(b, &nif->then_list);
586 progress |= opt_if_cf_list(b, &nif->else_list);
587 progress |= opt_if_loop_terminator(nif);
588 progress |= opt_if_simplification(b, nif);
589 break;
590 }
591
592 case nir_cf_node_loop: {
593 nir_loop *loop = nir_cf_node_as_loop(cf_node);
594 progress |= opt_if_cf_list(b, &loop->body);
595 progress |= opt_peel_loop_initial_if(loop);
596 break;
597 }
598
599 case nir_cf_node_function:
600 unreachable("Invalid cf type");
601 }
602 }
603
604 return progress;
605 }
606
607 /**
608 * These optimisations depend on nir_metadata_block_index and therefore must
609 * not do anything to cause the metadata to become invalid.
610 */
611 static bool
612 opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list)
613 {
614 bool progress = false;
615 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
616 switch (cf_node->type) {
617 case nir_cf_node_block:
618 break;
619
620 case nir_cf_node_if: {
621 nir_if *nif = nir_cf_node_as_if(cf_node);
622 progress |= opt_if_safe_cf_list(b, &nif->then_list);
623 progress |= opt_if_safe_cf_list(b, &nif->else_list);
624 progress |= opt_if_evaluate_condition_use(b, nif);
625 break;
626 }
627
628 case nir_cf_node_loop: {
629 nir_loop *loop = nir_cf_node_as_loop(cf_node);
630 progress |= opt_if_safe_cf_list(b, &loop->body);
631 break;
632 }
633
634 case nir_cf_node_function:
635 unreachable("Invalid cf type");
636 }
637 }
638
639 return progress;
640 }
641
642 bool
643 nir_opt_if(nir_shader *shader)
644 {
645 bool progress = false;
646
647 nir_foreach_function(function, shader) {
648 if (function->impl == NULL)
649 continue;
650
651 nir_builder b;
652 nir_builder_init(&b, function->impl);
653
654 nir_metadata_require(function->impl, nir_metadata_block_index |
655 nir_metadata_dominance);
656 progress = opt_if_safe_cf_list(&b, &function->impl->body);
657 nir_metadata_preserve(function->impl, nir_metadata_block_index |
658 nir_metadata_dominance);
659
660 if (opt_if_cf_list(&b, &function->impl->body)) {
661 nir_metadata_preserve(function->impl, nir_metadata_none);
662
663 /* If that made progress, we're no longer really in SSA form. We
664 * need to convert registers back into SSA defs and clean up SSA defs
665 * that don't dominate their uses.
666 */
667 nir_lower_regs_to_ssa_impl(function->impl);
668
669 progress = true;
670 }
671 }
672
673 return progress;
674 }