/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

/**
 * \file nir_lower_subgroups.c
 */

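/* Lowers subgroup intrinsics according to nir_lower_subgroups_options:
 * vote/ballot operations, subgroup masks, shuffles and quad operations can be
 * scalarized, split into 32-bit operations, or rewritten in terms of simpler
 * intrinsics, depending on what the backend supports natively.
 *
 * A minimal sketch of how a driver might invoke this pass (illustrative only;
 * the exact option values depend on the backend):
 *
 *    const nir_lower_subgroups_options opts = {
 *       .subgroup_size = 32,
 *       .ballot_bit_size = 32,
 *       .lower_to_scalar = true,
 *       .lower_subgroup_masks = true,
 *    };
 *    nir_lower_subgroups(shader, &opts);
 */
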
static nir_intrinsic_instr *
lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
                                      unsigned int component)
{
   nir_ssa_def *comp;
   if (component == 0)
      comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
   else
      comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);

   nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
   nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
   intr->const_index[0] = intrin->const_index[0];
   intr->const_index[1] = intrin->const_index[1];
   intr->src[0] = nir_src_for_ssa(comp);
   if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
      nir_src_copy(&intr->src[1], &intrin->src[1], intr);

   intr->num_components = 1;
   nir_builder_instr_insert(b, &intr->instr);
   return intr;
}

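/* Lowers a subgroup operation on a 64-bit scalar to two 32-bit operations,
 * one per half of the value (via lower_subgroups_64bit_split_intrinsic), and
 * packs the results back into a single 64-bit value.
 */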
static nir_ssa_def *
lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].ssa->bit_size == 64);
   nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
   nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
   return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
}

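/* Converts a uvec4 ballot value into a single 32- or 64-bit scalar so that
 * bitwise arithmetic can be performed on it directly.
 */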
static nir_ssa_def *
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
{
   /* We only use this on uvec4 types */
   assert(value->num_components == 4 && value->bit_size == 32);

   if (bit_size == 32) {
      return nir_channel(b, value, 0);
   } else {
      assert(bit_size == 64);
      return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
                                       nir_channel(b, value, 1));
   }
}

/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
static nir_ssa_def *
uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
                    unsigned num_components, unsigned bit_size)
{
   assert(value->num_components == 1);
   assert(value->bit_size == 32 || value->bit_size == 64);

   nir_ssa_def *zero = nir_imm_int(b, 0);
   if (num_components > 1) {
      /* SPIR-V uses a uvec4 for ballot values */
      assert(num_components == 4);
      assert(bit_size == 32);

      if (value->bit_size == 32) {
         return nir_vec4(b, value, zero, zero, zero);
      } else {
         assert(value->bit_size == 64);
         return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
                            nir_unpack_64_2x32_split_y(b, value),
                            zero, zero);
      }
   } else {
      /* GLSL uses a uint64_t for ballot values */
      assert(num_components == 1);
      assert(bit_size == 64);

      if (value->bit_size == 32) {
         return nir_pack_64_2x32_split(b, value, zero);
      } else {
         assert(value->bit_size == 64);
         return value;
      }
   }
}

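/* Scalarizes a vector subgroup operation by emitting one single-component
 * copy of the intrinsic per channel and recombining the results with a vec.
 * If lower_to_32bit is set, 64-bit channels are additionally split into
 * 32-bit operations.
 */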
static nir_ssa_def *
lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
                            bool lower_to_32bit)
{
   /* This is safe to call on scalar things but it would be silly */
   assert(intrin->dest.ssa.num_components > 1);

   nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
                                        intrin->num_components);
   nir_ssa_def *reads[4];

   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;

      /* value */
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* invocation */
      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
         nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
      }

      chan_intrin->const_index[0] = intrin->const_index[0];
      chan_intrin->const_index[1] = intrin->const_index[1];

      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
      } else {
         nir_builder_instr_insert(b, &chan_intrin->instr);
         reads[i] = &chan_intrin->dest.ssa;
      }
   }

   return nir_vec(b, reads, intrin->num_components);
}

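/* Scalarizes vote_ieq/vote_feq on a vector: votes on each channel separately
 * and ANDs the per-channel results together.
 */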
static nir_ssa_def *
lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   nir_ssa_def *result = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      nir_builder_instr_insert(b, &chan_intrin->instr);

      if (result) {
         result = nir_iand(b, result, &chan_intrin->dest.ssa);
      } else {
         result = &chan_intrin->dest.ssa;
      }
   }

   return result;
}

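/* Lowers vote_ieq/vote_feq to a ballot: each channel is compared against the
 * value read from the first invocation, the per-channel comparisons are ANDed
 * together, and a ballot is taken over the inverted result.  The vote passes
 * exactly when that ballot is zero, i.e. no active invocation disagreed.
 */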
static nir_ssa_def *
lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
                        const nir_lower_subgroups_options *options)
{
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *value = intrin->src[0].ssa;

   /* We have to implicitly lower to scalar */
   nir_ssa_def *all_eq = NULL;
   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *rfi =
         nir_intrinsic_instr_create(b->shader,
                                    nir_intrinsic_read_first_invocation);
      nir_ssa_dest_init(&rfi->instr, &rfi->dest,
                        1, value->bit_size, NULL);
      rfi->num_components = 1;
      rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      nir_builder_instr_insert(b, &rfi->instr);

      nir_ssa_def *is_eq;
      if (intrin->intrinsic == nir_intrinsic_vote_feq) {
         is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
      } else {
         is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
      }

      if (all_eq == NULL) {
         all_eq = is_eq;
      } else {
         all_eq = nir_iand(b, all_eq, is_eq);
      }
   }

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
   nir_ssa_dest_init(&ballot->instr, &ballot->dest,
                     1, options->ballot_bit_size, NULL);
   ballot->num_components = 1;
   ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
   nir_builder_instr_insert(b, &ballot->instr);

   return nir_ieq(b, &ballot->dest.ssa,
                  nir_imm_intN_t(b, 0, options->ballot_bit_size));
}

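/* Lowers shuffle_xor/up/down, quad_broadcast and the quad swaps to a plain
 * shuffle by computing the source invocation index explicitly from
 * load_subgroup_invocation.
 */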
static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
              bool lower_to_scalar, bool lower_to_32bit)
{
   nir_ssa_def *index = nir_load_subgroup_invocation(b);
   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle_xor:
      assert(intrin->src[1].is_ssa);
      index = nir_ixor(b, index, intrin->src[1].ssa);
      break;
   case nir_intrinsic_shuffle_up:
      assert(intrin->src[1].is_ssa);
      index = nir_isub(b, index, intrin->src[1].ssa);
      break;
   case nir_intrinsic_shuffle_down:
      assert(intrin->src[1].is_ssa);
      index = nir_iadd(b, index, intrin->src[1].ssa);
      break;
   case nir_intrinsic_quad_broadcast:
      assert(intrin->src[1].is_ssa);
      index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, ~0x3)),
                      intrin->src[1].ssa);
      break;
   case nir_intrinsic_quad_swap_horizontal:
      /* For Quad operations, subgroups are divided into quads where
       * (invocation % 4) is the index to a square arranged as follows:
       *
       *    +---+---+
       *    | 0 | 1 |
       *    +---+---+
       *    | 2 | 3 |
       *    +---+---+
       */
      index = nir_ixor(b, index, nir_imm_int(b, 0x1));
      break;
   case nir_intrinsic_quad_swap_vertical:
      index = nir_ixor(b, index, nir_imm_int(b, 0x2));
      break;
   case nir_intrinsic_quad_swap_diagonal:
      index = nir_ixor(b, index, nir_imm_int(b, 0x3));
      break;
   default:
      unreachable("Invalid intrinsic");
   }

   nir_intrinsic_instr *shuffle =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
   shuffle->num_components = intrin->num_components;
   nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle);
   shuffle->src[1] = nir_src_for_ssa(index);
   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   if (lower_to_scalar && shuffle->num_components > 1) {
      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
      return lower_subgroup_op_to_32bit(b, shuffle);
   } else {
      nir_builder_instr_insert(b, &shuffle->instr);
      return &shuffle->dest.ssa;
   }
}

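/* Every intrinsic is a candidate; lower_subgroups_instr() returns NULL for
 * anything it does not actually need to lower.
 */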
static bool
lower_subgroups_filter(const nir_instr *instr, const void *_options)
{
   return instr->type == nir_instr_type_intrinsic;
}

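/* Builds a mask with one bit set per invocation in the subgroup: shifting an
 * all-ones constant right by (bit_size - subgroup_size) leaves exactly the
 * low subgroup_size bits set.  For example, with bit_size == 32 and a
 * subgroup size of 8, 0xffffffff >> 24 == 0x000000ff.
 */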
static nir_ssa_def *
build_subgroup_mask(nir_builder *b, unsigned bit_size,
                    const nir_lower_subgroups_options *options)
{
   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
                   nir_isub(b, nir_imm_int(b, bit_size),
                            nir_load_subgroup_size(b)));
}

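/* Lowers a quad_broadcast with a non-constant index by emitting one
 * quad_broadcast per possible index (0..3) and selecting between the four
 * results with bcsel based on the dynamic index.  If the backend does not
 * request this, the broadcast is lowered to a shuffle instead.
 */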
static nir_ssa_def *
lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
                             const nir_lower_subgroups_options *options)
{
   if (!options->lower_quad_broadcast_dynamic_to_const)
      return lower_shuffle(b, intrin, options->lower_to_scalar, false);

   nir_ssa_def *dst = NULL;

   for (unsigned i = 0; i < 4; ++i) {
      nir_intrinsic_instr *qbcst =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);

      qbcst->num_components = intrin->num_components;
      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
                        intrin->dest.ssa.num_components,
                        intrin->dest.ssa.bit_size, NULL);

      nir_ssa_def *qbcst_dst = NULL;

      if (options->lower_to_scalar && qbcst->num_components > 1) {
         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
      } else {
         nir_builder_instr_insert(b, &qbcst->instr);
         qbcst_dst = &qbcst->dest.ssa;
      }

      if (i)
         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
                         qbcst_dst, dst);
      else
         dst = qbcst_dst;
   }

   return dst;
}

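/* Callback for nir_shader_lower_instructions(): returns a replacement SSA
 * value for the intrinsic, NIR_LOWER_INSTR_PROGRESS if the instruction was
 * modified in place, or NULL if it was left alone.
 */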
static nir_ssa_def *
lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   const nir_lower_subgroups_options *options = _options;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_vote_any:
   case nir_intrinsic_vote_all:
      if (options->lower_vote_trivial)
         return nir_ssa_for_src(b, intrin->src[0], 1);
      break;

   case nir_intrinsic_vote_feq:
   case nir_intrinsic_vote_ieq:
      if (options->lower_vote_trivial)
         return nir_imm_true(b);

      if (options->lower_vote_eq_to_ballot)
         return lower_vote_eq_to_ballot(b, intrin, options);

      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_vote_eq_to_scalar(b, intrin);
      break;

   case nir_intrinsic_load_subgroup_size:
      if (options->subgroup_size)
         return nir_imm_int(b, options->subgroup_size);
      break;

   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_load_subgroup_eq_mask:
   case nir_intrinsic_load_subgroup_ge_mask:
   case nir_intrinsic_load_subgroup_gt_mask:
   case nir_intrinsic_load_subgroup_le_mask:
   case nir_intrinsic_load_subgroup_lt_mask: {
      if (!options->lower_subgroup_masks)
         return NULL;

      /* If either the result or the requested bit size is 64-bits then we
       * know that we have 64-bit types and using them will probably be more
       * efficient than messing around with 32-bit shifts and packing.
       */
      const unsigned bit_size = MAX2(options->ballot_bit_size,
                                     intrin->dest.ssa.bit_size);

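      /* With id = subgroup_invocation, the masks below are:
       *
       *    eq = 1 << id        ge = ~0 << id      gt = ~1 << id
       *    le = ~(~1 << id)    lt = ~(~0 << id)
       *
       * ge and gt are additionally ANDed with build_subgroup_mask() so that
       * bits above the subgroup size are not set.
       */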
      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *val;
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_subgroup_eq_mask:
         val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
         break;
      case nir_intrinsic_load_subgroup_ge_mask:
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
                        build_subgroup_mask(b, bit_size, options));
         break;
      case nir_intrinsic_load_subgroup_gt_mask:
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
                        build_subgroup_mask(b, bit_size, options));
         break;
      case nir_intrinsic_load_subgroup_le_mask:
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
         break;
      case nir_intrinsic_load_subgroup_lt_mask:
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
         break;
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }

      return uint_to_ballot_type(b, val,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot: {
      if (intrin->dest.ssa.num_components == 1 &&
          intrin->dest.ssa.bit_size == options->ballot_bit_size)
         return NULL;

      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
      ballot->num_components = 1;
      nir_ssa_dest_init(&ballot->instr, &ballot->dest,
                        1, options->ballot_bit_size, NULL);
      nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
      nir_builder_instr_insert(b, &ballot->instr);

      return uint_to_ballot_type(b, &ballot->dest.ssa,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot_bitfield_extract:
   case nir_intrinsic_ballot_bit_count_reduce:
   case nir_intrinsic_ballot_find_lsb:
   case nir_intrinsic_ballot_find_msb: {
      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);

      if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
          intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
         /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
          *
          *    "Find the most significant bit set to 1 in Value, considering
          *    only the bits in Value required to represent all bits of the
          *    group’s invocations.  If none of the considered bits is set to
          *    1, the result is undefined."
          *
          * It has similar text for the other three.  This means that, in case
          * the subgroup size is less than 32, we have to mask off the unused
          * bits.  If the subgroup size is fixed and greater than or equal to
          * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
          * the iand.
          *
          * We only have to worry about this for BitCount and FindMSB because
          * FindLSB counts from the bottom and BitfieldExtract selects
          * individual bits.  In either case, if run outside the range of
          * valid bits, we hit the undefined results case and we can return
          * anything we want.
          */
         int_val = nir_iand(b, int_val,
            build_subgroup_mask(b, options->ballot_bit_size, options));
      }

      switch (intrin->intrinsic) {
      case nir_intrinsic_ballot_bitfield_extract:
         assert(intrin->src[1].is_ssa);
         return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
                                                intrin->src[1].ssa),
                                    nir_imm_intN_t(b, 1, options->ballot_bit_size)));
      case nir_intrinsic_ballot_bit_count_reduce:
         return nir_bit_count(b, int_val);
      case nir_intrinsic_ballot_find_lsb:
         return nir_find_lsb(b, int_val);
      case nir_intrinsic_ballot_find_msb:
         return nir_ufind_msb(b, int_val);
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }
   }

   case nir_intrinsic_ballot_bit_count_exclusive:
   case nir_intrinsic_ballot_bit_count_inclusive: {
      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
      if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
         const unsigned bits = options->ballot_bit_size;
         mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
      } else {
         mask = nir_inot(b, nir_ishl(b, mask, count));
      }

      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);

      return nir_bit_count(b, nir_iand(b, int_val, mask));
   }

   case nir_intrinsic_elect: {
      nir_intrinsic_instr *first =
         nir_intrinsic_instr_create(b->shader,
                                    nir_intrinsic_first_invocation);
      nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL);
      nir_builder_instr_insert(b, &first->instr);

      return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
   }

   case nir_intrinsic_shuffle:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;

   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
      if (options->lower_shuffle)
         return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
         return lower_subgroup_op_to_32bit(b, intrin);
      break;

   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      if (options->lower_quad ||
          (options->lower_quad_broadcast_dynamic &&
           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
           !nir_src_is_const(intrin->src[1])))
         return lower_dynamic_quad_broadcast(b, intrin, options);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   case nir_intrinsic_reduce: {
      nir_ssa_def *ret = NULL;
      /* A cluster size greater than the subgroup size is implementation defined */
      if (options->subgroup_size &&
          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
         nir_intrinsic_set_cluster_size(intrin, 0);
         ret = NIR_LOWER_INSTR_PROGRESS;
      }
      if (options->lower_to_scalar && intrin->num_components > 1)
         ret = lower_subgroup_op_to_scalar(b, intrin, false);
      return ret;
   }
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin, false);
      break;

   default:
      break;
   }

   return NULL;
}

bool
nir_lower_subgroups(nir_shader *shader,
                    const nir_lower_subgroups_options *options)
{
   return nir_shader_lower_instructions(shader,
                                        lower_subgroups_filter,
                                        lower_subgroups_instr,
                                        (void *)options);
}