[mesa.git] src/compiler/nir/nir_lower_subgroups.c
/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

/**
 * \file nir_lower_subgroups.c
 *
 * Lowers subgroup intrinsics to forms the rest of the compiler and the
 * backend can consume, based on nir_lower_subgroups_options.
 */

static nir_intrinsic_instr *
lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
                                      unsigned int component)
{
   nir_ssa_def *comp;
   if (component == 0)
      comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
   else
      comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);

   nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
   nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
   intr->const_index[0] = intrin->const_index[0];
   intr->const_index[1] = intrin->const_index[1];
   intr->src[0] = nir_src_for_ssa(comp);
   if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
      nir_src_copy(&intr->src[1], &intrin->src[1], intr);

   intr->num_components = 1;
   nir_builder_instr_insert(b, &intr->instr);
   return intr;
}

static nir_ssa_def *
lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].ssa->bit_size == 64);
   nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
   nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
   return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
}
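
/* Illustrative note (added for clarity, not in the original source): a 64-bit
 * subgroup op such as
 *
 *    %r = read_invocation(%v64, %i)
 *
 * is rewritten by lower_subgroup_op_to_32bit() into two 32-bit ops on the
 * unpacked halves, then repacked:
 *
 *    %lo = read_invocation(unpack_64_2x32_split_x(%v64), %i)
 *    %hi = read_invocation(unpack_64_2x32_split_y(%v64), %i)
 *    %r  = pack_64_2x32_split(%lo, %hi)
 */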

static nir_ssa_def *
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
{
   /* We only use this on uvec4 types */
   assert(value->num_components == 4 && value->bit_size == 32);

   if (bit_size == 32) {
      return nir_channel(b, value, 0);
   } else {
      assert(bit_size == 64);
      return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
                                       nir_channel(b, value, 1));
   }
}

/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
static nir_ssa_def *
uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
                    unsigned num_components, unsigned bit_size)
{
   assert(value->num_components == 1);
   assert(value->bit_size == 32 || value->bit_size == 64);

   nir_ssa_def *zero = nir_imm_int(b, 0);
   if (num_components > 1) {
      /* SPIR-V uses a uvec4 for ballot values */
      assert(num_components == 4);
      assert(bit_size == 32);

      if (value->bit_size == 32) {
         return nir_vec4(b, value, zero, zero, zero);
      } else {
         assert(value->bit_size == 64);
         return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
                            nir_unpack_64_2x32_split_y(b, value),
                            zero, zero);
      }
   } else {
      /* GLSL uses a uint64_t for ballot values */
      assert(num_components == 1);
      assert(bit_size == 64);

      if (value->bit_size == 32) {
         return nir_pack_64_2x32_split(b, value, zero);
      } else {
         assert(value->bit_size == 64);
         return value;
      }
   }
}
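
/* Illustrative examples (added, not in the original source): the helper above
 * maps a scalar ballot value onto the type the source language expects.  For
 * a 64-bit ballot result %b64:
 *
 *    SPIR-V (uvec4):    (unpack_lo(%b64), unpack_hi(%b64), 0, 0)
 *    GLSL (uint64_t):   %b64 unchanged
 *
 * and for a 32-bit ballot result %b32:
 *
 *    SPIR-V (uvec4):    (%b32, 0, 0, 0)
 *    GLSL (uint64_t):   pack_64_2x32_split(%b32, 0)
 */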

static nir_ssa_def *
lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
                            bool lower_to_32bit)
{
   /* This is safe to call on scalar things but it would be silly */
   assert(intrin->dest.ssa.num_components > 1);

   nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
                                        intrin->num_components);
   nir_ssa_def *reads[4];

   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;

      /* value */
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* invocation */
      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
         nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
      }

      chan_intrin->const_index[0] = intrin->const_index[0];
      chan_intrin->const_index[1] = intrin->const_index[1];

      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
      } else {
         nir_builder_instr_insert(b, &chan_intrin->instr);
         reads[i] = &chan_intrin->dest.ssa;
      }
   }

   return nir_vec(b, reads, intrin->num_components);
}
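
/* Example (illustrative, added): with lower_to_scalar set, a vec3
 *
 *    %r = read_invocation(%v.xyz, %i)
 *
 * becomes three scalar read_invocations whose results are re-assembled:
 *
 *    %r = vec3(read_invocation(%v.x, %i),
 *              read_invocation(%v.y, %i),
 *              read_invocation(%v.z, %i))
 */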

static nir_ssa_def *
lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   nir_ssa_def *result = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      nir_builder_instr_insert(b, &chan_intrin->instr);

      if (result) {
         result = nir_iand(b, result, &chan_intrin->dest.ssa);
      } else {
         result = &chan_intrin->dest.ssa;
      }
   }

   return result;
}

static nir_ssa_def *
lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
                        const nir_lower_subgroups_options *options)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   /* We have to implicitly lower to scalar */
   nir_ssa_def *all_eq = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *rfi =
         nir_intrinsic_instr_create(b->shader,
                                    nir_intrinsic_read_first_invocation);
      nir_ssa_dest_init(&rfi->instr, &rfi->dest,
                        1, value->bit_size, NULL);
      rfi->num_components = 1;
      rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      nir_builder_instr_insert(b, &rfi->instr);

      nir_ssa_def *is_eq;
      if (intrin->intrinsic == nir_intrinsic_vote_feq) {
         is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
      } else {
         is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
      }

      if (all_eq == NULL) {
         all_eq = is_eq;
      } else {
         all_eq = nir_iand(b, all_eq, is_eq);
      }
   }

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
   nir_ssa_dest_init(&ballot->instr, &ballot->dest,
                     1, options->ballot_bit_size, NULL);
   ballot->num_components = 1;
   ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
   nir_builder_instr_insert(b, &ballot->instr);

   return nir_ieq(b, &ballot->dest.ssa,
                  nir_imm_intN_t(b, 0, options->ballot_bit_size));
}
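
/* Note (added for clarity): the ballot above collects one bit per invocation
 * whose value does NOT match the one read from the first invocation.  The
 * vote therefore passes exactly when that ballot is zero, which is what the
 * final ieq-with-zero checks.
 */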

static nir_ssa_def *
lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
                         const nir_lower_subgroups_options *options)
{
   unsigned mask = nir_src_as_uint(intrin->src[1]);

   if (mask >= 32)
      return NULL;

   nir_intrinsic_instr *swizzle = nir_intrinsic_instr_create(
      b->shader, nir_intrinsic_masked_swizzle_amd);
   swizzle->num_components = intrin->num_components;
   nir_src_copy(&swizzle->src[0], &intrin->src[0], swizzle);
   nir_intrinsic_set_swizzle_mask(swizzle, (mask << 10) | 0x1f);
   nir_ssa_dest_init(&swizzle->instr, &swizzle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   if (options->lower_to_scalar && swizzle->num_components > 1) {
      return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
   } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
      return lower_subgroup_op_to_32bit(b, swizzle);
   } else {
      nir_builder_instr_insert(b, &swizzle->instr);
      return &swizzle->dest.ssa;
   }
}

static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
              const nir_lower_subgroups_options *options)
{
   if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
       options->lower_shuffle_to_swizzle_amd &&
       nir_src_is_const(intrin->src[1])) {
      nir_ssa_def *result =
         lower_shuffle_to_swizzle(b, intrin, options);
      if (result)
         return result;
   }

   nir_ssa_def *index = nir_load_subgroup_invocation(b);
   bool is_shuffle = false;
   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle_xor:
      assert(intrin->src[1].is_ssa);
      index = nir_ixor(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_shuffle_up:
      assert(intrin->src[1].is_ssa);
      index = nir_isub(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_shuffle_down:
      assert(intrin->src[1].is_ssa);
      index = nir_iadd(b, index, intrin->src[1].ssa);
      is_shuffle = true;
      break;
   case nir_intrinsic_quad_broadcast:
      assert(intrin->src[1].is_ssa);
      index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, ~0x3)),
                      intrin->src[1].ssa);
      break;
   case nir_intrinsic_quad_swap_horizontal:
      /* For Quad operations, subgroups are divided into quads where
       * (invocation % 4) is the index to a square arranged as follows:
       *
       *    +---+---+
       *    | 0 | 1 |
       *    +---+---+
       *    | 2 | 3 |
       *    +---+---+
       */
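      /* Added note (illustrative): with that layout, the XORs below give the
       * expected index mappings within each quad:
       *
       *    horizontal (^1): 0<->1, 2<->3
       *    vertical   (^2): 0<->2, 1<->3
       *    diagonal   (^3): 0<->3, 1<->2
       */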
      index = nir_ixor(b, index, nir_imm_int(b, 0x1));
      break;
   case nir_intrinsic_quad_swap_vertical:
      index = nir_ixor(b, index, nir_imm_int(b, 0x2));
      break;
   case nir_intrinsic_quad_swap_diagonal:
      index = nir_ixor(b, index, nir_imm_int(b, 0x3));
      break;
   default:
      unreachable("Invalid intrinsic");
   }

   nir_intrinsic_instr *shuffle =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
   shuffle->num_components = intrin->num_components;
   nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle);
   shuffle->src[1] = nir_src_for_ssa(index);
   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
   if (options->lower_to_scalar && shuffle->num_components > 1) {
      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
      return lower_subgroup_op_to_32bit(b, shuffle);
   } else {
      nir_builder_instr_insert(b, &shuffle->instr);
      return &shuffle->dest.ssa;
   }
}

static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
{
   return instr->type == nir_instr_type_intrinsic;
}

static nir_ssa_def *
build_subgroup_mask(nir_builder *b, unsigned bit_size,
                    const nir_lower_subgroups_options *options)
{
   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
                   nir_isub(b, nir_imm_int(b, bit_size),
                            nir_load_subgroup_size(b)));
}
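
/* Worked example (added, not in the original source): with bit_size == 32 and
 * a subgroup size of 16, build_subgroup_mask() computes
 *
 *    0xffffffff >> (32 - 16) == 0x0000ffff
 *
 * i.e. one bit set for every invocation that can exist in the subgroup.
 */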

static nir_ssa_def *
lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
                             const nir_lower_subgroups_options *options)
{
   if (!options->lower_quad_broadcast_dynamic_to_const)
      return lower_shuffle(b, intrin, options);

   nir_ssa_def *dst = NULL;

   for (unsigned i = 0; i < 4; ++i) {
      nir_intrinsic_instr *qbcst =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);

      qbcst->num_components = intrin->num_components;
      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
                        intrin->dest.ssa.num_components,
                        intrin->dest.ssa.bit_size, NULL);

      nir_ssa_def *qbcst_dst = NULL;

      if (options->lower_to_scalar && qbcst->num_components > 1) {
         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
      } else {
         nir_builder_instr_insert(b, &qbcst->instr);
         qbcst_dst = &qbcst->dest.ssa;
      }

      if (i)
         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
                         qbcst_dst, dst);
      else
         dst = qbcst_dst;
   }

   return dst;
}
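
/* Note (added for clarity): when the quad-broadcast lane is not a constant,
 * the loop above emits all four constant quad_broadcasts and selects between
 * them with bcsel on the dynamic index, roughly:
 *
 *    dst = (idx == 3) ? qb3 : (idx == 2) ? qb2 : (idx == 1) ? qb1 : qb0
 */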

static nir_ssa_def *
lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   const nir_lower_subgroups_options *options = _options;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_vote_any:
   case nir_intrinsic_vote_all:
      if (options->lower_vote_trivial)
         return nir_ssa_for_src(b, intrin->src[0], 1);
      break;

   case nir_intrinsic_vote_feq:
   case nir_intrinsic_vote_ieq:
      if (options->lower_vote_trivial)
         return nir_imm_true(b);

      if (options->lower_vote_eq_to_ballot)
         return lower_vote_eq_to_ballot(b, intrin, options);

      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_vote_eq_to_scalar(b, intrin);
      break;

   case nir_intrinsic_load_subgroup_size:
      if (options->subgroup_size)
         return nir_imm_int(b, options->subgroup_size);
      break;

   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_load_subgroup_eq_mask:
   case nir_intrinsic_load_subgroup_ge_mask:
   case nir_intrinsic_load_subgroup_gt_mask:
   case nir_intrinsic_load_subgroup_le_mask:
   case nir_intrinsic_load_subgroup_lt_mask: {
      if (!options->lower_subgroup_masks)
         return NULL;

      /* If either the result or the requested bit size is 64-bits then we
       * know that we have 64-bit types and using them will probably be more
       * efficient than messing around with 32-bit shifts and packing.
       */
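      /* Illustrative values (added, not in the original source): with
       * bit_size == 32 and subgroup_invocation == 3 in an 8-wide subgroup,
       * the masks built below are
       *
       *    eq = 0x08   ge = 0xf8   gt = 0xf0   le = 0x0f   lt = 0x07
       *
       * (ge and gt are additionally clamped by build_subgroup_mask() to the
       * 8 valid invocation bits).
       */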
      const unsigned bit_size = MAX2(options->ballot_bit_size,
                                     intrin->dest.ssa.bit_size);

      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *val;
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_subgroup_eq_mask:
         val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
         break;
      case nir_intrinsic_load_subgroup_ge_mask:
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
                        build_subgroup_mask(b, bit_size, options));
         break;
      case nir_intrinsic_load_subgroup_gt_mask:
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
                        build_subgroup_mask(b, bit_size, options));
         break;
      case nir_intrinsic_load_subgroup_le_mask:
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
         break;
      case nir_intrinsic_load_subgroup_lt_mask:
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
         break;
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }

      return uint_to_ballot_type(b, val,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot: {
      if (intrin->dest.ssa.num_components == 1 &&
          intrin->dest.ssa.bit_size == options->ballot_bit_size)
         return NULL;

      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
      ballot->num_components = 1;
      nir_ssa_dest_init(&ballot->instr, &ballot->dest,
                        1, options->ballot_bit_size, NULL);
      nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
      nir_builder_instr_insert(b, &ballot->instr);

      return uint_to_ballot_type(b, &ballot->dest.ssa,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot_bitfield_extract:
   case nir_intrinsic_ballot_bit_count_reduce:
   case nir_intrinsic_ballot_find_lsb:
   case nir_intrinsic_ballot_find_msb: {
      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);

      if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
          intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
         /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
          *
          *    "Find the most significant bit set to 1 in Value, considering
          *    only the bits in Value required to represent all bits of the
          *    group’s invocations.  If none of the considered bits is set to
          *    1, the result is undefined."
          *
          * It has similar text for the other three.  This means that, in case
          * the subgroup size is less than 32, we have to mask off the unused
          * bits.  If the subgroup size is fixed and greater than or equal to
          * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
          * the iand.
          *
          * We only have to worry about this for BitCount and FindMSB because
          * FindLSB counts from the bottom and BitfieldExtract selects
          * individual bits.  In either case, if run outside the range of
          * valid bits, we hit the undefined results case and we can return
          * anything we want.
          */
         int_val = nir_iand(b, int_val,
            build_subgroup_mask(b, options->ballot_bit_size, options));
      }
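
      /* Example (added): if the ballot is 32 bits wide but the subgroup only
       * has 16 invocations, the iand above clears the 16 undefined upper bits
       * (mask 0x0000ffff) so that bit_count and ufind_msb below only see bits
       * that correspond to real invocations.
       */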

      switch (intrin->intrinsic) {
      case nir_intrinsic_ballot_bitfield_extract:
         assert(intrin->src[1].is_ssa);
         return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
                                                intrin->src[1].ssa),
                                    nir_imm_intN_t(b, 1, options->ballot_bit_size)));
      case nir_intrinsic_ballot_bit_count_reduce:
         return nir_bit_count(b, int_val);
      case nir_intrinsic_ballot_find_lsb:
         return nir_find_lsb(b, int_val);
      case nir_intrinsic_ballot_find_msb:
         return nir_ufind_msb(b, int_val);
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }
   }

   case nir_intrinsic_ballot_bit_count_exclusive:
   case nir_intrinsic_ballot_bit_count_inclusive: {
      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
      if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
         const unsigned bits = options->ballot_bit_size;
         mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
      } else {
         mask = nir_inot(b, nir_ishl(b, mask, count));
      }

      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);

      return nir_bit_count(b, nir_iand(b, int_val, mask));
   }
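
   /* Worked example (added): with ballot_bit_size == 32 and
    * subgroup_invocation == 3, the masks above are
    *
    *    inclusive: 0xffffffff >> (31 - 3) == 0x0000000f   (bits 0..3)
    *    exclusive: ~(0xffffffff << 3)     == 0x00000007   (bits 0..2)
    *
    * so the bit_count only counts ballot bits at or below / strictly below
    * the current invocation.
    */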

   case nir_intrinsic_elect: {
      nir_intrinsic_instr *first =
         nir_intrinsic_instr_create(b->shader,
                                    nir_intrinsic_first_invocation);
      nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL);
      nir_builder_instr_insert(b, &first->instr);

      return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
   }

   case nir_intrinsic_shuffle:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
      if (options->lower_shuffle)
         return lower_shuffle(b, intrin, options);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;

   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      if (options->lower_quad ||
          (options->lower_quad_broadcast_dynamic &&
           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
           !nir_src_is_const(intrin->src[1])))
         return lower_dynamic_quad_broadcast(b, intrin, options);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_reduce: {
      nir_ssa_def *ret = NULL;
      /* A cluster size greater than the subgroup size is implementation defined */
      if (options->subgroup_size &&
          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
         nir_intrinsic_set_cluster_size(intrin, 0);
         ret = NIR_LOWER_INSTR_PROGRESS;
      }
      if (options->lower_to_scalar && intrin->num_components > 1)
         ret = lower_subgroup_op_to_scalar(b, intrin, false);
      return ret;
   }
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   default:
      break;
   }

   return NULL;
}

bool
nir_lower_subgroups(nir_shader *shader,
                    const nir_lower_subgroups_options *options)
{
   return nir_shader_lower_instructions(shader,
                                        lower_subgroups_filter,
                                        lower_subgroups_instr,
                                        (void *)options);
}
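
/* Usage sketch (added for illustration; not part of the original file): a
 * driver would typically fill out nir_lower_subgroups_options to match its
 * hardware and run the pass during shader compilation.  The option values and
 * the NIR_PASS_V helper below are assumptions for the example, not taken from
 * this file:
 *
 *    const nir_lower_subgroups_options opts = {
 *       .subgroup_size = 64,
 *       .ballot_bit_size = 64,
 *       .lower_to_scalar = true,
 *       .lower_subgroup_masks = true,
 *       .lower_shuffle = true,
 *       .lower_quad = false,
 *    };
 *    NIR_PASS_V(shader, nir_lower_subgroups, &opts);
 *
 * The pass reports progress (true) when any subgroup intrinsic was rewritten.
 */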