spirv: Move the emit a 'return value' store logic into own function
[mesa.git] / src / compiler / spirv / vtn_glsl450.c
index 933cf1d407ef937bf450a4be96d4c2e3902a48e2..242f3db02aa4809fe98b1ff4b0650c377ac259e9 100644 (file)
@@ -311,8 +311,6 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
 {
    struct nir_builder *nb = &b->nb;
    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   val->ssa = vtn_create_ssa_value(b, dest_type);
 
    /* Collect the various SSA sources */
    unsigned num_inputs = count - 5;
@@ -322,99 +320,100 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       if (vtn_untyped_value(b, w[i + 5])->value_type == vtn_value_type_pointer)
          continue;
 
-      src[i] = vtn_ssa_value(b, w[i + 5])->def;
+      src[i] = vtn_get_nir_ssa(b, w[i + 5]);
    }
 
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
    switch (entrypoint) {
    case GLSLstd450Radians:
-      val->ssa->def = nir_radians(nb, src[0]);
-      return;
+      dest->def = nir_radians(nb, src[0]);
+      break;
    case GLSLstd450Degrees:
-      val->ssa->def = nir_degrees(nb, src[0]);
-      return;
+      dest->def = nir_degrees(nb, src[0]);
+      break;
    case GLSLstd450Tan:
-      val->ssa->def = nir_ftan(nb, src[0]);
-      return;
+      dest->def = nir_ftan(nb, src[0]);
+      break;
 
    case GLSLstd450Modf: {
       nir_ssa_def *sign = nir_fsign(nb, src[0]);
       nir_ssa_def *abs = nir_fabs(nb, src[0]);
-      val->ssa->def = nir_fmul(nb, sign, nir_ffract(nb, abs));
+      dest->def = nir_fmul(nb, sign, nir_ffract(nb, abs));
       nir_store_deref(nb, vtn_nir_deref(b, w[6]),
                       nir_fmul(nb, sign, nir_ffloor(nb, abs)), 0xf);
-      return;
+      break;
    }
 
    case GLSLstd450ModfStruct: {
       nir_ssa_def *sign = nir_fsign(nb, src[0]);
       nir_ssa_def *abs = nir_fabs(nb, src[0]);
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
-      val->ssa->elems[0]->def = nir_fmul(nb, sign, nir_ffract(nb, abs));
-      val->ssa->elems[1]->def = nir_fmul(nb, sign, nir_ffloor(nb, abs));
-      return;
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_fmul(nb, sign, nir_ffract(nb, abs));
+      dest->elems[1]->def = nir_fmul(nb, sign, nir_ffloor(nb, abs));
+      break;
    }
 
    case GLSLstd450Step:
-      val->ssa->def = nir_sge(nb, src[1], src[0]);
-      return;
+      dest->def = nir_sge(nb, src[1], src[0]);
+      break;
 
    case GLSLstd450Length:
-      val->ssa->def = nir_fast_length(nb, src[0]);
-      return;
+      dest->def = nir_fast_length(nb, src[0]);
+      break;
    case GLSLstd450Distance:
-      val->ssa->def = nir_fast_distance(nb, src[0], src[1]);
-      return;
+      dest->def = nir_fast_distance(nb, src[0], src[1]);
+      break;
    case GLSLstd450Normalize:
-      val->ssa->def = nir_fast_normalize(nb, src[0]);
-      return;
+      dest->def = nir_fast_normalize(nb, src[0]);
+      break;
 
    case GLSLstd450Exp:
-      val->ssa->def = nir_fexp(nb, src[0]);
-      return;
+      dest->def = nir_fexp(nb, src[0]);
+      break;
 
    case GLSLstd450Log:
-      val->ssa->def = nir_flog(nb, src[0]);
-      return;
+      dest->def = nir_flog(nb, src[0]);
+      break;
 
    case GLSLstd450FClamp:
-      val->ssa->def = nir_fclamp(nb, src[0], src[1], src[2]);
-      return;
+      dest->def = nir_fclamp(nb, src[0], src[1], src[2]);
+      break;
    case GLSLstd450NClamp:
       nb->exact = true;
-      val->ssa->def = nir_fclamp(nb, src[0], src[1], src[2]);
+      dest->def = nir_fclamp(nb, src[0], src[1], src[2]);
       nb->exact = false;
-      return;
+      break;
    case GLSLstd450UClamp:
-      val->ssa->def = nir_uclamp(nb, src[0], src[1], src[2]);
-      return;
+      dest->def = nir_uclamp(nb, src[0], src[1], src[2]);
+      break;
    case GLSLstd450SClamp:
-      val->ssa->def = nir_iclamp(nb, src[0], src[1], src[2]);
-      return;
+      dest->def = nir_iclamp(nb, src[0], src[1], src[2]);
+      break;
 
    case GLSLstd450Cross: {
-      val->ssa->def = nir_cross3(nb, src[0], src[1]);
-      return;
+      dest->def = nir_cross3(nb, src[0], src[1]);
+      break;
    }
 
    case GLSLstd450SmoothStep: {
-      val->ssa->def = nir_smoothstep(nb, src[0], src[1], src[2]);
-      return;
+      dest->def = nir_smoothstep(nb, src[0], src[1], src[2]);
+      break;
    }
 
    case GLSLstd450FaceForward:
-      val->ssa->def =
+      dest->def =
          nir_bcsel(nb, nir_flt(nb, nir_fdot(nb, src[2], src[1]),
                                    NIR_IMM_FP(nb, 0.0)),
                        src[0], nir_fneg(nb, src[0]));
-      return;
+      break;
 
    case GLSLstd450Reflect:
       /* I - 2 * dot(N, I) * N */
-      val->ssa->def =
+      dest->def =
          nir_fsub(nb, src[0], nir_fmul(nb, NIR_IMM_FP(nb, 2.0),
                               nir_fmul(nb, nir_fdot(nb, src[0], src[1]),
                                            src[1])));
-      return;
+      break;
 
    case GLSLstd450Refract: {
       nir_ssa_def *I = src[0];
@@ -446,25 +445,25 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
                       nir_fmul(nb, nir_fadd(nb, nir_fmul(nb, eta, n_dot_i),
                                                 nir_fsqrt(nb, k)), N));
       /* XXX: bcsel, or if statement? */
-      val->ssa->def = nir_bcsel(nb, nir_flt(nb, k, zero), zero, result);
-      return;
+      dest->def = nir_bcsel(nb, nir_flt(nb, k, zero), zero, result);
+      break;
    }
 
    case GLSLstd450Sinh:
       /* 0.5 * (e^x - e^(-x)) */
