/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
24 #include "vtn_private.h"
27 vtn_build_subgroup_instr(struct vtn_builder
*b
,
28 nir_intrinsic_op nir_op
,
29 struct vtn_ssa_value
*dst
,
30 struct vtn_ssa_value
*src0
,
35 /* Some of the subgroup operations take an index. SPIR-V allows this to be
36 * any integer type. To make things simpler for drivers, we only support
39 if (index
&& index
->bit_size
!= 32)
40 index
= nir_u2u32(&b
->nb
, index
);
42 vtn_assert(dst
->type
== src0
->type
);
43 if (!glsl_type_is_vector_or_scalar(dst
->type
)) {
44 for (unsigned i
= 0; i
< glsl_get_length(dst
->type
); i
++) {
45 vtn_build_subgroup_instr(b
, nir_op
, dst
->elems
[i
],
46 src0
->elems
[i
], index
,
47 const_idx0
, const_idx1
);
52 nir_intrinsic_instr
*intrin
=
53 nir_intrinsic_instr_create(b
->nb
.shader
, nir_op
);
54 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
56 intrin
->num_components
= intrin
->dest
.ssa
.num_components
;
58 intrin
->src
[0] = nir_src_for_ssa(src0
->def
);
60 intrin
->src
[1] = nir_src_for_ssa(index
);
62 intrin
->const_index
[0] = const_idx0
;
63 intrin
->const_index
[1] = const_idx1
;
65 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
67 dst
->def
= &intrin
->dest
.ssa
;
71 vtn_handle_subgroup(struct vtn_builder
*b
, SpvOp opcode
,
72 const uint32_t *w
, unsigned count
)
74 struct vtn_value
*val
= vtn_push_value(b
, w
[2], vtn_value_type_ssa
);
76 val
->ssa
= vtn_create_ssa_value(b
, val
->type
->type
);
79 case SpvOpGroupNonUniformElect
: {
80 vtn_fail_if(val
->type
->type
!= glsl_bool_type(),
81 "OpGroupNonUniformElect must return a Bool");
82 nir_intrinsic_instr
*elect
=
83 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_elect
);
84 nir_ssa_dest_init_for_type(&elect
->instr
, &elect
->dest
,
85 val
->type
->type
, NULL
);
86 nir_builder_instr_insert(&b
->nb
, &elect
->instr
);
87 val
->ssa
->def
= &elect
->dest
.ssa
;
91 case SpvOpGroupNonUniformBallot
: ++w
; /* fallthrough */
92 case SpvOpSubgroupBallotKHR
: {
93 vtn_fail_if(val
->type
->type
!= glsl_vector_type(GLSL_TYPE_UINT
, 4),
94 "OpGroupNonUniformBallot must return a uvec4");
95 nir_intrinsic_instr
*ballot
=
96 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_ballot
);
97 ballot
->src
[0] = nir_src_for_ssa(vtn_ssa_value(b
, w
[3])->def
);
98 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
, 4, 32, NULL
);
99 ballot
->num_components
= 4;
100 nir_builder_instr_insert(&b
->nb
, &ballot
->instr
);
101 val
->ssa
->def
= &ballot
->dest
.ssa
;
105 case SpvOpGroupNonUniformInverseBallot
: {
106 /* This one is just a BallotBitfieldExtract with subgroup invocation.
107 * We could add a NIR intrinsic but it's easier to just lower it on the
110 nir_intrinsic_instr
*intrin
=
111 nir_intrinsic_instr_create(b
->nb
.shader
,
112 nir_intrinsic_ballot_bitfield_extract
);
114 intrin
->src
[0] = nir_src_for_ssa(vtn_ssa_value(b
, w
[4])->def
);
115 intrin
->src
[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b
->nb
));
117 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
118 val
->type
->type
, NULL
);
119 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
121 val
->ssa
->def
= &intrin
->dest
.ssa
;
125 case SpvOpGroupNonUniformBallotBitExtract
:
126 case SpvOpGroupNonUniformBallotBitCount
:
127 case SpvOpGroupNonUniformBallotFindLSB
:
128 case SpvOpGroupNonUniformBallotFindMSB
: {
129 nir_ssa_def
*src0
, *src1
= NULL
;
132 case SpvOpGroupNonUniformBallotBitExtract
:
133 op
= nir_intrinsic_ballot_bitfield_extract
;
134 src0
= vtn_ssa_value(b
, w
[4])->def
;
135 src1
= vtn_ssa_value(b
, w
[5])->def
;
137 case SpvOpGroupNonUniformBallotBitCount
:
138 switch ((SpvGroupOperation
)w
[4]) {
139 case SpvGroupOperationReduce
:
140 op
= nir_intrinsic_ballot_bit_count_reduce
;
142 case SpvGroupOperationInclusiveScan
:
143 op
= nir_intrinsic_ballot_bit_count_inclusive
;
145 case SpvGroupOperationExclusiveScan
:
146 op
= nir_intrinsic_ballot_bit_count_exclusive
;
149 unreachable("Invalid group operation");
151 src0
= vtn_ssa_value(b
, w
[5])->def
;
153 case SpvOpGroupNonUniformBallotFindLSB
:
154 op
= nir_intrinsic_ballot_find_lsb
;
155 src0
= vtn_ssa_value(b
, w
[4])->def
;
157 case SpvOpGroupNonUniformBallotFindMSB
:
158 op
= nir_intrinsic_ballot_find_msb
;
159 src0
= vtn_ssa_value(b
, w
[4])->def
;
162 unreachable("Unhandled opcode");
165 nir_intrinsic_instr
*intrin
=
166 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
168 intrin
->src
[0] = nir_src_for_ssa(src0
);
170 intrin
->src
[1] = nir_src_for_ssa(src1
);
172 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
173 val
->type
->type
, NULL
);
174 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
176 val
->ssa
->def
= &intrin
->dest
.ssa
;
180 case SpvOpGroupNonUniformBroadcastFirst
: ++w
; /* fallthrough */
181 case SpvOpSubgroupFirstInvocationKHR
:
182 vtn_build_subgroup_instr(b
, nir_intrinsic_read_first_invocation
,
183 val
->ssa
, vtn_ssa_value(b
, w
[3]), NULL
, 0, 0);
186 case SpvOpGroupNonUniformBroadcast
:
187 case SpvOpGroupBroadcast
: ++w
; /* fallthrough */
188 case SpvOpSubgroupReadInvocationKHR
:
189 vtn_build_subgroup_instr(b
, nir_intrinsic_read_invocation
,
190 val
->ssa
, vtn_ssa_value(b
, w
[3]),
191 vtn_ssa_value(b
, w
[4])->def
, 0, 0);
194 case SpvOpGroupNonUniformAll
:
195 case SpvOpGroupNonUniformAny
:
196 case SpvOpGroupNonUniformAllEqual
:
199 case SpvOpSubgroupAllKHR
:
200 case SpvOpSubgroupAnyKHR
:
201 case SpvOpSubgroupAllEqualKHR
: {
202 vtn_fail_if(val
->type
->type
!= glsl_bool_type(),
203 "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
206 case SpvOpGroupNonUniformAll
:
208 case SpvOpSubgroupAllKHR
:
209 op
= nir_intrinsic_vote_all
;
211 case SpvOpGroupNonUniformAny
:
213 case SpvOpSubgroupAnyKHR
:
214 op
= nir_intrinsic_vote_any
;
216 case SpvOpSubgroupAllEqualKHR
:
217 op
= nir_intrinsic_vote_ieq
;
219 case SpvOpGroupNonUniformAllEqual
:
220 switch (glsl_get_base_type(vtn_ssa_value(b
, w
[4])->type
)) {
221 case GLSL_TYPE_FLOAT
:
222 case GLSL_TYPE_FLOAT16
:
223 case GLSL_TYPE_DOUBLE
:
224 op
= nir_intrinsic_vote_feq
;
228 case GLSL_TYPE_UINT8
:
230 case GLSL_TYPE_UINT16
:
231 case GLSL_TYPE_INT16
:
232 case GLSL_TYPE_UINT64
:
233 case GLSL_TYPE_INT64
:
235 op
= nir_intrinsic_vote_ieq
;
238 unreachable("Unhandled type");
242 unreachable("Unhandled opcode");
246 if (opcode
== SpvOpGroupNonUniformAll
|| opcode
== SpvOpGroupAll
||
247 opcode
== SpvOpGroupNonUniformAny
|| opcode
== SpvOpGroupAny
||
248 opcode
== SpvOpGroupNonUniformAllEqual
) {
249 src0
= vtn_ssa_value(b
, w
[4])->def
;
251 src0
= vtn_ssa_value(b
, w
[3])->def
;
253 nir_intrinsic_instr
*intrin
=
254 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
255 if (nir_intrinsic_infos
[op
].src_components
[0] == 0)
256 intrin
->num_components
= src0
->num_components
;
257 intrin
->src
[0] = nir_src_for_ssa(src0
);
258 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
259 val
->type
->type
, NULL
);
260 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
262 val
->ssa
->def
= &intrin
->dest
.ssa
;
266 case SpvOpGroupNonUniformShuffle
:
267 case SpvOpGroupNonUniformShuffleXor
:
268 case SpvOpGroupNonUniformShuffleUp
:
269 case SpvOpGroupNonUniformShuffleDown
: {
272 case SpvOpGroupNonUniformShuffle
:
273 op
= nir_intrinsic_shuffle
;
275 case SpvOpGroupNonUniformShuffleXor
:
276 op
= nir_intrinsic_shuffle_xor
;
278 case SpvOpGroupNonUniformShuffleUp
:
279 op
= nir_intrinsic_shuffle_up
;
281 case SpvOpGroupNonUniformShuffleDown
:
282 op
= nir_intrinsic_shuffle_down
;
285 unreachable("Invalid opcode");
287 vtn_build_subgroup_instr(b
, op
, val
->ssa
, vtn_ssa_value(b
, w
[4]),
288 vtn_ssa_value(b
, w
[5])->def
, 0, 0);
292 case SpvOpGroupNonUniformQuadBroadcast
:
293 vtn_build_subgroup_instr(b
, nir_intrinsic_quad_broadcast
,
294 val
->ssa
, vtn_ssa_value(b
, w
[4]),
295 vtn_ssa_value(b
, w
[5])->def
, 0, 0);
298 case SpvOpGroupNonUniformQuadSwap
: {
299 unsigned direction
= vtn_constant_uint(b
, w
[5]);
303 op
= nir_intrinsic_quad_swap_horizontal
;
306 op
= nir_intrinsic_quad_swap_vertical
;
309 op
= nir_intrinsic_quad_swap_diagonal
;
312 vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
314 vtn_build_subgroup_instr(b
, op
, val
->ssa
, vtn_ssa_value(b
, w
[4]),
319 case SpvOpGroupNonUniformIAdd
:
320 case SpvOpGroupNonUniformFAdd
:
321 case SpvOpGroupNonUniformIMul
:
322 case SpvOpGroupNonUniformFMul
:
323 case SpvOpGroupNonUniformSMin
:
324 case SpvOpGroupNonUniformUMin
:
325 case SpvOpGroupNonUniformFMin
:
326 case SpvOpGroupNonUniformSMax
:
327 case SpvOpGroupNonUniformUMax
:
328 case SpvOpGroupNonUniformFMax
:
329 case SpvOpGroupNonUniformBitwiseAnd
:
330 case SpvOpGroupNonUniformBitwiseOr
:
331 case SpvOpGroupNonUniformBitwiseXor
:
332 case SpvOpGroupNonUniformLogicalAnd
:
333 case SpvOpGroupNonUniformLogicalOr
:
334 case SpvOpGroupNonUniformLogicalXor
:
343 case SpvOpGroupIAddNonUniformAMD
:
344 case SpvOpGroupFAddNonUniformAMD
:
345 case SpvOpGroupFMinNonUniformAMD
:
346 case SpvOpGroupUMinNonUniformAMD
:
347 case SpvOpGroupSMinNonUniformAMD
:
348 case SpvOpGroupFMaxNonUniformAMD
:
349 case SpvOpGroupUMaxNonUniformAMD
:
350 case SpvOpGroupSMaxNonUniformAMD
: {
353 case SpvOpGroupNonUniformIAdd
:
355 case SpvOpGroupIAddNonUniformAMD
:
356 reduction_op
= nir_op_iadd
;
358 case SpvOpGroupNonUniformFAdd
:
360 case SpvOpGroupFAddNonUniformAMD
:
361 reduction_op
= nir_op_fadd
;
363 case SpvOpGroupNonUniformIMul
:
364 reduction_op
= nir_op_imul
;
366 case SpvOpGroupNonUniformFMul
:
367 reduction_op
= nir_op_fmul
;
369 case SpvOpGroupNonUniformSMin
:
371 case SpvOpGroupSMinNonUniformAMD
:
372 reduction_op
= nir_op_imin
;
374 case SpvOpGroupNonUniformUMin
:
376 case SpvOpGroupUMinNonUniformAMD
:
377 reduction_op
= nir_op_umin
;
379 case SpvOpGroupNonUniformFMin
:
381 case SpvOpGroupFMinNonUniformAMD
:
382 reduction_op
= nir_op_fmin
;
384 case SpvOpGroupNonUniformSMax
:
386 case SpvOpGroupSMaxNonUniformAMD
:
387 reduction_op
= nir_op_imax
;
389 case SpvOpGroupNonUniformUMax
:
391 case SpvOpGroupUMaxNonUniformAMD
:
392 reduction_op
= nir_op_umax
;
394 case SpvOpGroupNonUniformFMax
:
396 case SpvOpGroupFMaxNonUniformAMD
:
397 reduction_op
= nir_op_fmax
;
399 case SpvOpGroupNonUniformBitwiseAnd
:
400 case SpvOpGroupNonUniformLogicalAnd
:
401 reduction_op
= nir_op_iand
;
403 case SpvOpGroupNonUniformBitwiseOr
:
404 case SpvOpGroupNonUniformLogicalOr
:
405 reduction_op
= nir_op_ior
;
407 case SpvOpGroupNonUniformBitwiseXor
:
408 case SpvOpGroupNonUniformLogicalXor
:
409 reduction_op
= nir_op_ixor
;
412 unreachable("Invalid reduction operation");
416 unsigned cluster_size
= 0;
417 switch ((SpvGroupOperation
)w
[4]) {
418 case SpvGroupOperationReduce
:
419 op
= nir_intrinsic_reduce
;
421 case SpvGroupOperationInclusiveScan
:
422 op
= nir_intrinsic_inclusive_scan
;
424 case SpvGroupOperationExclusiveScan
:
425 op
= nir_intrinsic_exclusive_scan
;
427 case SpvGroupOperationClusteredReduce
:
428 op
= nir_intrinsic_reduce
;
430 cluster_size
= vtn_constant_uint(b
, w
[6]);
433 unreachable("Invalid group operation");
436 vtn_build_subgroup_instr(b
, op
, val
->ssa
, vtn_ssa_value(b
, w
[5]),
437 NULL
, reduction_op
, cluster_size
);
442 unreachable("Invalid SPIR-V opcode");