nir: add bit_size info to nir_ssa_undef_instr_create()
src/compiler/nir/nir_to_ssa.c
/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Connor Abbott (cwabbott0@gmail.com)
 *
 */

#include "nir.h"
#include <stdlib.h>
#include <unistd.h>

/*
 * Implements the classic to-SSA algorithm described by Cytron et al. in
 * "Efficiently Computing Static Single Assignment Form and the Control
 * Dependence Graph."
 */
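
/*
 * The pass runs in two phases: insert_phi_nodes() places trivial phis at the
 * iterated dominance frontier of each register's definitions, and
 * rewrite_block() walks the dominator tree renaming register defs and uses to
 * SSA values, keeping a stack of reaching definitions per register.
 */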

/* inserts a phi node of the form reg = phi(reg, reg, reg, ...) */
static void
insert_trivial_phi(nir_register *reg, nir_block *block, void *mem_ctx)
{
   nir_phi_instr *instr = nir_phi_instr_create(mem_ctx);

   instr->dest.reg.reg = reg;
   struct set_entry *entry;
   set_foreach(block->predecessors, entry) {
      nir_block *pred = (nir_block *) entry->key;

      nir_phi_src *src = ralloc(instr, nir_phi_src);
      src->pred = pred;
      src->src.is_ssa = false;
      src->src.reg.base_offset = 0;
      src->src.reg.indirect = NULL;
      src->src.reg.reg = reg;
      exec_list_push_tail(&instr->srcs, &src->node);
   }

   nir_instr_insert_before_block(block, &instr->instr);
}

static void
insert_phi_nodes(nir_function_impl *impl)
{
   void *mem_ctx = ralloc_parent(impl);

   unsigned *work = calloc(impl->num_blocks, sizeof(unsigned));
   unsigned *has_already = calloc(impl->num_blocks, sizeof(unsigned));

   /*
    * Since the work flags already prevent us from inserting a node that has
    * ever been inserted into W, we don't need to use a set to represent W.
    * Also, since no block can ever be inserted into W more than once, we know
    * that the maximum size of W is the number of basic blocks in the
    * function. So all we need to handle W is an array and a pointer to the
    * next element to be inserted and the next element to be removed.
    */
   nir_block **W = malloc(impl->num_blocks * sizeof(nir_block *));
   unsigned w_start, w_end;

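   /*
    * iter_count is bumped once per register; comparing the stamps stored in
    * work[] and has_already[] against it lets us reuse both arrays across
    * registers without re-clearing them.
    */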
   unsigned iter_count = 0;

   nir_index_blocks(impl);

   foreach_list_typed(nir_register, reg, node, &impl->registers) {
      if (reg->num_array_elems != 0)
         continue;

      w_start = w_end = 0;
      iter_count++;

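      /* Seed the worklist with every block containing a definition of reg. */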
      nir_foreach_def(reg, dest) {
         nir_instr *def = dest->reg.parent_instr;
         if (work[def->block->index] < iter_count)
            W[w_end++] = def->block;
         work[def->block->index] = iter_count;
      }

      while (w_start != w_end) {
         nir_block *cur = W[w_start++];
         struct set_entry *entry;
         set_foreach(cur->dom_frontier, entry) {
            nir_block *next = (nir_block *) entry->key;

            /*
             * If there's more than one return statement, then the end block
             * can be a join point for some definitions. However, there are
             * no instructions in the end block, so nothing would use those
             * phi nodes. Of course, we couldn't place those phi nodes
             * anyway due to the restriction of having no instructions in the
             * end block...
             */
            if (next == impl->end_block)
               continue;

            if (has_already[next->index] < iter_count) {
               insert_trivial_phi(reg, next, mem_ctx);
               has_already[next->index] = iter_count;
               if (work[next->index] < iter_count) {
                  work[next->index] = iter_count;
                  W[w_end++] = next;
               }
            }
         }
      }
   }

   free(work);
   free(has_already);
   free(W);
}

typedef struct {
   nir_ssa_def **stack;
   int index;
   unsigned num_defs; /**< used to add indices to debug names */
#ifndef NDEBUG
   unsigned stack_size;
#endif
} reg_state;

typedef struct {
   reg_state *states;
   void *mem_ctx;
   nir_instr *parent_instr;
   nir_if *parent_if;
   nir_function_impl *impl;

   /* map from SSA value -> original register */
   struct hash_table *ssa_map;
} rewrite_state;

static nir_ssa_def *
get_ssa_src(nir_register *reg, rewrite_state *state)
{
   unsigned index = reg->index;

   if (state->states[index].index == -1) {
      /*
       * We're using an undefined register; create a new undefined SSA value
       * to preserve the information that this source is undefined.
       */
      nir_ssa_undef_instr *instr =
         nir_ssa_undef_instr_create(state->mem_ctx, reg->num_components,
                                    reg->bit_size);

      /*
       * We could just insert the undefined instruction before the instruction
       * we're rewriting, but we could be rewriting a phi source in which case
       * we can't do that, so do the next easiest thing - insert it at the
       * beginning of the program. In the end, it doesn't really matter where
       * the undefined instructions are because they're going to be ignored
       * in the backend.
       */
      nir_instr_insert_before_cf_list(&state->impl->body, &instr->instr);
      return &instr->def;
   }

   return state->states[index].stack[state->states[index].index];
}

static bool
rewrite_use(nir_src *src, void *_state)
{
   rewrite_state *state = (rewrite_state *) _state;

   if (src->is_ssa)
      return true;

   unsigned index = src->reg.reg->index;

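   /* A NULL stack means the register is an array register that was left out
    * of the SSA conversion, so its uses stay as they are.
    */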
   if (state->states[index].stack == NULL)
      return true;

   nir_ssa_def *def = get_ssa_src(src->reg.reg, state);
   if (state->parent_instr)
      nir_instr_rewrite_src(state->parent_instr, src, nir_src_for_ssa(def));
   else
      nir_if_rewrite_condition(state->parent_if, nir_src_for_ssa(def));

   return true;
}

static bool
rewrite_def_forwards(nir_dest *dest, void *_state)
{
   rewrite_state *state = (rewrite_state *) _state;

   if (dest->is_ssa)
      return true;

   nir_register *reg = dest->reg.reg;
   unsigned index = reg->index;

   if (state->states[index].stack == NULL)
      return true;

   char *name = NULL;
   if (dest->reg.reg->name)
      name = ralloc_asprintf(state->mem_ctx, "%s_%u", dest->reg.reg->name,
                             state->states[index].num_defs);

   list_del(&dest->reg.def_link);
   nir_ssa_dest_init(state->parent_instr, dest, reg->num_components,
                     reg->bit_size, name);
   ralloc_free(name);

   /* push our SSA destination on the stack */
   state->states[index].index++;
   assert(state->states[index].index < state->states[index].stack_size);
   state->states[index].stack[state->states[index].index] = &dest->ssa;
   state->states[index].num_defs++;

   _mesa_hash_table_insert(state->ssa_map, &dest->ssa, reg);

   return true;
}

static void
rewrite_alu_instr_forward(nir_alu_instr *instr, rewrite_state *state)
{
   state->parent_instr = &instr->instr;

   nir_foreach_src(&instr->instr, rewrite_use, state);

   if (instr->dest.dest.is_ssa)
      return;

   nir_register *reg = instr->dest.dest.reg.reg;
   unsigned index = reg->index;

   if (state->states[index].stack == NULL)
      return;

   unsigned write_mask = instr->dest.write_mask;
   if (write_mask != (1 << instr->dest.dest.reg.reg->num_components) - 1) {
      /*
       * Calculate the number of components of the final instruction, which
       * for per-component operations is the number of enabled channels in
       * the write mask, and for non-per-component operations is the number
       * of output components of the instruction.
       */
      unsigned num_components;
      if (nir_op_infos[instr->op].output_size == 0) {
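         /*
          * This is a 4-bit popcount: sum adjacent bits in pairs, then sum
          * the two pair totals to count the channels enabled in write_mask.
          */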
         unsigned temp = (write_mask & 0x5) + ((write_mask >> 1) & 0x5);
         num_components = (temp & 0x3) + ((temp >> 2) & 0x3);
      } else {
         num_components = nir_op_infos[instr->op].output_size;
      }

      char *name = NULL;
      if (instr->dest.dest.reg.reg->name)
         name = ralloc_asprintf(state->mem_ctx, "%s_%u",
                                reg->name, state->states[index].num_defs);

      instr->dest.write_mask = (1 << num_components) - 1;
      list_del(&instr->dest.dest.reg.def_link);
      nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
                        reg->bit_size, name);
      ralloc_free(name);

      if (nir_op_infos[instr->op].output_size == 0) {
         /*
          * When we change the output writemask, we need to change the
          * swizzles for per-component inputs too.
          */
         for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
            if (nir_op_infos[instr->op].input_sizes[i] != 0)
               continue;

            unsigned new_swizzle[4] = {0, 0, 0, 0};

            /*
             * We keep two indices:
             * 1. The index of the original (non-SSA) component
             * 2. The index of the post-SSA, compacted component
             *
             * We need to map the swizzle component at index 1 to the swizzle
             * component at index 2.
             */

            unsigned ssa_index = 0;
            for (unsigned index = 0; index < 4; index++) {
               if (!((write_mask >> index) & 1))
                  continue;

               new_swizzle[ssa_index] = instr->src[i].swizzle[index];
               ssa_index++;
            }

            for (unsigned j = 0; j < 4; j++)
               instr->src[i].swizzle[j] = new_swizzle[j];
         }
      }

      nir_op op;
      switch (reg->num_components) {
      case 2: op = nir_op_vec2; break;
      case 3: op = nir_op_vec3; break;
      case 4: op = nir_op_vec4; break;
      default: unreachable("not reached");
      }

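      /*
       * Only some channels were written, so emit a vecN that merges the
       * freshly written channels (from the new SSA destination) with the
       * channels preserved from the register's previous value.
       */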
      nir_alu_instr *vec = nir_alu_instr_create(state->mem_ctx, op);

      vec->dest.dest.reg.reg = reg;
      vec->dest.write_mask = (1 << reg->num_components) - 1;

      nir_ssa_def *old_src = get_ssa_src(reg, state);
      nir_ssa_def *new_src = &instr->dest.dest.ssa;

      unsigned ssa_index = 0;
      for (unsigned i = 0; i < reg->num_components; i++) {
         vec->src[i].src.is_ssa = true;
         if ((write_mask >> i) & 1) {
            vec->src[i].src.ssa = new_src;
            if (nir_op_infos[instr->op].output_size == 0)
               vec->src[i].swizzle[0] = ssa_index;
            else
               vec->src[i].swizzle[0] = i;
            ssa_index++;
         } else {
            vec->src[i].src.ssa = old_src;
            vec->src[i].swizzle[0] = i;
         }
      }

      nir_instr_insert_after(&instr->instr, &vec->instr);

      state->parent_instr = &vec->instr;
      rewrite_def_forwards(&vec->dest.dest, state);
   } else {
      rewrite_def_forwards(&instr->dest.dest, state);
   }
}

static void
rewrite_phi_instr(nir_phi_instr *instr, rewrite_state *state)
{
   state->parent_instr = &instr->instr;
   rewrite_def_forwards(&instr->dest, state);
}

static void
rewrite_instr_forward(nir_instr *instr, rewrite_state *state)
{
   if (instr->type == nir_instr_type_alu) {
      rewrite_alu_instr_forward(nir_instr_as_alu(instr), state);
      return;
   }

   if (instr->type == nir_instr_type_phi) {
      rewrite_phi_instr(nir_instr_as_phi(instr), state);
      return;
   }

   state->parent_instr = instr;

   nir_foreach_src(instr, rewrite_use, state);
   nir_foreach_dest(instr, rewrite_def_forwards, state);
}

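/*
 * Phi sources are renamed from the predecessor's point of view: the value
 * flowing into a phi along the edge from pred is whatever definition is on
 * top of the stack at the end of pred.
 */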
static void
rewrite_phi_sources(nir_block *block, nir_block *pred, rewrite_state *state)
{
   nir_foreach_instr(block, instr) {
      if (instr->type != nir_instr_type_phi)
         break;

      nir_phi_instr *phi_instr = nir_instr_as_phi(instr);

      state->parent_instr = instr;

      nir_foreach_phi_src(phi_instr, src) {
         if (src->pred == pred) {
            rewrite_use(&src->src, state);
            break;
         }
      }
   }
}

static bool
rewrite_def_backwards(nir_dest *dest, void *_state)
{
   rewrite_state *state = (rewrite_state *) _state;

   if (!dest->is_ssa)
      return true;

   struct hash_entry *entry =
      _mesa_hash_table_search(state->ssa_map, &dest->ssa);

   if (!entry)
      return true;

   nir_register *reg = (nir_register *) entry->data;
   unsigned index = reg->index;

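   /* Pop the definition that rewrite_def_forwards() pushed; we're leaving
    * the dominator subtree in which it was the reaching definition.
    */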
   state->states[index].index--;
   assert(state->states[index].index >= -1);

   return true;
}

static void
rewrite_instr_backwards(nir_instr *instr, rewrite_state *state)
{
   nir_foreach_dest(instr, rewrite_def_backwards, state);
}

static void
rewrite_block(nir_block *block, rewrite_state *state)
{
   /* This will skip over any instructions after the current one, which is
    * what we want because those instructions (vector gather, conditional
    * select) will already be in SSA form.
    */
   nir_foreach_instr_safe(block, instr) {
      rewrite_instr_forward(instr, state);
   }

   if (block != state->impl->end_block &&
       !nir_cf_node_is_last(&block->cf_node) &&
       nir_cf_node_next(&block->cf_node)->type == nir_cf_node_if) {
      nir_if *if_stmt = nir_cf_node_as_if(nir_cf_node_next(&block->cf_node));
      state->parent_instr = NULL;
      state->parent_if = if_stmt;
      rewrite_use(&if_stmt->condition, state);
   }

   if (block->successors[0])
      rewrite_phi_sources(block->successors[0], block, state);
   if (block->successors[1])
      rewrite_phi_sources(block->successors[1], block, state);

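   /* Recurse into the dominator-tree children: every definition currently on
    * a stack dominates them, so it stays visible while they are renamed.
    */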
   for (unsigned i = 0; i < block->num_dom_children; i++)
      rewrite_block(block->dom_children[i], state);

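   /* On the way back up, pop this block's definitions off the stacks. */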
   nir_foreach_instr_reverse(block, instr) {
      rewrite_instr_backwards(instr, state);
   }
}

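/*
 * Every register we converted (stack != NULL) now has no remaining defs or
 * uses, so it can be deleted from the impl.
 */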
static void
remove_unused_regs(nir_function_impl *impl, rewrite_state *state)
{
   foreach_list_typed_safe(nir_register, reg, node, &impl->registers) {
      if (state->states[reg->index].stack != NULL)
         exec_node_remove(&reg->node);
   }
}

static void
init_rewrite_state(nir_function_impl *impl, rewrite_state *state)
{
   state->impl = impl;
   state->mem_ctx = ralloc_parent(impl);
   state->ssa_map = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                                            _mesa_key_pointer_equal);
   state->states = ralloc_array(NULL, reg_state, impl->reg_alloc);

   foreach_list_typed(nir_register, reg, node, &impl->registers) {
      assert(reg->index < impl->reg_alloc);
      if (reg->num_array_elems > 0) {
         state->states[reg->index].stack = NULL;
      } else {
         /*
          * Calculate a conservative estimate of the stack size based on the
          * number of definitions there are. Note that this function *must* be
          * called after phi nodes are inserted so we can count phi node
          * definitions too.
          */
         unsigned stack_size = list_length(&reg->defs);

         state->states[reg->index].stack = ralloc_array(state->states,
                                                        nir_ssa_def *,
                                                        stack_size);
#ifndef NDEBUG
         state->states[reg->index].stack_size = stack_size;
#endif
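         /* index == -1 marks an empty stack; get_ssa_src() turns reads in
          * that state into ssa_undef values.
          */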
         state->states[reg->index].index = -1;
         state->states[reg->index].num_defs = 0;
      }
   }
}

static void
destroy_rewrite_state(rewrite_state *state)
{
   _mesa_hash_table_destroy(state->ssa_map, NULL);
   ralloc_free(state->states);
}

void
nir_convert_to_ssa_impl(nir_function_impl *impl)
{
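   /* Both phi placement (dominance frontiers) and renaming (dominator-tree
    * walk) rely on up-to-date dominance information.
    */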
   nir_metadata_require(impl, nir_metadata_dominance);

   insert_phi_nodes(impl);

   rewrite_state state;
   init_rewrite_state(impl, &state);

   rewrite_block(nir_start_block(impl), &state);

   remove_unused_regs(impl, &state);

   nir_metadata_preserve(impl, nir_metadata_block_index |
                               nir_metadata_dominance);

   destroy_rewrite_state(&state);
}

void
nir_convert_to_ssa(nir_shader *shader)
{
   nir_foreach_function(shader, function) {
      if (function->impl)
         nir_convert_to_ssa_impl(function->impl);
   }
}