nir/spirv: cast shift operand to u32
[mesa.git] / src / compiler / spirv / vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want a single routine that
 * multiplies a matrix by a matrix, treating vectors as matrices with one
 * column. So we "wrap" vectors as single-column matrices and unwrap the
 * result before handing it back.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

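/* Multiply two wrapped matrices.  Vectors have already been wrapped as
 * single-column matrices by the callers, so this one helper covers
 * matrix * matrix, matrix * vector and (via a transpose) vector * matrix.
 */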
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                  src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

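/* Scale every column of the matrix by the scalar, using an integer or
 * floating-point multiply to match the matrix's base type.
 */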
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
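      /* OpVectorTimesMatrix treats the vector as a row vector.  Computing
       * transpose(M) * v with the shared matrix-multiply helper yields the
       * same components as v * M, so both directions reuse matrix_multiply().
       */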
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail("unknown matrix opcode");
   }
}

static void
vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
                   struct nir_ssa_def *src)
{
   if (glsl_get_vector_elements(dest->type) == src->num_components) {
      /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
       *
       * "If Result Type has the same number of components as Operand, they
       * must also have the same component width, and results are computed per
       * component."
       */
      dest->def = nir_imov(&b->nb, src);
      return;
   }

   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    * "If Result Type has a different number of components than Operand, the
    * total number of bits in Result Type must equal the total number of bits
    * in Operand. Let L be the type, either Result Type or Operand’s type, that
    * has the larger number of components. Let S be the other type, with the
    * smaller number of components. The number of components in L must be an
    * integer multiple of the number of components in S. The first component
    * (that is, the only or lowest-numbered component) of S maps to the first
    * components of L, and so on, up to the last component of S mapping to the
    * last components of L. Within this mapping, any single component of S
    * (mapping to multiple components of L) maps its lower-ordered bits to the
    * lower-numbered components of L."
    */
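   /* For example, bitcasting a two-component vector of 32-bit values to a
    * single 64-bit scalar packs component 0 into the low 32 bits and
    * component 1 into the high 32 bits; the reverse split puts the low bits
    * back in component 0.
    */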
   unsigned src_bit_size = src->bit_size;
   unsigned dest_bit_size = glsl_get_bit_size(dest->type);
   unsigned src_components = src->num_components;
   unsigned dest_components = glsl_get_vector_elements(dest->type);
   vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);

   nir_ssa_def *dest_chan[NIR_MAX_VEC_COMPONENTS];
   if (src_bit_size > dest_bit_size) {
      vtn_assert(src_bit_size % dest_bit_size == 0);
      unsigned divisor = src_bit_size / dest_bit_size;
      for (unsigned comp = 0; comp < src_components; comp++) {
         nir_ssa_def *split;
         if (src_bit_size == 64) {
            assert(dest_bit_size == 32 || dest_bit_size == 16);
            split = dest_bit_size == 32 ?
               nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp)) :
               nir_unpack_64_4x16(&b->nb, nir_channel(&b->nb, src, comp));
         } else {
            vtn_assert(src_bit_size == 32);
            vtn_assert(dest_bit_size == 16);
            split = nir_unpack_32_2x16(&b->nb, nir_channel(&b->nb, src, comp));
         }
         for (unsigned i = 0; i < divisor; i++)
            dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
      }
   } else {
      vtn_assert(dest_bit_size % src_bit_size == 0);
      unsigned divisor = dest_bit_size / src_bit_size;
      for (unsigned comp = 0; comp < dest_components; comp++) {
         unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
         nir_ssa_def *src_chan = nir_channels(&b->nb, src, channels);
         if (dest_bit_size == 64) {
            assert(src_bit_size == 32 || src_bit_size == 16);
            dest_chan[comp] = src_bit_size == 32 ?
               nir_pack_64_2x32(&b->nb, src_chan) :
               nir_pack_64_4x16(&b->nb, src_chan);
         } else {
            vtn_assert(dest_bit_size == 32);
            vtn_assert(src_bit_size == 16);
            dest_chan[comp] = nir_pack_32_2x16(&b->nb, src_chan);
         }
      }
   }
   dest->def = nir_vec(&b->nb, dest_chan, dest_components);
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;
   case SpvOpBitCount: return nir_op_bit_count;

   /* The ordered/unordered comparison operators need extra handling beyond
    * the NIR opcode returned here, since they must also check whether the
    * operands are ordered (i.e. neither one is NaN).  That is done by the
    * caller in vtn_handle_alu().
    */
   case SpvOpFOrdEqual: return nir_op_feq;
   case SpvOpFUnordEqual: return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpFOrdNotEqual: return nir_op_fne;
   case SpvOpFUnordNotEqual: return nir_op_fne;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: return nir_op_flt;
   case SpvOpFUnordLessThan: return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
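      /* nir_alu_type packs the base type and the bit size into one value, so
       * OR-ing in the bit size yields e.g. nir_type_float32.
       */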
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

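/* The NoContraction decoration forbids contracting the decorated operation
 * with other operations (e.g. fusing a multiply and add into an fma), which
 * NIR expresses by marking the emitted instructions with the builder's
 * "exact" flag.
 */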
static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

static void
handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
                     const struct vtn_decoration *dec, void *_out_rounding_mode)
{
   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationFPRoundingMode)
      return;
   switch (dec->literals[0]) {
   case SpvFPRoundingModeRTE:
      *out_rounding_mode = nir_rounding_mode_rtne;
      break;
   case SpvFPRoundingModeRTZ:
      *out_rounding_mode = nir_rounding_mode_rtz;
      break;
   default:
      unreachable("Not supported rounding mode");
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = false;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
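   /* OpAny is "any component != false" and OpAll is "all components == true",
    * so compare the whole vector against an immediate false/true.
    */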
   case SpvOpAny:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_bany_inequal2; break;
         case 3: op = nir_op_bany_inequal3; break;
         case 4: op = nir_op_bany_inequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_false(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_imov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_ball_iequal2; break;
         case 3: op = nir_op_ball_iequal3; break;
         case 4: op = nir_op_ball_iequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_true(&b->nb),
                                       NULL, NULL);
      }
      break;

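   /* Column i of the outer product is src0 scaled by component i of src1. */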
   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

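   /* These opcodes return a two-member struct: the low part of the result in
    * the first member and the carry, borrow or high bits in the second.
    */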
   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      break;

   case SpvOpSMulExtended:
      vtn_assert(glsl_type_is_struct(val->ssa->type));
      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      break;

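   /* fwidth(p) is defined as abs(dFdx(p)) + abs(dFdy(p)), using the matching
    * fine/coarse derivative variants below.
    */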
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

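   /* NaN is the only floating-point value that compares unequal to itself. */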
   case SpvOpIsNan:
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

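      /* An unordered comparison is true whenever either operand is NaN, so
       * OR the base (ordered) comparison with a per-operand NaN check.
       */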
      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered, so we don't need to handle it specially.
       */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpBitcast:
      vtn_handle_bitcast(b, val->ssa, src[0]);
      break;

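   /* OpFConvert gets its own case so that an FPRoundingMode decoration on the
    * result can select the rounding variant of the conversion opcode.
    */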
   case SpvOpFConvert: {
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bitsize supported by the NIR instructions. See discussion
             * here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   default: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

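      /* NIR's shift opcodes take a 32-bit shift count, so cast any other
       * size of shift operand down to u32, matching the explicit shift
       * cases handled above.
       */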
      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = false;
}