From fbb9e1f008f8059c373ae9f130be139e0d4e1ae9 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Fri, 25 Mar 2016 15:30:46 -0700 Subject: [PATCH] spirv/alu: Add support for the NoContraction decoration --- src/compiler/nir/spirv/vtn_alu.c | 69 ++++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/src/compiler/nir/spirv/vtn_alu.c b/src/compiler/nir/spirv/vtn_alu.c index 450bc158be9..c9526f1d818 100644 --- a/src/compiler/nir/spirv/vtn_alu.c +++ b/src/compiler/nir/spirv/vtn_alu.c @@ -305,6 +305,17 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap) } } +static void +handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member, + const struct vtn_decoration *dec, void *_void) +{ + assert(dec->scope == VTN_DEC_DECORATION); + if (dec->decoration != SpvDecorationNoContraction) + return; + + b->nb.exact = true; +} + void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -313,15 +324,39 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const struct glsl_type *type = vtn_value(b, w[1], vtn_value_type_type)->type->type; + vtn_foreach_decoration(b, val, handle_no_contraction, NULL); + /* Collect the various SSA sources */ const unsigned num_inputs = count - 3; struct vtn_ssa_value *vtn_src[4] = { NULL, }; - for (unsigned i = 0; i < num_inputs; i++) + for (unsigned i = 0; i < num_inputs; i++) { vtn_src[i] = vtn_ssa_value(b, w[i + 3]); + /* The way SPIR-V defines the NoContraction decoration is ridiculous. + * It expressly says in the SPIR-V spec: + * + * "For example, if applied to an OpFMul, that multiply can’t be + * combined with an addition to yield a fused multiply-add + * operation." + * + * Technically, this means we would have to either rewrite NIR with + * another silly "don't fuse me" flag or we would have to propagate + * the NoContraction decoration to all consumers of a value which + * would make it far more infectious than anyone intended.
+ * + * Instead, we take a short-cut by simply looking at the sources and + * seeing if any of them have it. That should be good enough. + * + * See also issue #17 on the SPIR-V gitlab + */ + vtn_foreach_decoration(b, vtn_untyped_value(b, w[i + 3]), + handle_no_contraction, NULL); + } + if (glsl_type_is_matrix(vtn_src[0]->type) || (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) { vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]); + b->nb.exact = false; return; } @@ -347,7 +382,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, nir_imm_int(&b->nb, NIR_FALSE), NULL, NULL); } - return; + break; case SpvOpAll: if (src[0]->num_components == 1) { @@ -363,73 +398,73 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, nir_imm_int(&b->nb, NIR_TRUE), NULL, NULL); } - return; + break; case SpvOpOuterProduct: { for (unsigned i = 0; i < src[1]->num_components; i++) { val->ssa->elems[i]->def = nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i)); } - return; + break; } case SpvOpDot: val->ssa->def = nir_fdot(&b->nb, src[0], src[1]); - return; + break; case SpvOpIAddCarry: assert(glsl_type_is_struct(val->ssa->type)); val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]); val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]); - return; + break; case SpvOpISubBorrow: assert(glsl_type_is_struct(val->ssa->type)); val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]); val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]); - return; + break; case SpvOpUMulExtended: assert(glsl_type_is_struct(val->ssa->type)); val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]); val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]); - return; + break; case SpvOpSMulExtended: assert(glsl_type_is_struct(val->ssa->type)); val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]); val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]); - return; + break; case SpvOpFwidth: val->ssa->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])), nir_fabs(&b->nb, nir_fddy(&b->nb, src[0]))); - return; + break; case SpvOpFwidthFine: val->ssa->def = nir_fadd(&b->nb, nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])), nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0]))); - return; + break; case SpvOpFwidthCoarse: val->ssa->def = nir_fadd(&b->nb, nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])), nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0]))); - return; + break; case SpvOpVectorTimesScalar: /* The builder will take care of splatting for us. */ val->ssa->def = nir_fmul(&b->nb, src[0], src[1]); - return; + break; case SpvOpIsNan: val->ssa->def = nir_fne(&b->nb, src[0], src[0]); - return; + break; case SpvOpIsInf: val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]), nir_imm_float(&b->nb, INFINITY)); - return; + break; default: { bool swap; @@ -442,7 +477,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, } val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]); - return; + break; } /* default */ } + + b->nb.exact = false; } -- 2.30.2