-      val->ssa->def =
+      dest->def =
          nir_fmul_imm(nb, nir_fsub(nb, nir_fexp(nb, src[0]),
                                        nir_fexp(nb, nir_fneg(nb, src[0]))),
                           0.5f);
-      return;
+      break;
 
    case GLSLstd450Cosh:
       /* 0.5 * (e^x + e^(-x)) */
-      val->ssa->def =
+      dest->def =
          nir_fmul_imm(nb, nir_fadd(nb, nir_fexp(nb, src[0]),
                                        nir_fexp(nb, nir_fneg(nb, src[0]))),
                           0.5f);
-      return;
+      break;
 
    case GLSLstd450Tanh: {
       /* tanh(x) := (e^x - e^(-x)) / (e^x + e^(-x))
@@ -480,64 +479,64 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       nir_ssa_def *x = nir_fclamp(nb, src[0],
                                   nir_imm_floatN_t(nb, -clamped_x, bit_size),
                                   nir_imm_floatN_t(nb, clamped_x, bit_size));
-      val->ssa->def =
+      dest->def =
          nir_fdiv(nb, nir_fsub(nb, nir_fexp(nb, x),
                                nir_fexp(nb, nir_fneg(nb, x))),
                   nir_fadd(nb, nir_fexp(nb, x),
                            nir_fexp(nb, nir_fneg(nb, x))));
-      return;
+      break;
    }
 
    case GLSLstd450Asinh:
-      val->ssa->def = nir_fmul(nb, nir_fsign(nb, src[0]),
+      dest->def = nir_fmul(nb, nir_fsign(nb, src[0]),
          nir_flog(nb, nir_fadd(nb, nir_fabs(nb, src[0]),
                       nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
                                                     1.0f)))));
-      return;
+      break;
    case GLSLstd450Acosh:
-      val->ssa->def = nir_flog(nb, nir_fadd(nb, src[0],
+      dest->def = nir_flog(nb, nir_fadd(nb, src[0],
          nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
                                         -1.0f))));
-      return;
+      break;
    case GLSLstd450Atanh: {
       nir_ssa_def *one = nir_imm_floatN_t(nb, 1.0, src[0]->bit_size);
-      val->ssa->def =
+      dest->def =
          nir_fmul_imm(nb, nir_flog(nb, nir_fdiv(nb, nir_fadd(nb, src[0], one),
                                        nir_fsub(nb, one, src[0]))),
                           0.5f);
-      return;
+      break;
    }
 
    case GLSLstd450Asin:
-      val->ssa->def = build_asin(nb, src[0], 0.086566724, -0.03102955, true);
-      return;
+      dest->def = build_asin(nb, src[0], 0.086566724, -0.03102955, true);
+      break;
 
    case GLSLstd450Acos:
-      val->ssa->def =
+      dest->def =
          nir_fsub(nb, nir_imm_floatN_t(nb, M_PI_2f, src[0]->bit_size),
                       build_asin(nb, src[0], 0.08132463, -0.02363318, false));
-      return;
+      break;
 
    case GLSLstd450Atan:
-      val->ssa->def = nir_atan(nb, src[0]);
-      return;
+      dest->def = nir_atan(nb, src[0]);
+      break;
 
    case GLSLstd450Atan2:
-      val->ssa->def = nir_atan2(nb, src[0], src[1]);
-      return;
+      dest->def = nir_atan2(nb, src[0], src[1]);
+      break;
 
    case GLSLstd450Frexp: {
       nir_ssa_def *exponent = nir_frexp_exp(nb, src[0]);
-      val->ssa->def = nir_frexp_sig(nb, src[0]);
+      dest->def = nir_frexp_sig(nb, src[0]);
       nir_store_deref(nb, vtn_nir_deref(b, w[6]), exponent, 0xf);
-      return;
+      break;
    }
 
    case GLSLstd450FrexpStruct: {
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
-      val->ssa->elems[0]->def = nir_frexp_sig(nb, src[0]);
-      val->ssa->elems[1]->def = nir_frexp_exp(nb, src[0]);
-      return;
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_frexp_sig(nb, src[0]);
+      dest->elems[1]->def = nir_frexp_exp(nb, src[0]);
+      break;
    }
 
    default: {
@@ -546,11 +545,13 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       bool exact;
       nir_op op = vtn_nir_alu_op_for_spirv_glsl_opcode(b, entrypoint, execution_mode, &exact);
       b->nb.exact = exact;
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], NULL);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], NULL);
       b->nb.exact = false;
-      return;
+      break;
    }
    }
+
+   vtn_push_ssa_value(b, w[2], dest);
 }
 
 static void
@@ -598,7 +599,7 @@ handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode,
       break;
    case GLSLstd450InterpolateAtSample:
    case GLSLstd450InterpolateAtOffset:
-      intrin->src[1] = nir_src_for_ssa(vtn_ssa_value(b, w[6])->def);
+      intrin->src[1] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[6]));
       break;
    default:
       vtn_fail("Invalid opcode");
@@ -629,8 +630,7 @@ vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode,
    }
 
    case GLSLstd450MatrixInverse: {
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-      val->ssa = matrix_inverse(b, vtn_ssa_value(b, w[5]));
+      vtn_push_ssa_value(b, w[2], matrix_inverse(b, vtn_ssa_value(b, w[5])));
       break;
    }