/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
24 #include "vtn_private.h"
26 static struct vtn_ssa_value
*
27 vtn_build_subgroup_instr(struct vtn_builder
*b
,
28 nir_intrinsic_op nir_op
,
29 struct vtn_ssa_value
*src0
,
34 /* Some of the subgroup operations take an index. SPIR-V allows this to be
35 * any integer type. To make things simpler for drivers, we only support
38 if (index
&& index
->bit_size
!= 32)
39 index
= nir_u2u32(&b
->nb
, index
);
41 struct vtn_ssa_value
*dst
= vtn_create_ssa_value(b
, src0
->type
);
43 vtn_assert(dst
->type
== src0
->type
);
44 if (!glsl_type_is_vector_or_scalar(dst
->type
)) {
45 for (unsigned i
= 0; i
< glsl_get_length(dst
->type
); i
++) {
47 vtn_build_subgroup_instr(b
, nir_op
, src0
->elems
[i
], index
,
48 const_idx0
, const_idx1
);
53 nir_intrinsic_instr
*intrin
=
54 nir_intrinsic_instr_create(b
->nb
.shader
, nir_op
);
55 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
57 intrin
->num_components
= intrin
->dest
.ssa
.num_components
;
59 intrin
->src
[0] = nir_src_for_ssa(src0
->def
);
61 intrin
->src
[1] = nir_src_for_ssa(index
);
63 intrin
->const_index
[0] = const_idx0
;
64 intrin
->const_index
[1] = const_idx1
;
66 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
68 dst
->def
= &intrin
->dest
.ssa
;
74 vtn_handle_subgroup(struct vtn_builder
*b
, SpvOp opcode
,
75 const uint32_t *w
, unsigned count
)
77 struct vtn_type
*dest_type
= vtn_get_type(b
, w
[1]);
80 case SpvOpGroupNonUniformElect
: {
81 vtn_fail_if(dest_type
->type
!= glsl_bool_type(),
82 "OpGroupNonUniformElect must return a Bool");
83 nir_intrinsic_instr
*elect
=
84 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_elect
);
85 nir_ssa_dest_init_for_type(&elect
->instr
, &elect
->dest
,
86 dest_type
->type
, NULL
);
87 nir_builder_instr_insert(&b
->nb
, &elect
->instr
);
88 vtn_push_nir_ssa(b
, w
[2], &elect
->dest
.ssa
);
92 case SpvOpGroupNonUniformBallot
:
93 case SpvOpSubgroupBallotKHR
: {
94 bool has_scope
= (opcode
!= SpvOpSubgroupBallotKHR
);
95 vtn_fail_if(dest_type
->type
!= glsl_vector_type(GLSL_TYPE_UINT
, 4),
96 "OpGroupNonUniformBallot must return a uvec4");
97 nir_intrinsic_instr
*ballot
=
98 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_ballot
);
99 ballot
->src
[0] = nir_src_for_ssa(vtn_get_nir_ssa(b
, w
[3 + has_scope
]));
100 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
, 4, 32, NULL
);
101 ballot
->num_components
= 4;
102 nir_builder_instr_insert(&b
->nb
, &ballot
->instr
);
103 vtn_push_nir_ssa(b
, w
[2], &ballot
->dest
.ssa
);
107 case SpvOpGroupNonUniformInverseBallot
: {
108 /* This one is just a BallotBitfieldExtract with subgroup invocation.
109 * We could add a NIR intrinsic but it's easier to just lower it on the
112 nir_intrinsic_instr
*intrin
=
113 nir_intrinsic_instr_create(b
->nb
.shader
,
114 nir_intrinsic_ballot_bitfield_extract
);
116 intrin
->src
[0] = nir_src_for_ssa(vtn_get_nir_ssa(b
, w
[4]));
117 intrin
->src
[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b
->nb
));
119 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
120 dest_type
->type
, NULL
);
121 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
123 vtn_push_nir_ssa(b
, w
[2], &intrin
->dest
.ssa
);
127 case SpvOpGroupNonUniformBallotBitExtract
:
128 case SpvOpGroupNonUniformBallotBitCount
:
129 case SpvOpGroupNonUniformBallotFindLSB
:
130 case SpvOpGroupNonUniformBallotFindMSB
: {
131 nir_ssa_def
*src0
, *src1
= NULL
;
134 case SpvOpGroupNonUniformBallotBitExtract
:
135 op
= nir_intrinsic_ballot_bitfield_extract
;
136 src0
= vtn_get_nir_ssa(b
, w
[4]);
137 src1
= vtn_get_nir_ssa(b
, w
[5]);
139 case SpvOpGroupNonUniformBallotBitCount
:
140 switch ((SpvGroupOperation
)w
[4]) {
141 case SpvGroupOperationReduce
:
142 op
= nir_intrinsic_ballot_bit_count_reduce
;
144 case SpvGroupOperationInclusiveScan
:
145 op
= nir_intrinsic_ballot_bit_count_inclusive
;
147 case SpvGroupOperationExclusiveScan
:
148 op
= nir_intrinsic_ballot_bit_count_exclusive
;
151 unreachable("Invalid group operation");
153 src0
= vtn_get_nir_ssa(b
, w
[5]);
155 case SpvOpGroupNonUniformBallotFindLSB
:
156 op
= nir_intrinsic_ballot_find_lsb
;
157 src0
= vtn_get_nir_ssa(b
, w
[4]);
159 case SpvOpGroupNonUniformBallotFindMSB
:
160 op
= nir_intrinsic_ballot_find_msb
;
161 src0
= vtn_get_nir_ssa(b
, w
[4]);
164 unreachable("Unhandled opcode");
167 nir_intrinsic_instr
*intrin
=
168 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
170 intrin
->src
[0] = nir_src_for_ssa(src0
);
172 intrin
->src
[1] = nir_src_for_ssa(src1
);
174 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
175 dest_type
->type
, NULL
);
176 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
178 vtn_push_nir_ssa(b
, w
[2], &intrin
->dest
.ssa
);
182 case SpvOpGroupNonUniformBroadcastFirst
:
183 case SpvOpSubgroupFirstInvocationKHR
: {
184 bool has_scope
= (opcode
!= SpvOpSubgroupFirstInvocationKHR
);
185 vtn_push_ssa_value(b
, w
[2],
186 vtn_build_subgroup_instr(b
, nir_intrinsic_read_first_invocation
,
187 vtn_ssa_value(b
, w
[3 + has_scope
]),
192 case SpvOpGroupNonUniformBroadcast
:
193 case SpvOpGroupBroadcast
:
194 case SpvOpSubgroupReadInvocationKHR
: {
195 bool has_scope
= (opcode
!= SpvOpSubgroupReadInvocationKHR
);
196 vtn_push_ssa_value(b
, w
[2],
197 vtn_build_subgroup_instr(b
, nir_intrinsic_read_invocation
,
198 vtn_ssa_value(b
, w
[3 + has_scope
]),
199 vtn_get_nir_ssa(b
, w
[4 + has_scope
]), 0, 0));
203 case SpvOpGroupNonUniformAll
:
204 case SpvOpGroupNonUniformAny
:
205 case SpvOpGroupNonUniformAllEqual
:
208 case SpvOpSubgroupAllKHR
:
209 case SpvOpSubgroupAnyKHR
:
210 case SpvOpSubgroupAllEqualKHR
: {
211 vtn_fail_if(dest_type
->type
!= glsl_bool_type(),
212 "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
215 case SpvOpGroupNonUniformAll
:
217 case SpvOpSubgroupAllKHR
:
218 op
= nir_intrinsic_vote_all
;
220 case SpvOpGroupNonUniformAny
:
222 case SpvOpSubgroupAnyKHR
:
223 op
= nir_intrinsic_vote_any
;
225 case SpvOpSubgroupAllEqualKHR
:
226 op
= nir_intrinsic_vote_ieq
;
228 case SpvOpGroupNonUniformAllEqual
:
229 switch (glsl_get_base_type(vtn_ssa_value(b
, w
[4])->type
)) {
230 case GLSL_TYPE_FLOAT
:
231 case GLSL_TYPE_FLOAT16
:
232 case GLSL_TYPE_DOUBLE
:
233 op
= nir_intrinsic_vote_feq
;
237 case GLSL_TYPE_UINT8
:
239 case GLSL_TYPE_UINT16
:
240 case GLSL_TYPE_INT16
:
241 case GLSL_TYPE_UINT64
:
242 case GLSL_TYPE_INT64
:
244 op
= nir_intrinsic_vote_ieq
;
247 unreachable("Unhandled type");
251 unreachable("Unhandled opcode");
255 if (opcode
== SpvOpGroupNonUniformAll
|| opcode
== SpvOpGroupAll
||
256 opcode
== SpvOpGroupNonUniformAny
|| opcode
== SpvOpGroupAny
||
257 opcode
== SpvOpGroupNonUniformAllEqual
) {
258 src0
= vtn_get_nir_ssa(b
, w
[4]);
260 src0
= vtn_get_nir_ssa(b
, w
[3]);
262 nir_intrinsic_instr
*intrin
=
263 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
264 if (nir_intrinsic_infos
[op
].src_components
[0] == 0)
265 intrin
->num_components
= src0
->num_components
;
266 intrin
->src
[0] = nir_src_for_ssa(src0
);
267 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
268 dest_type
->type
, NULL
);
269 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
271 vtn_push_nir_ssa(b
, w
[2], &intrin
->dest
.ssa
);
275 case SpvOpGroupNonUniformShuffle
:
276 case SpvOpGroupNonUniformShuffleXor
:
277 case SpvOpGroupNonUniformShuffleUp
:
278 case SpvOpGroupNonUniformShuffleDown
: {
281 case SpvOpGroupNonUniformShuffle
:
282 op
= nir_intrinsic_shuffle
;
284 case SpvOpGroupNonUniformShuffleXor
:
285 op
= nir_intrinsic_shuffle_xor
;
287 case SpvOpGroupNonUniformShuffleUp
:
288 op
= nir_intrinsic_shuffle_up
;
290 case SpvOpGroupNonUniformShuffleDown
:
291 op
= nir_intrinsic_shuffle_down
;
294 unreachable("Invalid opcode");
296 vtn_push_ssa_value(b
, w
[2],
297 vtn_build_subgroup_instr(b
, op
, vtn_ssa_value(b
, w
[4]),
298 vtn_get_nir_ssa(b
, w
[5]), 0, 0));
302 case SpvOpGroupNonUniformQuadBroadcast
:
303 vtn_push_ssa_value(b
, w
[2],
304 vtn_build_subgroup_instr(b
, nir_intrinsic_quad_broadcast
,
305 vtn_ssa_value(b
, w
[4]),
306 vtn_get_nir_ssa(b
, w
[5]), 0, 0));
309 case SpvOpGroupNonUniformQuadSwap
: {
310 unsigned direction
= vtn_constant_uint(b
, w
[5]);
314 op
= nir_intrinsic_quad_swap_horizontal
;
317 op
= nir_intrinsic_quad_swap_vertical
;
320 op
= nir_intrinsic_quad_swap_diagonal
;
323 vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
325 vtn_push_ssa_value(b
, w
[2],
326 vtn_build_subgroup_instr(b
, op
, vtn_ssa_value(b
, w
[4]), NULL
, 0, 0));
330 case SpvOpGroupNonUniformIAdd
:
331 case SpvOpGroupNonUniformFAdd
:
332 case SpvOpGroupNonUniformIMul
:
333 case SpvOpGroupNonUniformFMul
:
334 case SpvOpGroupNonUniformSMin
:
335 case SpvOpGroupNonUniformUMin
:
336 case SpvOpGroupNonUniformFMin
:
337 case SpvOpGroupNonUniformSMax
:
338 case SpvOpGroupNonUniformUMax
:
339 case SpvOpGroupNonUniformFMax
:
340 case SpvOpGroupNonUniformBitwiseAnd
:
341 case SpvOpGroupNonUniformBitwiseOr
:
342 case SpvOpGroupNonUniformBitwiseXor
:
343 case SpvOpGroupNonUniformLogicalAnd
:
344 case SpvOpGroupNonUniformLogicalOr
:
345 case SpvOpGroupNonUniformLogicalXor
:
354 case SpvOpGroupIAddNonUniformAMD
:
355 case SpvOpGroupFAddNonUniformAMD
:
356 case SpvOpGroupFMinNonUniformAMD
:
357 case SpvOpGroupUMinNonUniformAMD
:
358 case SpvOpGroupSMinNonUniformAMD
:
359 case SpvOpGroupFMaxNonUniformAMD
:
360 case SpvOpGroupUMaxNonUniformAMD
:
361 case SpvOpGroupSMaxNonUniformAMD
: {
364 case SpvOpGroupNonUniformIAdd
:
366 case SpvOpGroupIAddNonUniformAMD
:
367 reduction_op
= nir_op_iadd
;
369 case SpvOpGroupNonUniformFAdd
:
371 case SpvOpGroupFAddNonUniformAMD
:
372 reduction_op
= nir_op_fadd
;
374 case SpvOpGroupNonUniformIMul
:
375 reduction_op
= nir_op_imul
;
377 case SpvOpGroupNonUniformFMul
:
378 reduction_op
= nir_op_fmul
;
380 case SpvOpGroupNonUniformSMin
:
382 case SpvOpGroupSMinNonUniformAMD
:
383 reduction_op
= nir_op_imin
;
385 case SpvOpGroupNonUniformUMin
:
387 case SpvOpGroupUMinNonUniformAMD
:
388 reduction_op
= nir_op_umin
;
390 case SpvOpGroupNonUniformFMin
:
392 case SpvOpGroupFMinNonUniformAMD
:
393 reduction_op
= nir_op_fmin
;
395 case SpvOpGroupNonUniformSMax
:
397 case SpvOpGroupSMaxNonUniformAMD
:
398 reduction_op
= nir_op_imax
;
400 case SpvOpGroupNonUniformUMax
:
402 case SpvOpGroupUMaxNonUniformAMD
:
403 reduction_op
= nir_op_umax
;
405 case SpvOpGroupNonUniformFMax
:
407 case SpvOpGroupFMaxNonUniformAMD
:
408 reduction_op
= nir_op_fmax
;
410 case SpvOpGroupNonUniformBitwiseAnd
:
411 case SpvOpGroupNonUniformLogicalAnd
:
412 reduction_op
= nir_op_iand
;
414 case SpvOpGroupNonUniformBitwiseOr
:
415 case SpvOpGroupNonUniformLogicalOr
:
416 reduction_op
= nir_op_ior
;
418 case SpvOpGroupNonUniformBitwiseXor
:
419 case SpvOpGroupNonUniformLogicalXor
:
420 reduction_op
= nir_op_ixor
;
423 unreachable("Invalid reduction operation");
427 unsigned cluster_size
= 0;
428 switch ((SpvGroupOperation
)w
[4]) {
429 case SpvGroupOperationReduce
:
430 op
= nir_intrinsic_reduce
;
432 case SpvGroupOperationInclusiveScan
:
433 op
= nir_intrinsic_inclusive_scan
;
435 case SpvGroupOperationExclusiveScan
:
436 op
= nir_intrinsic_exclusive_scan
;
438 case SpvGroupOperationClusteredReduce
:
439 op
= nir_intrinsic_reduce
;
441 cluster_size
= vtn_constant_uint(b
, w
[6]);
444 unreachable("Invalid group operation");
447 vtn_push_ssa_value(b
, w
[2],
448 vtn_build_subgroup_instr(b
, op
, vtn_ssa_value(b
, w
[5]), NULL
,
449 reduction_op
, cluster_size
));
454 unreachable("Invalid SPIR-V opcode");