nir: Add subgroup shuffle intrinsics and lowering
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
1 /*
2 * Copyright © 2017 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
/**
 * \file nir_lower_subgroups.c
 */
30
31 static nir_ssa_def *
32 ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
33 {
34 /* We only use this on uvec4 types */
35 assert(value->num_components == 4 && value->bit_size == 32);
36
37 if (bit_size == 32) {
38 return nir_channel(b, value, 0);
39 } else {
40 assert(bit_size == 64);
41 return nir_pack_64_2x32_split(b, nir_channel(b, value, 0),
42 nir_channel(b, value, 1));
43 }
44 }
45
46 /* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
47 static nir_ssa_def *
48 uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
49 unsigned num_components, unsigned bit_size)
50 {
51 assert(value->num_components == 1);
52 assert(value->bit_size == 32 || value->bit_size == 64);
53
54 nir_ssa_def *zero = nir_imm_int(b, 0);
55 if (num_components > 1) {
56 /* SPIR-V uses a uvec4 for ballot values */
57 assert(num_components == 4);
58 assert(bit_size == 32);
59
60 if (value->bit_size == 32) {
61 return nir_vec4(b, value, zero, zero, zero);
62 } else {
63 assert(value->bit_size == 64);
64 return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
65 nir_unpack_64_2x32_split_y(b, value),
66 zero, zero);
67 }
68 } else {
69 /* GLSL uses a uint64_t for ballot values */
70 assert(num_components == 1);
71 assert(bit_size == 64);
72
73 if (value->bit_size == 32) {
74 return nir_pack_64_2x32_split(b, value, zero);
75 } else {
76 assert(value->bit_size == 64);
77 return value;
78 }
79 }
80 }
81
/* Splits a multi-component subgroup intrinsic into one single-component
 * intrinsic per channel and recombines the per-channel results with a vec.
 *
 * Works for any subgroup intrinsic whose first source is the value being
 * operated on; an optional second source (e.g. the invocation index of
 * read_invocation) is copied through unchanged to every per-channel op.
 */
static nir_ssa_def *
lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
{
   /* This is safe to call on scalar things but it would be silly */
   assert(intrin->dest.ssa.num_components > 1);

   nir_ssa_def *value = nir_ssa_for_src(b, intrin->src[0],
                                        intrin->num_components);
   nir_ssa_def *reads[4]; /* NIR vectors have at most 4 components */

   for (unsigned i = 0; i < intrin->num_components; i++) {
      nir_intrinsic_instr *chan_intrin =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
                        1, intrin->dest.ssa.bit_size, NULL);
      chan_intrin->num_components = 1;

      /* value */
      chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
      /* invocation */
      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
         nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
      }

      nir_builder_instr_insert(b, &chan_intrin->instr);

      reads[i] = &chan_intrin->dest.ssa;
   }

   return nir_vec(b, reads, intrin->num_components);
}
114
115 static nir_ssa_def *
116 lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
117 {
118 assert(intrin->src[0].is_ssa);
119 nir_ssa_def *value = intrin->src[0].ssa;
120
121 nir_ssa_def *result = NULL;
122 for (unsigned i = 0; i < intrin->num_components; i++) {
123 nir_intrinsic_instr *chan_intrin =
124 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
125 nir_ssa_dest_init(&chan_intrin->instr, &chan_intrin->dest,
126 1, intrin->dest.ssa.bit_size, NULL);
127 chan_intrin->num_components = 1;
128 chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
129 nir_builder_instr_insert(b, &chan_intrin->instr);
130
131 if (result) {
132 result = nir_iand(b, result, &chan_intrin->dest.ssa);
133 } else {
134 result = &chan_intrin->dest.ssa;
135 }
136 }
137
138 return result;
139 }
140
/* Lowers shuffle_xor/shuffle_up/shuffle_down to a generic shuffle by
 * computing the source invocation index explicitly from
 * load_subgroup_invocation: XOR with the mask for shuffle_xor, subtract
 * the delta for shuffle_up, add it for shuffle_down.
 */
