X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;ds=inline;f=src%2Fcompiler%2Fspirv%2Fvtn_alu.c;h=b403a25e51b20659bc22b13fc61f32a73355427e;hb=99fe3ef8ba400d9555a832d0feade58f5ca3d604;hp=fa8f259a006c105070945a7e985817545c763e25;hpb=58bcebd987b7c4e7d741f42699d34b8189ab9e79;p=mesa.git

diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index fa8f259a006..b403a25e51b 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -23,6 +23,7 @@
 #include <math.h>
 
 #include "vtn_private.h"
+#include "spirv_info.h"
 
 /*
  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
@@ -42,7 +43,7 @@ wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
       return val;
 
    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
+   dest->type = glsl_get_bare_type(val->type);
    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
    dest->elems[0] = val;
 
@@ -152,48 +153,46 @@ mat_times_scalar(struct vtn_builder *b,
    return dest;
 }
 
-static void
+static struct vtn_ssa_value *
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      struct vtn_value *dest,
                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
 {
    switch (opcode) {
    case SpvOpFNegate: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
-      break;
+         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
+      return dest;
    }
 
    case SpvOpFAdd: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpFSub: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpTranspose:
-      dest->ssa = vtn_ssa_transpose(b, src0);
-      break;
+      return vtn_ssa_transpose(b, src0);
 
    case SpvOpMatrixTimesScalar:
       if (src0->transposed) {
-         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
-                                                           src1->def));
+         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
+                                                      src1->def));
       } else {
-         dest->ssa = mat_times_scalar(b, src0, src1->def);
+         return mat_times_scalar(b, src0, src1->def);
       }
       break;
 
@@ -201,13 +200,13 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
       if (opcode == SpvOpVectorTimesMatrix) {
-         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
+         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
       } else {
-         dest->ssa = matrix_multiply(b, src0, src1);
+         return matrix_multiply(b, src0, src1);
       }
       break;
 
-   default: vtn_fail("unknown matrix opcode");
+   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
    }
 }
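Annotation (not part of the patch): vtn_handle_matrix_alu() now hands its result back to the caller instead of writing it through a struct vtn_value, so pushing onto the SPIR-V value stack happens at a single call site. A minimal sketch of the new caller contract, mirroring the call site this patch adds in vtn_handle_alu() further down:

   /* sketch: consume the returned vtn_ssa_value; w[2] is the result id */
   struct vtn_ssa_value *mat =
      vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
   vtn_push_ssa_value(b, w[2], mat);

The SpvOpVectorTimesMatrix arm works because, with matrices stored as arrays of column vectors, v * M produces the same components as M^T * v, so transposing src1 reduces the opcode to the matrix-times-vector path that matrix_multiply() already handles.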
@@ -258,7 +257,21 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpBitFieldSExtract:   return nir_op_ibitfield_extract;
    case SpvOpBitFieldUExtract:   return nir_op_ubitfield_extract;
    case SpvOpBitReverse:         return nir_op_bitfield_reverse;
-   case SpvOpBitCount:           return nir_op_bit_count;
+
+   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
+   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
+   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
+   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
+   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
+   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
+   case SpvOpIAverageINTEL: return nir_op_ihadd;
+   case SpvOpUAverageINTEL: return nir_op_uhadd;
+   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
+   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
+   case SpvOpISubSatINTEL: return nir_op_isub_sat;
+   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
+   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
+   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
 
    /* The ordered / unordered operators need special implementation besides
     * the logical operator to use since they also need to check if operands are
     * ordered.
     */
@@ -267,8 +280,9 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpFOrdEqual:                            return nir_op_feq;
    case SpvOpFUnordEqual:                          return nir_op_feq;
    case SpvOpINotEqual:                            return nir_op_ine;
-   case SpvOpFOrdNotEqual:                         return nir_op_fne;
-   case SpvOpFUnordNotEqual:                       return nir_op_fne;
+   case SpvOpLessOrGreater:  /* Deprecated, use OrdNotEqual */
+   case SpvOpFOrdNotEqual:                         return nir_op_fneu;
+   case SpvOpFUnordNotEqual:                       return nir_op_fneu;
    case SpvOpULessThan:                            return nir_op_ult;
    case SpvOpSLessThan:                            return nir_op_ilt;
    case SpvOpFOrdLessThan:                         return nir_op_flt;
@@ -339,6 +353,9 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpDPdxCoarse:    return nir_op_fddx_coarse;
    case SpvOpDPdyCoarse:    return nir_op_fddy_coarse;
 
+   case SpvOpIsNormal:      return nir_op_fisnormal;
+   case SpvOpIsFinite:      return nir_op_fisfinite;
+
    default:
       vtn_fail("No NIR equivalent: %u", opcode);
    }
@@ -363,7 +380,7 @@ handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
    assert(dec->scope == VTN_DEC_DECORATION);
    if (dec->decoration != SpvDecorationFPRoundingMode)
       return;
-   switch (dec->literals[0]) {
+   switch (dec->operands[0]) {
    case SpvFPRoundingModeRTE:
       *out_rounding_mode = nir_rounding_mode_rtne;
       break;
@@ -376,15 +393,32 @@ handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
    }
 }
 
+static void
+handle_no_wrap(struct vtn_builder *b, struct vtn_value *val, int member,
+               const struct vtn_decoration *dec, void *_alu)
+{
+   nir_alu_instr *alu = _alu;
+   switch (dec->decoration) {
+   case SpvDecorationNoSignedWrap:
+      alu->no_signed_wrap = true;
+      break;
+   case SpvDecorationNoUnsignedWrap:
+      alu->no_unsigned_wrap = true;
+      break;
+   default:
+      /* Do nothing. */
+      break;
+   }
+}
+
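Annotation (illustrative, not part of the patch): handle_no_wrap() translates the SPIR-V NoSignedWrap / NoUnsignedWrap decorations into NIR's per-instruction wrap flags. The hookup, mirroring the switch this patch adds at the end of vtn_handle_alu():

   /* after building the iadd/imul/isub/ishl/ineg for the SPIR-V result: */
   nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
   vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);

Once the decoration walk has run, alu->no_signed_wrap or alu->no_unsigned_wrap is set, and later NIR passes may assume the operation does not wrap in that sense.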
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
 {
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
-   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
+   vtn_foreach_decoration(b, dest_val, handle_no_contraction, NULL);
 
    /* Collect the various SSA sources */
    const unsigned num_inputs = count - 3;
@@ -394,12 +428,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
-      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
+      vtn_push_ssa_value(b, w[2],
+         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
       b->nb.exact = b->exact;
       return;
    }
 
-   val->ssa = vtn_create_ssa_value(b, type);
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
    nir_ssa_def *src[4] = { NULL, };
    for (unsigned i = 0; i < num_inputs; i++) {
       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
@@ -408,107 +443,91 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    switch (opcode) {
    case SpvOpAny:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_bany_inequal2; break;
-         case 3:  op = nir_op_bany_inequal3; break;
-         case 4:  op = nir_op_bany_inequal4; break;
-         default: vtn_fail("invalid number of components");
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_false(&b->nb),
-                                       NULL, NULL);
-      }
+      dest->def = nir_bany(&b->nb, src[0]);
       break;
 
    case SpvOpAll:
-      if (src[0]->num_components == 1) {
-         val->ssa->def = nir_imov(&b->nb, src[0]);
-      } else {
-         nir_op op;
-         switch (src[0]->num_components) {
-         case 2:  op = nir_op_ball_iequal2; break;
-         case 3:  op = nir_op_ball_iequal3; break;
-         case 4:  op = nir_op_ball_iequal4; break;
-         default: vtn_fail("invalid number of components");
-         }
-         val->ssa->def = nir_build_alu(&b->nb, op, src[0],
-                                       nir_imm_true(&b->nb),
-                                       NULL, NULL);
-      }
+      dest->def = nir_ball(&b->nb, src[0]);
       break;
 
    case SpvOpOuterProduct: {
       for (unsigned i = 0; i < src[1]->num_components; i++) {
-         val->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
       }
       break;
    }
 
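Annotation (illustrative, not part of the patch): OpOuterProduct builds its result one column at a time; column i is the left-hand column vector scaled by component i of the right-hand vector. Unrolling the loop above for a vec3 src[0] and a vec2 src[1], which yields a matrix of two vec3 columns:

   dest->elems[0]->def = nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], 0));
   dest->elems[1]->def = nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], 1));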
    case SpvOpDot:
-      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
+      dest->def = nir_fdot(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpIAddCarry:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpISubBorrow:
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpUMulExtended: {
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
       nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
-      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
-      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
       break;
    }
 
    case SpvOpSMulExtended: {
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
       nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
-      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
-      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
       break;
    }
 
    case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
       break;
    case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
       break;
    case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
       break;
 
    case SpvOpVectorTimesScalar:
       /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
+      dest->def = nir_fmul(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
+      dest->def = nir_fneu(&b->nb, src[0], src[0]);
+      break;
+
+   case SpvOpOrdered:
+      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
+                                   nir_feq(&b->nb, src[1], src[1]));
+      break;
+
+   case SpvOpUnordered:
+      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
+                                  nir_fneu(&b->nb, src[1], src[1]));
       break;
 
    case SpvOpIsInf: {
       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
-      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
+      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
       break;
    }
 
@@ -520,7 +539,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFUnordGreaterThanEqual: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
@@ -530,15 +549,16 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          src[1] = tmp;
       }
 
-      val->ssa->def =
+      dest->def =
          nir_ior(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_ior(&b->nb,
-                         nir_fne(&b->nb, src[0], src[0]),
-                         nir_fne(&b->nb, src[1], src[1])));
+                         nir_fneu(&b->nb, src[0], src[0]),
+                         nir_fneu(&b->nb, src[1], src[1])));
       break;
    }
 
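Annotation (illustrative, not part of the patch): NIR's float comparisons other than fneu are ordered, i.e. false when either operand is NaN, so the unordered SPIR-V variants OR the comparison with explicit NaN tests; fneu(x, x) serves as the NaN test because NaN is the only value that compares unequal to itself. Unrolled for OpFUnordLessThan, this is what the generic code above emits:

   nir_ssa_def *lt   = nir_flt(&b->nb, src[0], src[1]);   /* false on NaN */
   nir_ssa_def *nan0 = nir_fneu(&b->nb, src[0], src[0]);
   nir_ssa_def *nan1 = nir_fneu(&b->nb, src[1], src[1]);
   dest->def = nir_ior(&b->nb, lt, nir_ior(&b->nb, nan0, nan1));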
+   case SpvOpLessOrGreater:
    case SpvOpFOrdNotEqual: {
       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
        * from the ALU will probably already be false if the operands are not
@@ -546,13 +566,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
        */
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
       assert(!swap);
 
-      val->ssa->def =
+      dest->def =
          nir_iand(&b->nb,
                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                   nir_iand(&b->nb,
@@ -561,43 +581,15 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpBitcast:
-      /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
-       *
-       * "If Result Type has the same number of components as Operand, they
-       *  must also have the same component width, and results are computed
-       *  per component.
-       *
-       *  If Result Type has a different number of components than Operand,
-       *  the total number of bits in Result Type must equal the total
-       *  number of bits in Operand. Let L be the type, either Result Type
-       *  or Operand’s type, that has the larger number of components. Let S
-       *  be the other type, with the smaller number of components. The
-       *  number of components in L must be an integer multiple of the
-       *  number of components in S. The first component (that is, the only
-       *  or lowest-numbered component) of S maps to the first components of
-       *  L, and so on, up to the last component of S mapping to the last
-       *  components of L. Within this mapping, any single component of S
-       *  (mapping to multiple components of L) maps its lower-ordered bits
-       *  to the lower-numbered components of L."
-       */
-      vtn_fail_if(src[0]->num_components * src[0]->bit_size !=
-                  glsl_get_vector_elements(type) * glsl_get_bit_size(type),
-                  "Source and destination of OpBitcast must have the same "
-                  "total number of bits");
-      val->ssa->def = nir_bitcast_vector(&b->nb, src[0],
-                                         glsl_get_bit_size(type));
-      break;
-
    case SpvOpFConvert: {
       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(dest_type);
       nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
 
-      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+      vtn_foreach_decoration(b, dest_val, handle_rounding_mode, &rounding_mode);
       nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
 
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
       break;
    }
 
@@ -609,7 +601,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpShiftRightLogical: {
       bool swap;
       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src0_bit_size, dst_bit_size);
 
@@ -632,14 +624,33 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
             src[i] = nir_u2u32(&b->nb, src[i]);
          }
       }
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      break;
+   }
+
+   case SpvOpSignBitSet:
+      dest->def = nir_i2b(&b->nb,
+         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
+      break;
+
+   case SpvOpUCountTrailingZerosINTEL:
+      dest->def = nir_umin(&b->nb,
+                           nir_find_lsb(&b->nb, src[0]),
+                           nir_imm_int(&b->nb, 32u));
+      break;
+
+   case SpvOpBitCount: {
+      /* bit_count always returns int32, but the SPIR-V opcode just says the return
+       * value needs to be big enough to store the number of bits.
+       */
+      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
       break;
    }
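Annotation (illustrative, not part of the patch): two details in the additions above. nir_find_lsb() returns -1 when its source is zero; reinterpreted as unsigned that is 0xffffffff, so the unsigned min against 32 produces exactly the 32 that UCountTrailingZerosINTEL requires for a zero input. And since nir_bit_count() always yields a 32-bit value, the nir_u2u() resize is what lets OpBitCount land in whatever integer width the result type declares; for a hypothetical 16-bit result type the sequence is:

   nir_ssa_def *count = nir_bit_count(&b->nb, src[0]);   /* always 32-bit */
   dest->def = nir_u2u(&b->nb, count, 16);               /* resize to dest width */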
+ */ + dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type)); break; } default: { bool swap; unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type); - unsigned dst_bit_size = glsl_get_bit_size(type); + unsigned dst_bit_size = glsl_get_bit_size(dest_type); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, src_bit_size, dst_bit_size); @@ -660,10 +671,62 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, break; } - val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]); + dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]); break; } /* default */ } + switch (opcode) { + case SpvOpIAdd: + case SpvOpIMul: + case SpvOpISub: + case SpvOpShiftLeftLogical: + case SpvOpSNegate: { + nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr); + vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu); + break; + } + default: + /* Do nothing. */ + break; + } + + vtn_push_ssa_value(b, w[2], dest); + b->nb.exact = b->exact; } + +void +vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count) +{ + vtn_assert(count == 4); + /* From the definition of OpBitcast in the SPIR-V 1.2 spec: + * + * "If Result Type has the same number of components as Operand, they + * must also have the same component width, and results are computed per + * component. + * + * If Result Type has a different number of components than Operand, the + * total number of bits in Result Type must equal the total number of + * bits in Operand. Let L be the type, either Result Type or Operand’s + * type, that has the larger number of components. Let S be the other + * type, with the smaller number of components. The number of components + * in L must be an integer multiple of the number of components in S. + * The first component (that is, the only or lowest-numbered component) + * of S maps to the first components of L, and so on, up to the last + * component of S mapping to the last components of L. Within this + * mapping, any single component of S (mapping to multiple components of + * L) maps its lower-ordered bits to the lower-numbered components of L." + */ + + struct vtn_type *type = vtn_get_type(b, w[1]); + struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]); + + vtn_fail_if(src->num_components * src->bit_size != + glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type), + "Source and destination of OpBitcast must have the same " + "total number of bits"); + nir_ssa_def *val = + nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type)); + vtn_push_nir_ssa(b, w[2], val); +}