nir/spirv: implement ordered / unordered floating point comparisons properly
[mesa.git] / src / compiler / spirv / vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
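   /* Computes dest = src0 * src1, where src0 is src0_rows x src0_columns and
    * src1 is src0_columns x src1_columns, so the result is
    * src0_rows x src1_columns (a plain vector when src1 has one column).
    */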
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but src0 isn't,
       * since the general case only reads individual components of src1, so
       * the optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
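      /* A vector times a matrix is the same as the transposed matrix times
       * the vector treated as a column:  v * M == transpose(M) * v.
       */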
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: unreachable("unknown matrix opcode");
   }
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;
   case SpvOpBitCount: return nir_op_bit_count;
   /* The ordered / unordered comparisons need more than the single NIR
    * opcode returned here: they also have to check whether the operands are
    * ordered (i.e. neither is NaN). vtn_handle_alu() adds that check.
    */
   case SpvOpFOrdEqual: return nir_op_feq;
   case SpvOpFUnordEqual: return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpFOrdNotEqual: return nir_op_fne;
   case SpvOpFUnordNotEqual: return nir_op_fne;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: return nir_op_flt;
   case SpvOpFUnordLessThan: return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: return nir_op_fge;

   /* Conversions: */
   case SpvOpConvertFToU: return nir_op_f2u;
   case SpvOpConvertFToS: return nir_op_f2i;
   case SpvOpConvertSToF: return nir_op_i2f;
   case SpvOpConvertUToF: return nir_op_u2f;
   case SpvOpBitcast: return nir_op_imov;
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   /* TODO: NIR is 32-bit only, so width conversions are no-ops for now. */
   case SpvOpUConvert: return nir_op_imov;
   case SpvOpSConvert: return nir_op_imov;
   case SpvOpFConvert: return nir_op_fmov;

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   default:
      unreachable("No NIR equivalent");
   }
}

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
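   /* SPIR-V's NoContraction decoration forbids fast-math transformations
    * such as fusing a multiply and an add into an fma; NIR expresses this by
    * marking the emitted instructions as "exact" via the builder.
    */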
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = false;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
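   /* OpAny is "some component is true" and OpAll is "every component is
    * true"; with NIR booleans these reduce to comparing the vector against
    * all-false / all-true respectively.
    */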
   case SpvOpAny:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_bany_inequal2; break;
         case 3: op = nir_op_bany_inequal3; break;
         case 4: op = nir_op_bany_inequal4; break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_FALSE),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_ball_iequal2; break;
         case 3: op = nir_op_ball_iequal3; break;
         case 4: op = nir_op_ball_iequal4; break;
         default: unreachable("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_int(&b->nb, NIR_TRUE),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

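   /* The extended-arithmetic opcodes return a two-member struct: the low
    * bits of the result in member 0 and the carry / borrow / high bits in
    * member 1.
    */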
   case SpvOpIAddCarry:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpSMulExtended:
      assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      break;

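   /* Fwidth is defined as abs(dFdx(p)) + abs(dFdy(p)), using the matching
    * fine / coarse derivative variants.
    */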
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan:
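      /* NaN is the only value that compares unequal to itself. */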
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf:
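      /* x is infinite iff |x| == +INFINITY. */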
      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
                              nir_imm_float(&b->nb, INFINITY));
      break;

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

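      /* An unordered comparison is true whenever either operand is NaN, so
       * the result is cmp(a, b) || isnan(a) || isnan(b), with isnan(x)
       * implemented as x != x.
       */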
      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdEqual:
   case SpvOpFOrdNotEqual:
   case SpvOpFOrdLessThan:
   case SpvOpFOrdGreaterThan:
   case SpvOpFOrdLessThanEqual:
   case SpvOpFOrdGreaterThanEqual: {
      bool swap;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

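      /* An ordered comparison is only true when neither operand is NaN, so
       * the result is cmp(a, b) && !isnan(a) && !isnan(b), with !isnan(x)
       * implemented as x == x.
       */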
      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   default: {
      bool swap;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = false;
}