nir: Move propagation of cast derefs to a new nir_opt_deref pass
[mesa.git] / src / compiler / nir / nir_deref.c
1 /*
2 * Copyright © 2018 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_deref.h"
27 #include "util/hash_table.h"
28
/* Builds a root-to-leaf array of the deref chain ending at \p deref.
 *
 * The resulting path->path is NULL-terminated and starts at the
 * nir_deref_type_var instruction.  For short chains the fixed-size
 * _short_path buffer inside the path struct is used; longer chains are
 * ralloc'd off \p mem_ctx.  Callers must pair this with
 * nir_deref_path_finish() to release any allocation.
 */
void
nir_deref_path_init(nir_deref_path *path,
                    nir_deref_instr *deref, void *mem_ctx)
{
   assert(deref != NULL);

   /* The length of the short path is at most ARRAY_SIZE - 1 because we need
    * room for the NULL terminator.
    */
   static const int max_short_path_len = ARRAY_SIZE(path->_short_path) - 1;

   int count = 0;

   /* Fill the short path backwards from its end while counting the chain
    * length; if the chain turns out to fit, head already points at the root.
    */
   nir_deref_instr **tail = &path->_short_path[max_short_path_len];
   nir_deref_instr **head = tail;

   *tail = NULL;
   for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d)) {
      count++;
      if (count <= max_short_path_len)
         *(--head) = d;
   }

   if (count <= max_short_path_len) {
      /* If we're under max_short_path_len, just use the short path. */
      path->path = head;
      goto done;
   }

#ifndef NDEBUG
   /* Just in case someone uses short_path by accident */
   for (unsigned i = 0; i < ARRAY_SIZE(path->_short_path); i++)
      path->_short_path[i] = (void *)0xdeadbeef;
#endif

   /* Chain is too long for the inline buffer: allocate exactly count + 1
    * slots (for the NULL terminator) and walk the chain a second time.
    */
   path->path = ralloc_array(mem_ctx, nir_deref_instr *, count + 1);
   head = tail = path->path + count;
   *tail = NULL;
   for (nir_deref_instr *d = deref; d; d = nir_deref_instr_parent(d))
      *(--head) = d;

done:
   assert(head == path->path);
   assert(tail == head + count);
   assert((*head)->deref_type == nir_deref_type_var);
   assert(*tail == NULL);
}
76
77 void
78 nir_deref_path_finish(nir_deref_path *path)
79 {
80 if (path->path < &path->_short_path[0] ||
81 path->path > &path->_short_path[ARRAY_SIZE(path->_short_path) - 1])
82 ralloc_free(path->path);
83 }
84
85 /**
86 * Recursively removes unused deref instructions
87 */
88 bool
89 nir_deref_instr_remove_if_unused(nir_deref_instr *instr)
90 {
91 bool progress = false;
92
93 for (nir_deref_instr *d = instr; d; d = nir_deref_instr_parent(d)) {
94 /* If anyone is using this deref, leave it alone */
95 assert(d->dest.is_ssa);
96 if (!list_empty(&d->dest.ssa.uses))
97 break;
98
99 nir_instr_remove(&d->instr);
100 progress = true;
101 }
102
103 return progress;
104 }
105
106 bool
107 nir_deref_instr_has_indirect(nir_deref_instr *instr)
108 {
109 while (instr->deref_type != nir_deref_type_var) {
110 /* Consider casts to be indirects */
111 if (instr->deref_type == nir_deref_type_cast)
112 return true;
113
114 if (instr->deref_type == nir_deref_type_array &&
115 !nir_src_is_const(instr->arr.index))
116 return true;
117
118 instr = nir_deref_instr_parent(instr);
119 }
120
121 return false;
122 }
123
124 static unsigned
125 type_get_array_stride(const struct glsl_type *elem_type,
126 glsl_type_size_align_func size_align)
127 {
128 unsigned elem_size, elem_align;
129 glsl_get_natural_size_align_bytes(elem_type, &elem_size, &elem_align);
130 return ALIGN_POT(elem_size, elem_align);
131 }
132
133 static unsigned
134 struct_type_get_field_offset(const struct glsl_type *struct_type,
135 glsl_type_size_align_func size_align,
136 unsigned field_idx)
137 {
138 assert(glsl_type_is_struct(struct_type));
139 unsigned offset = 0;
140 for (unsigned i = 0; i <= field_idx; i++) {
141 unsigned elem_size, elem_align;
142 glsl_get_natural_size_align_bytes(glsl_get_struct_field(struct_type, i),
143 &elem_size, &elem_align);
144 offset = ALIGN_POT(offset, elem_align);
145 if (i < field_idx)
146 offset += elem_size;
147 }
148 return offset;
149 }
150
/* Computes the constant byte offset of \p deref from its base variable.
 *
 * Requires a fully-constant chain rooted at a nir_deref_type_var: every
 * array index must be a constant (nir_src_as_uint asserts this) and only
 * array and struct derefs are supported.  Sizes and alignments come from
 * the \p size_align callback via the helpers above.
 */
