nir: implement the GLSL equivalent of if simplification in nir_opt_if
[mesa.git] / src / compiler / nir / nir_opt_if.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir/nir_builder.h"
26 #include "nir_control_flow.h"
27
28 /**
29 * This optimization detects if statements at the tops of loops where the
30 * condition is a phi node of two constants and moves half of the if to above
31 * the loop and the other half of the if to the end of the loop. A simple for
32 * loop "for (int i = 0; i < 4; i++)", when run through the SPIR-V front-end,
33 * ends up looking something like this:
34 *
35 * vec1 32 ssa_0 = load_const (0x00000000)
36 * vec1 32 ssa_1 = load_const (0xffffffff)
37 * loop {
38 * block block_1:
39 * vec1 32 ssa_2 = phi block_0: ssa_0, block_7: ssa_5
40 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
41 * if ssa_2 {
42 * block block_2:
43 * vec1 32 ssa_4 = load_const (0x00000001)
44 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
45 * } else {
46 * block block_3:
47 * }
48 * block block_4:
49 * vec1 32 ssa_6 = load_const (0x00000004)
50 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
51 * if ssa_7 {
52 * block block_5:
53 * } else {
54 * block block_6:
55 * break
56 * }
57 * block block_7:
58 * }
59 *
60 * This turns it into something like this:
61 *
62 * // Stuff from block 1
63 * // Stuff from block 3
64 * loop {
65 * block block_1:
66 * vec1 32 ssa_3 = phi block_0: ssa_0, block_7: ssa_1
67 * vec1 32 ssa_6 = load_const (0x00000004)
68 * vec1 32 ssa_7 = ilt ssa_5, ssa_6
69 * if ssa_7 {
70 * block block_5:
71 * } else {
72 * block block_6:
73 * break
74 * }
75 * block block_7:
76 * // Stuff from block 1
77 * // Stuff from block 2
78 * vec1 32 ssa_4 = load_const (0x00000001)
79 * vec1 32 ssa_5 = iadd ssa_2, ssa_4
80 * }
81 */
82 static bool
83 opt_peel_loop_initial_if(nir_loop *loop)
84 {
85 nir_block *header_block = nir_loop_first_block(loop);
86 nir_block *prev_block =
87 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
88
89 /* It would be insane if this were not true */
90 assert(_mesa_set_search(header_block->predecessors, prev_block));
91
92 /* The loop must have exactly one continue block which could be a block
93 * ending in a continue instruction or the "natural" continue from the
94 * last block in the loop back to the top.
95 */
96 if (header_block->predecessors->entries != 2)
97 return false;
98
99 nir_block *continue_block = NULL;
100 struct set_entry *pred_entry;
101 set_foreach(header_block->predecessors, pred_entry) {
102 if (pred_entry->key != prev_block)
103 continue_block = (void *)pred_entry->key;
104 }
105
106 nir_cf_node *if_node = nir_cf_node_next(&header_block->cf_node);
107 if (!if_node || if_node->type != nir_cf_node_if)
108 return false;
109
110 nir_if *nif = nir_cf_node_as_if(if_node);
111 assert(nif->condition.is_ssa);
112
113 nir_ssa_def *cond = nif->condition.ssa;
114 if (cond->parent_instr->type != nir_instr_type_phi)
115 return false;
116
117 nir_phi_instr *cond_phi = nir_instr_as_phi(cond->parent_instr);
118 if (cond->parent_instr->block != header_block)
119 return false;
120
121 /* We already know we have exactly one continue */
122 assert(exec_list_length(&cond_phi->srcs) == 2);
123
124 uint32_t entry_val = 0, continue_val = 0;
125 nir_foreach_phi_src(src, cond_phi) {
126 assert(src->src.is_ssa);
127 nir_const_value *const_src = nir_src_as_const_value(src->src);
128 if (!const_src)
129 return false;
130
131 if (src->pred == continue_block) {
132 continue_val = const_src->u32[0];
133 } else {
134 assert(src->pred == prev_block);
135 entry_val = const_src->u32[0];
136 }
137 }
138
139 /* If they both execute or both don't execute, this is a job for
140 * nir_dead_cf, not this pass.
141 */
142 if ((entry_val && continue_val) || (!entry_val && !continue_val))
143 return false;
144
145 struct exec_list *continue_list, *entry_list;
146 if (continue_val) {
147 continue_list = &nif->then_list;
148 entry_list = &nif->else_list;
149 } else {
150 continue_list = &nif->else_list;
151 entry_list = &nif->then_list;
152 }
153
154 /* We want to be moving the contents of entry_list to above the loop so it
155 * can't contain any break or continue instructions.
156 */
157 foreach_list_typed(nir_cf_node, cf_node, node, entry_list) {
158 nir_foreach_block_in_cf_node(block, cf_node) {
159 nir_instr *last_instr = nir_block_last_instr(block);
160 if (last_instr && last_instr->type == nir_instr_type_jump)
161 return false;
162 }
163 }
164
165 /* Before we do anything, convert the loop to LCSSA. We're about to
166 * replace a bunch of SSA defs with registers and this will prevent any of
167 * it from leaking outside the loop.
168 */
169 nir_convert_loop_to_lcssa(loop);
170
171 nir_block *after_if_block =
172 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
173
174 /* Get rid of phis in the header block since we will be duplicating it */
175 nir_lower_phis_to_regs_block(header_block);
176 /* Get rid of phis after the if since dominance will change */
177 nir_lower_phis_to_regs_block(after_if_block);
178
179 /* Get rid of SSA defs in the pieces we're about to move around */
180 nir_lower_ssa_defs_to_regs_block(header_block);
181 nir_foreach_block_in_cf_node(block, &nif->cf_node)
182 nir_lower_ssa_defs_to_regs_block(block);
183
184 nir_cf_list header, tmp;
185 nir_cf_extract(&header, nir_before_block(header_block),
186 nir_after_block(header_block));
187
188 nir_cf_list_clone(&tmp, &header, &loop->cf_node, NULL);
189 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
190 nir_cf_extract(&tmp, nir_before_cf_list(entry_list),
191 nir_after_cf_list(entry_list));
192 nir_cf_reinsert(&tmp, nir_before_cf_node(&loop->cf_node));
193
194 nir_cf_reinsert(&header, nir_after_block_before_jump(continue_block));
195 nir_cf_extract(&tmp, nir_before_cf_list(continue_list),
196 nir_after_cf_list(continue_list));
197 nir_cf_reinsert(&tmp, nir_after_block_before_jump(continue_block));
198
199 nir_cf_node_remove(&nif->cf_node);
200
201 return true;
202 }
203
204 static bool
205 is_block_empty(nir_block *block)
206 {
207 return nir_cf_node_is_last(&block->cf_node) &&
208 exec_list_is_empty(&block->instr_list);
209 }
210
211 /**
212 * This optimization turns:
213 *
214 * if (cond) {
215 * } else {
216 * do_work();
217 * }
218 *
219 * into:
220 *
221 * if (!cond) {
222 * do_work();
223 * } else {
224 * }
225 */
226 static bool
227 opt_if_simplification(nir_builder *b, nir_if *nif)
228 {
229 /* Only simplify if the then block is empty and the else block is not. */
230 if (!is_block_empty(nir_if_first_then_block(nif)) ||
231 is_block_empty(nir_if_first_else_block(nif)))
232 return false;
233
234 /* Make sure the condition is a comparison operation. */
235 nir_instr *src_instr = nif->condition.ssa->parent_instr;
236 if (src_instr->type != nir_instr_type_alu)
237 return false;
238
239 nir_alu_instr *alu_instr = nir_instr_as_alu(src_instr);
240 if (!nir_alu_instr_is_comparison(alu_instr))
241 return false;
242
243 /* Insert the inverted instruction and rewrite the condition. */
244 b->cursor = nir_after_instr(&alu_instr->instr);
245
246 nir_ssa_def *new_condition =
247 nir_inot(b, &alu_instr->dest.dest.ssa);
248
249 nir_if_rewrite_condition(nif, nir_src_for_ssa(new_condition));
250
251 /* Grab pointers to the last then/else blocks for fixing up the phis. */
252 nir_block *then_block = nir_if_last_then_block(nif);
253 nir_block *else_block = nir_if_last_else_block(nif);
254
255 /* Walk all the phis in the block immediately following the if statement and
256 * swap the blocks.
257 */
258 nir_block *after_if_block =
259 nir_cf_node_as_block(nir_cf_node_next(&nif->cf_node));
260
261 nir_foreach_instr(instr, after_if_block) {
262 if (instr->type != nir_instr_type_phi)
263 continue;
264
265 nir_phi_instr *phi = nir_instr_as_phi(instr);
266
267 foreach_list_typed(nir_phi_src, src, node, &phi->srcs) {
268 if (src->pred == else_block) {
269 src->pred = then_block;
270 } else if (src->pred == then_block) {
271 src->pred = else_block;
272 }
273 }
274 }
275
276 /* Finally, move the else block to the then block. */
277 nir_cf_list tmp;
278 nir_cf_extract(&tmp, nir_before_cf_list(&nif->else_list),
279 nir_after_cf_list(&nif->else_list));
280 nir_cf_reinsert(&tmp, nir_before_cf_list(&nif->then_list));
281 nir_cf_delete(&tmp);
282
283 return true;
284 }
285
286 static bool
287 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
288 {
289 bool progress = false;
290 foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
291 switch (cf_node->type) {
292 case nir_cf_node_block:
293 break;
294
295 case nir_cf_node_if: {
296 nir_if *nif = nir_cf_node_as_if(cf_node);
297 progress |= opt_if_cf_list(b, &nif->then_list);
298 progress |= opt_if_cf_list(b, &nif->else_list);
299 progress |= opt_if_simplification(b, nif);
300 break;
301 }
302
303 case nir_cf_node_loop: {
304 nir_loop *loop = nir_cf_node_as_loop(cf_node);
305 progress |= opt_if_cf_list(b, &loop->body);
306 progress |= opt_peel_loop_initial_if(loop);
307 break;
308 }
309
310 case nir_cf_node_function:
311 unreachable("Invalid cf type");
312 }
313 }
314
315 return progress;
316 }
317
318 bool
319 nir_opt_if(nir_shader *shader)
320 {
321 bool progress = false;
322
323 nir_foreach_function(function, shader) {
324 if (function->impl == NULL)
325 continue;
326
327 nir_builder b;
328 nir_builder_init(&b, function->impl);
329
330 if (opt_if_cf_list(&b, &function->impl->body)) {
331 nir_metadata_preserve(function->impl, nir_metadata_none);
332
333 /* If that made progress, we're no longer really in SSA form. We
334 * need to convert registers back into SSA defs and clean up SSA defs
335 * that don't dominate their uses.
336 */
337 nir_lower_regs_to_ssa_impl(function->impl);
338 progress = true;
339 }
340 }
341
342 return progress;
343 }