X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fspirv%2Fvtn_alu.c;h=b403a25e51b20659bc22b13fc61f32a73355427e;hb=938d6ceb8300b194a7cbaf640e2c899cbecc6c5a;hp=71e743cdd1e7b6f629b84a2a932c056b5e0df2ee;hpb=6e499572b9a7b33165b8438a85db37ae1ba0ce0e;p=mesa.git

diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 71e743cdd1e..b403a25e51b 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -23,6 +23,7 @@
 #include <math.h>
 #include "vtn_private.h"
+#include "spirv_info.h"
 
 /*
  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
@@ -42,7 +43,7 @@ wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
       return val;
 
    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
+   dest->type = glsl_get_bare_type(val->type);
    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
    dest->elems[0] = val;
 
@@ -152,48 +153,46 @@ mat_times_scalar(struct vtn_builder *b,
    return dest;
 }
 
-static void
+static struct vtn_ssa_value *
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      struct vtn_value *dest,
                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
 {
    switch (opcode) {
    case SpvOpFNegate: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
-      break;
+         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
+      return dest;
    }
 
    case SpvOpFAdd: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpFSub: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpTranspose:
-      dest->ssa = vtn_ssa_transpose(b, src0);
-      break;
+      return vtn_ssa_transpose(b, src0);
 
    case SpvOpMatrixTimesScalar:
       if (src0->transposed) {
-         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
-                                                           src1->def));
+         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
+                                                         src1->def));
       } else {
-         dest->ssa = mat_times_scalar(b, src0, src1->def);
+         return mat_times_scalar(b, src0, src1->def);
       }
       break;
 
@@ -201,76 +200,14 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
       if (opcode == SpvOpVectorTimesMatrix) {
-         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
+         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
       } else {
-         dest->ssa = matrix_multiply(b, src0, src1);
+         return matrix_multiply(b, src0, src1);
       }
       break;
 
-   default: vtn_fail("unknown matrix opcode");
-   }
-}
-
-static void
-vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
-                   struct nir_ssa_def *src)
-{
-   if (glsl_get_vector_elements(dest->type) == src->num_components) {
-      /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
-       *
-       * "If Result Type has the same number of components as Operand, they
-       * must also have the same component width, and results are computed per
-       * component."
-       */
-      dest->def = nir_imov(&b->nb, src);
-      return;
-   }
-
-   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
-    *
-    * "If Result Type has a different number of components than Operand, the
-    * total number of bits in Result Type must equal the total number of bits
-    * in Operand. Let L be the type, either Result Type or Operand’s type, that
-    * has the larger number of components. Let S be the other type, with the
-    * smaller number of components. The number of components in L must be an
-    * integer multiple of the number of components in S. The first component
-    * (that is, the only or lowest-numbered component) of S maps to the first
-    * components of L, and so on, up to the last component of S mapping to the
-    * last components of L. Within this mapping, any single component of S
-    * (mapping to multiple components of L) maps its lower-ordered bits to the
-    * lower-numbered components of L."
-    */
-   unsigned src_bit_size = src->bit_size;
-   unsigned dest_bit_size = glsl_get_bit_size(dest->type);
-   unsigned src_components = src->num_components;
-   unsigned dest_components = glsl_get_vector_elements(dest->type);
-   vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
-
-   nir_ssa_def *dest_chan[4];
-   if (src_bit_size > dest_bit_size) {
-      vtn_assert(src_bit_size % dest_bit_size == 0);
-      unsigned divisor = src_bit_size / dest_bit_size;
-      for (unsigned comp = 0; comp < src_components; comp++) {
-         vtn_assert(src_bit_size == 64);
-         vtn_assert(dest_bit_size == 32);
-         nir_ssa_def *split =
-            nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp));
-         for (unsigned i = 0; i < divisor; i++)
-            dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
-      }
-   } else {
-      vtn_assert(dest_bit_size % src_bit_size == 0);
-      unsigned divisor = dest_bit_size / src_bit_size;
-      for (unsigned comp = 0; comp < dest_components; comp++) {
-         unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
-         nir_ssa_def *src_chan =
-            nir_channels(&b->nb, src, channels);
-         vtn_assert(dest_bit_size == 64);
-         vtn_assert(src_bit_size == 32);
-         dest_chan[comp] = nir_pack_64_2x32(&b->nb, src_chan);
-      }
+   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
    }
-   dest->def = nir_vec(&b->nb, dest_chan, dest_components);
 }
 
 nir_op
@@ -320,7 +257,21 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
    case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
    case SpvOpBitReverse: return nir_op_bitfield_reverse;
-   case SpvOpBitCount: return nir_op_bit_count;
+
+   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
+   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
+   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
+   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
+   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
+   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
+   case SpvOpIAverageINTEL: return nir_op_ihadd;
+   case SpvOpUAverageINTEL: return nir_op_uhadd;
+   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
+   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
+   case SpvOpISubSatINTEL: return nir_op_isub_sat;
+   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
+   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
+   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
 
    /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
@@ -329,8 +280,9 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpFOrdEqual: return nir_op_feq;
    case SpvOpFUnordEqual: return nir_op_feq;
    case SpvOpINotEqual: return nir_op_ine;
-   case SpvOpFOrdNotEqual: return nir_op_fne;
-   case SpvOpFUnordNotEqual: return nir_op_fne;
+   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
+   case SpvOpFOrdNotEqual: return nir_op_fneu;
+   case SpvOpFUnordNotEqual: return nir_op_fneu;
    case SpvOpULessThan: return nir_op_ult;
    case SpvOpSLessThan: return nir_op_ilt;
    case SpvOpFOrdLessThan: return nir_op_flt;
@@ -401,8 +353,11 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
    case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
 
+   case SpvOpIsNormal: return nir_op_fisnormal;
+   case SpvOpIsFinite: return nir_op_fisfinite;
+
    default:
-      vtn_fail("No NIR equivalent");
+      vtn_fail("No NIR equivalent: %u", opcode);
    }
 }
 
@@ -425,7 +380,7 @@ handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
    assert(dec->scope == VTN_DEC_DECORATION);
    if (dec->decoration != SpvDecorationFPRoundingMode)
       return;
-   switch (dec->literals[0]) {
+   switch (dec->operands[0]) {
    case SpvFPRoundingModeRTE:
       *out_rounding_mode = nir_rounding_mode_rtne;
       break;
@@ -438,15 +393,32 @@ handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
    }
 }
 
+static void
+handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
+               const struct vtn_decoration *dec, void *_alu)
+{
+   nir_alu_instr *alu = _alu;
+   switch (dec->decoration) {
+   case SpvDecorationNoSignedWrap:
+      alu->no_signed_wrap = true;
+      break;
+   case SpvDecorationNoUnsignedWrap:
+      alu->no_unsigned_wrap = true;
+      break;
+   default:
+      /* Do nothing. */
+      break;
+   }
+}
+
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
 {
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
-   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
+   vtn_foreach_decoration(b, dest_val, handle_no_contraction, NULL);
 
    /* Collect the various SSA sources */
    const unsigned num_inputs = count - 3;
@@ -456,12 +428,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
-      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
-      b->nb.exact = false;
+      vtn_push_ssa_value(b, w[2],
+         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
+      b->nb.exact = b->exact;
       return;
    }
 
-   val->ssa = vtn_create_ssa_value(b, type);
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
    nir_ssa_def *src[4] = { NULL, };
    for (unsigned i = 0; i < num_inputs; i++) {
       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
@@ -470,103 +443,91 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    switch (opcode) {
    case SpvOpAny:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2: op = nir_op_bany_inequal2; break;
-         case 3: op = nir_op_bany_inequal3; break;
-         case 4: op = nir_op_bany_inequal4; break;
-         default: vtn_fail("invalid number of components");
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_int(&b->nb, NIR_FALSE),
-                                       NULL, NULL);
-      }
+      dest->def = nir_bany(&b->nb, src[0]);
       break;
 
    case SpvOpAll:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2: op = nir_op_ball_iequal2; break;
-         case 3: op = nir_op_ball_iequal3; break;
-         case 4: op = nir_op_ball_iequal4; break;
-         default: vtn_fail("invalid number of components");
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_int(&b->nb, NIR_TRUE),
-                                       NULL, NULL);
-      }
+      dest->def = nir_ball(&b->nb, src[0]);
       break;
 
    case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
-         val->ssa->elems[i]->def =
+         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
    }
 
    case SpvOpDot:
-      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
+      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;
 
    case SpvOpIAddCarry:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;
 
    case SpvOpISubBorrow:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;
 
-   case SpvOpUMulExtended:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
+   case SpvOpUMulExtended: {
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
       break;
+   }
 
-   case SpvOpSMulExtended:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
+   case SpvOpSMulExtended: {
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
       break;
+   }
 
    case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
       break;
    case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
       break;
    case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
       break;
 
    case SpvOpVectorTimesScalar:
       /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
+      dest->def = nir_fmul(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
+      dest->def = nir_fneu(&b->nb, src[0], src[0]);
+      break;
+
+   case SpvOpOrdered:
+      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
+                                   nir_feq(&b->nb, src[1], src[1]));
+      break;
+
+   case SpvOpUnordered:
+      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
+                                  nir_fneu(&b->nb, src[1], src[1]));
       break;
 
    case SpvOpIsInf: {
       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
-      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
+      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
       break;
    }
 
@@ -578,7 +539,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFUnordGreaterThanEqual: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
@@ -588,34 +549,30 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          src[1] = tmp;
       }
 
-      val->ssa->def =
+      dest->def =
          nir_ior(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_ior(&b->nb,
-                         nir_fne(&b->nb, src[0], src[0]),
-                         nir_fne(&b->nb, src[1], src[1])));
+                         nir_fneu(&b->nb, src[0], src[0]),
+                         nir_fneu(&b->nb, src[1], src[1])));
       break;
    }
 
-   case SpvOpFOrdEqual:
-   case SpvOpFOrdNotEqual:
-   case SpvOpFOrdLessThan:
-   case SpvOpFOrdGreaterThan:
-   case SpvOpFOrdLessThanEqual:
-   case SpvOpFOrdGreaterThanEqual: {
+   case SpvOpLessOrGreater:
+   case SpvOpFOrdNotEqual: {
+      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
+       * from the ALU will probably already be false if the operands are not
+       * ordered so we don’t need to handle it specially.
+       */
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
-      if (swap) {
-         nir_ssa_def *tmp = src[0];
-         src[0] = src[1];
-         src[1] = tmp;
-      }
+      assert(!swap);
 
-      val->ssa->def =
+      dest->def =
         nir_iand(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_iand(&b->nb,
@@ -624,26 +581,76 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpBitcast:
-      vtn_handle_bitcast(b, val->ssa, src[0]);
-      break;
-
    case SpvOpFConvert: {
       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(dest_type);
       nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
 
-      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+      vtn_foreach_decoration(b, dest_val, handle_rounding_mode, &rounding_mode);
       nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
 
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
+      break;
+   }
+
+   case SpvOpBitFieldInsert:
+   case SpvOpBitFieldSExtract:
+   case SpvOpBitFieldUExtract:
+   case SpvOpShiftLeftLogical:
+   case SpvOpShiftRightArithmetic:
+   case SpvOpShiftRightLogical: {
+      bool swap;
+      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+                                                  src0_bit_size, dst_bit_size);
+
+      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
+              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
+              op == nir_op_ibitfield_extract);
+
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+         unsigned src_bit_size =
+            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
+         if (src_bit_size == 0)
+            continue;
+         if (src_bit_size != src[i]->bit_size) {
+            assert(src_bit_size == 32);
+            /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
+             * supported by the NIR instructions. See discussion here:
+             *
+             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
+             */
+            src[i] = nir_u2u32(&b->nb, src[i]);
+         }
+      }
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      break;
+   }
+
+   case SpvOpSignBitSet:
+      dest->def = nir_i2b(&b->nb,
+         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
+      break;
+
+   case SpvOpUCountTrailingZerosINTEL:
+      dest->def = nir_umin(&b->nb,
+                               nir_find_lsb(&b->nb, src[0]),
+                               nir_imm_int(&b->nb, 32u));
+      break;
+
+   case SpvOpBitCount: {
+      /* bit_count always returns int32, but the SPIR-V opcode just says the return
+       * value needs to be big enough to store the number of bits.
+ */ + dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type)); break; } default: { bool swap; unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); - unsigned dst_bit_size = glsl_get_bit_size(type); + unsigned dst_bit_size = glsl_get_bit_size(dest_type); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, src_bit_size, dst_bit_size); @@ -653,10 +660,73 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, src[1] = tmp; } - val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]); + switch (op) { + case nir_op_ishl: + case nir_op_ishr: + case nir_op_ushr: + if (src[1]->bit_size != 32) + src[1] = nir_u2u32(&b->nb, src[1]); + break; + default: + break; + } + + dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]); break; } /* default */ } - b->nb.exact = false; + switch (opcode) { + case SpvOpIAdd: + case SpvOpIMul: + case SpvOpISub: + case SpvOpShiftLeftLogical: + case SpvOpSNegate: { + nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr); + vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu); + break; + } + default: + /* Do nothing. */ + break; + } + + vtn_push_ssa_value(b, w[2], dest); + + b->nb.exact = b->exact; +} + +void +vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) +{ + vtn_assert(count == 4); + /* From the definition of OpBitcast in the SPIR-V 1.2 spec: + * + * "If Result Type has the same number of components as Operand, they + * must also have the same component width, and results are computed per + * component. + * + * If Result Type has a different number of components than Operand, the + * total number of bits in Result Type must equal the total number of + * bits in Operand. Let L be the type, either Result Type or Operand’s + * type, that has the larger number of components. Let S be the other + * type, with the smaller number of components. The number of components + * in L must be an integer multiple of the number of components in S. + * The first component (that is, the only or lowest-numbered component) + * of S maps to the first components of L, and so on, up to the last + * component of S mapping to the last components of L. Within this + * mapping, any single component of S (mapping to multiple components of + * L) maps its lower-ordered bits to the lower-numbered components of L." + */ + + struct vtn_type *type = vtn_get_type(b, w[1]); + struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]); + + vtn_fail_if(src->num_components * src->bit_size != + glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type), + "Source and destination of OpBitcast must have the same " + "total number of bits"); + nir_ssa_def *val = + nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type)); + vtn_push_nir_ssa(b, w[2], val); }
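
The fneu-based lowerings of SpvOpIsNan, SpvOpOrdered and SpvOpUnordered in the diff rely on the IEEE 754 rule that NaN compares unequal to everything, including itself. A hypothetical standalone C sketch of the same trick (again not Mesa code; it must be compiled without -ffast-math, which would break the NaN self-comparison):

#include <math.h>
#include <stdbool.h>
#include <stdio.h>

/* IEEE 754: NaN != NaN, so "x != x" is exactly "x is NaN". This mirrors the
 * nir_fneu(x, x) pattern used for SpvOpIsNan; two operands are "unordered"
 * when at least one of them is NaN.
 */
static bool is_nan(float x) { return x != x; }
static bool unordered(float a, float b) { return is_nan(a) || is_nan(b); }

int main(void)
{
   float qnan = nanf("");
   printf("is_nan(qnan) = %d\n", is_nan(qnan));                   /* 1 */
   printf("unordered(1.0f, qnan) = %d\n", unordered(1.0f, qnan)); /* 1 */
   printf("unordered(1.0f, 2.0f) = %d\n", unordered(1.0f, 2.0f)); /* 0 */
   return 0;
}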