nir: Add new SPIR-V ballot intrinsics and lowering
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
/**
 * \file nir_lower_subgroups.c
 */
30
31 static nir_ssa_def *
32 ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
33 {
34 /* We only use this on uvec4 types */
35 assert(value->num_components == 4 && value->bit_size == 32);
36
37 if (bit_size == 32) {
38 return nir_channel(b, value, 0);
39 } else {
40 assert(bit_size == 64);
41 return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
42 nir_channel(b, value, 1));
43 }
44 }
45
46 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
47 static nir_ssa_def *
48 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
49 unsigned num_components, unsigned bit_size)
50 {
51 assert(value->num_components == 1);
52 assert(value->bit_size == 32 || value->bit_size == 64);
53
54 nir_ssa_def *zero = nir_imm_int(b, 0);
55 if (num_components > 1) {
56 /* SPIR-V uses a uvec4 for ballot values */
57 assert(num_components == 4);
58 assert(bit_size == 32);
59
60 if (value->bit_size == 32) {
61 return nir_vec4(b, value, zero, zero, zero);
62 } else {
63 assert(value->bit_size == 64);
64 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
65 nir_unpack_64_2x32_split_y(b, value),
66 zero, zero);
67 }
68 } else {
69 /* GLSL uses a uint64_t for ballot values */
70 assert(num_components == 1);
71 assert(bit_size == 64);
72
73 if (value->bit_size == 32) {
74 return nir_pack_64_2x32_split(b, value, zero);
75 } else {
76 assert(value->bit_size == 64);
77 return value;
78 }
79 }
80 }
81
82 static nir_ssa_def *
83 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
84 {
85 /* This is safe to call on scalar things but it would be silly */
86 assert(intrin->dest.ssa.num_components > 1);
87
88 nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
89 intrin->num_components);
90 nir_ssa_def *reads[4];
91
92 for (unsigned i = 0; i < intrin->num_components; i++) {
93 nir_intrinsic_instr *chan_intrin =
94 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
95 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
96 1, intrin->dest.ssa.bit_size, NULL);
97 chan_intrin->num_components = 1;
98
99 /* value */
100 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
101 /* invocation */
102 if (intrin->intrinsic == nir_intrinsic_read_invocation)
103 nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
104
105 nir_builder_instr_insert(b, &chan_intrin->instr);
106
107 reads[i] = &chan_intrin->dest.ssa;
108 }
109
110 return nir_vec(b, reads, intrin->num_components);
111 }
112
113 static nir_ssa_def *
114 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
115 const nir_lower_subgroups_options *options)
116 {
117 switch (intrin->intrinsic) {
118 case nir_intrinsic_vote_any:
119 case nir_intrinsic_vote_all:
120 if (options->lower_vote_trivial)
121 return nir_ssa_for_src(b, intrin->src[0], 1);
122 break;
123
124 case nir_intrinsic_vote_eq:
125 if (options->lower_vote_trivial)
126 return nir_imm_int(b, NIR_TRUE);
127 break;
128
129 case nir_intrinsic_load_subgroup_size:
130 if (options->subgroup_size)
131 return nir_imm_int(b, options->subgroup_size);
132 break;
133
134 case nir_intrinsic_read_invocation:
135 case nir_intrinsic_read_first_invocation:
136 if (options->lower_to_scalar && intrin->num_components > 1)
137 return lower_read_invocation_to_scalar(b, intrin);
138 break;
139
140 case nir_intrinsic_load_subgroup_eq_mask:
141 case nir_intrinsic_load_subgroup_ge_mask:
142 case nir_intrinsic_load_subgroup_gt_mask:
143 case nir_intrinsic_load_subgroup_le_mask:
144 case nir_intrinsic_load_subgroup_lt_mask: {
145 if (!options->lower_subgroup_masks)
146 return NULL;
147
148 /* If either the result or the requested bit size is 64-bits then we
149 * know that we have 64-bit types and using them will probably be more
150 * efficient than messing around with 32-bit shifts and packing.
151 */
152 const unsigned bit_size = MAX2(options->ballot_bit_size,
153 intrin->dest.ssa.bit_size);
154
155 assert(options->subgroup_size <= 64);
156 uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
157
158 nir_ssa_def *count = nir_load_subgroup_invocation(b);
159 nir_ssa_def *val;
160 switch (intrin->intrinsic) {
161 case nir_intrinsic_load_subgroup_eq_mask:
162 val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
163 break;
164 case nir_intrinsic_load_subgroup_ge_mask:
165 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
166 nir_imm_intN_t(b, group_mask, bit_size));
167 break;
168 case nir_intrinsic_load_subgroup_gt_mask:
169 val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
170 nir_imm_intN_t(b, group_mask, bit_size));
171 break;
172 case nir_intrinsic_load_subgroup_le_mask:
173 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
174 break;
175 case nir_intrinsic_load_subgroup_lt_mask:
176 val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
177 break;
178 default:
179 unreachable("you seriously can't tell this is unreachable?");
180 }
181
182 return uint_to_ballot_type(b, val,
183 intrin->dest.ssa.num_components,
184 intrin->dest.ssa.bit_size);
185 }
186
187 case nir_intrinsic_ballot: {
188 if (intrin->dest.ssa.num_components == 1 &&
189 intrin->dest.ssa.bit_size == options->ballot_bit_size)
190 return NULL;
191
192 nir_intrinsic_instr *ballot =
193 nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
194 ballot->num_components = 1;
195 nir_ssa_dest_init(&ballot->instr, &ballot->dest,
196 1, options->ballot_bit_size, NULL);
197 nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
198 nir_builder_instr_insert(b, &ballot->instr);
199
200 return uint_to_ballot_type(b, &ballot->dest.ssa,
201 intrin->dest.ssa.num_components,
202 intrin->dest.ssa.bit_size);
203 }
204
205 case nir_intrinsic_ballot_bitfield_extract:
206 case nir_intrinsic_ballot_bit_count_reduce:
207 case nir_intrinsic_ballot_find_lsb:
208 case nir_intrinsic_ballot_find_msb: {
209 assert(intrin->src[0].is_ssa);
210 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
211 options->ballot_bit_size);
212 switch (intrin->intrinsic) {
213 case nir_intrinsic_ballot_bitfield_extract:
214 assert(intrin->src[1].is_ssa);
215 return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
216 intrin->src[1].ssa),
217 nir_imm_int(b, 1)));
218 case nir_intrinsic_ballot_bit_count_reduce:
219 return nir_bit_count(b, int_val);
220 case nir_intrinsic_ballot_find_lsb:
221 return nir_find_lsb(b, int_val);
222 case nir_intrinsic_ballot_find_msb:
223 return nir_ufind_msb(b, int_val);
224 default:
225 unreachable("you seriously can't tell this is unreachable?");
226 }
227 }
228
229 case nir_intrinsic_ballot_bit_count_exclusive:
230 case nir_intrinsic_ballot_bit_count_inclusive: {
231 nir_ssa_def *count = nir_load_subgroup_invocation(b);
232 nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
233 if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
234 const unsigned bits = options->ballot_bit_size;
235 mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
236 } else {
237 mask = nir_inot(b, nir_ishl(b, mask, count));
238 }
239
240 assert(intrin->src[0].is_ssa);
241 nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
242 options->ballot_bit_size);
243
244 return nir_bit_count(b, nir_iand(b, int_val, mask));
245 }
246
247 case nir_intrinsic_elect: {
248 nir_intrinsic_instr *first =
249 nir_intrinsic_instr_create(b->shader,
250 nir_intrinsic_first_invocation);
251 nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL);
252 nir_builder_instr_insert(b, &first->instr);
253
254 return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
255 }
256
257 default:
258 break;
259 }
260
261 return NULL;
262 }
263
264 static bool
265 lower_subgroups_impl(nir_function_impl *impl,
266 const nir_lower_subgroups_options *options)
267 {
268 nir_builder b;
269 nir_builder_init(&b, impl);
270 bool progress = false;
271
272 nir_foreach_block(block, impl) {
273 nir_foreach_instr_safe(instr, block) {
274 if (instr->type != nir_instr_type_intrinsic)
275 continue;
276
277 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
278 b.cursor = nir_before_instr(instr);
279
280 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
281 if (!lower)
282 continue;
283
284 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
285 nir_instr_remove(instr);
286 progress = true;
287 }
288 }
289
290 return progress;
291 }
292
293 bool
294 nir_lower_subgroups(nir_shader *shader,
295 const nir_lower_subgroups_options *options)
296 {
297 bool progress = false;
298
299 nir_foreach_function(function, shader) {
300 if (!function->impl)
301 continue;
302
303 if (lower_subgroups_impl(function->impl, options)) {
304 progress = true;
305 nir_metadata_preserve(function->impl, nir_metadata_block_index |
306 nir_metadata_dominance);
307 }
308 }
309
310 return progress;
311 }