nir/find_array_copies: Handle wildcards and overlapping copies
[mesa.git] src/compiler/nir/nir_opt_find_array_copies.c
/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"

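/* Each variable gets a tree of match nodes that mirrors its type: structs
 * have one child per member, and arrays/matrices have one child per element
 * plus a trailing child that stands in for wildcard or indirect accesses.
 */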
struct match_node {
   /* Note: these fields are only valid for leaf nodes */

   unsigned next_array_idx;
   int src_wildcard_idx;
   nir_deref_path first_src_path;

   /* The index of the first read of the source path that's part of the copy
    * we're matching. If the last write to the source path is after this, we
    * would get a different result from reading it at the end and we can't
    * emit the copy.
    */
   unsigned first_src_read;

   /* The last time there was a write to this node. */
   unsigned last_overwritten;

   /* The last time there was a write to this node which successfully advanced
    * next_array_idx. This helps us catch any intervening aliased writes.
    */
   unsigned last_successful_write;

   unsigned num_children;
   struct match_node *children[];
};

struct match_state {
   /* Map from nir_variable * -> match_node */
   struct hash_table *table;

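   /* Index of the instruction currently being processed; used to timestamp
    * writes in clobber().
    */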
   unsigned cur_instr;

   nir_builder builder;

   void *dead_ctx;
};

static struct match_node *
create_match_node(const struct glsl_type *type, struct match_state *state)
{
   unsigned num_children = 0;
   if (glsl_type_is_array_or_matrix(type)) {
      /* One for wildcards */
      num_children = glsl_get_length(type) + 1;
   } else if (glsl_type_is_struct_or_ifc(type)) {
      num_children = glsl_get_length(type);
   }

   struct match_node *node = rzalloc_size(state->dead_ctx,
                                          sizeof(struct match_node) +
                                          num_children * sizeof(struct match_node *));
   node->num_children = num_children;
   node->src_wildcard_idx = -1;
   node->first_src_read = UINT32_MAX;
   return node;
}

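/* Return the match node for a single deref, creating it if it doesn't exist
 * yet. Constant array indices get their own child node; wildcards and
 * indirect indices map to the extra trailing "wildcard" child.
 */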
static struct match_node *
node_for_deref(nir_deref_instr *instr, struct match_node *parent,
               struct match_state *state)
{
   unsigned idx;
   switch (instr->deref_type) {
   case nir_deref_type_var: {
      struct hash_entry *entry = _mesa_hash_table_search(state->table, instr->var);
      if (entry) {
         return entry->data;
      } else {
         struct match_node *node = create_match_node(instr->type, state);
         _mesa_hash_table_insert(state->table, instr->var, node);
         return node;
      }
   }

   case nir_deref_type_array_wildcard:
      idx = glsl_get_length(instr->type);
      break;

   case nir_deref_type_array:
      if (nir_src_is_const(instr->arr.index)) {
         idx = nir_src_as_uint(instr->arr.index);
      } else {
         idx = glsl_get_length(instr->type);
      }
      break;

   case nir_deref_type_struct:
      idx = instr->strct.index;
      break;

   default:
      unreachable("bad deref type");
   }

   assert(idx < parent->num_children);
   if (parent->children[idx]) {
      return parent->children[idx];
   } else {
      struct match_node *node = create_match_node(instr->type, state);
      parent->children[idx] = node;
      return node;
   }
}

static struct match_node *
node_for_wildcard(const struct glsl_type *type, struct match_node *parent,
                  struct match_state *state)
{
   assert(glsl_type_is_array_or_matrix(type));
   unsigned idx = glsl_get_length(type);

   if (parent->children[idx]) {
      return parent->children[idx];
   } else {
      struct match_node *node =
         create_match_node(glsl_get_array_element(type), state);
      parent->children[idx] = node;
      return node;
   }
}

static struct match_node *
node_for_path(nir_deref_path *path, struct match_state *state)
{
   struct match_node *node = NULL;
   for (nir_deref_instr **instr = path->path; *instr; instr++)
      node = node_for_deref(*instr, node, state);

   return node;
}

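/* Like node_for_path(), except that the deref at wildcard_idx is looked up
 * through its parent's wildcard entry rather than its constant index.
 */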
static struct match_node *
node_for_path_with_wildcard(nir_deref_path *path, unsigned wildcard_idx,
                            struct match_state *state)
{
   struct match_node *node = NULL;
   unsigned idx = 0;
   for (nir_deref_instr **instr = path->path; *instr; instr++, idx++) {
      if (idx == wildcard_idx)
         node = node_for_wildcard((*(instr - 1))->type, node, state);
      else
         node = node_for_deref(*instr, node, state);
   }

   return node;
}

typedef void (*match_cb)(struct match_node *, struct match_state *);

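/* Recursively walk the match tree and invoke cb on every node that the given
 * deref chain might alias. Wildcards and indirect indices may touch any
 * element, while a constant index also aliases the wildcard entry.
 */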
static void
_foreach_aliasing(nir_deref_instr **deref, match_cb cb,
                  struct match_node *node, struct match_state *state)
{
   if (*deref == NULL) {
      cb(node, state);
      return;
   }

   switch ((*deref)->deref_type) {
   case nir_deref_type_struct: {
      struct match_node *child = node->children[(*deref)->strct.index];
      if (child)
         _foreach_aliasing(deref + 1, cb, child, state);
      return;
   }

   case nir_deref_type_array:
   case nir_deref_type_array_wildcard: {
      if ((*deref)->deref_type == nir_deref_type_array_wildcard ||
          !nir_src_is_const((*deref)->arr.index)) {
         /* This access may touch any index, so we have to visit all of
          * them.
          */
         for (unsigned i = 0; i < node->num_children; i++) {
            if (node->children[i])
               _foreach_aliasing(deref + 1, cb, node->children[i], state);
         }
      } else {
         /* Visit the wildcard entry if any */
         if (node->children[node->num_children - 1]) {
            _foreach_aliasing(deref + 1, cb,
                              node->children[node->num_children - 1], state);
         }

         unsigned index = nir_src_as_uint((*deref)->arr.index);
         /* Check that the index is in-bounds */
         if (index < node->num_children - 1 && node->children[index])
            _foreach_aliasing(deref + 1, cb, node->children[index], state);
      }
      return;
   }

   default:
      unreachable("bad deref type");
   }
}

/* Given a deref path, find all the leaf deref nodes that alias it. */

static void
foreach_aliasing_node(nir_deref_path *path,
                      match_cb cb,
                      struct match_state *state)
{
   assert(path->path[0]->deref_type == nir_deref_type_var);
   struct hash_entry *entry = _mesa_hash_table_search(state->table,
                                                      path->path[0]->var);
   if (entry)
      _foreach_aliasing(&path->path[1], cb, entry->data, state);
}

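/* Re-emit the deref chain in path with the array deref at wildcard_idx
 * replaced by an array wildcard.
 */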
static nir_deref_instr *
build_wildcard_deref(nir_builder *b, nir_deref_path *path,
                     unsigned wildcard_idx)
{
   assert(path->path[wildcard_idx]->deref_type == nir_deref_type_array);

   nir_deref_instr *tail =
      nir_build_deref_array_wildcard(b, path->path[wildcard_idx - 1]);

   for (unsigned i = wildcard_idx + 1; path->path[i]; i++)
      tail = nir_build_deref_follower(b, tail, path->path[i]);

   return tail;
}

static void
clobber(struct match_node *node, struct match_state *state)
{
   node->last_overwritten = state->cur_instr;
}

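/* Check whether deref_path matches base_path except for exactly one array
 * index, which must be 0 in base_path and arr_idx in deref_path. The
 * position of that index is recorded in *path_array_idx so that subsequent
 * elements of the copy have to keep using the same position.
 */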
static bool
try_match_deref(nir_deref_path *base_path, int *path_array_idx,
                nir_deref_path *deref_path, int arr_idx)
{
   for (int i = 0; ; i++) {
      nir_deref_instr *b = base_path->path[i];
      nir_deref_instr *d = deref_path->path[i];
      /* They have to be the same length */
      if ((b == NULL) != (d == NULL))
         return false;

      if (b == NULL)
         break;

      /* This can happen if one is a deref_array and the other a wildcard */
      if (b->deref_type != d->deref_type)
         return false;

      switch (b->deref_type) {
      case nir_deref_type_var:
         if (b->var != d->var)
            return false;
         continue;

      case nir_deref_type_array:
         assert(b->arr.index.is_ssa && d->arr.index.is_ssa);
         const bool const_b_idx = nir_src_is_const(b->arr.index);
         const bool const_d_idx = nir_src_is_const(d->arr.index);
         const unsigned b_idx = const_b_idx ? nir_src_as_uint(b->arr.index) : 0;
         const unsigned d_idx = const_d_idx ? nir_src_as_uint(d->arr.index) : 0;

         /* If we don't have an index into the path yet or if this entry in
          * the path is at the array index, see if this is a candidate. We're
          * looking for an index which is zero in the base deref and arr_idx
          * in the search deref.
          */
         if ((*path_array_idx < 0 || *path_array_idx == i) &&
             const_b_idx && b_idx == 0 &&
             const_d_idx && d_idx == arr_idx) {
            *path_array_idx = i;
            continue;
         }

         /* We're at the array index but not a candidate */
         if (*path_array_idx == i)
            return false;

         /* If we're not the path array index, we must match exactly. We
          * could probably just compare SSA values and trust in copy
          * propagation but doing it ourselves means this pass can run a bit
          * earlier.
          */
         if (b->arr.index.ssa == d->arr.index.ssa ||
             (const_b_idx && const_d_idx && b_idx == d_idx))
            continue;

         return false;

      case nir_deref_type_array_wildcard:
         continue;

      case nir_deref_type_struct:
         if (b->strct.index != d->strct.index)
            return false;
         continue;

      default:
         unreachable("Invalid deref type in a path");
      }
   }

   /* If we got here without failing, we've matched. However, it isn't an
    * array match unless we found an altered array index.
    */
   return *path_array_idx > 0;
}

static void
handle_read(nir_deref_instr *src, struct match_state *state)
{
   /* We only need to create an entry for sources that might be used to form
    * an array copy. Hence no indirects or indexing into a vector.
    */
   if (nir_deref_instr_has_indirect(src) ||
       nir_deref_instr_is_known_out_of_bounds(src) ||
       (src->deref_type == nir_deref_type_array &&
        glsl_type_is_vector(nir_src_as_deref(src->parent)->type)))
      return;

   nir_deref_path src_path;
   nir_deref_path_init(&src_path, src, state->dead_ctx);

   /* Create a node for this source if it doesn't exist. The point of this is
    * to know which nodes aliasing a given store we actually need to care
    * about, to avoid creating an excessive amount of nodes.
    */
   node_for_path(&src_path, state);
}

/* The core implementation, which is used for both copies and writes. Return
 * true if a copy is created.
 */
static bool
handle_write(nir_deref_instr *dst, nir_deref_instr *src,
             unsigned write_index, unsigned read_index,
             struct match_state *state)
{
   nir_builder *b = &state->builder;

   nir_deref_path dst_path;
   nir_deref_path_init(&dst_path, dst, state->dead_ctx);

   unsigned idx = 0;
   for (nir_deref_instr **instr = dst_path.path; *instr; instr++, idx++) {
      if ((*instr)->deref_type != nir_deref_type_array)
         continue;

      /* Get the entry where the index is replaced by a wildcard, so that we
       * hopefully can keep matching an array copy.
       */
      struct match_node *dst_node =
         node_for_path_with_wildcard(&dst_path, idx, state);

      if (!src)
         goto reset;

      if (nir_src_as_uint((*instr)->arr.index) != dst_node->next_array_idx)
         goto reset;

      if (dst_node->next_array_idx == 0) {
         /* At this point there may be multiple source indices which are zero,
          * so we can't pin down the actual source index. Just store it and
          * move on.
          */
         nir_deref_path_init(&dst_node->first_src_path, src, state->dead_ctx);
      } else {
         nir_deref_path src_path;
         nir_deref_path_init(&src_path, src, state->dead_ctx);
         bool result = try_match_deref(&dst_node->first_src_path,
                                       &dst_node->src_wildcard_idx,
                                       &src_path, dst_node->next_array_idx);
         nir_deref_path_finish(&src_path);
         if (!result)
            goto reset;
      }

      /* Check if an aliasing write clobbered the array after the last normal
       * write. For example, with a sequence like this:
       *
       *    dst[0][*] = src[0][*];
       *    dst[0][0] = 0; // invalidates the array copy dst[*][*] = src[*][*]
       *    dst[1][*] = src[1][*];
       *
       * Note that the second write wouldn't reset the entry for dst[*][*]
       * by itself, but it'll be caught by this check when processing the
       * third copy.
       */
      if (dst_node->last_successful_write < dst_node->last_overwritten)
         goto reset;

      dst_node->last_successful_write = write_index;

      /* In this case we've successfully processed an array element. Check if
       * this is the last, so that we can emit an array copy.
       */
      dst_node->next_array_idx++;
      dst_node->first_src_read = MIN2(dst_node->first_src_read, read_index);
      if (dst_node->next_array_idx > 1 &&
          dst_node->next_array_idx == glsl_get_length((*(instr - 1))->type)) {
         /* Make sure that nothing was overwritten. */
         struct match_node *src_node =
            node_for_path_with_wildcard(&dst_node->first_src_path,
                                        dst_node->src_wildcard_idx,
                                        state);

         if (src_node->last_overwritten <= dst_node->first_src_read) {
            nir_copy_deref(b, build_wildcard_deref(b, &dst_path, idx),
                           build_wildcard_deref(b, &dst_node->first_src_path,
                                                dst_node->src_wildcard_idx));
            foreach_aliasing_node(&dst_path, clobber, state);
            return true;
         }
      } else {
         continue;
      }

   reset:
      dst_node->next_array_idx = 0;
      dst_node->src_wildcard_idx = -1;
      dst_node->last_successful_write = 0;
      dst_node->first_src_read = UINT32_MAX;
   }

   /* Mark everything aliasing dst_path as clobbered. This needs to happen
    * last since in the loop above we need to know what last clobbered
    * dst_node and this overwrites that.
    */
   foreach_aliasing_node(&dst_path, clobber, state);

   return false;
}

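/* Process a single block. The match table is cleared at the start so that
 * matches never cross block boundaries, and the intrinsics in the block are
 * given increasing indices so that reads and writes can be ordered against
 * each other.
 */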
static bool
opt_find_array_copies_block(nir_builder *b, nir_block *block,
                            struct match_state *state)
{
   bool progress = false;

   unsigned next_index = 0;

   _mesa_hash_table_clear(state->table, NULL);

   nir_foreach_instr(instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      /* Index the instructions before we do anything else. */
      instr->index = next_index++;

      /* Save the index of this instruction */
      state->cur_instr = instr->index;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      if (intrin->intrinsic == nir_intrinsic_load_deref) {
         handle_read(nir_src_as_deref(intrin->src[0]), state);
         continue;
      }

      if (intrin->intrinsic != nir_intrinsic_copy_deref &&
          intrin->intrinsic != nir_intrinsic_store_deref)
         continue;

      nir_deref_instr *dst_deref = nir_src_as_deref(intrin->src[0]);

      /* The destination must be local. If we see a non-local store, we
       * continue on because it won't affect local stores or read-only
       * variables.
       */
      if (dst_deref->mode != nir_var_function_temp)
         continue;

      /* If there are any known out-of-bounds writes, then we can just skip
       * this write as it's undefined and won't contribute to building up an
       * array copy anyways.
       */
      if (nir_deref_instr_is_known_out_of_bounds(dst_deref))
         continue;

      nir_deref_instr *src_deref;
      unsigned load_index = 0;
      if (intrin->intrinsic == nir_intrinsic_copy_deref) {
         src_deref = nir_src_as_deref(intrin->src[1]);
         load_index = intrin->instr.index;
      } else {
         assert(intrin->intrinsic == nir_intrinsic_store_deref);
         nir_intrinsic_instr *load = nir_src_as_intrinsic(intrin->src[1]);
         if (load == NULL || load->intrinsic != nir_intrinsic_load_deref) {
            src_deref = NULL;
         } else {
            src_deref = nir_src_as_deref(load->src[0]);
            load_index = load->instr.index;
         }

         if (nir_intrinsic_write_mask(intrin) !=
             (1 << glsl_get_components(dst_deref->type)) - 1) {
            src_deref = NULL;
         }
      }

      /* The source must be either local or something that's guaranteed to be
       * read-only.
       */
      const nir_variable_mode read_only_modes =
         nir_var_shader_in | nir_var_uniform | nir_var_system_value;
      if (src_deref &&
          !(src_deref->mode & (nir_var_function_temp | read_only_modes))) {
         src_deref = NULL;
      }

      /* There must be no indirects in the source or destination and no known
       * out-of-bounds accesses in the source, and the copy must be fully
       * qualified, or else we can't build up the array copy. We handled
       * out-of-bounds accesses to the dest above.
       */
      if (src_deref &&
          (nir_deref_instr_has_indirect(src_deref) ||
           nir_deref_instr_is_known_out_of_bounds(src_deref) ||
           nir_deref_instr_has_indirect(dst_deref) ||
           !glsl_type_is_vector_or_scalar(src_deref->type))) {
         src_deref = NULL;
      }

      state->builder.cursor = nir_after_instr(instr);
      progress |= handle_write(dst_deref, src_deref, instr->index,
                               load_index, state);
   }

   return progress;
}

static bool
opt_find_array_copies_impl(nir_function_impl *impl)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   bool progress = false;

   struct match_state s;
   s.dead_ctx = ralloc_context(NULL);
   s.table = _mesa_pointer_hash_table_create(s.dead_ctx);
   nir_builder_init(&s.builder, impl);

   nir_foreach_block(block, impl) {
      if (opt_find_array_copies_block(&b, block, &s))
         progress = true;
   }

   ralloc_free(s.dead_ctx);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   }

   return progress;
}

/**
 * This peephole optimization looks for a series of load/store_deref or
 * copy_deref instructions that copy an array from one variable to another and
 * turns it into a copy_deref that copies the entire array. The pattern it
 * looks for is extremely specific but it's good enough to pick up on the
 * input array copies in DXVK and should also be able to pick up the sequence
 * generated by spirv_to_nir for an OpLoad of a large composite followed by
 * OpStore.
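 *
 * For example, a sequence of per-element copies such as
 *
 *    dst[0] = src[0];
 *    dst[1] = src[1];
 *    dst[2] = src[2];
 *    dst[3] = src[3];
 *
 * gets turned into a single dst[*] = src[*] copy of the whole array.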
 *
 * TODO: Support out-of-order copies.
 */
bool
nir_opt_find_array_copies(nir_shader *shader)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl && opt_find_array_copies_impl(function->impl))
         progress = true;
   }

   return progress;
}