spirv: Replace vtn_constant_value with vtn_constant_uint
[mesa.git] / src / compiler / spirv / vtn_subgroup.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25
26 static void
27 vtn_build_subgroup_instr(struct vtn_builder *b,
28 nir_intrinsic_op nir_op,
29 struct vtn_ssa_value *dst,
30 struct vtn_ssa_value *src0,
31 nir_ssa_def *index,
32 unsigned const_idx0,
33 unsigned const_idx1)
34 {
35 /* Some of the subgroup operations take an index. SPIR-V allows this to be
36 * any integer type. To make things simpler for drivers, we only support
37 * 32-bit indices.
38 */
39 if (index && index->bit_size != 32)
40 index = nir_u2u32(&b->nb, index);
41
42 vtn_assert(dst->type == src0->type);
43 if (!glsl_type_is_vector_or_scalar(dst->type)) {
44 for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
45 vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
46 src0->elems[i], index,
47 const_idx0, const_idx1);
48 }
49 return;
50 }
51
52 nir_intrinsic_instr *intrin =
53 nir_intrinsic_instr_create(b->nb.shader, nir_op);
54 nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
55 dst->type, NULL);
56 intrin->num_components = intrin->dest.ssa.num_components;
57
58 intrin->src[0] = nir_src_for_ssa(src0->def);
59 if (index)
60 intrin->src[1] = nir_src_for_ssa(index);
61
62 intrin->const_index[0] = const_idx0;
63 intrin->const_index[1] = const_idx1;
64
65 nir_builder_instr_insert(&b->nb, &intrin->instr);
66
67 dst->def = &intrin->dest.ssa;
68 }
69
/* Translate a SPIR-V OpGroupNonUniform* instruction into NIR.
 *
 * w[1] is the result type id and w[2] the result id; the meaning of the
 * remaining words varies per opcode (noted on each case below).  Every
 * path pushes an SSA value for w[2].
 *
 * NOTE(review): w[3] (the execution Scope operand) is never read here —
 * presumably the NIR subgroup intrinsics are implicitly subgroup-scoped;
 * confirm against the callers/validation.
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);

   val->ssa = vtn_create_ssa_value(b, val->type->type);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      /* No value operands: just emit the zero-source elect intrinsic,
       * which produces a Bool.
       */
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      val->ssa->def = &elect->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallot: {
      /* w[4] is the Bool predicate; the result is a uvec4 bitmask with one
       * bit per subgroup invocation.
       */
      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      val->ssa->def = &ballot->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      /* src[0] = ballot value (w[4]), src[1] = this invocation's index */
      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      /* All four query a ballot bitmask.  Operand layout differs:
       * BitExtract takes value (w[4]) + bit index (w[5]); BitCount takes a
       * group operation (w[4]) + value (w[5]); the Find* ops take only the
       * value (w[4]).
       */
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_ssa_value(b, w[4])->def;
         src1 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         /* The group operation selects reduce vs. inclusive/exclusive scan
          * flavors of the bit count.
          */
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_ssa_value(b, w[5])->def;
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_ssa_value(b, w[4])->def;
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
      /* w[4] is the value; no index source needed. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0);
      break;

   case SpvOpGroupNonUniformBroadcast:
      /* w[4] is the value, w[5] the invocation id to read from. */
      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual: {
      vtn_fail_if(val->type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpGroupNonUniformAllEqual: {
         /* AllEqual compares values, so the vote intrinsic must match the
          * operand's base type: float equality for float/double, integer
          * equality for the rest (including bool).
          */
         switch (glsl_get_base_type(val->type->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      }
      default:
         unreachable("Unhandled opcode");
      }

      /* w[4] is the predicate (All/Any) or the value (AllEqual). */
      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 val->type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      val->ssa->def = &intrin->dest.ssa;
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      /* All shuffles take a value (w[4]) plus an id/mask/delta (w[5]);
       * they map 1:1 onto NIR shuffle intrinsics.
       */
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* w[4] is the value, w[5] the quad lane index (0..3). */
      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                               val->ssa, vtn_ssa_value(b, w[4]),
                               vtn_ssa_value(b, w[5])->def, 0, 0);
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      /* w[5] is a constant direction: 0 = horizontal, 1 = vertical,
       * 2 = diagonal.  Each direction has its own NIR intrinsic.
       */
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
                               NULL, 0, 0);
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor: {
      /* Arithmetic/bitwise group operations: w[4] is the SPIR-V group
       * operation (reduce/scan/clustered), w[5] the value, and w[6] the
       * cluster size for clustered reductions.  The combining ALU op goes
       * into const_index[0] and the cluster size into const_index[1].
       */
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
         reduction_op = nir_op_fmax;
         break;
      /* Logical* and Bitwise* share NIR ops; the operand type (bool vs.
       * int) is what distinguishes them in SPIR-V.
       */
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;  /* 0 = whole subgroup */
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         /* Clustered reduce is a plain reduce with a non-zero cluster
          * size; the size must be a constant (w[6]).
          */
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
                               NULL, reduction_op, cluster_size);
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}