static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
              bool lower_to_scalar)
{
   nir_ssa_def *index = nir_load_subgroup_invocation(b);
   switch (intrin->intrinsic) {
   case nir_intrinsic_shuffle_xor:
      assert(intrin->src[1].is_ssa);
      index = nir_ixor(b, index, intrin->src[1].ssa);
      break;
   case nir_intrinsic_shuffle_up:
      assert(intrin->src[1].is_ssa);
      index = nir_isub(b, index, intrin->src[1].ssa);
      break;
   case nir_intrinsic_shuffle_down:
      assert(intrin->src[1].is_ssa);
      index = nir_iadd(b, index, intrin->src[1].ssa);
      break;
   default:
      unreachable("Invalid intrinsic");
   }

   nir_intrinsic_instr *shuffle =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
   shuffle->num_components = intrin->num_components;
   nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle);
   shuffle->src[1] = nir_src_for_ssa(index);
   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
                     intrin->dest.ssa.num_components,
                     intrin->dest.ssa.bit_size, NULL);

   if (lower_to_scalar && shuffle->num_components > 1) {
      /* NOTE(review): in this path the freshly-built shuffle is only used
       * as a template for the per-channel ops and is never itself inserted
       * into the IR; it appears to be ralloc'ed against the shader so it is
       * reclaimed with it — confirm.
       */
      return lower_subgroup_op_to_scalar(b, shuffle);
   } else {
      nir_builder_instr_insert(b, &shuffle->instr);
      return &shuffle->dest.ssa;
   }
}
179
/* Returns the lowered replacement for a single subgroup intrinsic, or NULL
 * if the intrinsic requires no lowering under the given options.
 */