unsigned
nir_deref_instr_get_const_offset(nir_deref_instr *deref,
                                 glsl_type_size_align_func size_align)
{
   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);

   assert(path.path[0]->deref_type == nir_deref_type_var);

   unsigned offset = 0;
   /* path[0] is the variable itself; start accumulating at path[1]. */
   for (nir_deref_instr **p = &path.path[1]; *p; p++) {
      if ((*p)->deref_type == nir_deref_type_array) {
         offset += nir_src_as_uint((*p)->arr.index) *
                   type_get_array_stride((*p)->type, size_align);
      } else if ((*p)->deref_type == nir_deref_type_struct) {
         /* p starts at path[1], so this is safe */
         nir_deref_instr *parent = *(p - 1);
         offset += struct_type_get_field_offset(parent->type, size_align,
                                                (*p)->strct.index);
      } else {
         unreachable("Unsupported deref type");
      }
   }

   nir_deref_path_finish(&path);

   return offset;
}
179
180 nir_ssa_def *
181 nir_build_deref_offset(nir_builder *b, nir_deref_instr *deref,
182 glsl_type_size_align_func size_align)
183 {
184 nir_deref_path path;
185 nir_deref_path_init(&path, deref, NULL);
186
187 assert(path.path[0]->deref_type == nir_deref_type_var);
188
189 nir_ssa_def *offset = nir_imm_int(b, 0);
190 for (nir_deref_instr **p = &path.path[1]; *p; p++) {
191 if ((*p)->deref_type == nir_deref_type_array) {
192 nir_ssa_def *index = nir_ssa_for_src(b, (*p)->arr.index, 1);
193 nir_ssa_def *stride =
194 nir_imm_int(b, type_get_array_stride((*p)->type, size_align));
195 offset = nir_iadd(b, offset, nir_imul(b, index, stride));
196 } else if ((*p)->deref_type == nir_deref_type_struct) {
197 /* p starts at path[1], so this is safe */
198 nir_deref_instr *parent = *(p - 1);
199 unsigned field_offset =
200 struct_type_get_field_offset(parent->type, size_align,
201 (*p)->strct.index);
202 nir_iadd(b, offset, nir_imm_int(b, field_offset));
203 } else {
204 unreachable("Unsupported deref type");
205 }
206 }
207
208 nir_deref_path_finish(&path);
209
210 return offset;
211 }
212
213 bool
214 nir_remove_dead_derefs_impl(nir_function_impl *impl)
215 {
216 bool progress = false;
217
218 nir_foreach_block(block, impl) {
219 nir_foreach_instr_safe(instr, block) {
220 if (instr->type == nir_instr_type_deref &&
221 nir_deref_instr_remove_if_unused(nir_instr_as_deref(instr)))
222 progress = true;
223 }
224 }
225
226 if (progress)
227 nir_metadata_preserve(impl, nir_metadata_block_index |
228 nir_metadata_dominance);
229
230 return progress;
231 }
232
233 bool
234 nir_remove_dead_derefs(nir_shader *shader)
235 {
236 bool progress = false;
237 nir_foreach_function(function, shader) {
238 if (function->impl && nir_remove_dead_derefs_impl(function->impl))
239 progress = true;
240 }
241
242 return progress;
243 }
244
245 void
246 nir_fixup_deref_modes(nir_shader *shader)
247 {
248 nir_foreach_function(function, shader) {
249 if (!function->impl)
250 continue;
251
252 nir_foreach_block(block, function->impl) {
253 nir_foreach_instr(instr, block) {
254 if (instr->type != nir_instr_type_deref)
255 continue;
256
257 nir_deref_instr *deref = nir_instr_as_deref(instr);
258
259 nir_variable_mode parent_mode;
260 if (deref->deref_type == nir_deref_type_var) {
261 parent_mode = deref->var->data.mode;
262 } else {
263 assert(deref->parent.is_ssa);
264 nir_deref_instr *parent =
265 nir_instr_as_deref(deref->parent.ssa->parent_instr);
266 parent_mode = parent->mode;
267 }
268
269 deref->mode = parent_mode;
270 }
271 }
272 }
273 }
274
/* Compares two deref paths rooted at variables and classifies how they can
 * alias.  The result is a bitset: may-alias, a-contains-b, b-contains-a,
 * and (derived at the end) equality.  Paths into different variables never
 * alias.  Both paths must be var-rooted; casts are not handled here.
 */
