/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
25 #include "nir_builder.h"
/**
 * \file nir_lower_subgroups.c
 *
 * Lowers subgroup (SIMD-group) intrinsics — votes, ballots, shuffles, quad
 * swaps, reductions — into simpler intrinsics according to the options the
 * driver passes in.  (The original \file tag named nir_opt_intrinsics.c,
 * which appears to be a copy-paste leftover.)
 */
32 ballot_type_to_uint(nir_builder
*b
, nir_ssa_def
*value
, unsigned bit_size
)
34 /* We only use this on uvec4 types */
35 assert(value
->num_components
== 4 && value
->bit_size
== 32);
38 return nir_channel(b
, value
, 0);
40 assert(bit_size
== 64);
41 return nir_pack_64_2x32_split(b
, nir_channel(b
, value
, 0),
42 nir_channel(b
, value
, 1));
46 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
48 uint_to_ballot_type(nir_builder
*b
, nir_ssa_def
*value
,
49 unsigned num_components
, unsigned bit_size
)
51 assert(value
->num_components
== 1);
52 assert(value
->bit_size
== 32 || value
->bit_size
== 64);
54 nir_ssa_def
*zero
= nir_imm_int(b
, 0);
55 if (num_components
> 1) {
56 /* SPIR-V uses a uvec4 for ballot values */
57 assert(num_components
== 4);
58 assert(bit_size
== 32);
60 if (value
->bit_size
== 32) {
61 return nir_vec4(b
, value
, zero
, zero
, zero
);
63 assert(value
->bit_size
== 64);
64 return nir_vec4(b
, nir_unpack_64_2x32_split_x(b
, value
),
65 nir_unpack_64_2x32_split_y(b
, value
),
69 /* GLSL uses a uint64_t for ballot values */
70 assert(num_components
== 1);
71 assert(bit_size
== 64);
73 if (value
->bit_size
== 32) {
74 return nir_pack_64_2x32_split(b
, value
, zero
);
76 assert(value
->bit_size
== 64);
83 lower_subgroup_op_to_scalar(nir_builder
*b
, nir_intrinsic_instr
*intrin
)
85 /* This is safe to call on scalar things but it would be silly */
86 assert(intrin
->dest
.ssa
.num_components
> 1);
88 nir_ssa_def
*value
= nir_ssa_for_src(b
, intrin
->src
[0],
89 intrin
->num_components
);
90 nir_ssa_def
*reads
[4];
92 for (unsigned i
= 0; i
< intrin
->num_components
; i
++) {
93 nir_intrinsic_instr
*chan_intrin
=
94 nir_intrinsic_instr_create(b
->shader
, intrin
->intrinsic
);
95 nir_ssa_dest_init(&chan_intrin
->instr
, &chan_intrin
->dest
,
96 1, intrin
->dest
.ssa
.bit_size
, NULL
);
97 chan_intrin
->num_components
= 1;
100 chan_intrin
->src
[0] = nir_src_for_ssa(nir_channel(b
, value
, i
));
102 if (nir_intrinsic_infos
[intrin
->intrinsic
].num_srcs
> 1) {
103 assert(nir_intrinsic_infos
[intrin
->intrinsic
].num_srcs
== 2);
104 nir_src_copy(&chan_intrin
->src
[1], &intrin
->src
[1], chan_intrin
);
107 chan_intrin
->const_index
[0] = intrin
->const_index
[0];
108 chan_intrin
->const_index
[1] = intrin
->const_index
[1];
110 nir_builder_instr_insert(b
, &chan_intrin
->instr
);
112 reads
[i
] = &chan_intrin
->dest
.ssa
;
115 return nir_vec(b
, reads
, intrin
->num_components
);
119 lower_vote_eq_to_scalar(nir_builder
*b
, nir_intrinsic_instr
*intrin
)
121 assert(intrin
->src
[0].is_ssa
);
122 nir_ssa_def
*value
= intrin
->src
[0].ssa
;
124 nir_ssa_def
*result
= NULL
;
125 for (unsigned i
= 0; i
< intrin
->num_components
; i
++) {
126 nir_intrinsic_instr
*chan_intrin
=
127 nir_intrinsic_instr_create(b
->shader
, intrin
->intrinsic
);
128 nir_ssa_dest_init(&chan_intrin
->instr
, &chan_intrin
->dest
,
129 1, intrin
->dest
.ssa
.bit_size
, NULL
);
130 chan_intrin
->num_components
= 1;
131 chan_intrin
->src
[0] = nir_src_for_ssa(nir_channel(b
, value
, i
));
132 nir_builder_instr_insert(b
, &chan_intrin
->instr
);
135 result
= nir_iand(b
, result
, &chan_intrin
->dest
.ssa
);
137 result
= &chan_intrin
->dest
.ssa
;
145 lower_vote_eq_to_ballot(nir_builder
*b
, nir_intrinsic_instr
*intrin
,
146 const nir_lower_subgroups_options
*options
)
148 assert(intrin
->src
[0].is_ssa
);
149 nir_ssa_def
*value
= intrin
->src
[0].ssa
;
151 /* We have to implicitly lower to scalar */
152 nir_ssa_def
*all_eq
= NULL
;
153 for (unsigned i
= 0; i
< intrin
->num_components
; i
++) {
154 nir_intrinsic_instr
*rfi
=
155 nir_intrinsic_instr_create(b
->shader
,
156 nir_intrinsic_read_first_invocation
);
157 nir_ssa_dest_init(&rfi
->instr
, &rfi
->dest
,
158 1, value
->bit_size
, NULL
);
159 rfi
->num_components
= 1;
160 rfi
->src
[0] = nir_src_for_ssa(nir_channel(b
, value
, i
));
161 nir_builder_instr_insert(b
, &rfi
->instr
);
164 if (intrin
->intrinsic
== nir_intrinsic_vote_feq
) {
165 is_eq
= nir_feq(b
, &rfi
->dest
.ssa
, nir_channel(b
, value
, i
));
167 is_eq
= nir_ieq(b
, &rfi
->dest
.ssa
, nir_channel(b
, value
, i
));
170 if (all_eq
== NULL
) {
173 all_eq
= nir_iand(b
, all_eq
, is_eq
);
177 nir_intrinsic_instr
*ballot
=
178 nir_intrinsic_instr_create(b
->shader
, nir_intrinsic_ballot
);
179 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
,
180 1, options
->ballot_bit_size
, NULL
);
181 ballot
->num_components
= 1;
182 ballot
->src
[0] = nir_src_for_ssa(nir_inot(b
, all_eq
));
183 nir_builder_instr_insert(b
, &ballot
->instr
);
185 return nir_ieq(b
, &ballot
->dest
.ssa
,
186 nir_imm_intN_t(b
, 0, options
->ballot_bit_size
));
190 lower_shuffle(nir_builder
*b
, nir_intrinsic_instr
*intrin
,
191 bool lower_to_scalar
)
193 nir_ssa_def
*index
= nir_load_subgroup_invocation(b
);
194 switch (intrin
->intrinsic
) {
195 case nir_intrinsic_shuffle_xor
:
196 assert(intrin
->src
[1].is_ssa
);
197 index
= nir_ixor(b
, index
, intrin
->src
[1].ssa
);
199 case nir_intrinsic_shuffle_up
:
200 assert(intrin
->src
[1].is_ssa
);
201 index
= nir_isub(b
, index
, intrin
->src
[1].ssa
);
203 case nir_intrinsic_shuffle_down
:
204 assert(intrin
->src
[1].is_ssa
);
205 index
= nir_iadd(b
, index
, intrin
->src
[1].ssa
);
207 case nir_intrinsic_quad_broadcast
:
208 assert(intrin
->src
[1].is_ssa
);
209 index
= nir_ior(b
, nir_iand(b
, index
, nir_imm_int(b
, ~0x3)),
212 case nir_intrinsic_quad_swap_horizontal
:
213 /* For Quad operations, subgroups are divided into quads where
214 * (invocation % 4) is the index to a square arranged as follows:
222 index
= nir_ixor(b
, index
, nir_imm_int(b
, 0x1));
224 case nir_intrinsic_quad_swap_vertical
:
225 index
= nir_ixor(b
, index
, nir_imm_int(b
, 0x2));
227 case nir_intrinsic_quad_swap_diagonal
:
228 index
= nir_ixor(b
, index
, nir_imm_int(b
, 0x3));
231 unreachable("Invalid intrinsic");
234 nir_intrinsic_instr
*shuffle
=
235 nir_intrinsic_instr_create(b
->shader
, nir_intrinsic_shuffle
);
236 shuffle
->num_components
= intrin
->num_components
;
237 nir_src_copy(&shuffle
->src
[0], &intrin
->src
[0], shuffle
);
238 shuffle
->src
[1] = nir_src_for_ssa(index
);
239 nir_ssa_dest_init(&shuffle
->instr
, &shuffle
->dest
,
240 intrin
->dest
.ssa
.num_components
,
241 intrin
->dest
.ssa
.bit_size
, NULL
);
243 if (lower_to_scalar
&& shuffle
->num_components
> 1) {
244 return lower_subgroup_op_to_scalar(b
, shuffle
);
246 nir_builder_instr_insert(b
, &shuffle
->instr
);
247 return &shuffle
->dest
.ssa
;
252 lower_subgroups_intrin(nir_builder
*b
, nir_intrinsic_instr
*intrin
,
253 const nir_lower_subgroups_options
*options
)
255 switch (intrin
->intrinsic
) {
256 case nir_intrinsic_vote_any
:
257 case nir_intrinsic_vote_all
:
258 if (options
->lower_vote_trivial
)
259 return nir_ssa_for_src(b
, intrin
->src
[0], 1);
262 case nir_intrinsic_vote_feq
:
263 case nir_intrinsic_vote_ieq
:
264 if (options
->lower_vote_trivial
)
265 return nir_imm_int(b
, NIR_TRUE
);
267 if (options
->lower_vote_eq_to_ballot
)
268 return lower_vote_eq_to_ballot(b
, intrin
, options
);
270 if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
271 return lower_vote_eq_to_scalar(b
, intrin
);
274 case nir_intrinsic_load_subgroup_size
:
275 if (options
->subgroup_size
)
276 return nir_imm_int(b
, options
->subgroup_size
);
279 case nir_intrinsic_read_invocation
:
280 case nir_intrinsic_read_first_invocation
:
281 if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
282 return lower_subgroup_op_to_scalar(b
, intrin
);
285 case nir_intrinsic_load_subgroup_eq_mask
:
286 case nir_intrinsic_load_subgroup_ge_mask
:
287 case nir_intrinsic_load_subgroup_gt_mask
:
288 case nir_intrinsic_load_subgroup_le_mask
:
289 case nir_intrinsic_load_subgroup_lt_mask
: {
290 if (!options
->lower_subgroup_masks
)
293 /* If either the result or the requested bit size is 64-bits then we
294 * know that we have 64-bit types and using them will probably be more
295 * efficient than messing around with 32-bit shifts and packing.
297 const unsigned bit_size
= MAX2(options
->ballot_bit_size
,
298 intrin
->dest
.ssa
.bit_size
);
300 assert(options
->subgroup_size
<= 64);
301 uint64_t group_mask
= ~0ull >> (64 - options
->subgroup_size
);
303 nir_ssa_def
*count
= nir_load_subgroup_invocation(b
);
305 switch (intrin
->intrinsic
) {
306 case nir_intrinsic_load_subgroup_eq_mask
:
307 val
= nir_ishl(b
, nir_imm_intN_t(b
, 1ull, bit_size
), count
);
309 case nir_intrinsic_load_subgroup_ge_mask
:
310 val
= nir_iand(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~0ull, bit_size
), count
),
311 nir_imm_intN_t(b
, group_mask
, bit_size
));
313 case nir_intrinsic_load_subgroup_gt_mask
:
314 val
= nir_iand(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~1ull, bit_size
), count
),
315 nir_imm_intN_t(b
, group_mask
, bit_size
));
317 case nir_intrinsic_load_subgroup_le_mask
:
318 val
= nir_inot(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~1ull, bit_size
), count
));
320 case nir_intrinsic_load_subgroup_lt_mask
:
321 val
= nir_inot(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~0ull, bit_size
), count
));
324 unreachable("you seriously can't tell this is unreachable?");
327 return uint_to_ballot_type(b
, val
,
328 intrin
->dest
.ssa
.num_components
,
329 intrin
->dest
.ssa
.bit_size
);
332 case nir_intrinsic_ballot
: {
333 if (intrin
->dest
.ssa
.num_components
== 1 &&
334 intrin
->dest
.ssa
.bit_size
== options
->ballot_bit_size
)
337 nir_intrinsic_instr
*ballot
=
338 nir_intrinsic_instr_create(b
->shader
, nir_intrinsic_ballot
);
339 ballot
->num_components
= 1;
340 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
,
341 1, options
->ballot_bit_size
, NULL
);
342 nir_src_copy(&ballot
->src
[0], &intrin
->src
[0], ballot
);
343 nir_builder_instr_insert(b
, &ballot
->instr
);
345 return uint_to_ballot_type(b
, &ballot
->dest
.ssa
,
346 intrin
->dest
.ssa
.num_components
,
347 intrin
->dest
.ssa
.bit_size
);
350 case nir_intrinsic_ballot_bitfield_extract
:
351 case nir_intrinsic_ballot_bit_count_reduce
:
352 case nir_intrinsic_ballot_find_lsb
:
353 case nir_intrinsic_ballot_find_msb
: {
354 assert(intrin
->src
[0].is_ssa
);
355 nir_ssa_def
*int_val
= ballot_type_to_uint(b
, intrin
->src
[0].ssa
,
356 options
->ballot_bit_size
);
357 switch (intrin
->intrinsic
) {
358 case nir_intrinsic_ballot_bitfield_extract
:
359 assert(intrin
->src
[1].is_ssa
);
360 return nir_i2b(b
, nir_iand(b
, nir_ushr(b
, int_val
,
363 case nir_intrinsic_ballot_bit_count_reduce
:
364 return nir_bit_count(b
, int_val
);
365 case nir_intrinsic_ballot_find_lsb
:
366 return nir_find_lsb(b
, int_val
);
367 case nir_intrinsic_ballot_find_msb
:
368 return nir_ufind_msb(b
, int_val
);
370 unreachable("you seriously can't tell this is unreachable?");
374 case nir_intrinsic_ballot_bit_count_exclusive
:
375 case nir_intrinsic_ballot_bit_count_inclusive
: {
376 nir_ssa_def
*count
= nir_load_subgroup_invocation(b
);
377 nir_ssa_def
*mask
= nir_imm_intN_t(b
, ~0ull, options
->ballot_bit_size
);
378 if (intrin
->intrinsic
== nir_intrinsic_ballot_bit_count_inclusive
) {
379 const unsigned bits
= options
->ballot_bit_size
;
380 mask
= nir_ushr(b
, mask
, nir_isub(b
, nir_imm_int(b
, bits
- 1), count
));
382 mask
= nir_inot(b
, nir_ishl(b
, mask
, count
));
385 assert(intrin
->src
[0].is_ssa
);
386 nir_ssa_def
*int_val
= ballot_type_to_uint(b
, intrin
->src
[0].ssa
,
387 options
->ballot_bit_size
);
389 return nir_bit_count(b
, nir_iand(b
, int_val
, mask
));
392 case nir_intrinsic_elect
: {
393 nir_intrinsic_instr
*first
=
394 nir_intrinsic_instr_create(b
->shader
,
395 nir_intrinsic_first_invocation
);
396 nir_ssa_dest_init(&first
->instr
, &first
->dest
, 1, 32, NULL
);
397 nir_builder_instr_insert(b
, &first
->instr
);
399 return nir_ieq(b
, nir_load_subgroup_invocation(b
), &first
->dest
.ssa
);
402 case nir_intrinsic_shuffle
:
403 if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
404 return lower_subgroup_op_to_scalar(b
, intrin
);
407 case nir_intrinsic_shuffle_xor
:
408 case nir_intrinsic_shuffle_up
:
409 case nir_intrinsic_shuffle_down
:
410 if (options
->lower_shuffle
)
411 return lower_shuffle(b
, intrin
, options
->lower_to_scalar
);
412 else if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
413 return lower_subgroup_op_to_scalar(b
, intrin
);
416 case nir_intrinsic_quad_broadcast
:
417 case nir_intrinsic_quad_swap_horizontal
:
418 case nir_intrinsic_quad_swap_vertical
:
419 case nir_intrinsic_quad_swap_diagonal
:
420 if (options
->lower_quad
)
421 return lower_shuffle(b
, intrin
, options
->lower_to_scalar
);
422 else if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
423 return lower_subgroup_op_to_scalar(b
, intrin
);
426 case nir_intrinsic_reduce
:
427 case nir_intrinsic_inclusive_scan
:
428 case nir_intrinsic_exclusive_scan
:
429 if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
430 return lower_subgroup_op_to_scalar(b
, intrin
);
441 lower_subgroups_impl(nir_function_impl
*impl
,
442 const nir_lower_subgroups_options
*options
)
445 nir_builder_init(&b
, impl
);
446 bool progress
= false;
448 nir_foreach_block(block
, impl
) {
449 nir_foreach_instr_safe(instr
, block
) {
450 if (instr
->type
!= nir_instr_type_intrinsic
)
453 nir_intrinsic_instr
*intrin
= nir_instr_as_intrinsic(instr
);
454 b
.cursor
= nir_before_instr(instr
);
456 nir_ssa_def
*lower
= lower_subgroups_intrin(&b
, intrin
, options
);
460 nir_ssa_def_rewrite_uses(&intrin
->dest
.ssa
, nir_src_for_ssa(lower
));
461 nir_instr_remove(instr
);
470 nir_lower_subgroups(nir_shader
*shader
,
471 const nir_lower_subgroups_options
*options
)
473 bool progress
= false;
475 nir_foreach_function(function
, shader
) {
479 if (lower_subgroups_impl(function
->impl
, options
)) {
481 nir_metadata_preserve(function
->impl
, nir_metadata_block_index
|
482 nir_metadata_dominance
);