nir: Rename convert_to_ssa to lower_regs_to_ssa
[mesa.git] / src / compiler / nir / nir_lower_regs_to_ssa.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Connor Abbott (cwabbott0@gmail.com)
25 *
26 */
27
28 #include "nir.h"
29 #include <stdlib.h>
30
/*
 * Implements the classic to-SSA algorithm described by Cytron et al. in
 * "Efficiently Computing Static Single Assignment Form and the Control
 * Dependence Graph."
 */
36
37 /* inserts a phi node of the form reg = phi(reg, reg, reg, ...) */
38
39 static void
40 insert_trivial_phi(nir_register *reg, nir_block *block, void *mem_ctx)
41 {
42 nir_phi_instr *instr = nir_phi_instr_create(mem_ctx);
43
44 instr->dest.reg.reg = reg;
45 struct set_entry *entry;
46 set_foreach(block->predecessors, entry) {
47 nir_block *pred = (nir_block *) entry->key;
48
49 nir_phi_src *src = ralloc(instr, nir_phi_src);
50 src->pred = pred;
51 src->src.is_ssa = false;
52 src->src.reg.base_offset = 0;
53 src->src.reg.indirect = NULL;
54 src->src.reg.reg = reg;
55 exec_list_push_tail(&instr->srcs, &src->node);
56 }
57
58 nir_instr_insert_before_block(block, &instr->instr);
59 }
60
61 static void
62 insert_phi_nodes(nir_function_impl *impl)
63 {
64 void *mem_ctx = ralloc_parent(impl);
65
66 unsigned *work = calloc(impl->num_blocks, sizeof(unsigned));
67 unsigned *has_already = calloc(impl->num_blocks, sizeof(unsigned));
68
69 /*
70 * Since the work flags already prevent us from inserting a node that has
71 * ever been inserted into W, we don't need to use a set to represent W.
72 * Also, since no block can ever be inserted into W more than once, we know
73 * that the maximum size of W is the number of basic blocks in the
74 * function. So all we need to handle W is an array and a pointer to the
75 * next element to be inserted and the next element to be removed.
76 */
77 nir_block **W = malloc(impl->num_blocks * sizeof(nir_block *));
78 unsigned w_start, w_end;
79
80 unsigned iter_count = 0;
81
82 nir_index_blocks(impl);
83
84 foreach_list_typed(nir_register, reg, node, &impl->registers) {
85 if (reg->num_array_elems != 0)
86 continue;
87
88 w_start = w_end = 0;
89 iter_count++;
90
91 nir_foreach_def(dest, reg) {
92 nir_instr *def = dest->reg.parent_instr;
93 if (work[def->block->index] < iter_count)
94 W[w_end++] = def->block;
95 work[def->block->index] = iter_count;
96 }
97
98 while (w_start != w_end) {
99 nir_block *cur = W[w_start++];
100 struct set_entry *entry;
101 set_foreach(cur->dom_frontier, entry) {
102 nir_block *next = (nir_block *) entry->key;
103
104 /*
105 * If there's more than one return statement, then the end block
106 * can be a join point for some definitions. However, there are
107 * no instructions in the end block, so nothing would use those
108 * phi nodes. Of course, we couldn't place those phi nodes
109 * anyways due to the restriction of having no instructions in the
110 * end block...
111 */
112 if (next == impl->end_block)
113 continue;
114
115 if (has_already[next->index] < iter_count) {
116 insert_trivial_phi(reg, next, mem_ctx);
117 has_already[next->index] = iter_count;
118 if (work[next->index] < iter_count) {
119 work[next->index] = iter_count;
120 W[w_end++] = next;
121 }
122 }
123 }
124 }
125 }
126
127 free(work);
128 free(has_already);
129 free(W);
130 }
131
132 typedef struct {
133 nir_ssa_def **stack;
134 int index;
135 unsigned num_defs; /** < used to add indices to debug names */
136 #ifndef NDEBUG
137 unsigned stack_size;
138 #endif
139 } reg_state;
140
141 typedef struct {
142 reg_state *states;
143 void *mem_ctx;
144 nir_instr *parent_instr;
145 nir_if *parent_if;
146 nir_function_impl *impl;
147
148 /* map from SSA value -> original register */
149 struct hash_table *ssa_map;
150 } rewrite_state;
151
152 static nir_ssa_def *get_ssa_src(nir_register *reg, rewrite_state *state)
153 {
154 unsigned index = reg->index;
155
156 if (state->states[index].index == -1) {
157 /*
158 * We're using an undefined register, create a new undefined SSA value
159 * to preserve the information that this source is undefined
160 */
161 nir_ssa_undef_instr *instr =
162 nir_ssa_undef_instr_create(state->mem_ctx, reg->num_components,
163 reg->bit_size);
164
165 /*
166 * We could just insert the undefined instruction before the instruction
167 * we're rewriting, but we could be rewriting a phi source in which case
168 * we can't do that, so do the next easiest thing - insert it at the
169 * beginning of the program. In the end, it doesn't really matter where
170 * the undefined instructions are because they're going to be ignored
171 * in the backend.
172 */
173 nir_instr_insert_before_cf_list(&state->impl->body, &instr->instr);
174 return &instr->def;
175 }
176
177 return state->states[index].stack[state->states[index].index];
178 }
179
180 static bool
181 rewrite_use(nir_src *src, void *_state)
182 {
183 rewrite_state *state = (rewrite_state *) _state;
184
185 if (src->is_ssa)
186 return true;
187
188 unsigned index = src->reg.reg->index;
189
190 if (state->states[index].stack == NULL)
191 return true;
192
193 nir_ssa_def *def = get_ssa_src(src->reg.reg, state);
194 if (state->parent_instr)
195 nir_instr_rewrite_src(state->parent_instr, src, nir_src_for_ssa(def));
196 else
197 nir_if_rewrite_condition(state->parent_if, nir_src_for_ssa(def));
198
199 return true;
200 }
201
202 static bool
203 rewrite_def_forwards(nir_dest *dest, void *_state)
204 {
205 rewrite_state *state = (rewrite_state *) _state;
206
207 if (dest->is_ssa)
208 return true;
209
210 nir_register *reg = dest->reg.reg;
211 unsigned index = reg->index;
212
213 if (state->states[index].stack == NULL)
214 return true;
215
216 char *name = NULL;
217 if (dest->reg.reg->name)
218 name = ralloc_asprintf(state->mem_ctx, "%s_%u", dest->reg.reg->name,
219 state->states[index].num_defs);
220
221 list_del(&dest->reg.def_link);
222 nir_ssa_dest_init(state->parent_instr, dest, reg->num_components,
223 reg->bit_size, name);
224 ralloc_free(name);
225
226 /* push our SSA destination on the stack */
227 state->states[index].index++;
228 assert(state->states[index].index < state->states[index].stack_size);
229 state->states[index].stack[state->states[index].index] = &dest->ssa;
230 state->states[index].num_defs++;
231
232 _mesa_hash_table_insert(state->ssa_map, &dest->ssa, reg);
233
234 return true;
235 }
236
237 static void
238 rewrite_alu_instr_forward(nir_alu_instr *instr, rewrite_state *state)
239 {
240 state->parent_instr = &instr->instr;
241
242 nir_foreach_src(&instr->instr, rewrite_use, state);
243
244 if (instr->dest.dest.is_ssa)
245 return;
246
247 nir_register *reg = instr->dest.dest.reg.reg;
248 unsigned index = reg->index;
249
250 if (state->states[index].stack == NULL)
251 return;
252
253 unsigned write_mask = instr->dest.write_mask;
254 if (write_mask != (1 << instr->dest.dest.reg.reg->num_components) - 1) {
255 /*
256 * Calculate the number of components the final instruction, which for
257 * per-component things is the number of output components of the
258 * instruction and non-per-component things is the number of enabled
259 * channels in the write mask.
260 */
261 unsigned num_components;
262 if (nir_op_infos[instr->op].output_size == 0) {
263 unsigned temp = (write_mask & 0x5) + ((write_mask >> 1) & 0x5);
264 num_components = (temp & 0x3) + ((temp >> 2) & 0x3);
265 } else {
266 num_components = nir_op_infos[instr->op].output_size;
267 }
268
269 char *name = NULL;
270 if (instr->dest.dest.reg.reg->name)
271 name = ralloc_asprintf(state->mem_ctx, "%s_%u",
272 reg->name, state->states[index].num_defs);
273
274 instr->dest.write_mask = (1 << num_components) - 1;
275 list_del(&instr->dest.dest.reg.def_link);
276 nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
277 reg->bit_size, name);
278 ralloc_free(name);
279
280 if (nir_op_infos[instr->op].output_size == 0) {
281 /*
282 * When we change the output writemask, we need to change the
283 * swizzles for per-component inputs too
284 */
285 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
286 if (nir_op_infos[instr->op].input_sizes[i] != 0)
287 continue;
288
289 unsigned new_swizzle[4] = {0, 0, 0, 0};
290
291 /*
292 * We keep two indices:
293 * 1. The index of the original (non-SSA) component
294 * 2. The index of the post-SSA, compacted, component
295 *
296 * We need to map the swizzle component at index 1 to the swizzle
297 * component at index 2.
298 */
299
300 unsigned ssa_index = 0;
301 for (unsigned index = 0; index < 4; index++) {
302 if (!((write_mask >> index) & 1))
303 continue;
304
305 new_swizzle[ssa_index] = instr->src[i].swizzle[index];
306 ssa_index++;
307 }
308
309 for (unsigned j = 0; j < 4; j++)
310 instr->src[i].swizzle[j] = new_swizzle[j];
311 }
312 }
313
314 nir_op op;
315 switch (reg->num_components) {
316 case 2: op = nir_op_vec2; break;
317 case 3: op = nir_op_vec3; break;
318 case 4: op = nir_op_vec4; break;
319 default: unreachable("not reached");
320 }
321
322 nir_alu_instr *vec = nir_alu_instr_create(state->mem_ctx, op);
323
324 vec->dest.dest.reg.reg = reg;
325 vec->dest.write_mask = (1 << reg->num_components) - 1;
326
327 nir_ssa_def *old_src = get_ssa_src(reg, state);
328 nir_ssa_def *new_src = &instr->dest.dest.ssa;
329
330 unsigned ssa_index = 0;
331 for (unsigned i = 0; i < reg->num_components; i++) {
332 vec->src[i].src.is_ssa = true;
333 if ((write_mask >> i) & 1) {
334 vec->src[i].src.ssa = new_src;
335 if (nir_op_infos[instr->op].output_size == 0)
336 vec->src[i].swizzle[0] = ssa_index;
337 else
338 vec->src[i].swizzle[0] = i;
339 ssa_index++;
340 } else {
341 vec->src[i].src.ssa = old_src;
342 vec->src[i].swizzle[0] = i;
343 }
344 }
345
346 nir_instr_insert_after(&instr->instr, &vec->instr);
347
348 state->parent_instr = &vec->instr;
349 rewrite_def_forwards(&vec->dest.dest, state);
350 } else {
351 rewrite_def_forwards(&instr->dest.dest, state);
352 }
353 }
354
355 static void
356 rewrite_phi_instr(nir_phi_instr *instr, rewrite_state *state)
357 {
358 state->parent_instr = &instr->instr;
359 rewrite_def_forwards(&instr->dest, state);
360 }
361
362 static void
363 rewrite_instr_forward(nir_instr *instr, rewrite_state *state)
364 {
365 if (instr->type == nir_instr_type_alu) {
366 rewrite_alu_instr_forward(nir_instr_as_alu(instr), state);
367 return;
368 }
369
370 if (instr->type == nir_instr_type_phi) {
371 rewrite_phi_instr(nir_instr_as_phi(instr), state);
372 return;
373 }
374
375 state->parent_instr = instr;
376
377 nir_foreach_src(instr, rewrite_use, state);
378 nir_foreach_dest(instr, rewrite_def_forwards, state);
379 }
380
381 static void
382 rewrite_phi_sources(nir_block *block, nir_block *pred, rewrite_state *state)
383 {
384 nir_foreach_instr(instr, block) {
385 if (instr->type != nir_instr_type_phi)
386 break;
387
388 nir_phi_instr *phi_instr = nir_instr_as_phi(instr);
389
390 state->parent_instr = instr;
391
392 nir_foreach_phi_src(src, phi_instr) {
393 if (src->pred == pred) {
394 rewrite_use(&src->src, state);
395 break;
396 }
397 }
398 }
399 }
400
401 static bool
402 rewrite_def_backwards(nir_dest *dest, void *_state)
403 {
404 rewrite_state *state = (rewrite_state *) _state;
405
406 if (!dest->is_ssa)
407 return true;
408
409 struct hash_entry *entry =
410 _mesa_hash_table_search(state->ssa_map, &dest->ssa);
411
412 if (!entry)
413 return true;
414
415 nir_register *reg = (nir_register *) entry->data;
416 unsigned index = reg->index;
417
418 state->states[index].index--;
419 assert(state->states[index].index >= -1);
420
421 return true;
422 }
423
424 static void
425 rewrite_instr_backwards(nir_instr *instr, rewrite_state *state)
426 {
427 nir_foreach_dest(instr, rewrite_def_backwards, state);
428 }
429
430 static void
431 rewrite_block(nir_block *block, rewrite_state *state)
432 {
433 /* This will skip over any instructions after the current one, which is
434 * what we want because those instructions (vector gather, conditional
435 * select) will already be in SSA form.
436 */
437 nir_foreach_instr_safe(instr, block) {
438 rewrite_instr_forward(instr, state);
439 }
440
441 if (block != state->impl->end_block &&
442 !nir_cf_node_is_last(&block->cf_node) &&
443 nir_cf_node_next(&block->cf_node)->type == nir_cf_node_if) {
444 nir_if *if_stmt = nir_cf_node_as_if(nir_cf_node_next(&block->cf_node));
445 state->parent_instr = NULL;
446 state->parent_if = if_stmt;
447 rewrite_use(&if_stmt->condition, state);
448 }
449
450 if (block->successors[0])
451 rewrite_phi_sources(block->successors[0], block, state);
452 if (block->successors[1])
453 rewrite_phi_sources(block->successors[1], block, state);
454
455 for (unsigned i = 0; i < block->num_dom_children; i++)
456 rewrite_block(block->dom_children[i], state);
457
458 nir_foreach_instr_reverse(instr, block) {
459 rewrite_instr_backwards(instr, state);
460 }
461 }
462
463 static void
464 remove_unused_regs(nir_function_impl *impl, rewrite_state *state)
465 {
466 foreach_list_typed_safe(nir_register, reg, node, &impl->registers) {
467 if (state->states[reg->index].stack != NULL)
468 exec_node_remove(&reg->node);
469 }
470 }
471
472 static void
473 init_rewrite_state(nir_function_impl *impl, rewrite_state *state)
474 {
475 state->impl = impl;
476 state->mem_ctx = ralloc_parent(impl);
477 state->ssa_map = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
478 _mesa_key_pointer_equal);
479 state->states = rzalloc_array(NULL, reg_state, impl->reg_alloc);
480
481 foreach_list_typed(nir_register, reg, node, &impl->registers) {
482 assert(reg->index < impl->reg_alloc);
483 if (reg->num_array_elems > 0) {
484 state->states[reg->index].stack = NULL;
485 } else {
486 /*
487 * Calculate a conservative estimate of the stack size based on the
488 * number of definitions there are. Note that this function *must* be
489 * called after phi nodes are inserted so we can count phi node
490 * definitions too.
491 */
492 unsigned stack_size = list_length(&reg->defs);
493
494 state->states[reg->index].stack = ralloc_array(state->states,
495 nir_ssa_def *,
496 stack_size);
497 #ifndef NDEBUG
498 state->states[reg->index].stack_size = stack_size;
499 #endif
500 state->states[reg->index].index = -1;
501 state->states[reg->index].num_defs = 0;
502 }
503 }
504 }
505
506 static void
507 destroy_rewrite_state(rewrite_state *state)
508 {
509 _mesa_hash_table_destroy(state->ssa_map, NULL);
510 ralloc_free(state->states);
511 }
512
513 void
514 nir_lower_regs_to_ssa_impl(nir_function_impl *impl)
515 {
516 nir_metadata_require(impl, nir_metadata_dominance);
517
518 insert_phi_nodes(impl);
519
520 rewrite_state state;
521 init_rewrite_state(impl, &state);
522
523 rewrite_block(nir_start_block(impl), &state);
524
525 remove_unused_regs(impl, &state);
526
527 nir_metadata_preserve(impl, nir_metadata_block_index |
528 nir_metadata_dominance);
529
530 destroy_rewrite_state(&state);
531 }
532
533 void
534 nir_lower_regs_to_ssa(nir_shader *shader)
535 {
536 nir_foreach_function(function, shader) {
537 if (function->impl)
538 nir_lower_regs_to_ssa_impl(function->impl);
539 }
540 }