spirv/alu: Add support for the NoContraction decoration
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 25 Mar 2016 22:30:46 +0000 (15:30 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Sat, 26 Mar 2016 04:35:41 +0000 (21:35 -0700)
src/compiler/nir/spirv/vtn_alu.c

index 450bc158be9fa874e3dfbd90f9a0932feddf55fb..c9526f1d8186acdd33bf7427caf2466279354d08 100644 (file)
@@ -305,6 +305,17 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
    }
 }
 
+static void
+handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
+                      const struct vtn_decoration *dec, void *_void)
+{
+   assert(dec->scope == VTN_DEC_DECORATION);
+   if (dec->decoration != SpvDecorationNoContraction)
+      return;
+
+   b->nb.exact = true;
+}
+
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
@@ -313,15 +324,39 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    const struct glsl_type *type =
       vtn_value(b, w[1], vtn_value_type_type)->type->type;
 
+   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
+
    /* Collect the various SSA sources */
    const unsigned num_inputs = count - 3;
    struct vtn_ssa_value *vtn_src[4] = { NULL, };
-   for (unsigned i = 0; i < num_inputs; i++)
+   for (unsigned i = 0; i < num_inputs; i++) {
       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
 
+      /* The way SPIR-V defines the NoContraction decoration is rediculous.
+       * It expressly says in the SPIR-V spec:
+       *
+       *    "For example, if applied to an OpFMul, that multiply can’t be
+       *    combined with an addition to yield a fused multiply-add
+       *    operation."
+       *
+       * Technically, this means we would have to either rewrite NIR with
+       * another silly "don't fuse me" flag or we would have to propagate
+       * the NoContraction decoration to all consumers of a value which
+       * would make it far more infectious than anyone intended.
+       *
+       * Instead, we take a short-cut by simply looking at the sources and
+       * see if any of them have it.  That should be good enough.
+       *
+       * See also issue #17 on the SPIR-V gitlab
+       */
+      vtn_foreach_decoration(b, vtn_untyped_value(b, w[i + 3]),
+                             handle_no_contraction, NULL);
+   }
+
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
       vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
+      b->nb.exact = false;
       return;
    }
 
@@ -347,7 +382,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                                        nir_imm_int(&b->nb, NIR_FALSE),
                                        NULL, NULL);
       }
-      return;
+      break;
 
    case SpvOpAll:
       if (src[0]->num_components == 1) {
@@ -363,73 +398,73 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                                        nir_imm_int(&b->nb, NIR_TRUE),
                                        NULL, NULL);
       }
-      return;
+      break;
 
    case SpvOpOuterProduct: {
       for (unsigned i = 0; i < src[1]->num_components; i++) {
          val->ssa->elems[i]->def =
             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
       }
-      return;
+      break;
    }
 
    case SpvOpDot:
       val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpIAddCarry:
       assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpISubBorrow:
       assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpUMulExtended:
       assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpSMulExtended:
       assert(glsl_type_is_struct(val->ssa->type));
       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
       val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpFwidth:
       val->ssa->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
-      return;
+      break;
    case SpvOpFwidthFine:
       val->ssa->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
-      return;
+      break;
    case SpvOpFwidthCoarse:
       val->ssa->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
-      return;
+      break;
 
    case SpvOpVectorTimesScalar:
       /* The builder will take care of splatting for us. */
       val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
-      return;
+      break;
 
    case SpvOpIsNan:
       val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
-      return;
+      break;
 
    case SpvOpIsInf:
       val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
                                       nir_imm_float(&b->nb, INFINITY));
-      return;
+      break;
 
    default: {
       bool swap;
@@ -442,7 +477,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       }
 
       val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
-      return;
+      break;
    } /* default */
    }
+
+   b->nb.exact = false;
 }