spirv: Update headers and grammar json
[mesa.git] / src / compiler / spirv / vtn_subgroup.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25
26 static void
27 vtn_build_subgroup_instr(struct vtn_builder *b,
28 nir_intrinsic_op nir_op,
29 struct vtn_ssa_value *dst,
30 struct vtn_ssa_value *src0,
31 nir_ssa_def *index,
32 unsigned const_idx0,
33 unsigned const_idx1)
34 {
35 /* Some of the subgroup operations take an index. SPIR-V allows this to be
36 * any integer type. To make things simpler for drivers, we only support
37 * 32-bit indices.
38 */
39 if (index && index->bit_size != 32)
40 index = nir_u2u32(&b->nb, index);
41
42 vtn_assert(dst->type == src0->type);
43 if (!glsl_type_is_vector_or_scalar(dst->type)) {
44 for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
45 vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
46 src0->elems[i], index,
47 const_idx0, const_idx1);
48 }
49 return;
50 }
51
52 nir_intrinsic_instr *intrin =
53 nir_intrinsic_instr_create(b->nb.shader, nir_op);
54 nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
55 dst->type, NULL);
56 intrin->num_components = intrin->dest.ssa.num_components;
57
58 intrin->src[0] = nir_src_for_ssa(src0->def);
59 if (index)
60 intrin->src[1] = nir_src_for_ssa(index);
61
62 intrin->const_index[0] = const_idx0;
63 intrin->const_index[1] = const_idx1;
64
65 nir_builder_instr_insert(&b->nb, &intrin->instr);
66
67 dst->def = &intrin->dest.ssa;
68 }
69
/* Translate one SPIR-V subgroup/group instruction into NIR.
 *
 * \param b       the vtn builder
 * \param opcode  the SPIR-V opcode being handled
 * \param w       the instruction's word stream; w[1] is the result type id,
 *                w[2] the result id, operands follow
 * \param count   the instruction's total word count
 *
 * NOTE: several cases below do "++w; / * fallthrough * /".  The core
 * SpvOpGroupNonUniform* / SpvOpGroup* forms carry an extra Execution-scope
 * word that the KHR forms they share code with do not, so bumping w by one
 * makes the remaining operands line up at the same w[] indices for both
 * opcode variants.
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);

   val->ssa = vtn_create_ssa_value(b, val->type->type);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      /* elect takes no sources; it just produces a boolean result. */
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      val->ssa->def = &elect->dest.ssa;
      break;
   }

   /* ++w skips the core form's Execution-scope word so the predicate is at
    * w[3] for both opcodes (see the note in the function comment).
    */
   case SpvOpGroupNonUniformBallot: ++w; /* fallthrough */
   case SpvOpSubgroupBallotKHR: {
      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
      /* The ballot result is always a uvec4 bitmask, one bit per invocation. */
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      val->ssa->def = &ballot->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         /* src0 = ballot value, src1 = invocation index to test. */
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_ssa_value(b, w[4])->def;
         src1 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         /* w[4] is the GroupOperation selecting reduce vs. scan variants;
          * the ballot value then lives at w[5].
          */
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   /* ++w aligns the core form's operands with the KHR form: the value to
    * read ends up at w[3] for both.
    */
   case SpvOpGroupNonUniformBroadcastFirst: ++w; /* fallthrough */
   case SpvOpSubgroupFirstInvocationKHR:
      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                               val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
      break;

   /* Same word-shift trick: after ++w the value is at w[3] and the
    * invocation index at w[4] for all three opcode forms.
    */
   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast: ++w; /* fallthrough */
   case SpvOpSubgroupReadInvocationKHR:
      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                               val->ssa, vtn_ssa_value(b, w[3]),
                               vtn_ssa_value(b, w[4])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         /* NOTE(review): the KHR form always uses integer equality here,
          * even for float-typed values — confirm this matches the intended
          * SPV_KHR_shader_ballot semantics.
          */
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         /* Pick float vs. integer equality voting based on the operand's
          * base type.
          */
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      /* The core opcodes carry an Execution-scope word, pushing the source
       * operand to w[4]; the KHR forms have it at w[3].
       */
      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_ssa_value(b, w[4])->def;
      } else {
         src0 = vtn_ssa_value(b, w[3])->def;
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      /* Variable-component intrinsics (src_components == 0) take their
       * component count from the source.
       */
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      /* All four shuffles share the same shape: value at w[4], index/mask/
       * delta at w[5]; only the intrinsic differs.
       */
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* Value at w[4], quad lane index at w[5]. */
      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      /* The swap Direction operand (w[5]) must be a constant; it selects
       * which dedicated quad-swap intrinsic to emit.
       */
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               NULL, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      /* Step 1: map the opcode to the NIR ALU op used as the combining
       * function of the reduction/scan.
       */
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      /* Logical and bitwise variants lower to the same integer ALU ops. */
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      /* Step 2: map the GroupOperation word (w[4]) to the NIR intrinsic;
       * a clustered reduce additionally carries a constant ClusterSize
       * operand at w[6] (0 means "no clustering" downstream).
       */
      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      /* The ALU op and cluster size travel in const_index[0]/[1] (see
       * vtn_build_subgroup_instr); the value operand is at w[5].
       */
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
                               NULL, reduction_op, cluster_size);
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}