nir/opt_if: Remove unneeded phis if we make progress
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir/nir_builder.h"
26 #include "nir_control_flow.h"
27 #include "nir_loop_analyze.h"
28
29 /**
30 * This optimization detects if statements at the tops of loops where the
31 * condition is a phi node of two constants and moves half of the if to above
32 * the loop and the other half of the if to the end of the loop. A simple for
33 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
34 * ends up looking something like this:
35 *
36 * vec1 32 ssa_0 = load_const (0x00000000)
37 * vec1 32 ssa_1 = load_const (0xffffffff)
38 * loop {
39 * block block_1:
40 * vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
41 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
42 * if ssa_3 {
43 * block block_2:
44 * vec1 32 ssa_4 = load_const (0x00000001)
45 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
46 * } else {
47 * block block_3:
48 * }
49 * block block_4:
50 * vec1 32 ssa_6 = load_const (0x00000004)
51 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
52 * if ssa_7 {
53 * block block_5:
54 * } else {
55 * block block_6:
56 * break
57 * }
58 * block block_7:
59 * }
60 *
61 * This turns it into something like this:
62 *
63 * // Stuff from block 1
64 * // Stuff from block 3
65 * loop {
66 * block block_1:
67 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
68 * vec1 32 ssa_6 = load_const (0x00000004)
69 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
70 * if ssa_7 {
71 * block block_5:
72 * } else {
73 * block block_6:
74 * break
75 * }
76 * block block_7:
77 * // Stuff from block 1
78 * // Stuff from block 2
79 * vec1 32 ssa_4 = load_const (0x00000001)
80 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
81 * }
82 */
83 static bool
84 opt_peel_loop_initial_if(nir_loop *loop)
85 {
86 nir_block *header_block = nir_loop_first_block(loop);
87 nir_block *prev_block =
88 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
89
90 /* It would be insane if this were not true */
91 assert(_mesa_set_search(header_block->predecessors, prev_block));
92
93 /* The loop must have exactly one continue block which could be a block
94 * ending in a continue instruction or the "natural" continue from the
95 * last block in the loop back to the top.
96 */
97 if (header_block->predecessors->entries != 2)
98 return false;
99
100 nir_block *continue_block = NULL;
101 struct set_entry *pred_entry;
102 set_foreach(header_block->predecessors, pred_entry) {
103 if (pred_entry->key != prev_block)
104 continue_block = (void *)pred_entry->key;
105 }
106
107 nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
108 if (!if_node || if_node->type != nir_cf_node_if)
109 return false;
110
111 nir_if *nif = nir_cf_node_as_if(if_node);
112 assert(nif->condition.is_ssa);
113
114 nir_ssa_def *cond = nif->condition.ssa;
115 if (cond->parent_instr->type != nir_instr_type_phi)
116 return false;
117
118 nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
119 if (cond->parent_instr->block != header_block)
120 return false;
121
122 /* We already know we have exactly one continue */
123 assert(exec_list_length(&cond_phi->srcs) == 2);
124
125 uint32_t entry_val = 0, continue_val = 0;
126 nir_foreach_phi_src(src, cond_phi) {
127 assert(src->src.is_ssa);
128 nir_const_value *const_src = nir_src_as_const_value(src->src);
129 if (!const_src)
130 return false;
131
132 if (src->pred == continue_block) {
133 continue_val = const_src->u32[0];
134 } else {
135 assert(src->pred == prev_block);
136 entry_val = const_src->u32[0];
137 }
138 }
139
140 /* If they both execute or both don't execute, this is a job for
141 * nir_dead_cf, not this pass.
142 */
143 if ((entry_val && continue_val) || (!entry_val && !continue_val))
144 return false;
145
146 struct exec_list *continue_list, *entry_list;
147 if (continue_val) {
148 continue_list = &nif->then_list;
149 entry_list = &nif->else_list;
150 } else {
151 continue_list = &nif->else_list;
152 entry_list = &nif->then_list;
153 }
154
155 /* We want to be moving the contents of entry_list to above the loop so it
156 * can't contain any break or continue instructions.
157 */
158 foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
159 nir_foreach_block_in_cf_node(block, cf_node) {
160 nir_instr *last_instr = nir_block_last_instr(block);
161 if (last_instr && last_instr->type == nir_instr_type_jump)
162 return false;
163 }
164 }
165
166 /* Before we do anything, convert the loop to LCSSA. We're about to
167 * replace a bunch of SSA defs with registers and this will prevent any of
168 * it from leaking outside the loop.
169 */
170 nir_convert_loop_to_lcssa(loop);
171
172 nir_block *after_if_block =
173 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
174
175 /* Get rid of phis in the header block since we will be duplicating it */
176 nir_lower_phis_to_regs_block(header_block);
177 /* Get rid of phis after the if since dominance will change */
178 nir_lower_phis_to_regs_block(after_if_block);
179
180 /* Get rid of SSA defs in the pieces we're about to move around */
181 nir_lower_ssa_defs_to_regs_block(header_block);
182 nir_foreach_block_in_cf_node(block, &nif->cf_node)
183 nir_lower_ssa_defs_to_regs_block(block);
184
185 nir_cf_list header, tmp;
186 nir_cf_extract(&header, nir_before_block(header_block),
187 nir_after_block(header_block));
188
189 nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
190 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
191 nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
192 nir_after_cf_list(entry_list));
193 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
194
195 nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));
196 nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
197 nir_after_cf_list(continue_list));
198 nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));
199
200 nir_cf_node_remove(&nif->cf_node);
201
202 return true;
203 }
204
205 static bool
206 is_block_empty(nir_block *block)
207 {
208 return nir_cf_node_is_last(&block->cf_node) &&
209 exec_list_is_empty(&block->instr_list);
210 }
211
212 /**
213 * This optimization turns:
214 *
215 * if (cond) {
216 * } else {
217 * do_work();
218 * }
219 *
220 * into:
221 *
222 * if (!cond) {
223 * do_work();
224 * } else {
225 * }
226 */
227 static bool
228 opt_if_simplification(nir_builder *b, nir_if *nif)
229 {
230 /* Only simplify if the then block is empty and the else block is not. */
231 if (!is_block_empty(nir_if_first_then_block(nif)) ||
232 is_block_empty(nir_if_first_else_block(nif)))
233 return false;
234
235 /* Make sure the condition is a comparison operation. */
236 nir_instr *src_instr = nif->condition.ssa->parent_instr;
237 if (src_instr->type != nir_instr_type_alu)
238 return false;
239
240 nir_alu_instr *alu_instr = nir_instr_as_alu(src_instr);
241 if (!nir_alu_instr_is_comparison(alu_instr))
242 return false;
243
244 /* Insert the inverted instruction and rewrite the condition. */
245 b->cursor = nir_after_instr(&alu_instr->instr);
246
247 nir_ssa_def *new_condition =
248 nir_inot(b, &alu_instr->dest.dest.ssa);
249
250 nir_if_rewrite_condition(nif, nir_src_for_ssa(new_condition));
251
252 /* Grab pointers to the last then/else blocks for fixing up the phis. */
253 nir_block *then_block = nir_if_last_then_block(nif);
254 nir_block *else_block = nir_if_last_else_block(nif);
255
256 /* Walk all the phis in the block immediately following the if statement and
257 * swap the blocks.
258 */
259 nir_block *after_if_block =
260 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
261
262 nir_foreach_instr(instr, after_if_block) {
263 if (instr->type != nir_instr_type_phi)
264 continue;
265
266 nir_phi_instr *phi = nir_instr_as_phi(instr);
267
268 foreach_list_typed(nir_phi_src, src, node, &phi->srcs) {
269 if (src->pred == else_block) {
270 src->pred = then_block;
271 } else if (src->pred == then_block) {
272 src->pred = else_block;
273 }
274 }
275 }
276
277 /* Finally, move the else block to the then block. */
278 nir_cf_list tmp;
279 nir_cf_extract(&tmp, nir_before_cf_list(&nif->else_list),
280 nir_after_cf_list(&nif->else_list));
281 nir_cf_reinsert(&tmp, nir_before_cf_list(&nif->then_list));
282 nir_cf_delete(&tmp);
283
284 return true;
285 }
286
287 /**
288 * This optimization simplifies potential loop terminators which then allows
289 * other passes such as opt_if_simplification() and loop unrolling to progress
290 * further:
291 *
292 * if (cond) {
293 * ... then block instructions ...
294 * } else {
295 * ...
296 * break;
297 * }
298 *
299 * into:
300 *
301 * if (cond) {
302 * } else {
303 * ...
304 * break;
305 * }
306 * ... then block instructions ...
307 */
308 static bool
309 opt_if_loop_terminator(nir_if *nif)
310 {
311 nir_block *break_blk = NULL;
312 nir_block *continue_from_blk = NULL;
313 bool continue_from_then = true;
314
315 nir_block *last_then = nir_if_last_then_block(nif);
316 nir_block *last_else = nir_if_last_else_block(nif);
317
318 if (nir_block_ends_in_break(last_then)) {
319 break_blk = last_then;
320 continue_from_blk = last_else;
321 continue_from_then = false;
322 } else if (nir_block_ends_in_break(last_else)) {
323 break_blk = last_else;
324 continue_from_blk = last_then;
325 }
326
327 /* Continue if the if-statement contained no jumps at all */
328 if (!break_blk)
329 return false;
330
331 /* If the continue from block is empty then return as there is nothing to
332 * move.
333 */
334 nir_block *first_continue_from_blk = continue_from_then ?
335 nir_if_first_then_block(nif) :
336 nir_if_first_else_block(nif);
337 if (is_block_empty(first_continue_from_blk))
338 return false;
339
340 if (!nir_is_trivial_loop_if(nif, break_blk))
341 return false;
342
343 /* Finally, move the continue from branch after the if-statement. */
344 nir_cf_list tmp;
345 nir_cf_extract(&tmp, nir_before_block(first_continue_from_blk),
346 nir_after_block(continue_from_blk));
347 nir_cf_reinsert(&tmp, nir_after_cf_node(&nif->cf_node));
348 nir_cf_delete(&tmp);
349
350 return true;
351 }
352
353 static bool
354 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
355 {
356 bool progress = false;
357 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
358 switch (cf_node->type) {
359 case nir_cf_node_block:
360 break;
361
362 case nir_cf_node_if: {
363 nir_if *nif = nir_cf_node_as_if(cf_node);
364 progress |= opt_if_cf_list(b, &nif->then_list);
365 progress |= opt_if_cf_list(b, &nif->else_list);
366 progress |= opt_if_loop_terminator(nif);
367 progress |= opt_if_simplification(b, nif);
368 break;
369 }
370
371 case nir_cf_node_loop: {
372 nir_loop *loop = nir_cf_node_as_loop(cf_node);
373 progress |= opt_if_cf_list(b, &loop->body);
374 progress |= opt_peel_loop_initial_if(loop);
375 break;
376 }
377
378 case nir_cf_node_function:
379 unreachable("Invalid cf type");
380 }
381 }
382
383 return progress;
384 }
385
386 bool
387 nir_opt_if(nir_shader *shader)
388 {
389 bool progress = false;
390
391 nir_foreach_function(function, shader) {
392 if (function->impl == NULL)
393 continue;
394
395 nir_builder b;
396 nir_builder_init(&b, function->impl);
397
398 if (opt_if_cf_list(&b, &function->impl->body)) {
399 nir_metadata_preserve(function->impl, nir_metadata_none);
400
401 /* If that made progress, we're no longer really in SSA form. We
402 * need to convert registers back into SSA defs and clean up SSA defs
403 * that don't dominate their uses.
404 */
405 nir_lower_regs_to_ssa_impl(function->impl);
406
407 /* Calling nir_convert_loop_to_lcssa() in opt_peel_loop_initial_if()
408 * adds extra phi nodes which may not be valid if they're used for
409 * something such as a deref. Remove any unneeded phis.
410 */
411 nir_opt_remove_phis_impl(function->impl);
412
413 progress = true;
414 }
415 }
416
417 return progress;
418 }