/*
 * Copyright © 2017 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"

/**
 * \file nir_opt_intrinsics.c
 */

32 ballot_type_to_uint(nir_builder
*b
, nir_ssa_def
*value
, unsigned bit_size
)
34 /* We only use this on uvec4 types */
35 assert(value
->num_components
== 4 && value
->bit_size
== 32);
38 return nir_channel(b
, value
, 0);
40 assert(bit_size
== 64);
41 return nir_pack_64_2x32_split(b
, nir_channel(b
, value
, 0),
42 nir_channel(b
, value
, 1));
46 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
48 uint_to_ballot_type(nir_builder
*b
, nir_ssa_def
*value
,
49 unsigned num_components
, unsigned bit_size
)
51 assert(value
->num_components
== 1);
52 assert(value
->bit_size
== 32 || value
->bit_size
== 64);
54 nir_ssa_def
*zero
= nir_imm_int(b
, 0);
55 if (num_components
> 1) {
56 /* SPIR-V uses a uvec4 for ballot values */
57 assert(num_components
== 4);
58 assert(bit_size
== 32);
60 if (value
->bit_size
== 32) {
61 return nir_vec4(b
, value
, zero
, zero
, zero
);
63 assert(value
->bit_size
== 64);
64 return nir_vec4(b
, nir_unpack_64_2x32_split_x(b
, value
),
65 nir_unpack_64_2x32_split_y(b
, value
),
69 /* GLSL uses a uint64_t for ballot values */
70 assert(num_components
== 1);
71 assert(bit_size
== 64);
73 if (value
->bit_size
== 32) {
74 return nir_pack_64_2x32_split(b
, value
, zero
);
76 assert(value
->bit_size
== 64);
83 lower_read_invocation_to_scalar(nir_builder
*b
, nir_intrinsic_instr
*intrin
)
85 /* This is safe to call on scalar things but it would be silly */
86 assert(intrin
->dest
.ssa
.num_components
> 1);
88 nir_ssa_def
*value
= nir_ssa_for_src(b
, intrin
->src
[0],
89 intrin
->num_components
);
90 nir_ssa_def
*reads
[4];
92 for (unsigned i
= 0; i
< intrin
->num_components
; i
++) {
93 nir_intrinsic_instr
*chan_intrin
=
94 nir_intrinsic_instr_create(b
->shader
, intrin
->intrinsic
);
95 nir_ssa_dest_init(&chan_intrin
->instr
, &chan_intrin
->dest
,
96 1, intrin
->dest
.ssa
.bit_size
, NULL
);
97 chan_intrin
->num_components
= 1;
100 chan_intrin
->src
[0] = nir_src_for_ssa(nir_channel(b
, value
, i
));
102 if (intrin
->intrinsic
== nir_intrinsic_read_invocation
)
103 nir_src_copy(&chan_intrin
->src
[1], &intrin
->src
[1], chan_intrin
);
105 nir_builder_instr_insert(b
, &chan_intrin
->instr
);
107 reads
[i
] = &chan_intrin
->dest
.ssa
;
110 return nir_vec(b
, reads
, intrin
->num_components
);
114 lower_subgroups_intrin(nir_builder
*b
, nir_intrinsic_instr
*intrin
,
115 const nir_lower_subgroups_options
*options
)
117 switch (intrin
->intrinsic
) {
118 case nir_intrinsic_vote_any
:
119 case nir_intrinsic_vote_all
:
120 if (options
->lower_vote_trivial
)
121 return nir_ssa_for_src(b
, intrin
->src
[0], 1);
124 case nir_intrinsic_vote_feq
:
125 case nir_intrinsic_vote_ieq
:
126 if (options
->lower_vote_trivial
)
127 return nir_imm_int(b
, NIR_TRUE
);
130 case nir_intrinsic_load_subgroup_size
:
131 if (options
->subgroup_size
)
132 return nir_imm_int(b
, options
->subgroup_size
);
135 case nir_intrinsic_read_invocation
:
136 case nir_intrinsic_read_first_invocation
:
137 if (options
->lower_to_scalar
&& intrin
->num_components
> 1)
138 return lower_read_invocation_to_scalar(b
, intrin
);
141 case nir_intrinsic_load_subgroup_eq_mask
:
142 case nir_intrinsic_load_subgroup_ge_mask
:
143 case nir_intrinsic_load_subgroup_gt_mask
:
144 case nir_intrinsic_load_subgroup_le_mask
:
145 case nir_intrinsic_load_subgroup_lt_mask
: {
146 if (!options
->lower_subgroup_masks
)
149 /* If either the result or the requested bit size is 64-bits then we
150 * know that we have 64-bit types and using them will probably be more
151 * efficient than messing around with 32-bit shifts and packing.
153 const unsigned bit_size
= MAX2(options
->ballot_bit_size
,
154 intrin
->dest
.ssa
.bit_size
);
156 assert(options
->subgroup_size
<= 64);
157 uint64_t group_mask
= ~0ull >> (64 - options
->subgroup_size
);
159 nir_ssa_def
*count
= nir_load_subgroup_invocation(b
);
161 switch (intrin
->intrinsic
) {
162 case nir_intrinsic_load_subgroup_eq_mask
:
163 val
= nir_ishl(b
, nir_imm_intN_t(b
, 1ull, bit_size
), count
);
165 case nir_intrinsic_load_subgroup_ge_mask
:
166 val
= nir_iand(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~0ull, bit_size
), count
),
167 nir_imm_intN_t(b
, group_mask
, bit_size
));
169 case nir_intrinsic_load_subgroup_gt_mask
:
170 val
= nir_iand(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~1ull, bit_size
), count
),
171 nir_imm_intN_t(b
, group_mask
, bit_size
));
173 case nir_intrinsic_load_subgroup_le_mask
:
174 val
= nir_inot(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~1ull, bit_size
), count
));
176 case nir_intrinsic_load_subgroup_lt_mask
:
177 val
= nir_inot(b
, nir_ishl(b
, nir_imm_intN_t(b
, ~0ull, bit_size
), count
));
180 unreachable("you seriously can't tell this is unreachable?");
183 return uint_to_ballot_type(b
, val
,
184 intrin
->dest
.ssa
.num_components
,
185 intrin
->dest
.ssa
.bit_size
);
188 case nir_intrinsic_ballot
: {
189 if (intrin
->dest
.ssa
.num_components
== 1 &&
190 intrin
->dest
.ssa
.bit_size
== options
->ballot_bit_size
)
193 nir_intrinsic_instr
*ballot
=
194 nir_intrinsic_instr_create(b
->shader
, nir_intrinsic_ballot
);
195 ballot
->num_components
= 1;
196 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
,
197 1, options
->ballot_bit_size
, NULL
);
198 nir_src_copy(&ballot
->src
[0], &intrin
->src
[0], ballot
);
199 nir_builder_instr_insert(b
, &ballot
->instr
);
201 return uint_to_ballot_type(b
, &ballot
->dest
.ssa
,
202 intrin
->dest
.ssa
.num_components
,
203 intrin
->dest
.ssa
.bit_size
);
206 case nir_intrinsic_ballot_bitfield_extract
:
207 case nir_intrinsic_ballot_bit_count_reduce
:
208 case nir_intrinsic_ballot_find_lsb
:
209 case nir_intrinsic_ballot_find_msb
: {
210 assert(intrin
->src
[0].is_ssa
);
211 nir_ssa_def
*int_val
= ballot_type_to_uint(b
, intrin
->src
[0].ssa
,
212 options
->ballot_bit_size
);
213 switch (intrin
->intrinsic
) {
214 case nir_intrinsic_ballot_bitfield_extract
:
215 assert(intrin
->src
[1].is_ssa
);
216 return nir_i2b(b
, nir_iand(b
, nir_ushr(b
, int_val
,
219 case nir_intrinsic_ballot_bit_count_reduce
:
220 return nir_bit_count(b
, int_val
);
221 case nir_intrinsic_ballot_find_lsb
:
222 return nir_find_lsb(b
, int_val
);
223 case nir_intrinsic_ballot_find_msb
:
224 return nir_ufind_msb(b
, int_val
);
226 unreachable("you seriously can't tell this is unreachable?");
230 case nir_intrinsic_ballot_bit_count_exclusive
:
231 case nir_intrinsic_ballot_bit_count_inclusive
: {
232 nir_ssa_def
*count
= nir_load_subgroup_invocation(b
);
233 nir_ssa_def
*mask
= nir_imm_intN_t(b
, ~0ull, options
->ballot_bit_size
);
234 if (intrin
->intrinsic
== nir_intrinsic_ballot_bit_count_inclusive
) {
235 const unsigned bits
= options
->ballot_bit_size
;
236 mask
= nir_ushr(b
, mask
, nir_isub(b
, nir_imm_int(b
, bits
- 1), count
));
238 mask
= nir_inot(b
, nir_ishl(b
, mask
, count
));
241 assert(intrin
->src
[0].is_ssa
);
242 nir_ssa_def
*int_val
= ballot_type_to_uint(b
, intrin
->src
[0].ssa
,
243 options
->ballot_bit_size
);
245 return nir_bit_count(b
, nir_iand(b
, int_val
, mask
));
248 case nir_intrinsic_elect
: {
249 nir_intrinsic_instr
*first
=
250 nir_intrinsic_instr_create(b
->shader
,
251 nir_intrinsic_first_invocation
);
252 nir_ssa_dest_init(&first
->instr
, &first
->dest
, 1, 32, NULL
);
253 nir_builder_instr_insert(b
, &first
->instr
);
255 return nir_ieq(b
, nir_load_subgroup_invocation(b
), &first
->dest
.ssa
);
266 lower_subgroups_impl(nir_function_impl
*impl
,
267 const nir_lower_subgroups_options
*options
)
270 nir_builder_init(&b
, impl
);
271 bool progress
= false;
273 nir_foreach_block(block
, impl
) {
274 nir_foreach_instr_safe(instr
, block
) {
275 if (instr
->type
!= nir_instr_type_intrinsic
)
278 nir_intrinsic_instr
*intrin
= nir_instr_as_intrinsic(instr
);
279 b
.cursor
= nir_before_instr(instr
);
281 nir_ssa_def
*lower
= lower_subgroups_intrin(&b
, intrin
, options
);
285 nir_ssa_def_rewrite_uses(&intrin
->dest
.ssa
, nir_src_for_ssa(lower
));
286 nir_instr_remove(instr
);
295 nir_lower_subgroups(nir_shader
*shader
,
296 const nir_lower_subgroups_options
*options
)
298 bool progress
= false;
300 nir_foreach_function(function
, shader
) {
304 if (lower_subgroups_impl(function
->impl
, options
)) {
306 nir_metadata_preserve(function
->impl
, nir_metadata_block_index
|
307 nir_metadata_dominance
);