/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
24 #include "vtn_private.h"
/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */
34 static struct vtn_ssa_value
*
35 wrap_matrix(struct vtn_builder
*b
, struct vtn_ssa_value
*val
)
40 if (glsl_type_is_matrix(val
->type
))
43 struct vtn_ssa_value
*dest
= rzalloc(b
, struct vtn_ssa_value
);
44 dest
->type
= val
->type
;
45 dest
->elems
= ralloc_array(b
, struct vtn_ssa_value
*, 1);
51 static struct vtn_ssa_value
*
52 unwrap_matrix(struct vtn_ssa_value
*val
)
54 if (glsl_type_is_matrix(val
->type
))
60 static struct vtn_ssa_value
*
61 matrix_multiply(struct vtn_builder
*b
,
62 struct vtn_ssa_value
*_src0
, struct vtn_ssa_value
*_src1
)
65 struct vtn_ssa_value
*src0
= wrap_matrix(b
, _src0
);
66 struct vtn_ssa_value
*src1
= wrap_matrix(b
, _src1
);
67 struct vtn_ssa_value
*src0_transpose
= wrap_matrix(b
, _src0
->transposed
);
68 struct vtn_ssa_value
*src1_transpose
= wrap_matrix(b
, _src1
->transposed
);
70 unsigned src0_rows
= glsl_get_vector_elements(src0
->type
);
71 unsigned src0_columns
= glsl_get_matrix_columns(src0
->type
);
72 unsigned src1_columns
= glsl_get_matrix_columns(src1
->type
);
74 const struct glsl_type
*dest_type
;
75 if (src1_columns
> 1) {
76 dest_type
= glsl_matrix_type(glsl_get_base_type(src0
->type
),
77 src0_rows
, src1_columns
);
79 dest_type
= glsl_vector_type(glsl_get_base_type(src0
->type
), src0_rows
);
81 struct vtn_ssa_value
*dest
= vtn_create_ssa_value(b
, dest_type
);
83 dest
= wrap_matrix(b
, dest
);
85 bool transpose_result
= false;
86 if (src0_transpose
&& src1_transpose
) {
87 /* transpose(A) * transpose(B) = transpose(B * A) */
88 src1
= src0_transpose
;
89 src0
= src1_transpose
;
90 src0_transpose
= NULL
;
91 src1_transpose
= NULL
;
92 transpose_result
= true;
95 if (src0_transpose
&& !src1_transpose
&&
96 glsl_get_base_type(src0
->type
) == GLSL_TYPE_FLOAT
) {
97 /* We already have the rows of src0 and the columns of src1 available,
98 * so we can just take the dot product of each row with each column to
102 for (unsigned i
= 0; i
< src1_columns
; i
++) {
103 nir_ssa_def
*vec_src
[4];
104 for (unsigned j
= 0; j
< src0_rows
; j
++) {
105 vec_src
[j
] = nir_fdot(&b
->nb
, src0_transpose
->elems
[j
]->def
,
106 src1
->elems
[i
]->def
);
108 dest
->elems
[i
]->def
= nir_vec(&b
->nb
, vec_src
, src0_rows
);
111 /* We don't handle the case where src1 is transposed but not src0, since
112 * the general case only uses individual components of src1 so the
113 * optimizer should chew through the transpose we emitted for src1.
116 for (unsigned i
= 0; i
< src1_columns
; i
++) {
117 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
118 dest
->elems
[i
]->def
=
119 nir_fmul(&b
->nb
, src0
->elems
[0]->def
,
120 nir_channel(&b
->nb
, src1
->elems
[i
]->def
, 0));
121 for (unsigned j
= 1; j
< src0_columns
; j
++) {
122 dest
->elems
[i
]->def
=
123 nir_fadd(&b
->nb
, dest
->elems
[i
]->def
,
124 nir_fmul(&b
->nb
, src0
->elems
[j
]->def
,
125 nir_channel(&b
->nb
, src1
->elems
[i
]->def
, j
)));
130 dest
= unwrap_matrix(dest
);
132 if (transpose_result
)
133 dest
= vtn_ssa_transpose(b
, dest
);
138 static struct vtn_ssa_value
*
139 mat_times_scalar(struct vtn_builder
*b
,
140 struct vtn_ssa_value
*mat
,
143 struct vtn_ssa_value
*dest
= vtn_create_ssa_value(b
, mat
->type
);
144 for (unsigned i
= 0; i
< glsl_get_matrix_columns(mat
->type
); i
++) {
145 if (glsl_get_base_type(mat
->type
) == GLSL_TYPE_FLOAT
)
146 dest
->elems
[i
]->def
= nir_fmul(&b
->nb
, mat
->elems
[i
]->def
, scalar
);
148 dest
->elems
[i
]->def
= nir_imul(&b
->nb
, mat
->elems
[i
]->def
, scalar
);
155 vtn_handle_matrix_alu(struct vtn_builder
*b
, SpvOp opcode
,
156 struct vtn_value
*dest
,
157 struct vtn_ssa_value
*src0
, struct vtn_ssa_value
*src1
)
161 dest
->ssa
= vtn_create_ssa_value(b
, src0
->type
);
162 unsigned cols
= glsl_get_matrix_columns(src0
->type
);
163 for (unsigned i
= 0; i
< cols
; i
++)
164 dest
->ssa
->elems
[i
]->def
= nir_fneg(&b
->nb
, src0
->elems
[i
]->def
);
169 dest
->ssa
= vtn_create_ssa_value(b
, src0
->type
);
170 unsigned cols
= glsl_get_matrix_columns(src0
->type
);
171 for (unsigned i
= 0; i
< cols
; i
++)
172 dest
->ssa
->elems
[i
]->def
=
173 nir_fadd(&b
->nb
, src0
->elems
[i
]->def
, src1
->elems
[i
]->def
);
178 dest
->ssa
= vtn_create_ssa_value(b
, src0
->type
);
179 unsigned cols
= glsl_get_matrix_columns(src0
->type
);
180 for (unsigned i
= 0; i
< cols
; i
++)
181 dest
->ssa
->elems
[i
]->def
=
182 nir_fsub(&b
->nb
, src0
->elems
[i
]->def
, src1
->elems
[i
]->def
);
187 dest
->ssa
= vtn_ssa_transpose(b
, src0
);
190 case SpvOpMatrixTimesScalar
:
191 if (src0
->transposed
) {
192 dest
->ssa
= vtn_ssa_transpose(b
, mat_times_scalar(b
, src0
->transposed
,
195 dest
->ssa
= mat_times_scalar(b
, src0
, src1
->def
);
199 case SpvOpVectorTimesMatrix
:
200 case SpvOpMatrixTimesVector
:
201 case SpvOpMatrixTimesMatrix
:
202 if (opcode
== SpvOpVectorTimesMatrix
) {
203 dest
->ssa
= matrix_multiply(b
, vtn_ssa_transpose(b
, src1
), src0
);
205 dest
->ssa
= matrix_multiply(b
, src0
, src1
);
209 default: unreachable("unknown matrix opcode");
214 vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode
, bool *swap
,
215 nir_alu_type src
, nir_alu_type dst
)
217 /* Indicates that the first two arguments should be swapped. This is
218 * used for implementing greater-than and less-than-or-equal.
223 case SpvOpSNegate
: return nir_op_ineg
;
224 case SpvOpFNegate
: return nir_op_fneg
;
225 case SpvOpNot
: return nir_op_inot
;
226 case SpvOpIAdd
: return nir_op_iadd
;
227 case SpvOpFAdd
: return nir_op_fadd
;
228 case SpvOpISub
: return nir_op_isub
;
229 case SpvOpFSub
: return nir_op_fsub
;
230 case SpvOpIMul
: return nir_op_imul
;
231 case SpvOpFMul
: return nir_op_fmul
;
232 case SpvOpUDiv
: return nir_op_udiv
;
233 case SpvOpSDiv
: return nir_op_idiv
;
234 case SpvOpFDiv
: return nir_op_fdiv
;
235 case SpvOpUMod
: return nir_op_umod
;
236 case SpvOpSMod
: return nir_op_imod
;
237 case SpvOpFMod
: return nir_op_fmod
;
238 case SpvOpSRem
: return nir_op_irem
;
239 case SpvOpFRem
: return nir_op_frem
;
241 case SpvOpShiftRightLogical
: return nir_op_ushr
;
242 case SpvOpShiftRightArithmetic
: return nir_op_ishr
;
243 case SpvOpShiftLeftLogical
: return nir_op_ishl
;
244 case SpvOpLogicalOr
: return nir_op_ior
;
245 case SpvOpLogicalEqual
: return nir_op_ieq
;
246 case SpvOpLogicalNotEqual
: return nir_op_ine
;
247 case SpvOpLogicalAnd
: return nir_op_iand
;
248 case SpvOpLogicalNot
: return nir_op_inot
;
249 case SpvOpBitwiseOr
: return nir_op_ior
;
250 case SpvOpBitwiseXor
: return nir_op_ixor
;
251 case SpvOpBitwiseAnd
: return nir_op_iand
;
252 case SpvOpSelect
: return nir_op_bcsel
;
253 case SpvOpIEqual
: return nir_op_ieq
;
255 case SpvOpBitFieldInsert
: return nir_op_bitfield_insert
;
256 case SpvOpBitFieldSExtract
: return nir_op_ibitfield_extract
;
257 case SpvOpBitFieldUExtract
: return nir_op_ubitfield_extract
;
258 case SpvOpBitReverse
: return nir_op_bitfield_reverse
;
259 case SpvOpBitCount
: return nir_op_bit_count
;
261 /* The ordered / unordered operators need special implementation besides
262 * the logical operator to use since they also need to check if operands are
265 case SpvOpFOrdEqual
: return nir_op_feq
;
266 case SpvOpFUnordEqual
: return nir_op_feq
;
267 case SpvOpINotEqual
: return nir_op_ine
;
268 case SpvOpFOrdNotEqual
: return nir_op_fne
;
269 case SpvOpFUnordNotEqual
: return nir_op_fne
;
270 case SpvOpULessThan
: return nir_op_ult
;
271 case SpvOpSLessThan
: return nir_op_ilt
;
272 case SpvOpFOrdLessThan
: return nir_op_flt
;
273 case SpvOpFUnordLessThan
: return nir_op_flt
;
274 case SpvOpUGreaterThan
: *swap
= true; return nir_op_ult
;
275 case SpvOpSGreaterThan
: *swap
= true; return nir_op_ilt
;
276 case SpvOpFOrdGreaterThan
: *swap
= true; return nir_op_flt
;
277 case SpvOpFUnordGreaterThan
: *swap
= true; return nir_op_flt
;
278 case SpvOpULessThanEqual
: *swap
= true; return nir_op_uge
;
279 case SpvOpSLessThanEqual
: *swap
= true; return nir_op_ige
;
280 case SpvOpFOrdLessThanEqual
: *swap
= true; return nir_op_fge
;
281 case SpvOpFUnordLessThanEqual
: *swap
= true; return nir_op_fge
;
282 case SpvOpUGreaterThanEqual
: return nir_op_uge
;
283 case SpvOpSGreaterThanEqual
: return nir_op_ige
;
284 case SpvOpFOrdGreaterThanEqual
: return nir_op_fge
;
285 case SpvOpFUnordGreaterThanEqual
: return nir_op_fge
;
288 case SpvOpBitcast
: return nir_op_imov
;
289 case SpvOpQuantizeToF16
: return nir_op_fquantize2f16
;
291 case SpvOpConvertFToU
:
292 case SpvOpConvertFToS
:
293 case SpvOpConvertSToF
:
294 case SpvOpConvertUToF
:
297 return nir_type_conversion_op(src
, dst
);
300 case SpvOpDPdx
: return nir_op_fddx
;
301 case SpvOpDPdy
: return nir_op_fddy
;
302 case SpvOpDPdxFine
: return nir_op_fddx_fine
;
303 case SpvOpDPdyFine
: return nir_op_fddy_fine
;
304 case SpvOpDPdxCoarse
: return nir_op_fddx_coarse
;
305 case SpvOpDPdyCoarse
: return nir_op_fddy_coarse
;
308 unreachable("No NIR equivalent");
313 handle_no_contraction(struct vtn_builder
*b
, struct vtn_value
*val
, int member
,
314 const struct vtn_decoration
*dec
, void *_void
)
316 assert(dec
->scope
== VTN_DEC_DECORATION
);
317 if (dec
->decoration
!= SpvDecorationNoContraction
)
324 vtn_handle_alu(struct vtn_builder
*b
, SpvOp opcode
,
325 const uint32_t *w
, unsigned count
)
327 struct vtn_value
*val
= vtn_push_value(b
, w
[2], vtn_value_type_ssa
);
328 const struct glsl_type
*type
=
329 vtn_value(b
, w
[1], vtn_value_type_type
)->type
->type
;
331 vtn_foreach_decoration(b
, val
, handle_no_contraction
, NULL
);
333 /* Collect the various SSA sources */
334 const unsigned num_inputs
= count
- 3;
335 struct vtn_ssa_value
*vtn_src
[4] = { NULL
, };
336 for (unsigned i
= 0; i
< num_inputs
; i
++)
337 vtn_src
[i
] = vtn_ssa_value(b
, w
[i
+ 3]);
339 if (glsl_type_is_matrix(vtn_src
[0]->type
) ||
340 (num_inputs
>= 2 && glsl_type_is_matrix(vtn_src
[1]->type
))) {
341 vtn_handle_matrix_alu(b
, opcode
, val
, vtn_src
[0], vtn_src
[1]);
346 val
->ssa
= vtn_create_ssa_value(b
, type
);
347 nir_ssa_def
*src
[4] = { NULL
, };
348 for (unsigned i
= 0; i
< num_inputs
; i
++) {
349 assert(glsl_type_is_vector_or_scalar(vtn_src
[i
]->type
));
350 src
[i
] = vtn_src
[i
]->def
;
355 if (src
[0]->num_components
== 1) {
356 val
->ssa
->def
= nir_imov(&b
->nb
, src
[0]);
359 switch (src
[0]->num_components
) {
360 case 2: op
= nir_op_bany_inequal2
; break;
361 case 3: op
= nir_op_bany_inequal3
; break;
362 case 4: op
= nir_op_bany_inequal4
; break;
363 default: unreachable("invalid number of components");
365 val
->ssa
->def
= nir_build_alu(&b
->nb
, op
, src
[0],
366 nir_imm_int(&b
->nb
, NIR_FALSE
),
372 if (src
[0]->num_components
== 1) {
373 val
->ssa
->def
= nir_imov(&b
->nb
, src
[0]);
376 switch (src
[0]->num_components
) {
377 case 2: op
= nir_op_ball_iequal2
; break;
378 case 3: op
= nir_op_ball_iequal3
; break;
379 case 4: op
= nir_op_ball_iequal4
; break;
380 default: unreachable("invalid number of components");
382 val
->ssa
->def
= nir_build_alu(&b
->nb
, op
, src
[0],
383 nir_imm_int(&b
->nb
, NIR_TRUE
),
388 case SpvOpOuterProduct
: {
389 for (unsigned i
= 0; i
< src
[1]->num_components
; i
++) {
390 val
->ssa
->elems
[i
]->def
=
391 nir_fmul(&b
->nb
, src
[0], nir_channel(&b
->nb
, src
[1], i
));
397 val
->ssa
->def
= nir_fdot(&b
->nb
, src
[0], src
[1]);
401 assert(glsl_type_is_struct(val
->ssa
->type
));
402 val
->ssa
->elems
[0]->def
= nir_iadd(&b
->nb
, src
[0], src
[1]);
403 val
->ssa
->elems
[1]->def
= nir_uadd_carry(&b
->nb
, src
[0], src
[1]);
406 case SpvOpISubBorrow
:
407 assert(glsl_type_is_struct(val
->ssa
->type
));
408 val
->ssa
->elems
[0]->def
= nir_isub(&b
->nb
, src
[0], src
[1]);
409 val
->ssa
->elems
[1]->def
= nir_usub_borrow(&b
->nb
, src
[0], src
[1]);
412 case SpvOpUMulExtended
:
413 assert(glsl_type_is_struct(val
->ssa
->type
));
414 val
->ssa
->elems
[0]->def
= nir_imul(&b
->nb
, src
[0], src
[1]);
415 val
->ssa
->elems
[1]->def
= nir_umul_high(&b
->nb
, src
[0], src
[1]);
418 case SpvOpSMulExtended
:
419 assert(glsl_type_is_struct(val
->ssa
->type
));
420 val
->ssa
->elems
[0]->def
= nir_imul(&b
->nb
, src
[0], src
[1]);
421 val
->ssa
->elems
[1]->def
= nir_imul_high(&b
->nb
, src
[0], src
[1]);
425 val
->ssa
->def
= nir_fadd(&b
->nb
,
426 nir_fabs(&b
->nb
, nir_fddx(&b
->nb
, src
[0])),
427 nir_fabs(&b
->nb
, nir_fddy(&b
->nb
, src
[0])));
429 case SpvOpFwidthFine
:
430 val
->ssa
->def
= nir_fadd(&b
->nb
,
431 nir_fabs(&b
->nb
, nir_fddx_fine(&b
->nb
, src
[0])),
432 nir_fabs(&b
->nb
, nir_fddy_fine(&b
->nb
, src
[0])));
434 case SpvOpFwidthCoarse
:
435 val
->ssa
->def
= nir_fadd(&b
->nb
,
436 nir_fabs(&b
->nb
, nir_fddx_coarse(&b
->nb
, src
[0])),
437 nir_fabs(&b
->nb
, nir_fddy_coarse(&b
->nb
, src
[0])));
440 case SpvOpVectorTimesScalar
:
441 /* The builder will take care of splatting for us. */
442 val
->ssa
->def
= nir_fmul(&b
->nb
, src
[0], src
[1]);
446 val
->ssa
->def
= nir_fne(&b
->nb
, src
[0], src
[0]);
450 val
->ssa
->def
= nir_feq(&b
->nb
, nir_fabs(&b
->nb
, src
[0]),
451 nir_imm_float(&b
->nb
, INFINITY
));
454 case SpvOpFUnordEqual
:
455 case SpvOpFUnordNotEqual
:
456 case SpvOpFUnordLessThan
:
457 case SpvOpFUnordGreaterThan
:
458 case SpvOpFUnordLessThanEqual
:
459 case SpvOpFUnordGreaterThanEqual
: {
461 nir_alu_type src_alu_type
= nir_get_nir_type_for_glsl_type(vtn_src
[0]->type
);
462 nir_alu_type dst_alu_type
= nir_get_nir_type_for_glsl_type(type
);
463 nir_op op
= vtn_nir_alu_op_for_spirv_opcode(opcode
, &swap
, src_alu_type
, dst_alu_type
);
466 nir_ssa_def
*tmp
= src
[0];
473 nir_build_alu(&b
->nb
, op
, src
[0], src
[1], NULL
, NULL
),
475 nir_fne(&b
->nb
, src
[0], src
[0]),
476 nir_fne(&b
->nb
, src
[1], src
[1])));
481 case SpvOpFOrdNotEqual
:
482 case SpvOpFOrdLessThan
:
483 case SpvOpFOrdGreaterThan
:
484 case SpvOpFOrdLessThanEqual
:
485 case SpvOpFOrdGreaterThanEqual
: {
487 nir_alu_type src_alu_type
= nir_get_nir_type_for_glsl_type(vtn_src
[0]->type
);
488 nir_alu_type dst_alu_type
= nir_get_nir_type_for_glsl_type(type
);
489 nir_op op
= vtn_nir_alu_op_for_spirv_opcode(opcode
, &swap
, src_alu_type
, dst_alu_type
);
492 nir_ssa_def
*tmp
= src
[0];
499 nir_build_alu(&b
->nb
, op
, src
[0], src
[1], NULL
, NULL
),
501 nir_feq(&b
->nb
, src
[0], src
[0]),
502 nir_feq(&b
->nb
, src
[1], src
[1])));
508 nir_alu_type src_alu_type
= nir_get_nir_type_for_glsl_type(vtn_src
[0]->type
);
509 nir_alu_type dst_alu_type
= nir_get_nir_type_for_glsl_type(type
);
510 nir_op op
= vtn_nir_alu_op_for_spirv_opcode(opcode
, &swap
, src_alu_type
, dst_alu_type
);
513 nir_ssa_def
*tmp
= src
[0];
518 val
->ssa
->def
= nir_build_alu(&b
->nb
, op
, src
[0], src
[1], src
[2], src
[3]);