/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>

#include "vtn_private.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */
static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}
61 static struct vtn_ssa_value
*
62 matrix_multiply(struct vtn_builder
*b
,
63 struct vtn_ssa_value
*_src0
, struct vtn_ssa_value
*_src1
)
66 struct vtn_ssa_value
*src0
= wrap_matrix(b
, _src0
);
67 struct vtn_ssa_value
*src1
= wrap_matrix(b
, _src1
);
68 struct vtn_ssa_value
*src0_transpose
= wrap_matrix(b
, _src0
->transposed
);
69 struct vtn_ssa_value
*src1_transpose
= wrap_matrix(b
, _src1
->transposed
);
71 unsigned src0_rows
= glsl_get_vector_elements(src0
->type
);
72 unsigned src0_columns
= glsl_get_matrix_columns(src0
->type
);
73 unsigned src1_columns
= glsl_get_matrix_columns(src1
->type
);
75 const struct glsl_type
*dest_type
;
76 if (src1_columns
> 1) {
77 dest_type
= glsl_matrix_type(glsl_get_base_type(src0
->type
),
78 src0_rows
, src1_columns
);
80 dest_type
= glsl_vector_type(glsl_get_base_type(src0
->type
), src0_rows
);
82 struct vtn_ssa_value
*dest
= vtn_create_ssa_value(b
, dest_type
);
84 dest
= wrap_matrix(b
, dest
);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */
      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }
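
   /* For illustration, with a two-column src0 the general case above expands
    * each destination column to roughly
    *
    *    dest[i] = src0[0] * src1[i].x + src0[1] * src1[i].y
    *
    * i.e. a linear combination of the columns of src0 weighted by the
    * components of src1's i-th column.
    */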

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}
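
/* Note that SpvOpMatrixTimesScalar (below) also reuses this helper for
 * transposed sources, relying on the identity
 * transpose(A) * s == transpose(A * s).
 */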

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;
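
   /* OpVectorTimesMatrix treats its vector as a row vector. matrix_multiply()
    * only understands column vectors, so v * M is computed as
    * transpose(M) * v, which yields the same components.
    */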
   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail("unknown matrix opcode");
   }
}

static void
vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
                   struct nir_ssa_def *src)
{
   if (glsl_get_vector_elements(dest->type) == src->num_components) {
      /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
       *
       * "If Result Type has the same number of components as Operand, they
       * must also have the same component width, and results are computed per
       * component."
       */
      dest->def = nir_imov(&b->nb, src);
      return;
   }

   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    * "If Result Type has a different number of components than Operand, the
    * total number of bits in Result Type must equal the total number of bits
    * in Operand. Let L be the type, either Result Type or Operand’s type, that
    * has the larger number of components. Let S be the other type, with the
    * smaller number of components. The number of components in L must be an
    * integer multiple of the number of components in S. The first component
    * (that is, the only or lowest-numbered component) of S maps to the first
    * components of L, and so on, up to the last component of S mapping to the
    * last components of L. Within this mapping, any single component of S
    * (mapping to multiple components of L) maps its lower-ordered bits to the
    * lower-numbered components of L."
    */
   unsigned src_bit_size = src->bit_size;
   unsigned dest_bit_size = glsl_get_bit_size(dest->type);
   unsigned src_components = src->num_components;
   unsigned dest_components = glsl_get_vector_elements(dest->type);
   vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
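
   /* Example of the mapping above: bitcasting one 64-bit scalar to a uvec2
    * puts the low 32 bits of the source in component 0 and the high 32 bits
    * in component 1.
    */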
   nir_ssa_def *dest_chan[NIR_MAX_VEC_COMPONENTS];
   if (src_bit_size > dest_bit_size) {
      vtn_assert(src_bit_size % dest_bit_size == 0);
      unsigned divisor = src_bit_size / dest_bit_size;
      for (unsigned comp = 0; comp < src_components; comp++) {
         nir_ssa_def *split;
         if (src_bit_size == 64) {
            assert(dest_bit_size == 32 || dest_bit_size == 16);
            split = dest_bit_size == 32 ?
               nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp)) :
               nir_unpack_64_4x16(&b->nb, nir_channel(&b->nb, src, comp));
         } else {
            vtn_assert(src_bit_size == 32);
            vtn_assert(dest_bit_size == 16);
            split = nir_unpack_32_2x16(&b->nb, nir_channel(&b->nb, src, comp));
         }
         for (unsigned i = 0; i < divisor; i++)
            dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
      }
   } else {
      vtn_assert(dest_bit_size % src_bit_size == 0);
      unsigned divisor = dest_bit_size / src_bit_size;
      for (unsigned comp = 0; comp < dest_components; comp++) {
         unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
         nir_ssa_def *src_chan = nir_channels(&b->nb, src, channels);
         if (dest_bit_size == 64) {
            assert(src_bit_size == 32 || src_bit_size == 16);
            dest_chan[comp] = src_bit_size == 32 ?
               nir_pack_64_2x32(&b->nb, src_chan) :
               nir_pack_64_4x16(&b->nb, src_chan);
         } else {
            vtn_assert(dest_bit_size == 32);
            vtn_assert(src_bit_size == 16);
            dest_chan[comp] = nir_pack_32_2x16(&b->nb, src_chan);
         }
      }
   }
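
   /* Illustration of the mask math in the widening branch: bitcasting a
    * 16-bit vec4 to a 32-bit vec2 gives divisor = 2, so dest component 0
    * packs source channels 0-1 (mask 0x3) and dest component 1 packs source
    * channels 2-3 (mask 0xc).
    */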

   dest->def = nir_vec(&b->nb, dest_chan, dest_components);
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;
   case SpvOpBitCount:              return nir_op_bit_count;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
    */
   case SpvOpFOrdEqual:                            return nir_op_feq;
   case SpvOpFUnordEqual:                          return nir_op_feq;
   case SpvOpINotEqual:                            return nir_op_ine;
   case SpvOpFOrdNotEqual:                         return nir_op_fne;
   case SpvOpFUnordNotEqual:                       return nir_op_fne;
   case SpvOpULessThan:                            return nir_op_ult;
   case SpvOpSLessThan:                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         return nir_op_flt;
   case SpvOpFUnordLessThan:                       return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
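
   /* Each *swap case above relies on operand reversal: for example, a > b is
    * emitted as (b < a) using nir_op_ult/ilt/flt.
    */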

   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
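
   /* (The bit-size OR above works because nir_alu_type encodes the base type
    * and the bit size in disjoint bits of one enum.)
    */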

   case SpvOpDPdx:        return nir_op_fddx;
   case SpvOpDPdy:        return nir_op_fddy;
   case SpvOpDPdxFine:    return nir_op_fddx_fine;
   case SpvOpDPdyFine:    return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:  return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:  return nir_op_fddy_coarse;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

static void
handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
                     const struct vtn_decoration *dec, void *_out_rounding_mode)
{
   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationFPRoundingMode)
      return;
   switch (dec->literals[0]) {
   case SpvFPRoundingModeRTE:
      *out_rounding_mode = nir_rounding_mode_rtne;
      break;
   case SpvFPRoundingModeRTZ:
      *out_rounding_mode = nir_rounding_mode_rtz;
      break;
   default:
      unreachable("Unsupported rounding mode");
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
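   /* (w[1] is the result type id and w[2] the result id, so the actual
    * operand ids start at w[3].)
    */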
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = false;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_bany_inequal2; break;
         case 3:  op = nir_op_bany_inequal3; break;
         case 4:  op = nir_op_bany_inequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_false(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2:  op = nir_op_ball_iequal2; break;
         case 3:  op = nir_op_ball_iequal3; break;
         case 4:  op = nir_op_ball_iequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_true(&b->nb),
                                       NULL, NULL);
      }
      break;
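
   /* In the lowerings above, any(v) is computed in effect as
    * "v != (false, ...)" via bany_inequal and all(v) as "v == (true, ...)"
    * via ball_iequal; the builder splats the immediate scalar.
    */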

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpSMulExtended:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      break;
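
   /* Each of the four struct-returning ops above produces two members:
    * elems[0] holds the low bits of the result and elems[1] the carry,
    * borrow, or high bits, as OpIAddCarry/OpISubBorrow/Op[US]MulExtended
    * require.
    */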

   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;
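
   /* All three Fwidth variants implement fwidth(p) = abs(dFdx(p)) +
    * abs(dFdy(p)); they differ only in the derivative precision requested.
    */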

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;
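
   /* (NaN is the only value that compares unequal to itself, which is why
    * fne(x, x) implements isnan.)
    */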

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }
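
   /* An unordered comparison must return true whenever either operand is
    * NaN, hence the ordered result above is OR'ed with isnan(src[0]) ||
    * isnan(src[1]), each computed as fne(x, x).
    */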

   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpBitcast:
      vtn_handle_bitcast(b, val->ssa, src[0]);
      break;

   case SpvOpFConvert: {
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bit size supported by the NIR instructions. See the
             * discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   default: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }
   }

   b->nb.exact = false;
}