spirv: handle SpvOpUConvert in proper place.
[mesa.git] / src / compiler / spirv / vtn_alu.c
1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "vtn_private.h"
25
26 /*
27 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
28 * definition. But for matrix multiplies, we want to do one routine for
29 * multiplying a matrix by a matrix and then pretend that vectors are matrices
30 * with one column. So we "wrap" these things, and unwrap the result before we
31 * send it off.
32 */
33
34 static struct vtn_ssa_value *
35 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
36 {
37 if (val == NULL)
38 return NULL;
39
40 if (glsl_type_is_matrix(val->type))
41 return val;
42
43 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
44 dest->type = val->type;
45 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
46 dest->elems[0] = val;
47
48 return dest;
49 }
50
51 static struct vtn_ssa_value *
52 unwrap_matrix(struct vtn_ssa_value *val)
53 {
54 if (glsl_type_is_matrix(val->type))
55 return val;
56
57 return val->elems[0];
58 }
59
60 static struct vtn_ssa_value *
61 matrix_multiply(struct vtn_builder *b,
62 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
63 {
64
65 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
66 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
67 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
68 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
69
70 unsigned src0_rows = glsl_get_vector_elements(src0->type);
71 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
72 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
73
74 const struct glsl_type *dest_type;
75 if (src1_columns > 1) {
76 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
77 src0_rows, src1_columns);
78 } else {
79 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
80 }
81 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
82
83 dest = wrap_matrix(b, dest);
84
85 bool transpose_result = false;
86 if (src0_transpose && src1_transpose) {
87 /* transpose(A) * transpose(B) = transpose(B * A) */
88 src1 = src0_transpose;
89 src0 = src1_transpose;
90 src0_transpose = NULL;
91 src1_transpose = NULL;
92 transpose_result = true;
93 }
94
95 if (src0_transpose && !src1_transpose &&
96 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
97 /* We already have the rows of src0 and the columns of src1 available,
98 * so we can just take the dot product of each row with each column to
99 * get the result.
100 */
101
102 for (unsigned i = 0; i < src1_columns; i++) {
103 nir_ssa_def *vec_src[4];
104 for (unsigned j = 0; j < src0_rows; j++) {
105 vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
106 src1->elems[i]->def);
107 }
108 dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
109 }
110 } else {
111 /* We don't handle the case where src1 is transposed but not src0, since
112 * the general case only uses individual components of src1 so the
113 * optimizer should chew through the transpose we emitted for src1.
114 */
115
116 for (unsigned i = 0; i < src1_columns; i++) {
117 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
118 dest->elems[i]->def =
119 nir_fmul(&b->nb, src0->elems[0]->def,
120 nir_channel(&b->nb, src1->elems[i]->def, 0));
121 for (unsigned j = 1; j < src0_columns; j++) {
122 dest->elems[i]->def =
123 nir_fadd(&b->nb, dest->elems[i]->def,
124 nir_fmul(&b->nb, src0->elems[j]->def,
125 nir_channel(&b->nb, src1->elems[i]->def, j)));
126 }
127 }
128 }
129
130 dest = unwrap_matrix(dest);
131
132 if (transpose_result)
133 dest = vtn_ssa_transpose(b, dest);
134
135 return dest;
136 }
137
138 static struct vtn_ssa_value *
139 mat_times_scalar(struct vtn_builder *b,
140 struct vtn_ssa_value *mat,
141 nir_ssa_def *scalar)
142 {
143 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
144 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
145 if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
146 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
147 else
148 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149 }
150
151 return dest;
152 }
153
154 static void
155 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
156 struct vtn_value *dest,
157 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
158 {
159 switch (opcode) {
160 case SpvOpFNegate: {
161 dest->ssa = vtn_create_ssa_value(b, src0->type);
162 unsigned cols = glsl_get_matrix_columns(src0->type);
163 for (unsigned i = 0; i < cols; i++)
164 dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
165 break;
166 }
167
168 case SpvOpFAdd: {
169 dest->ssa = vtn_create_ssa_value(b, src0->type);
170 unsigned cols = glsl_get_matrix_columns(src0->type);
171 for (unsigned i = 0; i < cols; i++)
172 dest->ssa->elems[i]->def =
173 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
174 break;
175 }
176
177 case SpvOpFSub: {
178 dest->ssa = vtn_create_ssa_value(b, src0->type);
179 unsigned cols = glsl_get_matrix_columns(src0->type);
180 for (unsigned i = 0; i < cols; i++)
181 dest->ssa->elems[i]->def =
182 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
183 break;
184 }
185
186 case SpvOpTranspose:
187 dest->ssa = vtn_ssa_transpose(b, src0);
188 break;
189
190 case SpvOpMatrixTimesScalar:
191 if (src0->transposed) {
192 dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
193 src1->def));
194 } else {
195 dest->ssa = mat_times_scalar(b, src0, src1->def);
196 }
197 break;
198
199 case SpvOpVectorTimesMatrix:
200 case SpvOpMatrixTimesVector:
201 case SpvOpMatrixTimesMatrix:
202 if (opcode == SpvOpVectorTimesMatrix) {
203 dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
204 } else {
205 dest->ssa = matrix_multiply(b, src0, src1);
206 }
207 break;
208
209 default: unreachable("unknown matrix opcode");
210 }
211 }
212
213 nir_op
214 vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
215 nir_alu_type src, nir_alu_type dst)
216 {
217 /* Indicates that the first two arguments should be swapped. This is
218 * used for implementing greater-than and less-than-or-equal.
219 */
220 *swap = false;
221
222 switch (opcode) {
223 case SpvOpSNegate: return nir_op_ineg;
224 case SpvOpFNegate: return nir_op_fneg;
225 case SpvOpNot: return nir_op_inot;
226 case SpvOpIAdd: return nir_op_iadd;
227 case SpvOpFAdd: return nir_op_fadd;
228 case SpvOpISub: return nir_op_isub;
229 case SpvOpFSub: return nir_op_fsub;
230 case SpvOpIMul: return nir_op_imul;
231 case SpvOpFMul: return nir_op_fmul;
232 case SpvOpUDiv: return nir_op_udiv;
233 case SpvOpSDiv: return nir_op_idiv;
234 case SpvOpFDiv: return nir_op_fdiv;
235 case SpvOpUMod: return nir_op_umod;
236 case SpvOpSMod: return nir_op_imod;
237 case SpvOpFMod: return nir_op_fmod;
238 case SpvOpSRem: return nir_op_irem;
239 case SpvOpFRem: return nir_op_frem;
240
241 case SpvOpShiftRightLogical: return nir_op_ushr;
242 case SpvOpShiftRightArithmetic: return nir_op_ishr;
243 case SpvOpShiftLeftLogical: return nir_op_ishl;
244 case SpvOpLogicalOr: return nir_op_ior;
245 case SpvOpLogicalEqual: return nir_op_ieq;
246 case SpvOpLogicalNotEqual: return nir_op_ine;
247 case SpvOpLogicalAnd: return nir_op_iand;
248 case SpvOpLogicalNot: return nir_op_inot;
249 case SpvOpBitwiseOr: return nir_op_ior;
250 case SpvOpBitwiseXor: return nir_op_ixor;
251 case SpvOpBitwiseAnd: return nir_op_iand;
252 case SpvOpSelect: return nir_op_bcsel;
253 case SpvOpIEqual: return nir_op_ieq;
254
255 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
256 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
257 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
258 case SpvOpBitReverse: return nir_op_bitfield_reverse;
259 case SpvOpBitCount: return nir_op_bit_count;
260
261 /* The ordered / unordered operators need special implementation besides
262 * the logical operator to use since they also need to check if operands are
263 * ordered.
264 */
265 case SpvOpFOrdEqual: return nir_op_feq;
266 case SpvOpFUnordEqual: return nir_op_feq;
267 case SpvOpINotEqual: return nir_op_ine;
268 case SpvOpFOrdNotEqual: return nir_op_fne;
269 case SpvOpFUnordNotEqual: return nir_op_fne;
270 case SpvOpULessThan: return nir_op_ult;
271 case SpvOpSLessThan: return nir_op_ilt;
272 case SpvOpFOrdLessThan: return nir_op_flt;
273 case SpvOpFUnordLessThan: return nir_op_flt;
274 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
275 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
276 case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
277 case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
278 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
279 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
280 case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
281 case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
282 case SpvOpUGreaterThanEqual: return nir_op_uge;
283 case SpvOpSGreaterThanEqual: return nir_op_ige;
284 case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
285 case SpvOpFUnordGreaterThanEqual: return nir_op_fge;
286
287 /* Conversions: */
288 case SpvOpBitcast: return nir_op_imov;
289 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
290 case SpvOpUConvert:
291 case SpvOpConvertFToU:
292 case SpvOpConvertFToS:
293 case SpvOpConvertSToF:
294 case SpvOpConvertUToF:
295 case SpvOpSConvert:
296 case SpvOpFConvert:
297 return nir_type_conversion_op(src, dst);
298
299 /* Derivatives: */
300 case SpvOpDPdx: return nir_op_fddx;
301 case SpvOpDPdy: return nir_op_fddy;
302 case SpvOpDPdxFine: return nir_op_fddx_fine;
303 case SpvOpDPdyFine: return nir_op_fddy_fine;
304 case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
305 case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
306
307 default:
308 unreachable("No NIR equivalent");
309 }
310 }
311
312 static void
313 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
314 const struct vtn_decoration *dec, void *_void)
315 {
316 assert(dec->scope == VTN_DEC_DECORATION);
317 if (dec->decoration != SpvDecorationNoContraction)
318 return;
319
320 b->nb.exact = true;
321 }
322
323 void
324 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
325 const uint32_t *w, unsigned count)
326 {
327 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
328 const struct glsl_type *type =
329 vtn_value(b, w[1], vtn_value_type_type)->type->type;
330
331 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
332
333 /* Collect the various SSA sources */
334 const unsigned num_inputs = count - 3;
335 struct vtn_ssa_value *vtn_src[4] = { NULL, };
336 for (unsigned i = 0; i < num_inputs; i++)
337 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
338
339 if (glsl_type_is_matrix(vtn_src[0]->type) ||
340 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
341 vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
342 b->nb.exact = false;
343 return;
344 }
345
346 val->ssa = vtn_create_ssa_value(b, type);
347 nir_ssa_def *src[4] = { NULL, };
348 for (unsigned i = 0; i < num_inputs; i++) {
349 assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
350 src[i] = vtn_src[i]->def;
351 }
352
353 switch (opcode) {
354 case SpvOpAny:
355 if (src[0]->num_components == 1) {
356 val->ssa->def = nir_imov(&b->nb, src[0]);
357 } else {
358 nir_op op;
359 switch (src[0]->num_components) {
360 case 2: op = nir_op_bany_inequal2; break;
361 case 3: op = nir_op_bany_inequal3; break;
362 case 4: op = nir_op_bany_inequal4; break;
363 default: unreachable("invalid number of components");
364 }
365 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
366 nir_imm_int(&b->nb, NIR_FALSE),
367 NULL, NULL);
368 }
369 break;
370
371 case SpvOpAll:
372 if (src[0]->num_components == 1) {
373 val->ssa->def = nir_imov(&b->nb, src[0]);
374 } else {
375 nir_op op;
376 switch (src[0]->num_components) {
377 case 2: op = nir_op_ball_iequal2; break;
378 case 3: op = nir_op_ball_iequal3; break;
379 case 4: op = nir_op_ball_iequal4; break;
380 default: unreachable("invalid number of components");
381 }
382 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
383 nir_imm_int(&b->nb, NIR_TRUE),
384 NULL, NULL);
385 }
386 break;
387
388 case SpvOpOuterProduct: {
389 for (unsigned i = 0; i < src[1]->num_components; i++) {
390 val->ssa->elems[i]->def =
391 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
392 }
393 break;
394 }
395
396 case SpvOpDot:
397 val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
398 break;
399
400 case SpvOpIAddCarry:
401 assert(glsl_type_is_struct(val->ssa->type));
402 val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
403 val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
404 break;
405
406 case SpvOpISubBorrow:
407 assert(glsl_type_is_struct(val->ssa->type));
408 val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
409 val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
410 break;
411
412 case SpvOpUMulExtended:
413 assert(glsl_type_is_struct(val->ssa->type));
414 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
415 val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
416 break;
417
418 case SpvOpSMulExtended:
419 assert(glsl_type_is_struct(val->ssa->type));
420 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
421 val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
422 break;
423
424 case SpvOpFwidth:
425 val->ssa->def = nir_fadd(&b->nb,
426 nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
427 nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
428 break;
429 case SpvOpFwidthFine:
430 val->ssa->def = nir_fadd(&b->nb,
431 nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
432 nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
433 break;
434 case SpvOpFwidthCoarse:
435 val->ssa->def = nir_fadd(&b->nb,
436 nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
437 nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
438 break;
439
440 case SpvOpVectorTimesScalar:
441 /* The builder will take care of splatting for us. */
442 val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
443 break;
444
445 case SpvOpIsNan:
446 val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
447 break;
448
449 case SpvOpIsInf:
450 val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
451 nir_imm_float(&b->nb, INFINITY));
452 break;
453
454 case SpvOpFUnordEqual:
455 case SpvOpFUnordNotEqual:
456 case SpvOpFUnordLessThan:
457 case SpvOpFUnordGreaterThan:
458 case SpvOpFUnordLessThanEqual:
459 case SpvOpFUnordGreaterThanEqual: {
460 bool swap;
461 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
462 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
463 nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
464
465 if (swap) {
466 nir_ssa_def *tmp = src[0];
467 src[0] = src[1];
468 src[1] = tmp;
469 }
470
471 val->ssa->def =
472 nir_ior(&b->nb,
473 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
474 nir_ior(&b->nb,
475 nir_fne(&b->nb, src[0], src[0]),
476 nir_fne(&b->nb, src[1], src[1])));
477 break;
478 }
479
480 case SpvOpFOrdEqual:
481 case SpvOpFOrdNotEqual:
482 case SpvOpFOrdLessThan:
483 case SpvOpFOrdGreaterThan:
484 case SpvOpFOrdLessThanEqual:
485 case SpvOpFOrdGreaterThanEqual: {
486 bool swap;
487 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
488 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
489 nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
490
491 if (swap) {
492 nir_ssa_def *tmp = src[0];
493 src[0] = src[1];
494 src[1] = tmp;
495 }
496
497 val->ssa->def =
498 nir_iand(&b->nb,
499 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
500 nir_iand(&b->nb,
501 nir_feq(&b->nb, src[0], src[0]),
502 nir_feq(&b->nb, src[1], src[1])));
503 break;
504 }
505
506 default: {
507 bool swap;
508 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
509 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
510 nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
511
512 if (swap) {
513 nir_ssa_def *tmp = src[0];
514 src[0] = src[1];
515 src[1] = tmp;
516 }
517
518 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
519 break;
520 } /* default */
521 }
522
523 b->nb.exact = false;
524 }