static nir_ssa_def *
lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
                       const nir_lower_subgroups_options *options)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_vote_any:
   case nir_intrinsic_vote_all:
      /* Trivial votes: the vote of a lone invocation is its own value
       * (presumably for drivers where each invocation is effectively its
       * own subgroup).
       */
      if (options->lower_vote_trivial)
         return nir_ssa_for_src(b, intrin->src[0], 1);
      break;

   case nir_intrinsic_vote_feq:
   case nir_intrinsic_vote_ieq:
      /* A lone invocation always agrees with itself */
      if (options->lower_vote_trivial)
         return nir_imm_int(b, NIR_TRUE);

      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_vote_eq_to_scalar(b, intrin);
      break;

   case nir_intrinsic_load_subgroup_size:
      /* A subgroup_size of 0 in the options means "not known / don't fold" */
      if (options->subgroup_size)
         return nir_imm_int(b, options->subgroup_size);
      break;

   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin);
      break;

   case nir_intrinsic_load_subgroup_eq_mask:
   case nir_intrinsic_load_subgroup_ge_mask:
   case nir_intrinsic_load_subgroup_gt_mask:
   case nir_intrinsic_load_subgroup_le_mask:
   case nir_intrinsic_load_subgroup_lt_mask: {
      if (!options->lower_subgroup_masks)
         return NULL;

      /* If either the result or the requested bit size is 64-bits then we
       * know that we have 64-bit types and using them will probably be more
       * efficient than messing around with 32-bit shifts and packing.
       */
      const unsigned bit_size = MAX2(options->ballot_bit_size,
                                     intrin->dest.ssa.bit_size);

      /* group_mask has one bit set per invocation in the subgroup; it is
       * used below to clear bits at or above the subgroup size.
       */
      assert(options->subgroup_size <= 64);
      uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);

      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *val;
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_subgroup_eq_mask:
         /* 1 << id: only our own bit */
         val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
         break;
      case nir_intrinsic_load_subgroup_ge_mask:
         /* (~0 << id) & group_mask: bits [id, size-1] */
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
                           nir_imm_intN_t(b, group_mask, bit_size));
         break;
      case nir_intrinsic_load_subgroup_gt_mask:
         /* (~1 << id) & group_mask: bits [id+1, size-1] */
         val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
                           nir_imm_intN_t(b, group_mask, bit_size));
         break;
      case nir_intrinsic_load_subgroup_le_mask:
         /* ~(~1 << id): bits [0, id]; already below the subgroup size so no
          * group_mask needed
          */
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
         break;
      case nir_intrinsic_load_subgroup_lt_mask:
         /* ~(~0 << id): bits [0, id-1] */
         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
         break;
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }

      return uint_to_ballot_type(b, val,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot: {
      /* Nothing to do if the shader already wants a scalar ballot of the
       * size the driver produces.
       */
      if (intrin->dest.ssa.num_components == 1 &&
          intrin->dest.ssa.bit_size == options->ballot_bit_size)
         return NULL;

      /* Re-emit the ballot at the driver's native scalar size, then convert
       * to the type the original intrinsic expected.
       */
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
      ballot->num_components = 1;
      nir_ssa_dest_init(&ballot->instr, &ballot->dest,
                        1, options->ballot_bit_size, NULL);
      nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
      nir_builder_instr_insert(b, &ballot->instr);

      return uint_to_ballot_type(b, &ballot->dest.ssa,
                                 intrin->dest.ssa.num_components,
                                 intrin->dest.ssa.bit_size);
   }

   case nir_intrinsic_ballot_bitfield_extract:
   case nir_intrinsic_ballot_bit_count_reduce:
   case nir_intrinsic_ballot_find_lsb:
   case nir_intrinsic_ballot_find_msb: {
      /* Flatten the uvec4 ballot to a scalar and use plain integer ops */
      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);
      switch (intrin->intrinsic) {
      case nir_intrinsic_ballot_bitfield_extract:
         /* (ballot >> index) & 1, as a boolean */
         assert(intrin->src[1].is_ssa);
         return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
                                                   intrin->src[1].ssa),
                                       nir_imm_int(b, 1)));
      case nir_intrinsic_ballot_bit_count_reduce:
         return nir_bit_count(b, int_val);
      case nir_intrinsic_ballot_find_lsb:
         return nir_find_lsb(b, int_val);
      case nir_intrinsic_ballot_find_msb:
         return nir_ufind_msb(b, int_val);
      default:
         unreachable("you seriously can't tell this is unreachable?");
      }
   }

   case nir_intrinsic_ballot_bit_count_exclusive:
   case nir_intrinsic_ballot_bit_count_inclusive: {
      /* Count only the ballot bits at (inclusive) or below (exclusive) our
       * own invocation index.
       */
      nir_ssa_def *count = nir_load_subgroup_invocation(b);
      nir_ssa_def *mask = nir_imm_intN_t(b, ~0ull, options->ballot_bit_size);
      if (intrin->intrinsic == nir_intrinsic_ballot_bit_count_inclusive) {
         /* ~0 >> (bits - 1 - id): bits [0, id] */
         const unsigned bits = options->ballot_bit_size;
         mask = nir_ushr(b, mask, nir_isub(b, nir_imm_int(b, bits - 1), count));
      } else {
         /* ~(~0 << id): bits [0, id-1] */
         mask = nir_inot(b, nir_ishl(b, mask, count));
      }

      assert(intrin->src[0].is_ssa);
      nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                 options->ballot_bit_size);

      return nir_bit_count(b, nir_iand(b, int_val, mask));
   }

   case nir_intrinsic_elect: {
      /* Elected invocation is the one whose id matches first_invocation */
      nir_intrinsic_instr *first =
         nir_intrinsic_instr_create(b->shader,
                                    nir_intrinsic_first_invocation);
      nir_ssa_dest_init(&first->instr, &first->dest, 1, 32, NULL);
      nir_builder_instr_insert(b, &first->instr);

      return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
   }

   case nir_intrinsic_shuffle:
      if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin);
      break;

   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
      if (options->lower_shuffle)
         return lower_shuffle(b, intrin, options->lower_to_scalar);
      else if (options->lower_to_scalar && intrin->num_components > 1)
         return lower_subgroup_op_to_scalar(b, intrin);
      break;

   default:
      break;
   }

   return NULL;
}
348
349 static bool
350 lower_subgroups_impl(nir_function_impl *impl,
351 const nir_lower_subgroups_options *options)
352 {
353 nir_builder b;
354 nir_builder_init(&b, impl);
355 bool progress = false;
356
357 nir_foreach_block(block, impl) {
358 nir_foreach_instr_safe(instr, block) {
359 if (instr->type != nir_instr_type_intrinsic)
360 continue;
361
362 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
363 b.cursor = nir_before_instr(instr);
364
365 nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
366 if (!lower)
367 continue;
368
369 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
370 nir_instr_remove(instr);
371 progress = true;
372 }
373 }
374
375 return progress;
376 }
377
378 bool
379 nir_lower_subgroups(nir_shader *shader,
380 const nir_lower_subgroups_options *options)
381 {
382 bool progress = false;
383
384 nir_foreach_function(function, shader) {
385 if (!function->impl)
386 continue;
387
388 if (lower_subgroups_impl(function->impl, options)) {
389 progress = true;
390 nir_metadata_preserve(function->impl, nir_metadata_block_index |
391 nir_metadata_dominance);
392 }
393 }
394
395 return progress;
396 }