[mesa.git] / src / compiler / spirv / vtn_subgroup.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

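/* Helper that emits a single subgroup intrinsic of the given nir_op.  src0 is
 * the main operand, index is an optional (32-bit) invocation index or delta,
 * and const_idx0/const_idx1 are copied into the intrinsic's const_index slots
 * (used, e.g., for the reduction op and cluster size of reductions).
 * Composite values are handled by recursing per element.
 */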
static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
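   /* Composite types (arrays, structs, matrices) don't map onto a single NIR
    * intrinsic, so recurse and emit one subgroup intrinsic per element.
    */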
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }
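
   /* Scalar/vector case: emit one NIR intrinsic of the requested op. */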
   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;

   return dst;
}

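/* Translates SPIR-V subgroup/group opcodes into NIR subgroup intrinsics.  As
 * with other SPIR-V instructions, w[1] is the result type id and w[2] the
 * result id; the remaining operands depend on the opcode.
 */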
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
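   /* Per the SPIR-V spec, Elect is true only in the active invocation with
    * the lowest id in the subgroup; it maps directly to nir_intrinsic_elect.
    */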
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }

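   /* Ballot returns a uvec4 bitmask with one bit per invocation for which the
    * predicate is true.  The KHR opcode has no Scope operand, hence the
    * has_scope offset into w[].
    */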
   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

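   /* The remaining ballot operations consume a uvec4 ballot value.  BitCount
    * selects a different NIR intrinsic per SPIR-V group operation (reduce,
    * inclusive or exclusive scan); the others map one-to-one.
    */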
   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

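   /* BroadcastFirst returns the value from the lowest-numbered active
    * invocation.  OpSubgroupFirstInvocationKHR has no Scope operand, so the
    * value operand shifts by one word.
    */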
   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

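   /* All/Any/AllEqual map onto the NIR vote intrinsics.  AllEqual picks
    * vote_feq or vote_ieq depending on whether the operand type is a float
    * or an integer/bool type.
    */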
   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

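   /* The shuffle variants all take a value (w[4]) and a 32-bit invocation
    * index, mask or delta (w[5]); they differ only in the NIR intrinsic that
    * gets emitted.
    */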
   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

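   /* QuadSwap's Direction operand must be a constant: 0 swaps horizontally,
    * 1 vertically and 2 diagonally within each quad.
    */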
   case SpvOpGroupNonUniformQuadSwap: {
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

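   /* Arithmetic, min/max and bitwise group operations.  The NIR ALU op ends
    * up in const_index[0] of the reduce/scan intrinsic and the cluster size
    * (0 meaning the whole subgroup) in const_index[1].
    */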
   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}