nir_deref_compare_result
nir_compare_deref_paths(nir_deref_path *a_path,
                        nir_deref_path *b_path)
{
   if (a_path->path[0]->var != b_path->path[0]->var)
      return nir_derefs_do_not_alias;

   /* Start off assuming they fully compare.  We ignore equality for now.  In
    * the end, we'll determine that by containment.
    */
   nir_deref_compare_result result = nir_derefs_may_alias_bit |
                                     nir_derefs_a_contains_b_bit |
                                     nir_derefs_b_contains_a_bit;

   /* Walk both paths in lockstep, clearing bits as evidence accumulates. */
   nir_deref_instr **a_p = &a_path->path[1];
   nir_deref_instr **b_p = &b_path->path[1];
   while (*a_p != NULL && *b_p != NULL) {
      nir_deref_instr *a_tail = *(a_p++);
      nir_deref_instr *b_tail = *(b_p++);

      /* Identical instructions contribute nothing new. */
      if (a_tail == b_tail)
         continue;

      switch (a_tail->deref_type) {
      case nir_deref_type_array:
      case nir_deref_type_array_wildcard: {
         assert(b_tail->deref_type == nir_deref_type_array ||
                b_tail->deref_type == nir_deref_type_array_wildcard);

         if (a_tail->deref_type == nir_deref_type_array_wildcard) {
            /* A wildcard covers every index, so a still contains b, but a
             * specific index on the b side cannot contain the wildcard.
             */
            if (b_tail->deref_type != nir_deref_type_array_wildcard)
               result &= ~nir_derefs_b_contains_a_bit;
         } else if (b_tail->deref_type == nir_deref_type_array_wildcard) {
            if (a_tail->deref_type != nir_deref_type_array_wildcard)
               result &= ~nir_derefs_a_contains_b_bit;
         } else {
            assert(a_tail->deref_type == nir_deref_type_array &&
                   b_tail->deref_type == nir_deref_type_array);
            assert(a_tail->arr.index.is_ssa && b_tail->arr.index.is_ssa);

            if (nir_src_is_const(a_tail->arr.index) &&
                nir_src_is_const(b_tail->arr.index)) {
               /* If they're both direct and have different offsets, they
                * don't even alias much less anything else.
                */
               if (nir_src_as_uint(a_tail->arr.index) !=
                   nir_src_as_uint(b_tail->arr.index))
                  return nir_derefs_do_not_alias;
            } else if (a_tail->arr.index.ssa == b_tail->arr.index.ssa) {
               /* They're the same indirect, continue on */
            } else {
               /* They're not the same index so we can't prove anything about
                * containment.
                */
               result &= ~(nir_derefs_a_contains_b_bit | nir_derefs_b_contains_a_bit);
            }
         }
         break;
      }

      case nir_deref_type_struct: {
         /* If they're different struct members, they don't even alias */
         if (a_tail->strct.index != b_tail->strct.index)
            return nir_derefs_do_not_alias;
         break;
      }

      default:
         unreachable("Invalid deref type");
      }
   }

   /* If a is longer than b, then it can't contain b */
   if (*a_p != NULL)
      result &= ~nir_derefs_a_contains_b_bit;
   if (*b_p != NULL)
      result &= ~nir_derefs_b_contains_a_bit;

   /* If a contains b and b contains a they must be equal. */
   if ((result & nir_derefs_a_contains_b_bit) && (result & nir_derefs_b_contains_a_bit))
      result |= nir_derefs_equal_bit;

   return result;
}
359
360 nir_deref_compare_result
361 nir_compare_derefs(nir_deref_instr *a, nir_deref_instr *b)
362 {
363 if (a == b) {
364 return nir_derefs_equal_bit | nir_derefs_may_alias_bit |
365 nir_derefs_a_contains_b_bit | nir_derefs_b_contains_a_bit;
366 }
367
368 nir_deref_path a_path, b_path;
369 nir_deref_path_init(&a_path, a, NULL);
370 nir_deref_path_init(&b_path, b, NULL);
371 assert(a_path.path[0]->deref_type == nir_deref_type_var);
372 assert(b_path.path[0]->deref_type == nir_deref_type_var);
373
374 nir_deref_compare_result result = nir_compare_deref_paths(&a_path, &b_path);
375
376 nir_deref_path_finish(&a_path);
377 nir_deref_path_finish(&b_path);
378
379 return result;
380 }
381
/* Walk state for the deref re-materialization pass below. */
struct rematerialize_deref_state {
   bool progress;            /* set whenever any deref use is rewritten */
   nir_builder builder;      /* used to insert cloned deref instructions */
   nir_block *block;         /* block currently being processed */
   struct hash_table *cache; /* original deref -> clone in this block;
                              * created lazily, reset per block */
};
388
/* Returns a deref equivalent to \p deref that lives in state->block,
 * cloning it (and, recursively, its parents) at the builder's cursor if
 * needed.  Clones are cached per block so repeated uses of the same deref
 * share one copy.
 */
