327de85d36d7e45fa2726974de3d8b09576f5267
[mesa.git] / src / compiler / nir / nir_to_lcssa.c
1 /*
2 * Copyright © 2015 Thomas Helland
3 * Copyright © 2019 Valve Corporation
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24
25 /*
26 * This pass converts the ssa-graph into "Loop Closed SSA form". This is
27 * done by placing phi nodes at the exits of the loop for all values
28 * that are used outside the loop. The result is it transforms:
29 *
30 * loop { -> loop {
31 * ssa2 = .... -> ssa2 = ...
32 * if (cond) -> if (cond)
33 * break; -> break;
34 * ssa3 = ssa2 * ssa4 -> ssa3 = ssa2 * ssa4
35 * } -> }
36 * ssa6 = ssa2 + 4 -> ssa5 = phi(ssa2)
37 * ssa6 = ssa5 + 4
38 */
39
40 #include "nir.h"
41
42 typedef struct {
43 /* The nir_shader we are transforming */
44 nir_shader *shader;
45
46 /* The loop we store information for */
47 nir_loop *loop;
48
49 /* Whether to skip loop invariant variables */
50 bool skip_invariants;
51 bool skip_bool_invariants;
52
53 bool progress;
54 } lcssa_state;
55
56 static bool
57 is_if_use_inside_loop(nir_src *use, nir_loop *loop)
58 {
59 nir_block *block_before_loop =
60 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
61 nir_block *block_after_loop =
62 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
63
64 nir_block *prev_block =
65 nir_cf_node_as_block(nir_cf_node_prev(&use->parent_if->cf_node));
66 if (prev_block->index <= block_before_loop->index ||
67 prev_block->index >= block_after_loop->index) {
68 return false;
69 }
70
71 return true;
72 }
73
74 static bool
75 is_use_inside_loop(nir_src *use, nir_loop *loop)
76 {
77 nir_block *block_before_loop =
78 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
79 nir_block *block_after_loop =
80 nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
81
82 if (use->parent_instr->block->index <= block_before_loop->index ||
83 use->parent_instr->block->index >= block_after_loop->index) {
84 return false;
85 }
86
87 return true;
88 }
89
90 static bool
91 is_defined_before_loop(nir_ssa_def *def, nir_loop *loop)
92 {
93 nir_instr *instr = def->parent_instr;
94 nir_block *block_before_loop =
95 nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
96
97 return instr->block->index <= block_before_loop->index;
98 }
99
100 typedef enum instr_invariance {
101 undefined = 0,
102 invariant,
103 not_invariant,
104 } instr_invariance;
105
106 static instr_invariance
107 instr_is_invariant(nir_instr *instr, nir_loop *loop);
108
109 static bool
110 def_is_invariant(nir_ssa_def *def, nir_loop *loop)
111 {
112 if (is_defined_before_loop(def, loop))
113 return invariant;
114
115 if (def->parent_instr->pass_flags == undefined)
116 def->parent_instr->pass_flags = instr_is_invariant(def->parent_instr, loop);
117
118 return def->parent_instr->pass_flags == invariant;
119 }
120
121 static bool
122 src_is_invariant(nir_src *src, void *state)
123 {
124 assert(src->is_ssa);
125 return def_is_invariant(src->ssa, (nir_loop *)state);
126 }
127
128 static instr_invariance
129 phi_is_invariant(nir_phi_instr *instr, nir_loop *loop)
130 {
131 /* Base case: it's a phi at the loop header
132 * Loop-header phis are updated in each loop iteration with
133 * the loop-carried value, and thus control-flow dependent
134 * on the loop itself.
135 */
136 if (instr->instr.block == nir_loop_first_block(loop))
137 return not_invariant;
138
139 nir_foreach_phi_src(src, instr) {
140 if (!src_is_invariant(&src->src, loop))
141 return not_invariant;
142 }
143
144 /* All loop header- and LCSSA-phis should be handled by this point. */
145 nir_cf_node *prev = nir_cf_node_prev(&instr->instr.block->cf_node);
146 assert(prev && prev->type == nir_cf_node_if);
147
148 /* Invariance of phis after if-nodes also depends on the invariance
149 * of the branch condition.
150 */
151 nir_if *if_node = nir_cf_node_as_if(prev);
152 if (!def_is_invariant(if_node->condition.ssa, loop))
153 return not_invariant;
154
155 return invariant;
156 }
157
158
159 /* An instruction is said to be loop-invariant if it
160 * - has no sideeffects and
161 * - solely depends on variables defined outside of the loop or
162 * by other invariant instructions
163 */
164 static instr_invariance
165 instr_is_invariant(nir_instr *instr, nir_loop *loop)
166 {
167 assert(instr->pass_flags == undefined);
168
169 switch (instr->type) {
170 case nir_instr_type_load_const:
171 case nir_instr_type_ssa_undef:
172 return invariant;
173 case nir_instr_type_call:
174 return not_invariant;
175 case nir_instr_type_phi:
176 return phi_is_invariant(nir_instr_as_phi(instr), loop);
177 case nir_instr_type_intrinsic: {
178 nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
179 if (!(nir_intrinsic_infos[intrinsic->intrinsic].flags & NIR_INTRINSIC_CAN_REORDER))
180 return not_invariant;
181 }
182 /* fallthrough */
183 default:
184 return nir_foreach_src(instr, src_is_invariant, loop) ? invariant : not_invariant;
185 }
186
187 return invariant;
188 }
189
190 static bool
191 convert_loop_exit_for_ssa(nir_ssa_def *def, void *void_state)
192 {
193 lcssa_state *state = void_state;
194 bool all_uses_inside_loop = true;
195
196 /* Don't create LCSSA-Phis for loop-invariant variables */
197 if (state->skip_invariants &&
198 (def->bit_size != 1 || state->skip_bool_invariants)) {
199 assert(def->parent_instr->pass_flags != undefined);
200 if (def->parent_instr->pass_flags == invariant)
201 return true;
202 }
203
204 nir_block *block_after_loop =
205 nir_cf_node_as_block(nir_cf_node_next(&state->loop->cf_node));
206
207 nir_foreach_use(use, def) {
208 if (use->parent_instr->type == nir_instr_type_phi &&
209 use->parent_instr->block == block_after_loop) {
210 continue;
211 }
212
213 if (!is_use_inside_loop(use, state->loop)) {
214 all_uses_inside_loop = false;
215 }
216 }
217
218 nir_foreach_if_use(use, def) {
219 if (!is_if_use_inside_loop(use, state->loop)) {
220 all_uses_inside_loop = false;
221 }
222 }
223
224 /* There where no sources that had defs outside the loop */
225 if (all_uses_inside_loop)
226 return true;
227
228 /* Initialize a phi-instruction */
229 nir_phi_instr *phi = nir_phi_instr_create(state->shader);
230 nir_ssa_dest_init(&phi->instr, &phi->dest,
231 def->num_components, def->bit_size, "LCSSA-phi");
232
233 /* Create a phi node with as many sources pointing to the same ssa_def as
234 * the block has predecessors.
235 */
236 set_foreach(block_after_loop->predecessors, entry) {
237 nir_phi_src *phi_src = ralloc(phi, nir_phi_src);
238 phi_src->src = nir_src_for_ssa(def);
239 phi_src->pred = (nir_block *) entry->key;
240
241 exec_list_push_tail(&phi->srcs, &phi_src->node);
242 }
243
244 nir_instr_insert_before_block(block_after_loop, &phi->instr);
245 nir_ssa_def *dest = &phi->dest.ssa;
246
247 /* deref instructions need a cast after the phi */
248 if (def->parent_instr->type == nir_instr_type_deref) {
249 nir_deref_instr *cast =
250 nir_deref_instr_create(state->shader, nir_deref_type_cast);
251
252 nir_deref_instr *instr = nir_instr_as_deref(def->parent_instr);
253 cast->mode = instr->mode;
254 cast->type = instr->type;
255 cast->parent = nir_src_for_ssa(&phi->dest.ssa);
256 cast->cast.ptr_stride = nir_deref_instr_ptr_as_array_stride(instr);
257
258 nir_ssa_dest_init(&cast->instr, &cast->dest,
259 phi->dest.ssa.num_components,
260 phi->dest.ssa.bit_size, NULL);
261 nir_instr_insert(nir_after_phis(block_after_loop), &cast->instr);
262 dest = &cast->dest.ssa;
263 }
264
265 /* Run through all uses and rewrite those outside the loop to point to
266 * the phi instead of pointing to the ssa-def.
267 */
268 nir_foreach_use_safe(use, def) {
269 if (use->parent_instr->type == nir_instr_type_phi &&
270 block_after_loop == use->parent_instr->block) {
271 continue;
272 }
273
274 if (!is_use_inside_loop(use, state->loop)) {
275 nir_instr_rewrite_src(use->parent_instr, use, nir_src_for_ssa(dest));
276 }
277 }
278
279 nir_foreach_if_use_safe(use, def) {
280 if (!is_if_use_inside_loop(use, state->loop)) {
281 nir_if_rewrite_condition(use->parent_if, nir_src_for_ssa(dest));
282 }
283 }
284
285 state->progress = true;
286 return true;
287 }
288
289 static void
290 convert_to_lcssa(nir_cf_node *cf_node, lcssa_state *state)
291 {
292 switch (cf_node->type) {
293 case nir_cf_node_block:
294 return;
295 case nir_cf_node_if: {
296 nir_if *if_stmt = nir_cf_node_as_if(cf_node);
297 foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->then_list)
298 convert_to_lcssa(nested_node, state);
299 foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->else_list)
300 convert_to_lcssa(nested_node, state);
301 return;
302 }
303 case nir_cf_node_loop: {
304 if (state->skip_invariants) {
305 nir_foreach_block_in_cf_node(block, cf_node) {
306 nir_foreach_instr(instr, block)
307 instr->pass_flags = undefined;
308 }
309 }
310
311 /* first, convert inner loops */
312 nir_loop *loop = nir_cf_node_as_loop(cf_node);
313 foreach_list_typed(nir_cf_node, nested_node, node, &loop->body)
314 convert_to_lcssa(nested_node, state);
315
316 /* mark loop-invariant instructions */
317 if (state->skip_invariants) {
318 nir_foreach_block_in_cf_node(block, cf_node) {
319 nir_foreach_instr(instr, block) {
320 if (instr->pass_flags == undefined)
321 instr->pass_flags = instr_is_invariant(instr, nir_cf_node_as_loop(cf_node));
322 }
323 }
324 }
325
326 state->loop = loop;
327 nir_foreach_block_in_cf_node(block, cf_node) {
328 nir_foreach_instr(instr, block) {
329 nir_foreach_ssa_def(instr, convert_loop_exit_for_ssa, state);
330
331 /* for outer loops, invariant instructions can be variant */
332 if (state->skip_invariants && instr->pass_flags == invariant)
333 instr->pass_flags = undefined;
334 }
335 }
336
337 /* For outer loops, the LCSSA-phi should be considered not invariant */
338 if (state->skip_invariants) {
339 nir_block *block_after_loop =
340 nir_cf_node_as_block(nir_cf_node_next(&state->loop->cf_node));
341 nir_foreach_instr(instr, block_after_loop) {
342 if (instr->type == nir_instr_type_phi)
343 instr->pass_flags = not_invariant;
344 else
345 break;
346 }
347 }
348 return;
349 }
350 default:
351 unreachable("unknown cf node type");
352 }
353 }
354
355 void
356 nir_convert_loop_to_lcssa(nir_loop *loop)
357 {
358 nir_function_impl *impl = nir_cf_node_get_function(&loop->cf_node);
359
360 nir_metadata_require(impl, nir_metadata_block_index);
361
362 lcssa_state *state = rzalloc(NULL, lcssa_state);
363 state->loop = loop;
364 state->shader = impl->function->shader;
365 state->skip_invariants = false;
366 state->skip_bool_invariants = false;
367
368 nir_foreach_block_in_cf_node (block, &loop->cf_node) {
369 nir_foreach_instr(instr, block)
370 nir_foreach_ssa_def(instr, convert_loop_exit_for_ssa, state);
371 }
372
373 ralloc_free(state);
374 }
375
376 bool
377 nir_convert_to_lcssa(nir_shader *shader, bool skip_invariants, bool skip_bool_invariants)
378 {
379 bool progress = false;
380 lcssa_state *state = rzalloc(NULL, lcssa_state);
381 state->shader = shader;
382 state->skip_invariants = skip_invariants;
383 state->skip_bool_invariants = skip_bool_invariants;
384
385 nir_foreach_function(function, shader) {
386 if (function->impl == NULL)
387 continue;
388
389 state->progress = false;
390 nir_metadata_require(function->impl, nir_metadata_block_index);
391
392 foreach_list_typed(nir_cf_node, node, node, &function->impl->body)
393 convert_to_lcssa(node, state);
394
395 if (state->progress) {
396 progress = true;
397 nir_metadata_preserve(function->impl, nir_metadata_block_index |
398 nir_metadata_dominance);
399 } else {
400 nir_metadata_preserve(function->impl, nir_metadata_all);
401 }
402 }
403
404 ralloc_free(state);
405 return progress;
406 }
407