spirv: Add subgroup shuffle support
[mesa.git] / src / compiler / spirv / vtn_subgroup.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25
26 static void
27 vtn_build_subgroup_instr(struct vtn_builder *b,
28 nir_intrinsic_op nir_op,
29 struct vtn_ssa_value *dst,
30 struct vtn_ssa_value *src0,
31 nir_ssa_def *index)
32 {
33 /* Some of the subgroup operations take an index. SPIR-V allows this to be
34 * any integer type. To make things simpler for drivers, we only support
35 * 32-bit indices.
36 */
37 if (index && index->bit_size != 32)
38 index = nir_u2u32(&b->nb, index);
39
40 vtn_assert(dst->type == src0->type);
41 if (!glsl_type_is_vector_or_scalar(dst->type)) {
42 for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
43 vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
44 src0->elems[i], index);
45 }
46 return;
47 }
48
49 nir_intrinsic_instr *intrin =
50 nir_intrinsic_instr_create(b->nb.shader, nir_op);
51 nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
52 dst->type, NULL);
53 intrin->num_components = intrin->dest.ssa.num_components;
54
55 intrin->src[0] = nir_src_for_ssa(src0->def);
56 if (index)
57 intrin->src[1] = nir_src_for_ssa(index);
58
59 nir_builder_instr_insert(&b->nb, &intrin->instr);
60
61 dst->def = &intrin->dest.ssa;
62 }
63
64 void
65 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
66 const uint32_t *w, unsigned count)
67 {
68 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
69
70 val->ssa = vtn_create_ssa_value(b, val->type->type);
71
72 switch (opcode) {
73 case SpvOpGroupNonUniformElect: {
74 vtn_fail_if(val->type->type != glsl_bool_type(),
75 "OpGroupNonUniformElect must return a Bool");
76 nir_intrinsic_instr *elect =
77 nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
78 nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
79 nir_builder_instr_insert(&b->nb, &elect->instr);
80 val->ssa->def = &elect->dest.ssa;
81 break;
82 }
83
84 case SpvOpGroupNonUniformBallot: {
85 vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
86 "OpGroupNonUniformBallot must return a uvec4");
87 nir_intrinsic_instr *ballot =
88 nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
89 ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
90 nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
91 ballot->num_components = 4;
92 nir_builder_instr_insert(&b->nb, &ballot->instr);
93 val->ssa->def = &ballot->dest.ssa;
94 break;
95 }
96
97 case SpvOpGroupNonUniformInverseBallot: {
98 /* This one is just a BallotBitfieldExtract with subgroup invocation.
99 * We could add a NIR intrinsic but it's easier to just lower it on the
100 * spot.
101 */
102 nir_intrinsic_instr *intrin =
103 nir_intrinsic_instr_create(b->nb.shader,
104 nir_intrinsic_ballot_bitfield_extract);
105
106 intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
107 intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
108
109 nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
110 nir_builder_instr_insert(&b->nb, &intrin->instr);
111
112 val->ssa->def = &intrin->dest.ssa;
113 break;
114 }
115
116 case SpvOpGroupNonUniformBallotBitExtract:
117 case SpvOpGroupNonUniformBallotBitCount:
118 case SpvOpGroupNonUniformBallotFindLSB:
119 case SpvOpGroupNonUniformBallotFindMSB: {
120 nir_ssa_def *src0, *src1 = NULL;
121 nir_intrinsic_op op;
122 switch (opcode) {
123 case SpvOpGroupNonUniformBallotBitExtract:
124 op = nir_intrinsic_ballot_bitfield_extract;
125 src0 = vtn_ssa_value(b, w[4])->def;
126 src1 = vtn_ssa_value(b, w[5])->def;
127 break;
128 case SpvOpGroupNonUniformBallotBitCount:
129 switch ((SpvGroupOperation)w[4]) {
130 case SpvGroupOperationReduce:
131 op = nir_intrinsic_ballot_bit_count_reduce;
132 break;
133 case SpvGroupOperationInclusiveScan:
134 op = nir_intrinsic_ballot_bit_count_inclusive;
135 break;
136 case SpvGroupOperationExclusiveScan:
137 op = nir_intrinsic_ballot_bit_count_exclusive;
138 break;
139 default:
140 unreachable("Invalid group operation");
141 }
142 src0 = vtn_ssa_value(b, w[5])->def;
143 break;
144 case SpvOpGroupNonUniformBallotFindLSB:
145 op = nir_intrinsic_ballot_find_lsb;
146 src0 = vtn_ssa_value(b, w[4])->def;
147 break;
148 case SpvOpGroupNonUniformBallotFindMSB:
149 op = nir_intrinsic_ballot_find_msb;
150 src0 = vtn_ssa_value(b, w[4])->def;
151 break;
152 default:
153 unreachable("Unhandled opcode");
154 }
155
156 nir_intrinsic_instr *intrin =
157 nir_intrinsic_instr_create(b->nb.shader, op);
158
159 intrin->src[0] = nir_src_for_ssa(src0);
160 if (src1)
161 intrin->src[1] = nir_src_for_ssa(src1);
162
163 nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
164 nir_builder_instr_insert(&b->nb, &intrin->instr);
165
166 val->ssa->def = &intrin->dest.ssa;
167 break;
168 }
169
170 case SpvOpGroupNonUniformBroadcastFirst:
171 vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
172 val->ssa, vtn_ssa_value(b, w[4]), NULL);
173 break;
174
175 case SpvOpGroupNonUniformBroadcast:
176 vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
177 val->ssa, vtn_ssa_value(b, w[4]),
178 vtn_ssa_value(b, w[5])->def);
179 break;
180
181 case SpvOpGroupNonUniformAll:
182 case SpvOpGroupNonUniformAny:
183 case SpvOpGroupNonUniformAllEqual: {
184 vtn_fail_if(val->type->type != glsl_bool_type(),
185 "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
186 nir_intrinsic_op op;
187 switch (opcode) {
188 case SpvOpGroupNonUniformAll:
189 op = nir_intrinsic_vote_all;
190 break;
191 case SpvOpGroupNonUniformAny:
192 op = nir_intrinsic_vote_any;
193 break;
194 case SpvOpGroupNonUniformAllEqual: {
195 switch (glsl_get_base_type(val->type->type)) {
196 case GLSL_TYPE_FLOAT:
197 case GLSL_TYPE_DOUBLE:
198 op = nir_intrinsic_vote_feq;
199 break;
200 case GLSL_TYPE_UINT:
201 case GLSL_TYPE_INT:
202 case GLSL_TYPE_UINT64:
203 case GLSL_TYPE_INT64:
204 case GLSL_TYPE_BOOL:
205 op = nir_intrinsic_vote_ieq;
206 break;
207 default:
208 unreachable("Unhandled type");
209 }
210 break;
211 }
212 default:
213 unreachable("Unhandled opcode");
214 }
215
216 nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;
217
218 nir_intrinsic_instr *intrin =
219 nir_intrinsic_instr_create(b->nb.shader, op);
220 intrin->num_components = src0->num_components;
221 intrin->src[0] = nir_src_for_ssa(src0);
222 nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
223 nir_builder_instr_insert(&b->nb, &intrin->instr);
224
225 val->ssa->def = &intrin->dest.ssa;
226 break;
227 }
228
229 case SpvOpGroupNonUniformShuffle:
230 case SpvOpGroupNonUniformShuffleXor:
231 case SpvOpGroupNonUniformShuffleUp:
232 case SpvOpGroupNonUniformShuffleDown: {
233 nir_intrinsic_op op;
234 switch (opcode) {
235 case SpvOpGroupNonUniformShuffle:
236 op = nir_intrinsic_shuffle;
237 break;
238 case SpvOpGroupNonUniformShuffleXor:
239 op = nir_intrinsic_shuffle_xor;
240 break;
241 case SpvOpGroupNonUniformShuffleUp:
242 op = nir_intrinsic_shuffle_up;
243 break;
244 case SpvOpGroupNonUniformShuffleDown:
245 op = nir_intrinsic_shuffle_down;
246 break;
247 default:
248 unreachable("Invalid opcode");
249 }
250 vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
251 vtn_ssa_value(b, w[5])->def);
252 break;
253 }
254
255 case SpvOpGroupNonUniformIAdd:
256 case SpvOpGroupNonUniformFAdd:
257 case SpvOpGroupNonUniformIMul:
258 case SpvOpGroupNonUniformFMul:
259 case SpvOpGroupNonUniformSMin:
260 case SpvOpGroupNonUniformUMin:
261 case SpvOpGroupNonUniformFMin:
262 case SpvOpGroupNonUniformSMax:
263 case SpvOpGroupNonUniformUMax:
264 case SpvOpGroupNonUniformFMax:
265 case SpvOpGroupNonUniformBitwiseAnd:
266 case SpvOpGroupNonUniformBitwiseOr:
267 case SpvOpGroupNonUniformBitwiseXor:
268 case SpvOpGroupNonUniformLogicalAnd:
269 case SpvOpGroupNonUniformLogicalOr:
270 case SpvOpGroupNonUniformLogicalXor:
271 case SpvOpGroupNonUniformQuadBroadcast:
272 case SpvOpGroupNonUniformQuadSwap:
273 default:
274 unreachable("Invalid SPIR-V opcode");
275 }
276 }