nir/spirv: Fix warning and add missing breaks.
[mesa.git] src/compiler/spirv/vtn_subgroup.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

static void
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *dst,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

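   /* Composite values (structs, arrays, matrices) don't map onto a single
    * NIR intrinsic, so recurse and emit one intrinsic per vector or scalar
    * element.
    */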
   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
                                  src0->elems[i], index,
                                  const_idx0, const_idx1);
      }
      return;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;
}

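/* Translate the SPIR-V OpGroupNonUniform* opcodes into the corresponding NIR
 * subgroup intrinsics.
 */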
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);

   val->ssa = vtn_create_ssa_value(b, val->type->type);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      val->ssa->def = &elect->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallot: {
      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      val->ssa->def = &ballot->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

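   /* These ballot opcodes lay out their operands differently: BitExtract
    * takes the ballot in w[4] and the bit index in w[5], BitCount takes the
    * group operation in w[4] and the ballot in w[5], and FindLSB/FindMSB
    * take just the ballot in w[4].
    */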
   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_ssa_value(b, w[4])->def;
         src1 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
      break;

   case SpvOpGroupNonUniformBroadcast:
      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpGroupNonUniformAllEqual: {
         switch (glsl_get_base_type(val->type->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      }
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

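   /* For OpGroupNonUniformQuadSwap, the Direction operand is a constant:
    * 0 selects a horizontal swap, 1 a vertical swap, and 2 a diagonal swap.
    */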
   case SpvOpGroupNonUniformQuadSwap: {
      unsigned direction = vtn_constant_value(b, w[5])->values[0].u32[0];
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               NULL, 0, 0);
      break;
   }

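   /* The arithmetic and logical group operations all funnel into the same
    * three NIR intrinsics (reduce, inclusive_scan, exclusive_scan).  The ALU
    * opcode is passed through const_index[0] and the cluster size, for
    * clustered reductions, through const_index[1] (0 meaning the whole
    * subgroup).
    */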
   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0];
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
                               NULL, reduction_op, cluster_size);
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}