static nir_deref_instr *
rematerialize_deref_in_block(nir_deref_instr *deref,
                             struct rematerialize_deref_state *state)
{
   /* Already in the right block: nothing to do. */
   if (deref->instr.block == state->block)
      return deref;

   /* Lazily create the per-block clone cache on first use. */
   if (!state->cache) {
      state->cache = _mesa_hash_table_create(NULL, _mesa_hash_pointer,
                                             _mesa_key_pointer_equal);
   }

   struct hash_entry *cached = _mesa_hash_table_search(state->cache, deref);
   if (cached)
      return cached->data;

   nir_builder *b = &state->builder;
   nir_deref_instr *new_deref =
      nir_deref_instr_create(b->shader, deref->deref_type);
   new_deref->mode = deref->mode;
   new_deref->type = deref->type;

   if (deref->deref_type == nir_deref_type_var) {
      new_deref->var = deref->var;
   } else {
      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
      if (parent) {
         /* Recursively pull the parent into this block first so the clone
          * chain stays use-in-same-block.
          */
         parent = rematerialize_deref_in_block(parent, state);
         new_deref->parent = nir_src_for_ssa(&parent->dest.ssa);
      } else {
         /* Non-deref parent (e.g. a cast of a raw pointer): copy the src. */
         nir_src_copy(&new_deref->parent, &deref->parent, new_deref);
      }
   }

   /* Copy the per-type payload. */
   switch (deref->deref_type) {
   case nir_deref_type_var:
   case nir_deref_type_array_wildcard:
   case nir_deref_type_cast:
      /* Nothing more to do */
      break;

   case nir_deref_type_array:
      /* Array indices must be scalar values, never derefs. */
      assert(!nir_src_as_deref(deref->arr.index));
      nir_src_copy(&new_deref->arr.index, &deref->arr.index, new_deref);
      break;

   case nir_deref_type_struct:
      new_deref->strct.index = deref->strct.index;
      break;

   default:
      unreachable("Invalid deref instruction type");
   }

