nir/vtn: Handle LessOrGreater deprecated opcode
[mesa.git] / src / compiler / spirv / vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */
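/* Illustrative note: wrap_matrix() below turns a plain vector into a
 * one-column "matrix" whose elems[0] is that vector, so OpMatrixTimesVector
 * can share the column-by-column loop used for OpMatrixTimesMatrix;
 * unwrap_matrix() then hands elems[0] back to the caller.
 */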

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */
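      /* Illustrative restatement of the loop below: component j of output
       * column i is fdot(row j of src0, column i of src1).
       */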

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                  src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                      src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
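   /* For example, SpvOpUGreaterThan maps to nir_op_ult with *swap = true,
    * so the caller emits a > b as b < a.
    */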
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;
   case SpvOpBitCount: return nir_op_bit_count;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
   case SpvOpIAverageINTEL: return nir_op_ihadd;
   case SpvOpUAverageINTEL: return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
   case SpvOpISubSatINTEL: return nir_op_isub_sat;
   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;

   /* The ordered / unordered comparison operators need special handling in
    * addition to the NIR opcode returned here, since they also need to check
    * whether the operands are ordered (i.e. neither one is NaN).
    */
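   /* Illustrative note: vtn_handle_alu() builds that check from
    * self-comparisons, e.g. FUnordLessThan ends up as
    * flt(a, b) || fne(a, a) || fne(b, b).
    */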
   case SpvOpFOrdEqual: return nir_op_feq;
   case SpvOpFUnordEqual: return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual: return nir_op_fne;
   case SpvOpFUnordNotEqual: return nir_op_fne;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: return nir_op_flt;
   case SpvOpFUnordLessThan: return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: return nir_op_fge;

   /* Conversions: */
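   /* Illustrative note (assumed mapping): the cases below combine the base
    * type with the source and destination bit sizes, so e.g. OpConvertSToF
    * from a 32-bit int to a 16-bit float would resolve to nir_op_i2f16 via
    * nir_type_conversion_op().
    */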
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }
   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   case SpvOpIsNormal: return nir_op_fisnormal;
   case SpvOpIsFinite: return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

static void
handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
                     const struct vtn_decoration *dec, void *_out_rounding_mode)
{
   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationFPRoundingMode)
      return;
   switch (dec->operands[0]) {
   case SpvFPRoundingModeRTE:
      *out_rounding_mode = nir_rounding_mode_rtne;
      break;
   case SpvFPRoundingModeRTZ:
      *out_rounding_mode = nir_rounding_mode_rtz;
      break;
   default:
      unreachable("Unsupported rounding mode");
      break;
   }
}

static void
handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_foreach_decoration(b, dest_val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_push_ssa_value(b, w[2],
         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

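   /* fwidth(p) is defined as fabs(dFdx(p)) + fabs(dFdy(p)); the Fine and
    * Coarse variants below just use the matching derivative opcodes.
    */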
   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

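   /* NaN is the only value that compares unequal to itself, so fne(x, x)
    * is true exactly when x is NaN.
    */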
   case SpvOpIsNan:
      dest->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      dest->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
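      /* Illustrative note: NotEqual (and the deprecated LessOrGreater alias)
       * does need the explicit check, since fne(a, b) is also true for
       * unordered operands; the feq(x, x) self-comparisons below are true
       * exactly when x is not NaN.
       */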
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      dest->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFConvert: {
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(dest_type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      vtn_foreach_decoration(b, dest_val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
             * supported by the NIR instructions. See discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

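   /* Descriptive note: nir_find_lsb() returns -1 for an input of 0, which as
    * an unsigned value gets clamped to 32 by the umin below, giving the
    * expected trailing-zero count of 32 for 0.
    */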
   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   default: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */
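   /* Illustrative example: bitcasting a two-component 32-bit vector to a
    * 64-bit scalar packs component 0 into the low 32 bits and component 1
    * into the high 32 bits, per the rule quoted above.
    */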

   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}