nir/load_store_vectorize: fix combining stores with aliasing loads between
mesa.git: src/compiler/nir/nir_opt_load_store_vectorize.c
1 /*
2 * Copyright © 2019 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 /**
25 * Although it's called a load/store "vectorization" pass, this also combines
26 * intersecting and identical loads/stores. It currently supports derefs, ubo,
27 * ssbo, shared and push constant loads/stores.
28 *
29 * This doesn't handle copy_deref intrinsics and assumes that
30 * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31 * modifiers. It also assumes that derefs have explicitly laid out types.
32 *
33 * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34 * and nir_lower_pack(). Note that this pass creates cast instructions with a
35 * deref as a source, which some parts of NIR may not be able to handle well.
36 *
37 * There are a few situations where this doesn't vectorize as well as it could:
38 * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39 * - It doesn't do global vectorization.
40 * Handling these cases probably wouldn't provide much benefit though.
41 */
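
/*
 * Rough usage sketch (illustrative only: "mem_vectorize_cb" and its policy
 * are made-up, but the parameters match how this pass invokes the
 * nir_should_vectorize_mem_func callback):
 *
 *    static bool
 *    mem_vectorize_cb(unsigned align, unsigned bit_size,
 *                     unsigned num_components, unsigned high_offset,
 *                     nir_intrinsic_instr *low, nir_intrinsic_instr *high)
 *    {
 *       // e.g. only allow naturally aligned vectors of at most 4 components
 *       return align % (bit_size / 8u) == 0 && num_components <= 4;
 *    }
 *
 *    nir_opt_load_store_vectorize(shader, nir_var_mem_ubo | nir_var_mem_ssbo,
 *                                 mem_vectorize_cb);
 */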
42
43 #include "nir.h"
44 #include "nir_deref.h"
45 #include "nir_builder.h"
46 #include "nir_worklist.h"
47 #include "util/u_dynarray.h"
48
49 #include <stdlib.h>
50
51 struct intrinsic_info {
52 nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
53 nir_intrinsic_op op;
54 bool is_atomic;
55 /* Indices into nir_intrinsic_instr::src[] or -1 if not applicable. */
56 int resource_src; /* resource (e.g. from vulkan_resource_index) */
57 int base_src; /* offset which it loads/stores from */
58 int deref_src; /* deref which it loads/stores from */
59 int value_src; /* the data it is storing */
60 };
61
62 static const struct intrinsic_info *
63 get_info(nir_intrinsic_op op) {
64 switch (op) {
65 #define INFO(mode, op, atomic, res, base, deref, val) \
66 case nir_intrinsic_##op: {\
67 static const struct intrinsic_info op##_info = {mode, nir_intrinsic_##op, atomic, res, base, deref, val};\
68 return &op##_info;\
69 }
70 #define LOAD(mode, op, res, base, deref) INFO(mode, load_##op, false, res, base, deref, -1)
71 #define STORE(mode, op, res, base, deref, val) INFO(mode, store_##op, false, res, base, deref, val)
72 #define ATOMIC(mode, type, op, res, base, deref, val) INFO(mode, type##_atomic_##op, true, res, base, deref, val)
73 LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1)
74 LOAD(nir_var_mem_ubo, ubo, 0, 1, -1)
75 LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1)
76 STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0)
77 LOAD(0, deref, -1, -1, 0)
78 STORE(0, deref, -1, -1, 0, 1)
79 LOAD(nir_var_mem_shared, shared, -1, 0, -1)
80 STORE(nir_var_mem_shared, shared, -1, 1, -1, 0)
81 ATOMIC(nir_var_mem_ssbo, ssbo, add, 0, 1, -1, 2)
82 ATOMIC(nir_var_mem_ssbo, ssbo, imin, 0, 1, -1, 2)
83 ATOMIC(nir_var_mem_ssbo, ssbo, umin, 0, 1, -1, 2)
84 ATOMIC(nir_var_mem_ssbo, ssbo, imax, 0, 1, -1, 2)
85 ATOMIC(nir_var_mem_ssbo, ssbo, umax, 0, 1, -1, 2)
86 ATOMIC(nir_var_mem_ssbo, ssbo, and, 0, 1, -1, 2)
87 ATOMIC(nir_var_mem_ssbo, ssbo, or, 0, 1, -1, 2)
88 ATOMIC(nir_var_mem_ssbo, ssbo, xor, 0, 1, -1, 2)
89 ATOMIC(nir_var_mem_ssbo, ssbo, exchange, 0, 1, -1, 2)
90 ATOMIC(nir_var_mem_ssbo, ssbo, comp_swap, 0, 1, -1, 2)
91 ATOMIC(nir_var_mem_ssbo, ssbo, fadd, 0, 1, -1, 2)
92 ATOMIC(nir_var_mem_ssbo, ssbo, fmin, 0, 1, -1, 2)
93 ATOMIC(nir_var_mem_ssbo, ssbo, fmax, 0, 1, -1, 2)
94 ATOMIC(nir_var_mem_ssbo, ssbo, fcomp_swap, 0, 1, -1, 2)
95 ATOMIC(0, deref, add, -1, -1, 0, 1)
96 ATOMIC(0, deref, imin, -1, -1, 0, 1)
97 ATOMIC(0, deref, umin, -1, -1, 0, 1)
98 ATOMIC(0, deref, imax, -1, -1, 0, 1)
99 ATOMIC(0, deref, umax, -1, -1, 0, 1)
100 ATOMIC(0, deref, and, -1, -1, 0, 1)
101 ATOMIC(0, deref, or, -1, -1, 0, 1)
102 ATOMIC(0, deref, xor, -1, -1, 0, 1)
103 ATOMIC(0, deref, exchange, -1, -1, 0, 1)
104 ATOMIC(0, deref, comp_swap, -1, -1, 0, 1)
105 ATOMIC(0, deref, fadd, -1, -1, 0, 1)
106 ATOMIC(0, deref, fmin, -1, -1, 0, 1)
107 ATOMIC(0, deref, fmax, -1, -1, 0, 1)
108 ATOMIC(0, deref, fcomp_swap, -1, -1, 0, 1)
109 ATOMIC(nir_var_mem_shared, shared, add, 0, 1, -1, 2)
110 ATOMIC(nir_var_mem_shared, shared, imin, 0, 1, -1, 2)
111 ATOMIC(nir_var_mem_shared, shared, umin, 0, 1, -1, 2)
112 ATOMIC(nir_var_mem_shared, shared, imax, 0, 1, -1, 2)
113 ATOMIC(nir_var_mem_shared, shared, umax, 0, 1, -1, 2)
114 ATOMIC(nir_var_mem_shared, shared, and, 0, 1, -1, 2)
115 ATOMIC(nir_var_mem_shared, shared, or, 0, 1, -1, 2)
116 ATOMIC(nir_var_mem_shared, shared, xor, 0, 1, -1, 2)
117 ATOMIC(nir_var_mem_shared, shared, exchange, 0, 1, -1, 2)
118 ATOMIC(nir_var_mem_shared, shared, comp_swap, 0, 1, -1, 2)
119 ATOMIC(nir_var_mem_shared, shared, fadd, 0, 1, -1, 2)
120 ATOMIC(nir_var_mem_shared, shared, fmin, 0, 1, -1, 2)
121 ATOMIC(nir_var_mem_shared, shared, fmax, 0, 1, -1, 2)
122 ATOMIC(nir_var_mem_shared, shared, fcomp_swap, 0, 1, -1, 2)
123 default:
124 break;
125 #undef ATOMIC
126 #undef STORE
127 #undef LOAD
128 #undef INFO
129 }
130 return NULL;
131 }
132
133 /*
134 * Information used to compare memory operations.
135 * It canonically represents an offset as:
136 * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
137 * "offset_defs" is sorted in ascending order by the ssa definition's index.
138 * "resource" or "var" may be NULL.
139 */
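/* For example (an illustrative sketch, not from a real shader): if a load's
 * offset source is "i * 16 + 4", parsing yields offset_defs = {i} and
 * offset_defs_mul = {16}, while the constant 4 is folded into entry::offset. */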
140 struct entry_key {
141 nir_ssa_def *resource;
142 nir_variable *var;
143 unsigned offset_def_count;
144 nir_ssa_def **offset_defs;
145 uint64_t *offset_defs_mul;
146 };
147
148 /* Information on a single memory operation. */
149 struct entry {
150 struct list_head head;
151 unsigned index;
152
153 struct entry_key *key;
154 union {
155 uint64_t offset; /* sign-extended */
156 int64_t offset_signed;
157 };
158 uint32_t best_align;
159
160 nir_instr *instr;
161 nir_intrinsic_instr *intrin;
162 const struct intrinsic_info *info;
163 enum gl_access_qualifier access;
164 bool is_store;
165
166 nir_deref_instr *deref;
167 };
168
169 struct vectorize_ctx {
170 nir_variable_mode modes;
171 nir_should_vectorize_mem_func callback;
172 struct list_head entries[nir_num_variable_modes];
173 struct hash_table *loads[nir_num_variable_modes];
174 struct hash_table *stores[nir_num_variable_modes];
175 };
176
177 static uint32_t hash_entry_key(const void *key_)
178 {
179 /* this is careful to not include pointers in the hash calculation so that
180 * the order of the hash table walk is deterministic */
181 struct entry_key *key = (struct entry_key*)key_;
182
183 uint32_t hash = _mesa_fnv32_1a_offset_bias;
184 if (key->resource)
185 hash = _mesa_fnv32_1a_accumulate(hash, key->resource->index);
186 if (key->var) {
187 hash = _mesa_fnv32_1a_accumulate(hash, key->var->index);
188 unsigned mode = key->var->data.mode;
189 hash = _mesa_fnv32_1a_accumulate(hash, mode);
190 }
191
192 for (unsigned i = 0; i < key->offset_def_count; i++)
193 hash = _mesa_fnv32_1a_accumulate(hash, key->offset_defs[i]->index);
194
195 hash = _mesa_fnv32_1a_accumulate_block(
196 hash, key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
197
198 return hash;
199 }
200
201 static bool entry_key_equals(const void *a_, const void *b_)
202 {
203 struct entry_key *a = (struct entry_key*)a_;
204 struct entry_key *b = (struct entry_key*)b_;
205
206 if (a->var != b->var || a->resource != b->resource)
207 return false;
208
209 if (a->offset_def_count != b->offset_def_count)
210 return false;
211
212 size_t offset_def_size = a->offset_def_count * sizeof(nir_ssa_def *);
213 size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
214 if (a->offset_def_count &&
215 (memcmp(a->offset_defs, b->offset_defs, offset_def_size) ||
216 memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size)))
217 return false;
218
219 return true;
220 }
221
222 static void delete_entry_dynarray(struct hash_entry *entry)
223 {
224 struct util_dynarray *arr = (struct util_dynarray *)entry->data;
225 ralloc_free(arr);
226 }
227
228 static int sort_entries(const void *a_, const void *b_)
229 {
230 struct entry *a = *(struct entry*const*)a_;
231 struct entry *b = *(struct entry*const*)b_;
232
233 if (a->offset_signed > b->offset_signed)
234 return 1;
235 else if (a->offset_signed < b->offset_signed)
236 return -1;
237 else
238 return 0;
239 }
240
241 static unsigned
242 get_bit_size(struct entry *entry)
243 {
244 unsigned size = entry->is_store ?
245 entry->intrin->src[entry->info->value_src].ssa->bit_size :
246 entry->intrin->dest.ssa.bit_size;
247 return size == 1 ? 32u : size;
248 }
249
250 /* If "def" is from an alu instruction with the opcode "op" and one of its
251 * sources is a constant, update "def" to be the non-constant source, fill "c"
252 * with the constant and return true. */
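/* For example (illustrative values): with *def = iadd(x, 16) and
 * op == nir_op_iadd, this sets *c = 16, updates *def to x and returns true. */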
253 static bool
254 parse_alu(nir_ssa_def **def, nir_op op, uint64_t *c)
255 {
256 nir_ssa_scalar scalar;
257 scalar.def = *def;
258 scalar.comp = 0;
259
260 if (!nir_ssa_scalar_is_alu(scalar) || nir_ssa_scalar_alu_op(scalar) != op)
261 return false;
262
263 nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(scalar, 0);
264 nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
265 if (op != nir_op_ishl && nir_ssa_scalar_is_const(src0) && src1.comp == 0) {
266 *c = nir_ssa_scalar_as_uint(src0);
267 *def = src1.def;
268 } else if (nir_ssa_scalar_is_const(src1) && src0.comp == 0) {
269 *c = nir_ssa_scalar_as_uint(src1);
270 *def = src0.def;
271 } else {
272 return false;
273 }
274 return true;
275 }
276
277 /* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32". */
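/* For example, "(a * 16 + 4) * 64 + 32" should yield *base = a,
 * *base_mul = 1024 and *offset = 288 (4 * 64 + 32). */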
278 static void
279 parse_offset(nir_ssa_def **base, uint64_t *base_mul, uint64_t *offset)
280 {
281 if ((*base)->parent_instr->type == nir_instr_type_load_const) {
282 *offset = nir_src_comp_as_uint(nir_src_for_ssa(*base), 0);
283 *base = NULL;
284 return;
285 }
286
287 uint64_t mul = 1;
288 uint64_t add = 0;
289 bool progress = false;
290 do {
291 uint64_t mul2 = 1, add2 = 0;
292
293 progress = parse_alu(base, nir_op_imul, &mul2);
294 mul *= mul2;
295
296 mul2 = 0;
297 progress |= parse_alu(base, nir_op_ishl, &mul2);
298 mul <<= mul2;
299
300 progress |= parse_alu(base, nir_op_iadd, &add2);
301 add += add2 * mul;
302 } while (progress);
303
304 *base_mul = mul;
305 *offset = add;
306 }
307
308 static unsigned
309 type_scalar_size_bytes(const struct glsl_type *type)
310 {
311 assert(glsl_type_is_vector_or_scalar(type) ||
312 glsl_type_is_matrix(type));
313 return glsl_type_is_boolean(type) ? 4u : glsl_get_bit_size(type) / 8u;
314 }
315
316 static int
317 get_array_stride(const struct glsl_type *type)
318 {
319 unsigned explicit_stride = glsl_get_explicit_stride(type);
320 if ((glsl_type_is_matrix(type) &&
321 glsl_matrix_type_is_row_major(type)) ||
322 (glsl_type_is_vector(type) && explicit_stride == 0))
323 return type_scalar_size_bytes(type);
324 return explicit_stride;
325 }
326
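/* Sign-extends the low "bit_size" bits of "val" to 64 bits, e.g.
 * mask_sign_extend(0xff, 8) == UINT64_MAX (i.e. -1). */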
327 static uint64_t
328 mask_sign_extend(uint64_t val, unsigned bit_size)
329 {
330 return (int64_t)(val << (64 - bit_size)) >> (64 - bit_size);
331 }
332
333 static unsigned
334 add_to_entry_key(nir_ssa_def **offset_defs, uint64_t *offset_defs_mul,
335 unsigned offset_def_count, nir_ssa_def *def, uint64_t mul)
336 {
337 mul = mask_sign_extend(mul, def->bit_size);
338
339 for (unsigned i = 0; i <= offset_def_count; i++) {
340 if (i == offset_def_count || def->index > offset_defs[i]->index) {
341 /* insert before i */
342 memmove(offset_defs + i + 1, offset_defs + i,
343 (offset_def_count - i) * sizeof(nir_ssa_def *));
344 memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
345 (offset_def_count - i) * sizeof(uint64_t));
346 offset_defs[i] = def;
347 offset_defs_mul[i] = mul;
348 return 1;
349 } else if (def->index == offset_defs[i]->index) {
350 /* merge with offset_def at i */
351 offset_defs_mul[i] += mul;
352 return 0;
353 }
354 }
355 unreachable("Unreachable.");
356 return 0;
357 }
358
359 static struct entry_key *
360 create_entry_key_from_deref(void *mem_ctx,
361 struct vectorize_ctx *ctx,
362 nir_deref_path *path,
363 uint64_t *offset_base)
364 {
365 unsigned path_len = 0;
366 while (path->path[path_len])
367 path_len++;
368
369 nir_ssa_def *offset_defs_stack[32];
370 uint64_t offset_defs_mul_stack[32];
371 nir_ssa_def **offset_defs = offset_defs_stack;
372 uint64_t *offset_defs_mul = offset_defs_mul_stack;
373 if (path_len > 32) {
374 offset_defs = malloc(path_len * sizeof(nir_ssa_def *));
375 offset_defs_mul = malloc(path_len * sizeof(uint64_t));
376 }
377 unsigned offset_def_count = 0;
378
379 struct entry_key* key = ralloc(mem_ctx, struct entry_key);
380 key->resource = NULL;
381 key->var = NULL;
382 *offset_base = 0;
383
384 for (unsigned i = 0; i < path_len; i++) {
385 nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
386 nir_deref_instr *deref = path->path[i];
387
388 switch (deref->deref_type) {
389 case nir_deref_type_var: {
390 assert(!parent);
391 key->var = deref->var;
392 break;
393 }
394 case nir_deref_type_array:
395 case nir_deref_type_ptr_as_array: {
396 assert(parent);
397 nir_ssa_def *index = deref->arr.index.ssa;
398 uint32_t stride;
399 if (deref->deref_type == nir_deref_type_ptr_as_array)
400 stride = nir_deref_instr_ptr_as_array_stride(deref);
401 else
402 stride = get_array_stride(parent->type);
403
404 nir_ssa_def *base = index;
405 uint64_t offset = 0, base_mul = 1;
406 parse_offset(&base, &base_mul, &offset);
407 offset = mask_sign_extend(offset, index->bit_size);
408
409 *offset_base += offset * stride;
410 if (base) {
411 offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
412 offset_def_count,
413 base, base_mul * stride);
414 }
415 break;
416 }
417 case nir_deref_type_struct: {
418 assert(parent);
419 int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
420 *offset_base += offset;
421 break;
422 }
423 case nir_deref_type_cast: {
424 if (!parent)
425 key->resource = deref->parent.ssa;
426 break;
427 }
428 default:
429 unreachable("Unhandled deref type");
430 }
431 }
432
433 key->offset_def_count = offset_def_count;
434 key->offset_defs = ralloc_array(mem_ctx, nir_ssa_def *, offset_def_count);
435 key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
436 memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_ssa_def *));
437 memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));
438
439 if (offset_defs != offset_defs_stack)
440 free(offset_defs);
441 if (offset_defs_mul != offset_defs_mul_stack)
442 free(offset_defs_mul);
443
444 return key;
445 }
446
447 static unsigned
448 parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
449 nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
450 {
451 uint64_t new_mul;
452 uint64_t new_offset;
453 parse_offset(&base, &new_mul, &new_offset);
454 *offset += new_offset * base_mul;
455
456 if (!base)
457 return 0;
458
459 base_mul *= new_mul;
460
461 assert(left >= 1);
462
463 if (left >= 2) {
464 nir_ssa_scalar scalar;
465 scalar.def = base;
466 scalar.comp = 0;
467 if (nir_ssa_scalar_is_alu(scalar) && nir_ssa_scalar_alu_op(scalar) == nir_op_iadd) {
468 nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(scalar, 0);
469 nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
470 if (src0.comp == 0 && src1.comp == 0) {
471 unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0.def, base_mul, offset);
472 amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1.def, base_mul, offset);
473 return amount;
474 }
475 }
476 }
477
478 return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
479 }
480
481 static struct entry_key *
482 create_entry_key_from_offset(void *mem_ctx, nir_ssa_def *base, uint64_t base_mul, uint64_t *offset)
483 {
484 struct entry_key *key = ralloc(mem_ctx, struct entry_key);
485 key->resource = NULL;
486 key->var = NULL;
487 if (base) {
488 nir_ssa_def *offset_defs[32];
489 uint64_t offset_defs_mul[32];
490 key->offset_defs = offset_defs;
491 key->offset_defs_mul = offset_defs_mul;
492
493 key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, base, base_mul, offset);
494
495 key->offset_defs = ralloc_array(mem_ctx, nir_ssa_def *, key->offset_def_count);
496 key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
497 memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_ssa_def *));
498 memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
499 } else {
500 key->offset_def_count = 0;
501 key->offset_defs = NULL;
502 key->offset_defs_mul = NULL;
503 }
504 return key;
505 }
506
507 static nir_variable_mode
508 get_variable_mode(struct entry *entry)
509 {
510 if (entry->info->mode)
511 return entry->info->mode;
512 assert(entry->deref);
513 return entry->deref->mode;
514 }
515
516 static struct entry *
517 create_entry(struct vectorize_ctx *ctx,
518 const struct intrinsic_info *info,
519 nir_intrinsic_instr *intrin)
520 {
521 struct entry *entry = rzalloc(ctx, struct entry);
522 entry->intrin = intrin;
523 entry->instr = &intrin->instr;
524 entry->info = info;
525 entry->best_align = UINT32_MAX;
526 entry->is_store = entry->info->value_src >= 0;
527
528 if (entry->info->deref_src >= 0) {
529 entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
530 nir_deref_path path;
531 nir_deref_path_init(&path, entry->deref, NULL);
532 entry->key = create_entry_key_from_deref(entry, ctx, &path, &entry->offset);
533 nir_deref_path_finish(&path);
534 } else {
535 nir_ssa_def *base = entry->info->base_src >= 0 ?
536 intrin->src[entry->info->base_src].ssa : NULL;
537 uint64_t offset = 0;
538 if (nir_intrinsic_infos[intrin->intrinsic].index_map[NIR_INTRINSIC_BASE])
539 offset += nir_intrinsic_base(intrin);
540 entry->key = create_entry_key_from_offset(entry, base, 1, &offset);
541 entry->offset = offset;
542
543 if (base)
544 entry->offset = mask_sign_extend(entry->offset, base->bit_size);
545 }
546
547 if (entry->info->resource_src >= 0)
548 entry->key->resource = intrin->src[entry->info->resource_src].ssa;
549
550 if (nir_intrinsic_infos[intrin->intrinsic].index_map[NIR_INTRINSIC_ACCESS])
551 entry->access = nir_intrinsic_access(intrin);
552 else if (entry->key->var)
553 entry->access = entry->key->var->data.access;
554
555 uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
556 restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
557 restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
558 restrict_modes |= nir_var_system_value | nir_var_mem_shared;
559 if (get_variable_mode(entry) & restrict_modes)
560 entry->access |= ACCESS_RESTRICT;
561
562 return entry;
563 }
564
565 static nir_deref_instr *
566 cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
567 {
568 if (glsl_get_components(deref->type) == num_components &&
569 type_scalar_size_bytes(deref->type)*8u == bit_size)
570 return deref;
571
572 enum glsl_base_type types[] = {
573 GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64};
574 enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
575 const struct glsl_type *type = glsl_vector_type(base, num_components);
576
577 if (deref->type == type)
578 return deref;
579
580 return nir_build_deref_cast(b, &deref->dest.ssa, deref->mode, type, 0);
581 }
582
583 /* Return true if the write mask "write_mask" of a store with "old_bit_size"
584 * bits per element can be represented for a store with "new_bit_size" bits per
585 * element. */
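/* For example, a write mask of 0x6 on a 16-bit store covers bits 16..47,
 * which cannot be expressed with 32-bit components, so old_bit_size=16 and
 * new_bit_size=32 would not be representable. */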
586 static bool
587 writemask_representable(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
588 {
589 while (write_mask) {
590 int start, count;
591 u_bit_scan_consecutive_range(&write_mask, &start, &count);
592 start *= old_bit_size;
593 count *= old_bit_size;
594 if (start % new_bit_size != 0)
595 return false;
596 if (count % new_bit_size != 0)
597 return false;
598 }
599 return true;
600 }
601
602 static uint64_t
603 gcd(uint64_t a, uint64_t b)
604 {
605 while (b) {
606 uint64_t old_b = b;
607 b = a % b;
608 a = old_b;
609 }
610 return a;
611 }
612
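/* Returns a conservative alignment for the entry's address: the GCD of the
 * constant offset and the offset multipliers (e.g. offset=12 with
 * offset_defs_mul={8} gives 4), possibly raised by the intrinsic's declared
 * alignment and then reduced to its power-of-two factor. */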
613 static uint32_t
614 get_best_align(struct entry *entry)
615 {
616 if (entry->best_align != UINT32_MAX)
617 return entry->best_align;
618
619 uint64_t best_align = entry->offset;
620 for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
621 if (!best_align)
622 best_align = entry->key->offset_defs_mul[i];
623 else if (entry->key->offset_defs_mul[i])
624 best_align = gcd(best_align, entry->key->offset_defs_mul[i]);
625 }
626
627 if (nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL])
628 best_align = MAX2(best_align, nir_intrinsic_align(entry->intrin));
629
630 /* ensure the result is a power of two that fits in an int32_t */
631 entry->best_align = gcd(best_align, 1u << 30);
632
633 return entry->best_align;
634 }
635
636 /* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
637 * of "low" and "high". */
638 static bool
639 new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
640 struct entry *low, struct entry *high, unsigned size)
641 {
642 if (size % new_bit_size != 0)
643 return false;
644
645 unsigned new_num_components = size / new_bit_size;
646 if (new_num_components > NIR_MAX_VEC_COMPONENTS)
647 return false;
648
649 unsigned high_offset = high->offset_signed - low->offset_signed;
650
651 /* check nir_extract_bits limitations */
652 unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
653 common_bit_size = MIN2(common_bit_size, new_bit_size);
654 if (high_offset > 0)
655 common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
656 if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
657 return false;
658
659 if (!ctx->callback(get_best_align(low), new_bit_size, new_num_components,
660 high_offset, low->intrin, high->intrin))
661 return false;
662
663 if (low->is_store) {
664 unsigned low_size = low->intrin->num_components * get_bit_size(low);
665 unsigned high_size = high->intrin->num_components * get_bit_size(high);
666
667 if (low_size % new_bit_size != 0)
668 return false;
669 if (high_size % new_bit_size != 0)
670 return false;
671
672 unsigned write_mask = nir_intrinsic_write_mask(low->intrin);
673 if (!writemask_representable(write_mask, low_size, new_bit_size))
674 return false;
675
676 write_mask = nir_intrinsic_write_mask(high->intrin);
677 if (!writemask_representable(write_mask, high_size, new_bit_size))
678 return false;
679 }
680
681 return true;
682 }
683
684 /* Updates a write mask, "write_mask", so that it can be used with a
685 * "new_bit_size"-bit store instead of an "old_bit_size"-bit store. */
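/* For example, a write mask of 0x3 on a 32-bit store becomes 0x1 when the
 * store is widened to 64-bit components. */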
686 static uint32_t
687 update_writemask(unsigned write_mask, unsigned old_bit_size, unsigned new_bit_size)
688 {
689 uint32_t res = 0;
690 while (write_mask) {
691 int start, count;
692 u_bit_scan_consecutive_range(&write_mask, &start, &count);
693 start = start * old_bit_size / new_bit_size;
694 count = count * old_bit_size / new_bit_size;
695 res |= ((1 << count) - 1) << start;
696 }
697 return res;
698 }
699
700 static nir_deref_instr *subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
701 {
702 /* avoid adding another deref to the path */
703 if (deref->deref_type == nir_deref_type_ptr_as_array &&
704 nir_src_is_const(deref->arr.index) &&
705 offset % nir_deref_instr_ptr_as_array_stride(deref) == 0) {
706 unsigned stride = nir_deref_instr_ptr_as_array_stride(deref);
707 nir_ssa_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
708 deref->dest.ssa.bit_size);
709 return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
710 }
711
712 if (deref->deref_type == nir_deref_type_array &&
713 nir_src_is_const(deref->arr.index)) {
714 nir_deref_instr *parent = nir_deref_instr_parent(deref);
715 unsigned stride = glsl_get_explicit_stride(parent->type);
716 if (offset % stride == 0)
717 return nir_build_deref_array_imm(
718 b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
719 }
720
721
722 deref = nir_build_deref_cast(b, &deref->dest.ssa, deref->mode,
723 glsl_scalar_type(GLSL_TYPE_UINT8), 1);
724 return nir_build_deref_ptr_as_array(
725 b, deref, nir_imm_intN_t(b, -offset, deref->dest.ssa.bit_size));
726 }
727
728 static bool update_align(struct entry *entry)
729 {
730 bool has_align_index =
731 nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
732 if (has_align_index) {
733 unsigned align = get_best_align(entry);
734 if (align != nir_intrinsic_align(entry->intrin)) {
735 nir_intrinsic_set_align(entry->intrin, align, 0);
736 return true;
737 }
738 }
739 return false;
740 }
741
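/* Combines the loads "low" and "high" into a single load of
 * "new_num_components" components of "new_bit_size" bits each. The combined
 * load reuses "first" (the instruction appearing earlier in the block) and
 * "second" is removed. */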
742 static void
743 vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
744 struct entry *low, struct entry *high,
745 struct entry *first, struct entry *second,
746 unsigned new_bit_size, unsigned new_num_components,
747 unsigned high_start)
748 {
749 unsigned low_bit_size = get_bit_size(low);
750 unsigned high_bit_size = get_bit_size(high);
751 bool low_bool = low->intrin->dest.ssa.bit_size == 1;
752 bool high_bool = high->intrin->dest.ssa.bit_size == 1;
753 nir_ssa_def *data = &first->intrin->dest.ssa;
754
755 b->cursor = nir_after_instr(first->instr);
756
757 /* update the load's destination size and extract data for each of the original loads */
758 data->num_components = new_num_components;
759 data->bit_size = new_bit_size;
760
761 nir_ssa_def *low_def = nir_extract_bits(
762 b, &data, 1, 0, low->intrin->num_components, low_bit_size);
763 nir_ssa_def *high_def = nir_extract_bits(
764 b, &data, 1, high_start, high->intrin->num_components, high_bit_size);
765
766 /* convert booleans */
767 low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
768 high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);
769
770 /* update uses */
771 if (first == low) {
772 nir_ssa_def_rewrite_uses_after(&low->intrin->dest.ssa, nir_src_for_ssa(low_def),
773 high_def->parent_instr);
774 nir_ssa_def_rewrite_uses(&high->intrin->dest.ssa, nir_src_for_ssa(high_def));
775 } else {
776 nir_ssa_def_rewrite_uses(&low->intrin->dest.ssa, nir_src_for_ssa(low_def));
777 nir_ssa_def_rewrite_uses_after(&high->intrin->dest.ssa, nir_src_for_ssa(high_def),
778 high_def->parent_instr);
779 }
780
781 /* update the intrinsic */
782 first->intrin->num_components = new_num_components;
783
784 const struct intrinsic_info *info = first->info;
785
786 /* update the offset */
787 if (first != low && info->base_src >= 0) {
788 /* let nir_opt_algebraic() remove this addition. this doesn't cause many
789 * issues with subtracting 16 from expressions like "(i + 1) * 16" because
790 * nir_opt_algebraic() turns them into "i * 16 + 16" */
791 b->cursor = nir_before_instr(first->instr);
792
793 nir_ssa_def *new_base = first->intrin->src[info->base_src].ssa;
794 new_base = nir_iadd(b, new_base, nir_imm_int(b, -(high_start / 8u)));
795
796 nir_instr_rewrite_src(first->instr, &first->intrin->src[info->base_src],
797 nir_src_for_ssa(new_base));
798 }
799
800 /* update the deref */
801 if (info->deref_src >= 0) {
802 b->cursor = nir_before_instr(first->instr);
803
804 nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
805 if (first != low && high_start != 0)
806 deref = subtract_deref(b, deref, high_start / 8u);
807 first->deref = cast_deref(b, new_num_components, new_bit_size, deref);
808
809 nir_instr_rewrite_src(first->instr, &first->intrin->src[info->deref_src],
810 nir_src_for_ssa(&first->deref->dest.ssa));
811 }
812
813 /* update base/align */
814 bool has_base_index =
815 nir_intrinsic_infos[first->intrin->intrinsic].index_map[NIR_INTRINSIC_BASE];
816
817 if (first != low && has_base_index)
818 nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));
819
820 first->key = low->key;
821 first->offset = low->offset;
822 first->best_align = get_best_align(low);
823
824 update_align(first);
825
826 nir_instr_remove(second->instr);
827 }
828
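/* Combines the stores "low" and "high" into a single store of
 * "new_num_components" components of "new_bit_size" bits each. The combined
 * store reuses "second" (the instruction appearing later in the block) and
 * "first" is removed. */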
829 static void
830 vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
831 struct entry *low, struct entry *high,
832 struct entry *first, struct entry *second,
833 unsigned new_bit_size, unsigned new_num_components,
834 unsigned high_start)
835 {
836 ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
837 assert(low_size % new_bit_size == 0);
838
839 b->cursor = nir_before_instr(second->instr);
840
841 /* get new writemasks */
842 uint32_t low_write_mask = nir_intrinsic_write_mask(low->intrin);
843 uint32_t high_write_mask = nir_intrinsic_write_mask(high->intrin);
844 low_write_mask = update_writemask(low_write_mask, get_bit_size(low), new_bit_size);
845 high_write_mask = update_writemask(high_write_mask, get_bit_size(high), new_bit_size);
846 high_write_mask <<= high_start / new_bit_size;
847
848 uint32_t write_mask = low_write_mask | high_write_mask;
849
850 /* convert booleans */
851 nir_ssa_def *low_val = low->intrin->src[low->info->value_src].ssa;
852 nir_ssa_def *high_val = high->intrin->src[high->info->value_src].ssa;
853 low_val = low_val->bit_size == 1 ? nir_b2i(b, low_val, 32) : low_val;
854 high_val = high_val->bit_size == 1 ? nir_b2i(b, high_val, 32) : high_val;
855
856 /* combine the data */
857 nir_ssa_def *data_channels[NIR_MAX_VEC_COMPONENTS];
858 for (unsigned i = 0; i < new_num_components; i++) {
859 bool set_low = low_write_mask & (1 << i);
860 bool set_high = high_write_mask & (1 << i);
861
862 if (set_low && (!set_high || low == second)) {
863 unsigned offset = i * new_bit_size;
864 data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
865 } else if (set_high) {
866 assert(!set_low || high == second);
867 unsigned offset = i * new_bit_size - high_start;
868 data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
869 } else {
870 data_channels[i] = nir_ssa_undef(b, 1, new_bit_size);
871 }
872 }
873 nir_ssa_def *data = nir_vec(b, data_channels, new_num_components);
874
875 /* update the intrinsic */
876 nir_intrinsic_set_write_mask(second->intrin, write_mask);
877 second->intrin->num_components = data->num_components;
878
879 const struct intrinsic_info *info = second->info;
880 assert(info->value_src >= 0);
881 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->value_src],
882 nir_src_for_ssa(data));
883
884 /* update the offset */
885 if (second != low && info->base_src >= 0)
886 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->base_src],
887 low->intrin->src[info->base_src]);
888
889 /* update the deref */
890 if (info->deref_src >= 0) {
891 b->cursor = nir_before_instr(second->instr);
892 second->deref = cast_deref(b, new_num_components, new_bit_size,
893 nir_src_as_deref(low->intrin->src[info->deref_src]));
894 nir_instr_rewrite_src(second->instr, &second->intrin->src[info->deref_src],
895 nir_src_for_ssa(&second->deref->dest.ssa));
896 }
897
898 /* update base/align */
899 bool has_base_index =
900 nir_intrinsic_infos[second->intrin->intrinsic].index_map[NIR_INTRINSIC_BASE];
901
902 if (second != low && has_base_index)
903 nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));
904
905 second->key = low->key;
906 second->offset = low->offset;
907 second->best_align = get_best_align(low);
908
909 update_align(second);
910
911 list_del(&first->head);
912 nir_instr_remove(first->instr);
913 }
914
915 /* Returns true if it can prove that "a" and "b" point to different resources. */
916 static bool
917 resources_different(nir_ssa_def *a, nir_ssa_def *b)
918 {
919 if (!a || !b)
920 return false;
921
922 if (a->parent_instr->type == nir_instr_type_load_const &&
923 b->parent_instr->type == nir_instr_type_load_const) {
924 return nir_src_as_uint(nir_src_for_ssa(a)) != nir_src_as_uint(nir_src_for_ssa(b));
925 }
926
927 if (a->parent_instr->type == nir_instr_type_intrinsic &&
928 b->parent_instr->type == nir_instr_type_intrinsic) {
929 nir_intrinsic_instr *aintrin = nir_instr_as_intrinsic(a->parent_instr);
930 nir_intrinsic_instr *bintrin = nir_instr_as_intrinsic(b->parent_instr);
931 if (aintrin->intrinsic == nir_intrinsic_vulkan_resource_index &&
932 bintrin->intrinsic == nir_intrinsic_vulkan_resource_index) {
933 return nir_intrinsic_desc_set(aintrin) != nir_intrinsic_desc_set(bintrin) ||
934 nir_intrinsic_binding(aintrin) != nir_intrinsic_binding(bintrin) ||
935 resources_different(aintrin->src[0].ssa, bintrin->src[0].ssa);
936 }
937 }
938
939 return false;
940 }
941
942 static int64_t
943 compare_entries(struct entry *a, struct entry *b)
944 {
945 if (!entry_key_equals(a->key, b->key))
946 return INT64_MAX;
947 return b->offset_signed - a->offset_signed;
948 }
949
950 static bool
951 may_alias(struct entry *a, struct entry *b)
952 {
953 assert(get_variable_mode(a) == get_variable_mode(b));
954
955 /* if the resources/variables are definitively different and both have
956 * ACCESS_RESTRICT, we can assume they do not alias. */
957 bool res_different = a->key->var != b->key->var ||
958 resources_different(a->key->resource, b->key->resource);
959 if (res_different && (a->access & ACCESS_RESTRICT) && (b->access & ACCESS_RESTRICT))
960 return false;
961
962 /* we can't compare offsets if the resources/variables might be different */
963 if (a->key->var != b->key->var || a->key->resource != b->key->resource)
964 return true;
965
966 /* use adjacency information */
967 /* TODO: we can look closer at the entry keys */
968 int64_t diff = compare_entries(a, b);
969 if (diff != INT64_MAX) {
970 /* with atomics, intrin->num_components can be 0 */
971 if (diff < 0)
972 return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
973 else
974 return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
975 }
976
977 /* TODO: we can use deref information */
978
979 return true;
980 }
981
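/* Returns true if "first" and "second" must not be combined because an access
 * between them may alias: for stores, any aliasing access in between blocks
 * the combination; for loads, any aliasing store in between does. */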
982 static bool
983 check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
984 {
985 nir_variable_mode mode = get_variable_mode(first);
986 if (mode & (nir_var_uniform | nir_var_system_value |
987 nir_var_mem_push_const | nir_var_mem_ubo))
988 return false;
989
990 unsigned mode_index = ffs(mode) - 1;
991 if (first->is_store) {
992 /* find first entry that aliases "first" */
993 list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
994 if (next == first)
995 continue;
996 if (next == second)
997 return false;
998 if (may_alias(first, next))
999 return true;
1000 }
1001 } else {
1002 /* find previous store that aliases this load */
1003 list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
1004 if (prev == second)
1005 continue;
1006 if (prev == first)
1007 return false;
1008 if (prev->is_store && may_alias(second, prev))
1009 return true;
1010 }
1011 }
1012
1013 return false;
1014 }
1015
1016 static bool
1017 is_strided_vector(const struct glsl_type *type)
1018 {
1019 if (glsl_type_is_vector(type)) {
1020 return glsl_get_explicit_stride(type) !=
1021 type_scalar_size_bytes(glsl_get_array_element(type));
1022 } else {
1023 return false;
1024 }
1025 }
1026
1027 static bool
1028 try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
1029 struct entry *low, struct entry *high,
1030 struct entry *first, struct entry *second)
1031 {
1032 if (check_for_aliasing(ctx, first, second))
1033 return false;
1034
1035 /* we can only vectorize non-volatile loads/stores of the same type and with
1036 * the same access */
1037 if (first->info != second->info || first->access != second->access ||
1038 (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1039 return false;
1040
1041 /* don't attempt to vectorize accesses of row-major matrix columns */
1042 if (first->deref) {
1043 const struct glsl_type *first_type = first->deref->type;
1044 const struct glsl_type *second_type = second->deref->type;
1045 if (is_strided_vector(first_type) || is_strided_vector(second_type))
1046 return false;
1047 }
1048
1049 /* gather information */
1050 uint64_t diff = high->offset_signed - low->offset_signed;
1051 unsigned low_bit_size = get_bit_size(low);
1052 unsigned high_bit_size = get_bit_size(high);
1053 unsigned low_size = low->intrin->num_components * low_bit_size;
1054 unsigned high_size = high->intrin->num_components * high_bit_size;
1055 unsigned new_size = MAX2(diff * 8u + high_size, low_size);
1056
1057 /* find a good bit size for the new load/store */
1058 unsigned new_bit_size = 0;
1059 if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
1060 new_bit_size = low_bit_size;
1061 } else if (low_bit_size != high_bit_size &&
1062 new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
1063 new_bit_size = high_bit_size;
1064 } else {
1065 new_bit_size = 64;
1066 for (; new_bit_size >= 8; new_bit_size /= 2) {
1067 /* don't repeat trying out bitsizes */
1068 if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
1069 continue;
1070 if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
1071 break;
1072 }
1073 if (new_bit_size < 8)
1074 return false;
1075 }
1076 unsigned new_num_components = new_size / new_bit_size;
1077
1078 /* vectorize the loads/stores */
1079 nir_builder b;
1080 nir_builder_init(&b, impl);
1081
1082 if (first->is_store)
1083 vectorize_stores(&b, ctx, low, high, first, second,
1084 new_bit_size, new_num_components, diff * 8u);
1085 else
1086 vectorize_loads(&b, ctx, low, high, first, second,
1087 new_bit_size, new_num_components, diff * 8u);
1088
1089 return true;
1090 }
1091
1092 static bool
1093 vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
1094 {
1095 if (!ht)
1096 return false;
1097
1098 bool progress = false;
1099 hash_table_foreach(ht, entry) {
1100 struct util_dynarray *arr = entry->data;
1101 if (!arr->size)
1102 continue;
1103
1104 qsort(util_dynarray_begin(arr),
1105 util_dynarray_num_elements(arr, struct entry *),
1106 sizeof(struct entry *), &sort_entries);
1107
1108 unsigned i = 0;
1109 for (; i < util_dynarray_num_elements(arr, struct entry*) - 1; i++) {
1110 struct entry *low = *util_dynarray_element(arr, struct entry *, i);
1111 struct entry *high = *util_dynarray_element(arr, struct entry *, i + 1);
1112
1113 uint64_t diff = high->offset_signed - low->offset_signed;
1114 if (diff > get_bit_size(low) / 8u * low->intrin->num_components) {
1115 progress |= update_align(low);
1116 continue;
1117 }
1118
1119 struct entry *first = low->index < high->index ? low : high;
1120 struct entry *second = low->index < high->index ? high : low;
1121
1122 if (try_vectorize(impl, ctx, low, high, first, second)) {
1123 *util_dynarray_element(arr, struct entry *, i) = NULL;
1124 *util_dynarray_element(arr, struct entry *, i + 1) = low->is_store ? second : first;
1125 progress = true;
1126 } else {
1127 progress |= update_align(low);
1128 }
1129 }
1130
1131 struct entry *last = *util_dynarray_element(arr, struct entry *, i);
1132 progress |= update_align(last);
1133 }
1134
1135 _mesa_hash_table_clear(ht, delete_entry_dynarray);
1136
1137 return progress;
1138 }
1139
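/* If "instr" is a barrier or another instruction that memory accesses must not
 * be moved across, flush the entries gathered so far for the affected modes by
 * vectorizing them and return true; otherwise return false. */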
1140 static bool
1141 handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
1142 {
1143 unsigned modes = 0;
1144 bool acquire = true;
1145 bool release = true;
1146 if (instr->type == nir_instr_type_intrinsic) {
1147 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1148 switch (intrin->intrinsic) {
1149 case nir_intrinsic_group_memory_barrier:
1150 case nir_intrinsic_memory_barrier:
1151 modes = nir_var_mem_ssbo | nir_var_mem_shared | nir_var_mem_global;
1152 break;
1153 /* prevent speculative loads/stores */
1154 case nir_intrinsic_discard_if:
1155 case nir_intrinsic_discard:
1156 modes = nir_var_all;
1157 break;
1158 case nir_intrinsic_memory_barrier_buffer:
1159 modes = nir_var_mem_ssbo | nir_var_mem_global;
1160 break;
1161 case nir_intrinsic_memory_barrier_shared:
1162 modes = nir_var_mem_shared;
1163 break;
1164 case nir_intrinsic_scoped_memory_barrier:
1165 modes = nir_intrinsic_memory_modes(intrin);
1166 acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
1167 release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
1168 switch (nir_intrinsic_memory_scope(intrin)) {
1169 case NIR_SCOPE_INVOCATION:
1170 case NIR_SCOPE_SUBGROUP:
1171 /* a barrier should never be required for correctness with these scopes */
1172 modes = 0;
1173 break;
1174 default:
1175 break;
1176 }
1177 break;
1178 default:
1179 return false;
1180 }
1181 } else if (instr->type == nir_instr_type_call) {
1182 modes = nir_var_all;
1183 } else {
1184 return false;
1185 }
1186
1187 while (modes) {
1188 unsigned mode_index = u_bit_scan(&modes);
1189
1190 if (acquire)
1191 *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
1192 if (release)
1193 *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
1194 }
1195
1196 return true;
1197 }
1198
1199 static bool
1200 process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
1201 {
1202 bool progress = false;
1203
1204 for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1205 list_inithead(&ctx->entries[i]);
1206 if (ctx->loads[i])
1207 _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
1208 if (ctx->stores[i])
1209 _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
1210 }
1211
1212 /* create entries */
1213 unsigned next_index = 0;
1214
1215 nir_foreach_instr_safe(instr, block) {
1216 if (handle_barrier(ctx, &progress, impl, instr))
1217 continue;
1218
1219 /* gather information */
1220 if (instr->type != nir_instr_type_intrinsic)
1221 continue;
1222 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1223
1224 const struct intrinsic_info *info = get_info(intrin->intrinsic);
1225 if (!info)
1226 continue;
1227
1228 nir_variable_mode mode = info->mode;
1229 if (!mode)
1230 mode = nir_src_as_deref(intrin->src[info->deref_src])->mode;
1231 if (!(mode & ctx->modes))
1232 continue;
1233 unsigned mode_index = ffs(mode) - 1;
1234
1235 /* create entry */
1236 struct entry *entry = create_entry(ctx, info, intrin);
1237 entry->index = next_index++;
1238
1239 list_addtail(&entry->head, &ctx->entries[mode_index]);
1240
1241 /* add the entry to a hash table */
1242
1243 struct hash_table *adj_ht = NULL;
1244 if (entry->is_store) {
1245 if (!ctx->stores[mode_index])
1246 ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1247 adj_ht = ctx->stores[mode_index];
1248 } else {
1249 if (!ctx->loads[mode_index])
1250 ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
1251 adj_ht = ctx->loads[mode_index];
1252 }
1253
1254 uint32_t key_hash = hash_entry_key(entry->key);
1255 struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
1256 struct util_dynarray *arr;
1257 if (adj_entry && adj_entry->data) {
1258 arr = (struct util_dynarray *)adj_entry->data;
1259 } else {
1260 arr = ralloc(ctx, struct util_dynarray);
1261 util_dynarray_init(arr, arr);
1262 _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
1263 }
1264 util_dynarray_append(arr, struct entry *, entry);
1265 }
1266
1267 /* sort and combine entries */
1268 for (unsigned i = 0; i < nir_num_variable_modes; i++) {
1269 progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
1270 progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
1271 }
1272
1273 return progress;
1274 }
1275
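/* Vectorizes and combines loads/stores of the given modes, calling "callback"
 * to ask the backend whether a combined access with the given alignment, bit
 * size and component count is acceptable. Returns true on progress. */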
1276 bool
1277 nir_opt_load_store_vectorize(nir_shader *shader, nir_variable_mode modes,
1278 nir_should_vectorize_mem_func callback)
1279 {
1280 bool progress = false;
1281
1282 struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
1283 ctx->modes = modes;
1284 ctx->callback = callback;
1285
1286 nir_index_vars(shader, NULL, modes);
1287
1288 nir_foreach_function(function, shader) {
1289 if (function->impl) {
1290 if (modes & nir_var_function_temp)
1291 nir_index_vars(shader, function->impl, nir_var_function_temp);
1292
1293 nir_foreach_block(block, function->impl)
1294 progress |= process_block(function->impl, ctx, block);
1295
1296 nir_metadata_preserve(function->impl,
1297 nir_metadata_block_index |
1298 nir_metadata_dominance |
1299 nir_metadata_live_ssa_defs);
1300 }
1301 }
1302
1303 ralloc_free(ctx);
1304 return progress;
1305 }