nir: evaluate if condition uses inside the if branches
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir/nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_loop_analyze.h"
28
29 /**
30 * Gets the single block that jumps back to the loop header. Already assumes
31 * there is exactly one such block.
32 */
33 static nir_block*
34 find_continue_block(nir_loop *loop)
35 {
36 nir_block *header_block = nir_loop_first_block(loop);
37 nir_block *prev_block =
38 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
39
40 assert(header_block->predecessors->entries == 2);
41
42 struct set_entry *pred_entry;
43 set_foreach(header_block->predecessors, pred_entry) {
44 if (pred_entry->key != prev_block)
45 return (nir_block*)pred_entry->key;
46 }
47
48 unreachable("Continue block not found!");
49 }
50
51 /**
52 * This optimization detects if statements at the tops of loops where the
53 * condition is a phi node of two constants and moves half of the if to above
54 * the loop and the other half of the if to the end of the loop. A simple for
55 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
56 * ends up looking something like this:
57 *
58 * vec1 32 ssa_0 = load_const (0x00000000)
59 * vec1 32 ssa_1 = load_const (0xffffffff)
60 * loop {
61 * block block_1:
62 * vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
63 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
64 * if ssa_2 {
65 * block block_2:
66 * vec1 32 ssa_4 = load_const (0x00000001)
67 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
68 * } else {
69 * block block_3:
70 * }
71 * block block_4:
72 * vec1 32 ssa_6 = load_const (0x00000004)
73 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
74 * if ssa_7 {
75 * block block_5:
76 * } else {
77 * block block_6:
78 * break
79 * }
80 * block block_7:
81 * }
82 *
83 * This turns it into something like this:
84 *
85 * // Stuff from block 1
86 * // Stuff from block 3
87 * loop {
88 * block block_1:
89 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
90 * vec1 32 ssa_6 = load_const (0x00000004)
91 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
92 * if ssa_7 {
93 * block block_5:
94 * } else {
95 * block block_6:
96 * break
97 * }
98 * block block_7:
99 * // Stuff from block 1
100 * // Stuff from block 2
101 * vec1 32 ssa_4 = load_const (0x00000001)
102 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
103 * }
104 */
105 static bool
106 opt_peel_loop_initial_if(nir_loop *loop)
107 {
108 nir_block *header_block = nir_loop_first_block(loop);
109 MAYBE_UNUSED nir_block *prev_block =
110 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
111
112 /* It would be insane if this were not true */
113 assert(_mesa_set_search(header_block->predecessors, prev_block));
114
115 /* The loop must have exactly one continue block which could be a block
116 * ending in a continue instruction or the "natural" continue from the
117 * last block in the loop back to the top.
118 */
119 if (header_block->predecessors->entries != 2)
120 return false;
121
122 nir_block *continue_block = find_continue_block(loop);
123
124 nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
125 if (!if_node || if_node->type != nir_cf_node_if)
126 return false;
127
128 nir_if *nif = nir_cf_node_as_if(if_node);
129 assert(nif->condition.is_ssa);
130
131 nir_ssa_def *cond = nif->condition.ssa;
132 if (cond->parent_instr->type != nir_instr_type_phi)
133 return false;
134
135 nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
136 if (cond->parent_instr->block != header_block)
137 return false;
138
139 /* We already know we have exactly one continue */
140 assert(exec_list_length(&cond_phi->srcs) == 2);
141
142 uint32_t entry_val = 0, continue_val = 0;
143 nir_foreach_phi_src(src, cond_phi) {
144 assert(src->src.is_ssa);
145 nir_const_value *const_src = nir_src_as_const_value(src->src);
146 if (!const_src)
147 return false;
148
149 if (src->pred == continue_block) {
150 continue_val = const_src->u32[0];
151 } else {
152 assert(src->pred == prev_block);
153 entry_val = const_src->u32[0];
154 }
155 }
156
157 /* If they both execute or both don't execute, this is a job for
158 * nir_dead_cf, not this pass.
159 */
160 if ((entry_val && continue_val) || (!entry_val && !continue_val))
161 return false;
162
163 struct exec_list *continue_list, *entry_list;
164 if (continue_val) {
165 continue_list = &nif->then_list;
166 entry_list = &nif->else_list;
167 } else {
168 continue_list = &nif->else_list;
169 entry_list = &nif->then_list;
170 }
171
172 /* We want to be moving the contents of entry_list to above the loop so it
173 * can't contain any break or continue instructions.
174 */
175 foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
176 nir_foreach_block_in_cf_node(block, cf_node) {
177 nir_instr *last_instr = nir_block_last_instr(block);
178 if (last_instr && last_instr->type == nir_instr_type_jump)
179 return false;
180 }
181 }
182
183 /* Before we do anything, convert the loop to LCSSA. We're about to
184 * replace a bunch of SSA defs with registers and this will prevent any of
185 * it from leaking outside the loop.
186 */
187 nir_convert_loop_to_lcssa(loop);
188
189 nir_block *after_if_block =
190 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
191
192 /* Get rid of phis in the header block since we will be duplicating it */
193 nir_lower_phis_to_regs_block(header_block);
194 /* Get rid of phis after the if since dominance will change */
195 nir_lower_phis_to_regs_block(after_if_block);
196
197 /* Get rid of SSA defs in the pieces we're about to move around */
198 nir_lower_ssa_defs_to_regs_block(header_block);
199 nir_foreach_block_in_cf_node(block, &nif->cf_node)
200 nir_lower_ssa_defs_to_regs_block(block);
201
202 nir_cf_list header, tmp;
203 nir_cf_extract(&header, nir_before_block(header_block),
204 nir_after_block(header_block));
205
206 nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
207 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
208 nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
209 nir_after_cf_list(entry_list));
210 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
211
212 nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));
213
214 /* Get continue block again as the previous reinsert might have removed the block. */
215 continue_block = find_continue_block(loop);
216
217 nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
218 nir_after_cf_list(continue_list));
219 nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));
220
221 nir_cf_node_remove(&nif->cf_node);
222
223 return true;
224 }
225
226 static bool
227 is_block_empty(nir_block *block)
228 {
229 return nir_cf_node_is_last(&block->cf_node) &&
230 exec_list_is_empty(&block->instr_list);
231 }
232
233 /**
234 * This optimization turns:
235 *
236 * if (cond) {
237 * } else {
238 * do_work();
239 * }
240 *
241 * into:
242 *
243 * if (!cond) {
244 * do_work();
245 * } else {
246 * }
247 */
248 static bool
249 opt_if_simplification(nir_builder *b, nir_if *nif)
250 {
251 /* Only simplify if the then block is empty and the else block is not. */
252 if (!is_block_empty(nir_if_first_then_block(nif)) ||
253 is_block_empty(nir_if_first_else_block(nif)))
254 return false;
255
256 /* Make sure the condition is a comparison operation. */
257 nir_instr *src_instr = nif->condition.ssa->parent_instr;
258 if (src_instr->type != nir_instr_type_alu)
259 return false;
260
261 nir_alu_instr *alu_instr = nir_instr_as_alu(src_instr);
262 if (!nir_alu_instr_is_comparison(alu_instr))
263 return false;
264
265 /* Insert the inverted instruction and rewrite the condition. */
266 b->cursor = nir_after_instr(&alu_instr->instr);
267
268 nir_ssa_def *new_condition =
269 nir_inot(b, &alu_instr->dest.dest.ssa);
270
271 nir_if_rewrite_condition(nif, nir_src_for_ssa(new_condition));
272
273 /* Grab pointers to the last then/else blocks for fixing up the phis. */
274 nir_block *then_block = nir_if_last_then_block(nif);
275 nir_block *else_block = nir_if_last_else_block(nif);
276
277 /* Walk all the phis in the block immediately following the if statement and
278 * swap the blocks.
279 */
280 nir_block *after_if_block =
281 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
282
283 nir_foreach_instr(instr, after_if_block) {
284 if (instr->type != nir_instr_type_phi)
285 continue;
286
287 nir_phi_instr *phi = nir_instr_as_phi(instr);
288
289 foreach_list_typed(nir_phi_src, src, node, &phi->srcs) {
290 if (src->pred == else_block) {
291 src->pred = then_block;
292 } else if (src->pred == then_block) {
293 src->pred = else_block;
294 }
295 }
296 }
297
298 /* Finally, move the else block to the then block. */
299 nir_cf_list tmp;
300 nir_cf_extract(&tmp, nir_before_cf_list(&nif->else_list),
301 nir_after_cf_list(&nif->else_list));
302 nir_cf_reinsert(&tmp, nir_before_cf_list(&nif->then_list));
303
304 return true;
305 }
306
307 /**
308 * This optimization simplifies potential loop terminators which then allows
309 * other passes such as opt_if_simplification() and loop unrolling to progress
310 * further:
311 *
312 * if (cond) {
313 * ... then block instructions ...
314 * } else {
315 * ...
316 * break;
317 * }
318 *
319 * into:
320 *
321 * if (cond) {
322 * } else {
323 * ...
324 * break;
325 * }
326 * ... then block instructions ...
327 */
328 static bool
329 opt_if_loop_terminator(nir_if *nif)
330 {
331 nir_block *break_blk = NULL;
332 nir_block *continue_from_blk = NULL;
333 bool continue_from_then = true;
334
335 nir_block *last_then = nir_if_last_then_block(nif);
336 nir_block *last_else = nir_if_last_else_block(nif);
337
338 if (nir_block_ends_in_break(last_then)) {
339 break_blk = last_then;
340 continue_from_blk = last_else;
341 continue_from_then = false;
342 } else if (nir_block_ends_in_break(last_else)) {
343 break_blk = last_else;
344 continue_from_blk = last_then;
345 }
346
347 /* Continue if the if-statement contained no jumps at all */
348 if (!break_blk)
349 return false;
350
351 /* If the continue from block is empty then return as there is nothing to
352 * move.
353 */
354 nir_block *first_continue_from_blk = continue_from_then ?
355 nir_if_first_then_block(nif) :
356 nir_if_first_else_block(nif);
357 if (is_block_empty(first_continue_from_blk))
358 return false;
359
360 if (!nir_is_trivial_loop_if(nif, break_blk))
361 return false;
362
363 /* Finally, move the continue from branch after the if-statement. */
364 nir_cf_list tmp;
365 nir_cf_extract(&tmp, nir_before_block(first_continue_from_blk),
366 nir_after_block(continue_from_blk));
367 nir_cf_reinsert(&tmp, nir_after_cf_node(&nif->cf_node));
368
369 return true;
370 }
371
372 static void
373 replace_if_condition_use_with_const(nir_builder *b, nir_src *use,
374 nir_const_value nir_boolean,
375 bool if_condition)
376 {
377 /* Create const */
378 nir_ssa_def *const_def = nir_build_imm(b, 1, 32, nir_boolean);
379
380 /* Rewrite use to use const */
381 nir_src new_src = nir_src_for_ssa(const_def);
382 if (if_condition)
383 nir_if_rewrite_condition(use->parent_if, new_src);
384 else
385 nir_instr_rewrite_src(use->parent_instr, use, new_src);
386 }
387
388 static bool
389 evaluate_if_condition(nir_if *nif, nir_cursor cursor, uint32_t *value)
390 {
391 nir_block *use_block = nir_cursor_current_block(cursor);
392 if (nir_block_dominates(nir_if_first_then_block(nif), use_block)) {
393 *value = NIR_TRUE;
394 return true;
395 } else if (nir_block_dominates(nir_if_first_else_block(nif), use_block)) {
396 *value = NIR_FALSE;
397 return true;
398 } else {
399 return false;
400 }
401 }
402
403 static bool
404 evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
405 bool is_if_condition)
406 {
407 bool progress = false;
408
409 nir_const_value value;
410 b->cursor = nir_before_src(use_src, is_if_condition);
411
412 if (evaluate_if_condition(nif, b->cursor, &value.u32[0])) {
413 replace_if_condition_use_with_const(b, use_src, value, is_if_condition);
414 progress = true;
415 }
416
417 return progress;
418 }
419
420 static bool
421 opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif)
422 {
423 bool progress = false;
424
425 /* Evaluate any uses of the if condition inside the if branches */
426 assert(nif->condition.is_ssa);
427 nir_foreach_use_safe(use_src, nif->condition.ssa) {
428 progress |= evaluate_condition_use(b, nif, use_src, false);
429 }
430
431 nir_foreach_if_use_safe(use_src, nif->condition.ssa) {
432 if (use_src->parent_if != nif)
433 progress |= evaluate_condition_use(b, nif, use_src, true);
434 }
435
436 return progress;
437 }
438
439 static bool
440 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
441 {
442 bool progress = false;
443 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
444 switch (cf_node->type) {
445 case nir_cf_node_block:
446 break;
447
448 case nir_cf_node_if: {
449 nir_if *nif = nir_cf_node_as_if(cf_node);
450 progress |= opt_if_cf_list(b, &nif->then_list);
451 progress |= opt_if_cf_list(b, &nif->else_list);
452 progress |= opt_if_loop_terminator(nif);
453 progress |= opt_if_simplification(b, nif);
454 break;
455 }
456
457 case nir_cf_node_loop: {
458 nir_loop *loop = nir_cf_node_as_loop(cf_node);
459 progress |= opt_if_cf_list(b, &loop->body);
460 progress |= opt_peel_loop_initial_if(loop);
461 break;
462 }
463
464 case nir_cf_node_function:
465 unreachable("Invalid cf type");
466 }
467 }
468
469 return progress;
470 }
471
472 /**
473 * These optimisations depend on nir_metadata_block_index and therefore must
474 * not do anything to cause the metadata to become invalid.
475 */
476 static bool
477 opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list)
478 {
479 bool progress = false;
480 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
481 switch (cf_node->type) {
482 case nir_cf_node_block:
483 break;
484
485 case nir_cf_node_if: {
486 nir_if *nif = nir_cf_node_as_if(cf_node);
487 progress |= opt_if_safe_cf_list(b, &nif->then_list);
488 progress |= opt_if_safe_cf_list(b, &nif->else_list);
489 progress |= opt_if_evaluate_condition_use(b, nif);
490 break;
491 }
492
493 case nir_cf_node_loop: {
494 nir_loop *loop = nir_cf_node_as_loop(cf_node);
495 progress |= opt_if_safe_cf_list(b, &loop->body);
496 break;
497 }
498
499 case nir_cf_node_function:
500 unreachable("Invalid cf type");
501 }
502 }
503
504 return progress;
505 }
506
507 bool
508 nir_opt_if(nir_shader *shader)
509 {
510 bool progress = false;
511
512 nir_foreach_function(function, shader) {
513 if (function->impl == NULL)
514 continue;
515
516 nir_builder b;
517 nir_builder_init(&b, function->impl);
518
519 nir_metadata_require(function->impl, nir_metadata_block_index |
520 nir_metadata_dominance);
521 progress = opt_if_safe_cf_list(&b, &function->impl->body);
522 nir_metadata_preserve(function->impl, nir_metadata_block_index |
523 nir_metadata_dominance);
524
525 if (opt_if_cf_list(&b, &function->impl->body)) {
526 nir_metadata_preserve(function->impl, nir_metadata_none);
527
528 /* If that made progress, we're no longer really in SSA form. We
529 * need to convert registers back into SSA defs and clean up SSA defs
530 * that don't dominate their uses.
531 */
532 nir_lower_regs_to_ssa_impl(function->impl);
533
534 /* Calling nir_convert_loop_to_lcssa() in opt_peel_loop_initial_if()
535 * adds extra phi nodes which may not be valid if they're used for
536 * something such as a deref. Remove any unneeded phis.
537 */
538 nir_opt_remove_phis_impl(function->impl);
539
540 progress = true;
541 }
542 }
543
544 return progress;
545 }