diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index ecf9cbc34d6..5f9cc97fdfb 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -21,6 +21,7 @@
  * IN THE SOFTWARE.
  */
 
+#include <math.h>
 #include "vtn_private.h"
 
 /*
@@ -142,10 +143,10 @@ mat_times_scalar(struct vtn_builder *b,
 {
    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
-      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
-         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
-      else
+      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
+      else
+         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
    }
 
    return dest;
@@ -206,7 +207,7 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
       }
       break;
 
-   default: unreachable("unknown matrix opcode");
+   default: vtn_fail("unknown matrix opcode");
    }
 }
 
@@ -243,38 +244,52 @@ vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
    unsigned dest_bit_size = glsl_get_bit_size(dest->type);
    unsigned src_components = src->num_components;
    unsigned dest_components = glsl_get_vector_elements(dest->type);
-   assert(src_bit_size * src_components == dest_bit_size * dest_components);
+   vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
 
    nir_ssa_def *dest_chan[4];
    if (src_bit_size > dest_bit_size) {
-      assert(src_bit_size % dest_bit_size == 0);
+      vtn_assert(src_bit_size % dest_bit_size == 0);
       unsigned divisor = src_bit_size / dest_bit_size;
       for (unsigned comp = 0; comp < src_components; comp++) {
-         assert(src_bit_size == 64);
-         assert(dest_bit_size == 32);
-         nir_ssa_def *split =
-            nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp));
+         nir_ssa_def *split;
+         if (src_bit_size == 64) {
+            assert(dest_bit_size == 32 || dest_bit_size == 16);
+            split = dest_bit_size == 32 ?
+               nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp)) :
+               nir_unpack_64_4x16(&b->nb, nir_channel(&b->nb, src, comp));
+         } else {
+            vtn_assert(src_bit_size == 32);
+            vtn_assert(dest_bit_size == 16);
+            split = nir_unpack_32_2x16(&b->nb, nir_channel(&b->nb, src, comp));
+         }
          for (unsigned i = 0; i < divisor; i++)
             dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
       }
    } else {
-      assert(dest_bit_size % src_bit_size == 0);
+      vtn_assert(dest_bit_size % src_bit_size == 0);
       unsigned divisor = dest_bit_size / src_bit_size;
       for (unsigned comp = 0; comp < dest_components; comp++) {
         unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
-         nir_ssa_def *src_chan =
-            nir_channels(&b->nb, src, channels);
-         assert(dest_bit_size == 64);
-         assert(src_bit_size == 32);
-         dest_chan[comp] = nir_pack_64_2x32(&b->nb, src_chan);
+         nir_ssa_def *src_chan = nir_channels(&b->nb, src, channels);
+         if (dest_bit_size == 64) {
+            assert(src_bit_size == 32 || src_bit_size == 16);
+            dest_chan[comp] = src_bit_size == 32 ?
+               nir_pack_64_2x32(&b->nb, src_chan) :
+               nir_pack_64_4x16(&b->nb, src_chan);
+         } else {
+            vtn_assert(dest_bit_size == 32);
+            vtn_assert(src_bit_size == 16);
+            dest_chan[comp] = nir_pack_32_2x16(&b->nb, src_chan);
+         }
       }
    }
 
    dest->def = nir_vec(&b->nb, dest_chan, dest_components);
 }
 
 nir_op
-vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
-                                nir_alu_type src, nir_alu_type dst)
+vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
+                                SpvOp opcode, bool *swap,
+                                unsigned src_bit_size, unsigned dst_bit_size)
 {
    /* Indicates that the first two arguments should be swapped. This is
     * used for implementing greater-than and less-than-or-equal.
@@ -354,9 +369,43 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
    case SpvOpConvertSToF:
    case SpvOpConvertUToF:
    case SpvOpSConvert:
-   case SpvOpFConvert:
-      return nir_type_conversion_op(src, dst);
-
+   case SpvOpFConvert: {
+      nir_alu_type src_type;
+      nir_alu_type dst_type;
+
+      switch (opcode) {
+      case SpvOpConvertFToS:
+         src_type = nir_type_float;
+         dst_type = nir_type_int;
+         break;
+      case SpvOpConvertFToU:
+         src_type = nir_type_float;
+         dst_type = nir_type_uint;
+         break;
+      case SpvOpFConvert:
+         src_type = dst_type = nir_type_float;
+         break;
+      case SpvOpConvertSToF:
+         src_type = nir_type_int;
+         dst_type = nir_type_float;
+         break;
+      case SpvOpSConvert:
+         src_type = dst_type = nir_type_int;
+         break;
+      case SpvOpConvertUToF:
+         src_type = nir_type_uint;
+         dst_type = nir_type_float;
+         break;
+      case SpvOpUConvert:
+         src_type = dst_type = nir_type_uint;
+         break;
+      default:
+         unreachable("Invalid opcode");
+      }
+      src_type |= src_bit_size;
+      dst_type |= dst_bit_size;
+      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
+   }
    /* Derivatives: */
    case SpvOpDPdx:         return nir_op_fddx;
    case SpvOpDPdy:         return nir_op_fddy;
@@ -366,7 +415,7 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
 
    default:
-      unreachable("No NIR equivalent");
+      vtn_fail("No NIR equivalent");
    }
 }
 
@@ -374,13 +423,34 @@ static void
 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                       const struct vtn_decoration *dec, void *_void)
 {
-   assert(dec->scope == VTN_DEC_DECORATION);
+   vtn_assert(dec->scope == VTN_DEC_DECORATION);
    if (dec->decoration != SpvDecorationNoContraction)
       return;
 
    b->nb.exact = true;
 }
 
+static void
+handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
+                     const struct vtn_decoration *dec, void *_out_rounding_mode)
+{
+   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
+   assert(dec->scope == VTN_DEC_DECORATION);
+   if (dec->decoration != SpvDecorationFPRoundingMode)
+      return;
+   switch (dec->literals[0]) {
+   case SpvFPRoundingModeRTE:
+      *out_rounding_mode = nir_rounding_mode_rtne;
+      break;
+   case SpvFPRoundingModeRTZ:
+      *out_rounding_mode = nir_rounding_mode_rtz;
+      break;
+   default:
+      unreachable("Not supported rounding mode");
+      break;
+   }
+}
+
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
@@ -407,7 +477,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    val->ssa = vtn_create_ssa_value(b, type);
    nir_ssa_def *src[4] = { NULL, };
    for (unsigned i = 0; i < num_inputs; i++) {
-      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
+      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
       src[i] = vtn_src[i]->def;
    }
 
@@ -421,7 +491,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       case 2: op = nir_op_bany_inequal2; break;
       case 3: op = nir_op_bany_inequal3; break;
       case 4: op = nir_op_bany_inequal4; break;
-      default: unreachable("invalid number of components");
+      default: vtn_fail("invalid number of components");
       }
       val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                     nir_imm_int(&b->nb, NIR_FALSE),
@@ -438,7 +508,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       case 2: op = nir_op_ball_iequal2; break;
       case 3: op = nir_op_ball_iequal3; break;
       case 4: op = nir_op_ball_iequal4; break;
-      default: unreachable("invalid number of components");
+      default: vtn_fail("invalid number of components");
       }
       val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                     nir_imm_int(&b->nb, NIR_TRUE),
@@ -459,25 +529,25 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpIAddCarry:
-      assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpISubBorrow:
-      assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpUMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpSMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
       break;
 
@@ -507,10 +577,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
       break;
 
-   case SpvOpIsInf:
-      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]),
-                              nir_imm_float(&b->nb, INFINITY));
+   case SpvOpIsInf: {
+      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
+      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
       break;
+   }
 
    case SpvOpFUnordEqual:
    case SpvOpFUnordNotEqual:
@@ -519,9 +590,10 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFUnordLessThanEqual:
    case SpvOpFUnordGreaterThanEqual: {
       bool swap;
-      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
+      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      unsigned dst_bit_size = glsl_get_bit_size(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+                                                  src_bit_size, dst_bit_size);
 
       if (swap) {
          nir_ssa_def *tmp = src[0];
@@ -538,22 +610,18 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpFOrdEqual:
-   case SpvOpFOrdNotEqual:
-   case SpvOpFOrdLessThan:
-   case SpvOpFOrdGreaterThan:
-   case SpvOpFOrdLessThanEqual:
-   case SpvOpFOrdGreaterThanEqual: {
+   case SpvOpFOrdNotEqual: {
+      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
+       * from the ALU will probably already be false if the operands are not
+       * ordered so we don’t need to handle it specially.
+       */
       bool swap;
-      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
+      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      unsigned dst_bit_size = glsl_get_bit_size(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+                                                  src_bit_size, dst_bit_size);
 
-      if (swap) {
-         nir_ssa_def *tmp = src[0];
-         src[0] = src[1];
-         src[1] = tmp;
-      }
+      assert(!swap);
 
       val->ssa->def =
          nir_iand(&b->nb,
@@ -568,11 +636,59 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_bitcast(b, val->ssa, src[0]);
       break;
 
-   default: {
-      bool swap;
+   case SpvOpFConvert: {
       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
       nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
+      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
+
+      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
+
+      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
+      break;
+   }
+
+   case SpvOpBitFieldInsert:
+   case SpvOpBitFieldSExtract:
+   case SpvOpBitFieldUExtract:
+   case SpvOpShiftLeftLogical:
+   case SpvOpShiftRightArithmetic:
+   case SpvOpShiftRightLogical: {
+      bool swap;
+      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      unsigned dst_bit_size = glsl_get_bit_size(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+                                                  src0_bit_size, dst_bit_size);
+
+      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
+              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
+              op == nir_op_ibitfield_extract);
+
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+         unsigned src_bit_size =
+            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
+         if (src_bit_size == 0)
+            continue;
+         if (src_bit_size != src[i]->bit_size) {
+            assert(src_bit_size == 32);
+            /* Convert the Shift, Offset and Count operands to 32 bits, which
+             * is the bitsize supported by the NIR instructions. See discussion
+             * here:
+             *
+             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
+             */
+            src[i] = nir_u2u32(&b->nb, src[i]);
+         }
+      }
+      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      break;
+   }
+
+   default: {
+      bool swap;
+      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      unsigned dst_bit_size = glsl_get_bit_size(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+                                                  src_bit_size, dst_bit_size);
 
       if (swap) {
          nir_ssa_def *tmp = src[0];
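
An aside on the technique in the bitcast hunk above: the lowering never reinterprets a whole vector at once; it walks the channels and, when the source channels are wider, unpacks each one into divisor = src_bit_size / dest_bit_size narrower channels, low bits first. A minimal scalar sketch of that arithmetic in plain C follows (the helper names are made up for illustration; they model the semantics of nir_unpack_64_2x32 / nir_pack_64_2x32, not the NIR builder API):

#include <stdint.h>
#include <stdio.h>

/* Scalar model of nir_unpack_64_2x32: one 64-bit channel becomes two
 * 32-bit channels, low word first. */
static void unpack_64_2x32(uint64_t src, uint32_t dst[2])
{
   dst[0] = (uint32_t)(src & 0xffffffffu);
   dst[1] = (uint32_t)(src >> 32);
}

/* Scalar model of nir_pack_64_2x32: the inverse operation. */
static uint64_t pack_64_2x32(const uint32_t src[2])
{
   return (uint64_t)src[0] | ((uint64_t)src[1] << 32);
}

int main(void)
{
   uint64_t x = 0x0123456789abcdefull;
   uint32_t chans[2];

   unpack_64_2x32(x, chans);
   printf("lo=0x%08x hi=0x%08x\n", chans[0], chans[1]); /* 0x89abcdef 0x01234567 */
   printf("roundtrip: %d\n", pack_64_2x32(chans) == x); /* 1 */
   return 0;
}

The 16-bit variants the hunk adds (nir_unpack_64_4x16 and nir_pack_32_2x16) follow the same pattern with divisors of 4 and 2 respectively.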
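Similarly, the handle_rounding_mode hunk maps SpvFPRoundingModeRTE and SpvFPRoundingModeRTZ onto nir_rounding_mode_rtne and nir_rounding_mode_rtz for SpvOpFConvert. As a rough scalar analogue (this is not how NIR lowers the conversion), C's fenv shows the two modes producing different results when narrowing a double to float; compile with -frounding-math so the conversion is not constant-folded:

#include <fenv.h>
#include <stdio.h>

int main(void)
{
   volatile double d = 0.1; /* not exactly representable as a float */
   volatile float f;

   fesetround(FE_TONEAREST);  /* SpvFPRoundingModeRTE -> nir_rounding_mode_rtne */
   f = (float)d;
   printf("RTE: %.9f\n", f);  /* 0.100000001 */

   fesetround(FE_TOWARDZERO); /* SpvFPRoundingModeRTZ -> nir_rounding_mode_rtz */
   f = (float)d;
   printf("RTZ: %.9f\n", f);  /* 0.099999994 */
   return 0;
}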