   nir_ssa_dest_init(&new_deref->instr, &new_deref->dest,
                     deref->dest.ssa.num_components,
                     deref->dest.ssa.bit_size,
                     deref->dest.ssa.name);
   nir_builder_instr_insert(b, &new_deref->instr);

   return new_deref;
}
451
452 static bool
453 rematerialize_deref_src(nir_src *src, void *_state)
454 {
455 struct rematerialize_deref_state *state = _state;
456
457 nir_deref_instr *deref = nir_src_as_deref(*src);
458 if (!deref)
459 return true;
460
461 nir_deref_instr *block_deref = rematerialize_deref_in_block(deref, state);
462 if (block_deref != deref) {
463 nir_instr_rewrite_src(src->parent_instr, src,
464 nir_src_for_ssa(&block_deref->dest.ssa));
465 nir_deref_instr_remove_if_unused(deref);
466 state->progress = true;
467 }
468
469 return true;
470 }
471
/** Re-materialize derefs in every block
 *
 * This pass re-materializes deref instructions in every block in which it is
 * used.  After this pass has been run, every use of a deref will be of a
 * deref in the same block as the use.  Also, all unused derefs will be
 * deleted as a side-effect.
 */
bool
nir_rematerialize_derefs_in_use_blocks_impl(nir_function_impl *impl)
{
   struct rematerialize_deref_state state = { 0 };
   nir_builder_init(&state.builder, impl);

   nir_foreach_block(block, impl) {
      state.block = block;

      /* Start each block with a fresh cache */
      if (state.cache)
         _mesa_hash_table_clear(state.cache, NULL);

      nir_foreach_instr_safe(instr, block) {
         /* Deref instructions themselves only need dead-code cleanup; their
          * sources are handled when the using instruction is visited.
          */
         if (instr->type == nir_instr_type_deref) {
            nir_deref_instr_remove_if_unused(nir_instr_as_deref(instr));
            continue;
         }

         /* Insert any needed clones immediately before the use. */
         state.builder.cursor = nir_before_instr(instr);
         nir_foreach_src(instr, rematerialize_deref_src, &state);
      }

#ifndef NDEBUG
      /* An if condition must never be a deref; sanity-check in debug. */
      nir_if *following_if = nir_block_get_following_if(block);
      if (following_if)
         assert(!nir_src_as_deref(following_if->condition));
#endif
   }

   _mesa_hash_table_destroy(state.cache, NULL);

   return state.progress;
}
513
514 static bool
515 is_trivial_deref_cast(nir_deref_instr *cast)
516 {
517 nir_deref_instr *parent = nir_src_as_deref(cast->parent);
518 if (!parent)
519 return false;
520
521 return cast->mode == parent->mode &&
522 cast->type == parent->type &&
523 cast->dest.ssa.num_components == parent->dest.ssa.num_components &&
524 cast->dest.ssa.bit_size == parent->dest.ssa.bit_size;
525 }
526
527 static bool
528 nir_opt_deref_impl(nir_function_impl *impl)
529 {
530 bool progress = false;
531
532 nir_foreach_block(block, impl) {
533 nir_foreach_instr_safe(instr, block) {
534 if (instr->type != nir_instr_type_deref)
535 continue;
536
537 nir_deref_instr *deref = nir_instr_as_deref(instr);
538 switch (deref->deref_type) {
539 case nir_deref_type_cast:
540 if (is_trivial_deref_cast(deref)) {
541 assert(deref->parent.is_ssa);
542 nir_ssa_def_rewrite_uses(&deref->dest.ssa,
543 nir_src_for_ssa(deref->parent.ssa));
544 nir_instr_remove(&deref->instr);
545 progress = true;
546 }
547 break;
548
549 default:
550 /* Do nothing */
551 break;
552 }
553 }
554 }
555
556 if (progress) {
557 nir_metadata_preserve(impl, nir_metadata_block_index |
558 nir_metadata_dominance);
559 }
560
561 return progress;
562 }
563
564 bool
565 nir_opt_deref(nir_shader *shader)
566 {
567 bool progress = false;
568
569 nir_foreach_function(func, shader) {
570 if (func->impl && nir_opt_deref_impl(func->impl))
571 progress = true;
572 }
573
574 return progress;
575 }