nir: Drop imov/fmov in favor of one mov instruction
[mesa.git] / src / compiler / spirv / vtn_alu.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = val->type;
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

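/* Computes the matrix product _src0 * _src1. A vector operand is treated as
 * a single-column matrix (see wrap_matrix above); cached transposes are used
 * when they let us express the product as row-by-column dot products.
 */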
static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                  src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[0]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, 0));
         for (unsigned j = 1; j < src0_columns; j++) {
            dest->elems[i]->def =
               nir_fadd(&b->nb, dest->elems[i]->def,
                        nir_fmul(&b->nb, src0->elems[j]->def,
                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

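/* Multiplies each column of mat by scalar, using integer or floating-point
 * multiplies depending on the matrix base type.
 */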
static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

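/* Handles the ALU opcodes that can take matrix operands. Column-wise
 * operations are expanded one column at a time; the multiply opcodes are
 * lowered through matrix_multiply() and mat_times_scalar() above.
 */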
static void
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_value *dest,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      break;
   }

   case SpvOpFAdd: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpFSub: {
      dest->ssa = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->ssa->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      break;
   }

   case SpvOpTranspose:
      dest->ssa = vtn_ssa_transpose(b, src0);
      break;

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                           src1->def));
      } else {
         dest->ssa = mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         dest->ssa = matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

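/* Returns the NIR ALU opcode that implements a given SPIR-V ALU opcode.
 * Comparison opcodes that NIR only provides in one direction are expressed by
 * setting *swap, which tells the caller to exchange the first two sources.
 * The conversion opcodes additionally need the source and destination bit
 * sizes to pick the right nir_type_conversion_op().
 */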
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;
   case SpvOpBitCount: return nir_op_bit_count;

   /* The ordered / unordered comparisons need more than the opcode returned
    * here: the caller also has to check whether the operands are ordered,
    * i.e. whether either of them is NaN.
    */
   case SpvOpFOrdEqual: return nir_op_feq;
   case SpvOpFUnordEqual: return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpFOrdNotEqual: return nir_op_fne;
   case SpvOpFUnordNotEqual: return nir_op_fne;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: return nir_op_flt;
   case SpvOpFUnordLessThan: return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type;
      nir_alu_type dst_type;

      switch (opcode) {
      case SpvOpConvertFToS:
         src_type = nir_type_float;
         dst_type = nir_type_int;
         break;
      case SpvOpConvertFToU:
         src_type = nir_type_float;
         dst_type = nir_type_uint;
         break;
      case SpvOpFConvert:
         src_type = dst_type = nir_type_float;
         break;
      case SpvOpConvertSToF:
         src_type = nir_type_int;
         dst_type = nir_type_float;
         break;
      case SpvOpSConvert:
         src_type = dst_type = nir_type_int;
         break;
      case SpvOpConvertUToF:
         src_type = nir_type_uint;
         dst_type = nir_type_float;
         break;
      case SpvOpUConvert:
         src_type = dst_type = nir_type_uint;
         break;
      default:
         unreachable("Invalid opcode");
      }
      src_type |= src_bit_size;
      dst_type |= dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

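/* Decoration callback: if the result value is decorated NoContraction, mark
 * the NIR builder as "exact" so the generated instructions are not subject to
 * value-changing floating-point optimizations.
 */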
static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

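/* Decoration callback: translates an FPRoundingMode decoration into the
 * corresponding nir_rounding_mode, leaving the output untouched if the
 * decoration is absent.
 */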
static void
handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
                     const struct vtn_decoration *dec, void *_out_rounding_mode)
{
   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
   assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationFPRoundingMode)
      return;
   switch (dec->operands[0]) {
   case SpvFPRoundingModeRTE:
      *out_rounding_mode = nir_rounding_mode_rtne;
      break;
   case SpvFPRoundingModeRTZ:
      *out_rounding_mode = nir_rounding_mode_rtz;
      break;
   default:
      unreachable("Unsupported rounding mode");
      break;
   }
}

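/* Main handler for SPIR-V ALU instructions. Matrix operands are dispatched
 * to vtn_handle_matrix_alu(); everything else is either lowered explicitly
 * below or mapped straight to a NIR opcode via
 * vtn_nir_alu_op_for_spirv_opcode().
 */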
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
   const struct glsl_type *type =
      vtn_value(b, w[1], vtn_value_type_type)->type->type;

   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
      b->nb.exact = b->exact;
      return;
   }

   val->ssa = vtn_create_ssa_value(b, type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
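   /* OpAny/OpAll reduce a Boolean vector to a single Boolean. For a single
    * component that is just a copy; otherwise it becomes a vector
    * "any-inequal"/"all-equal" comparison against false/true.
    */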
   case SpvOpAny:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_mov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_bany_inequal2; break;
         case 3: op = nir_op_bany_inequal3; break;
         case 4: op = nir_op_bany_inequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_false(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpAll:
      if (src[0]->num_components == 1) {
         val->ssa->def = nir_mov(&b->nb, src[0]);
      } else {
         nir_op op;
         switch (src[0]->num_components) {
         case 2: op = nir_op_ball_iequal2; break;
         case 3: op = nir_op_ball_iequal3; break;
         case 4: op = nir_op_ball_iequal4; break;
         default: vtn_fail("invalid number of components");
         }
         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                       nir_imm_true(&b->nb),
                                       NULL, NULL);
      }
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         val->ssa->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

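   /* The three Fwidth variants are implemented as abs(dFdx) + abs(dFdy),
    * using the matching derivative precision.
    */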
   case SpvOpFwidth:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      val->ssa->def = nir_fadd(&b->nb,
                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

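   /* NaN is the only value that compares unequal to itself, so x != x is a
    * NaN test. Infinity is detected by comparing |x| against +INFINITY.
    */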
   case SpvOpIsNan:
      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
      break;

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

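   /* An unordered comparison is true whenever either operand is NaN, so OR
    * the base comparison with explicit NaN checks on both operands.
    */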
   case SpvOpFUnordEqual:
   case SpvOpFUnordNotEqual:
   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      val->ssa->def =
         nir_ior(&b->nb,
                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                 nir_ior(&b->nb,
                         nir_fne(&b->nb, src[0], src[0]),
                         nir_fne(&b->nb, src[1], src[1])));
      break;
   }

   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don't need to handle it specially.
       */
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      assert(!swap);

      val->ssa->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));
      break;
   }

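   /* FConvert is handled separately from the other conversions because the
    * result may carry an FPRoundingMode decoration that selects the rounding
    * mode of the conversion.
    */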
   case SpvOpFConvert: {
      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;

      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src0_bit_size, dst_bit_size);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bit size supported by the NIR instructions. See
             * discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

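   /* SignBitSet: shift the sign bit down into bit 0 and convert the result to
    * a Boolean.
    */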
   case SpvOpSignBitSet: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      if (src[0]->num_components == 1)
         val->ssa->def =
            nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
      else
         val->ssa->def =
            nir_ishr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));

      val->ssa->def = nir_i2b(&b->nb, val->ssa->def);
      break;
   }

   default: {
      bool swap;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   } /* default */
   }

   b->nb.exact = b->exact;
}

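/* Handles OpBitcast. The source and destination may have different numbers
 * of components as long as the total number of bits matches; the reinterpret
 * itself is done by nir_bitcast_vector().
 */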
void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_value(b, w[1], vtn_value_type_type)->type;
   struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
   struct nir_ssa_def *src = vtn_src->def;
   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);

   vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type));

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_ssa(b, w[2], type, val);
}