/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
24 #include "vtn_private.h"
27 vtn_build_subgroup_instr(struct vtn_builder
*b
,
28 nir_intrinsic_op nir_op
,
29 struct vtn_ssa_value
*dst
,
30 struct vtn_ssa_value
*src0
,
33 /* Some of the subgroup operations take an index. SPIR-V allows this to be
34 * any integer type. To make things simpler for drivers, we only support
37 if (index
&& index
->bit_size
!= 32)
38 index
= nir_u2u32(&b
->nb
, index
);
40 vtn_assert(dst
->type
== src0
->type
);
41 if (!glsl_type_is_vector_or_scalar(dst
->type
)) {
42 for (unsigned i
= 0; i
< glsl_get_length(dst
->type
); i
++) {
43 vtn_build_subgroup_instr(b
, nir_op
, dst
->elems
[i
],
44 src0
->elems
[i
], index
);
49 nir_intrinsic_instr
*intrin
=
50 nir_intrinsic_instr_create(b
->nb
.shader
, nir_op
);
51 nir_ssa_dest_init_for_type(&intrin
->instr
, &intrin
->dest
,
53 intrin
->num_components
= intrin
->dest
.ssa
.num_components
;
55 intrin
->src
[0] = nir_src_for_ssa(src0
->def
);
57 intrin
->src
[1] = nir_src_for_ssa(index
);
59 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
61 dst
->def
= &intrin
->dest
.ssa
;
65 vtn_handle_subgroup(struct vtn_builder
*b
, SpvOp opcode
,
66 const uint32_t *w
, unsigned count
)
68 struct vtn_value
*val
= vtn_push_value(b
, w
[2], vtn_value_type_ssa
);
70 val
->ssa
= vtn_create_ssa_value(b
, val
->type
->type
);
73 case SpvOpGroupNonUniformElect
: {
74 vtn_fail_if(val
->type
->type
!= glsl_bool_type(),
75 "OpGroupNonUniformElect must return a Bool");
76 nir_intrinsic_instr
*elect
=
77 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_elect
);
78 nir_ssa_dest_init(&elect
->instr
, &elect
->dest
, 1, 32, NULL
);
79 nir_builder_instr_insert(&b
->nb
, &elect
->instr
);
80 val
->ssa
->def
= &elect
->dest
.ssa
;
84 case SpvOpGroupNonUniformBallot
: {
85 vtn_fail_if(val
->type
->type
!= glsl_vector_type(GLSL_TYPE_UINT
, 4),
86 "OpGroupNonUniformBallot must return a uvec4");
87 nir_intrinsic_instr
*ballot
=
88 nir_intrinsic_instr_create(b
->nb
.shader
, nir_intrinsic_ballot
);
89 ballot
->src
[0] = nir_src_for_ssa(vtn_ssa_value(b
, w
[4])->def
);
90 nir_ssa_dest_init(&ballot
->instr
, &ballot
->dest
, 4, 32, NULL
);
91 ballot
->num_components
= 4;
92 nir_builder_instr_insert(&b
->nb
, &ballot
->instr
);
93 val
->ssa
->def
= &ballot
->dest
.ssa
;
97 case SpvOpGroupNonUniformInverseBallot
: {
98 /* This one is just a BallotBitfieldExtract with subgroup invocation.
99 * We could add a NIR intrinsic but it's easier to just lower it on the
102 nir_intrinsic_instr
*intrin
=
103 nir_intrinsic_instr_create(b
->nb
.shader
,
104 nir_intrinsic_ballot_bitfield_extract
);
106 intrin
->src
[0] = nir_src_for_ssa(vtn_ssa_value(b
, w
[4])->def
);
107 intrin
->src
[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b
->nb
));
109 nir_ssa_dest_init(&intrin
->instr
, &intrin
->dest
, 1, 32, NULL
);
110 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
112 val
->ssa
->def
= &intrin
->dest
.ssa
;
116 case SpvOpGroupNonUniformBallotBitExtract
:
117 case SpvOpGroupNonUniformBallotBitCount
:
118 case SpvOpGroupNonUniformBallotFindLSB
:
119 case SpvOpGroupNonUniformBallotFindMSB
: {
120 nir_ssa_def
*src0
, *src1
= NULL
;
123 case SpvOpGroupNonUniformBallotBitExtract
:
124 op
= nir_intrinsic_ballot_bitfield_extract
;
125 src0
= vtn_ssa_value(b
, w
[4])->def
;
126 src1
= vtn_ssa_value(b
, w
[5])->def
;
128 case SpvOpGroupNonUniformBallotBitCount
:
129 switch ((SpvGroupOperation
)w
[4]) {
130 case SpvGroupOperationReduce
:
131 op
= nir_intrinsic_ballot_bit_count_reduce
;
133 case SpvGroupOperationInclusiveScan
:
134 op
= nir_intrinsic_ballot_bit_count_inclusive
;
136 case SpvGroupOperationExclusiveScan
:
137 op
= nir_intrinsic_ballot_bit_count_exclusive
;
140 unreachable("Invalid group operation");
142 src0
= vtn_ssa_value(b
, w
[5])->def
;
144 case SpvOpGroupNonUniformBallotFindLSB
:
145 op
= nir_intrinsic_ballot_find_lsb
;
146 src0
= vtn_ssa_value(b
, w
[4])->def
;
148 case SpvOpGroupNonUniformBallotFindMSB
:
149 op
= nir_intrinsic_ballot_find_msb
;
150 src0
= vtn_ssa_value(b
, w
[4])->def
;
153 unreachable("Unhandled opcode");
156 nir_intrinsic_instr
*intrin
=
157 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
159 intrin
->src
[0] = nir_src_for_ssa(src0
);
161 intrin
->src
[1] = nir_src_for_ssa(src1
);
163 nir_ssa_dest_init(&intrin
->instr
, &intrin
->dest
, 1, 32, NULL
);
164 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
166 val
->ssa
->def
= &intrin
->dest
.ssa
;
170 case SpvOpGroupNonUniformBroadcastFirst
:
171 vtn_build_subgroup_instr(b
, nir_intrinsic_read_first_invocation
,
172 val
->ssa
, vtn_ssa_value(b
, w
[4]), NULL
);
175 case SpvOpGroupNonUniformBroadcast
:
176 vtn_build_subgroup_instr(b
, nir_intrinsic_read_invocation
,
177 val
->ssa
, vtn_ssa_value(b
, w
[4]),
178 vtn_ssa_value(b
, w
[5])->def
);
181 case SpvOpGroupNonUniformAll
:
182 case SpvOpGroupNonUniformAny
:
183 case SpvOpGroupNonUniformAllEqual
: {
184 vtn_fail_if(val
->type
->type
!= glsl_bool_type(),
185 "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
188 case SpvOpGroupNonUniformAll
:
189 op
= nir_intrinsic_vote_all
;
191 case SpvOpGroupNonUniformAny
:
192 op
= nir_intrinsic_vote_any
;
194 case SpvOpGroupNonUniformAllEqual
: {
195 switch (glsl_get_base_type(val
->type
->type
)) {
196 case GLSL_TYPE_FLOAT
:
197 case GLSL_TYPE_DOUBLE
:
198 op
= nir_intrinsic_vote_feq
;
202 case GLSL_TYPE_UINT64
:
203 case GLSL_TYPE_INT64
:
205 op
= nir_intrinsic_vote_ieq
;
208 unreachable("Unhandled type");
213 unreachable("Unhandled opcode");
216 nir_ssa_def
*src0
= vtn_ssa_value(b
, w
[4])->def
;
218 nir_intrinsic_instr
*intrin
=
219 nir_intrinsic_instr_create(b
->nb
.shader
, op
);
220 intrin
->num_components
= src0
->num_components
;
221 intrin
->src
[0] = nir_src_for_ssa(src0
);
222 nir_ssa_dest_init(&intrin
->instr
, &intrin
->dest
, 1, 32, NULL
);
223 nir_builder_instr_insert(&b
->nb
, &intrin
->instr
);
225 val
->ssa
->def
= &intrin
->dest
.ssa
;
229 case SpvOpGroupNonUniformShuffle
:
230 case SpvOpGroupNonUniformShuffleXor
:
231 case SpvOpGroupNonUniformShuffleUp
:
232 case SpvOpGroupNonUniformShuffleDown
: {
235 case SpvOpGroupNonUniformShuffle
:
236 op
= nir_intrinsic_shuffle
;
238 case SpvOpGroupNonUniformShuffleXor
:
239 op
= nir_intrinsic_shuffle_xor
;
241 case SpvOpGroupNonUniformShuffleUp
:
242 op
= nir_intrinsic_shuffle_up
;
244 case SpvOpGroupNonUniformShuffleDown
:
245 op
= nir_intrinsic_shuffle_down
;
248 unreachable("Invalid opcode");
250 vtn_build_subgroup_instr(b
, op
, val
->ssa
, vtn_ssa_value(b
, w
[4]),
251 vtn_ssa_value(b
, w
[5])->def
);
255 case SpvOpGroupNonUniformIAdd
:
256 case SpvOpGroupNonUniformFAdd
:
257 case SpvOpGroupNonUniformIMul
:
258 case SpvOpGroupNonUniformFMul
:
259 case SpvOpGroupNonUniformSMin
:
260 case SpvOpGroupNonUniformUMin
:
261 case SpvOpGroupNonUniformFMin
:
262 case SpvOpGroupNonUniformSMax
:
263 case SpvOpGroupNonUniformUMax
:
264 case SpvOpGroupNonUniformFMax
:
265 case SpvOpGroupNonUniformBitwiseAnd
:
266 case SpvOpGroupNonUniformBitwiseOr
:
267 case SpvOpGroupNonUniformBitwiseXor
:
268 case SpvOpGroupNonUniformLogicalAnd
:
269 case SpvOpGroupNonUniformLogicalOr
:
270 case SpvOpGroupNonUniformLogicalXor
:
271 case SpvOpGroupNonUniformQuadBroadcast
:
272 case SpvOpGroupNonUniformQuadSwap
:
274 unreachable("Invalid SPIR-V opcode");