spirv: Add a vtn_push_nir_ssa helper
authorJason Ekstrand <jason@jlekstrand.net>
Wed, 27 May 2020 22:49:47 +0000 (17:49 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Fri, 24 Jul 2020 03:43:21 +0000 (22:43 -0500)
This makes it easy to write a simple NIR SSA value

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5278>

src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_alu.c
src/compiler/spirv/vtn_glsl450.c
src/compiler/spirv/vtn_opencl.c
src/compiler/spirv/vtn_private.h
src/compiler/spirv/vtn_variables.c

index 308102f4bf05a93b5e0916da6127c96a6866d53f..3940c2347f80ca87d86edf97841b951e3b9e83fe 100644 (file)
@@ -302,6 +302,21 @@ vtn_ssa_value(struct vtn_builder *b, uint32_t value_id)
    }
 }
 
+struct vtn_value *
+vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id, nir_ssa_def *def)
+{
+   /* Types for all SPIR-V SSA values are set as part of a pre-pass so the
+    * type will be valid by the time we get here.
+    */
+   struct vtn_type *type = vtn_get_value_type(b, value_id);
+   vtn_fail_if(def->num_components != glsl_get_vector_elements(type->type) ||
+               def->bit_size != glsl_get_bit_size(type->type),
+               "Mismatch between NIR and SPIR-V type.");
+   struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, type->type);
+   ssa->def = def;
+   return vtn_push_ssa(b, value_id, type, ssa);
+}
+
 static char *
 vtn_string_literal(struct vtn_builder *b, const uint32_t *words,
                    unsigned word_count, unsigned *words_used)
@@ -2672,11 +2687,9 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
       }
    }
 
-   struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, ret_type->type);
-   ssa->def = &instr->dest.ssa;
-   vtn_push_ssa(b, w[2], ret_type, ssa);
-
    nir_builder_instr_insert(&b->nb, &instr->instr);
+
+   vtn_push_nir_ssa(b, w[2], &instr->dest.ssa);
 }
 
 static void
@@ -3026,9 +3039,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
       if (nir_intrinsic_dest_components(intrin) != dest_components)
          result = nir_channels(&b->nb, result, (1 << dest_components) - 1);
 
-      struct vtn_value *val =
-         vtn_push_ssa(b, w[2], type, vtn_create_ssa_value(b, type->type));
-      val->ssa->def = result;
+      vtn_push_nir_ssa(b, w[2], result);
    } else {
       nir_builder_instr_insert(&b->nb, &intrin->instr);
    }
@@ -3318,10 +3329,7 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode,
                         glsl_get_vector_elements(type->type),
                         glsl_get_bit_size(type->type), NULL);
 
-      struct vtn_ssa_value *ssa = rzalloc(b, struct vtn_ssa_value);
-      ssa->def = &atomic->dest.ssa;
-      ssa->type = type->type;
-      vtn_push_ssa(b, w[2], type, ssa);
+      vtn_push_nir_ssa(b, w[2], &atomic->dest.ssa);
    }
 
    nir_builder_instr_insert(&b->nb, &atomic->instr);
@@ -4777,9 +4785,7 @@ vtn_handle_ptr(struct vtn_builder *b, SpvOp opcode,
       unreachable("Invalid ptr operation");
    }
 
-   struct vtn_ssa_value *ssa_value = vtn_create_ssa_value(b, type);
-   ssa_value->def = def;
-   vtn_push_ssa(b, w[2], vtn_type, ssa_value);
+   vtn_push_nir_ssa(b, w[2], def);
 }
 
 static bool
@@ -5134,11 +5140,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 1, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-      struct vtn_type *res_type = vtn_get_type(b, w[1]);
-      struct vtn_ssa_value *val = vtn_create_ssa_value(b, res_type->type);
-      val->def = &intrin->dest.ssa;
-
-      vtn_push_ssa(b, w[2], res_type, val);
+      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
       break;
    }
 
@@ -5179,10 +5181,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
          result = nir_pack_64_2x32(&b->nb, &intrin->dest.ssa);
       }
 
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-      val->type = type;
-      val->ssa = vtn_create_ssa_value(b, dest_type);
-      val->ssa->def = result;
+      vtn_push_nir_ssa(b, w[2], result);
       break;
    }
 
index 60e88144ceb196175c50b5752fb247297b790c46..035bd857f1924ff8a54cf486d72674bf9c509dee 100644 (file)
@@ -699,7 +699,6 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
    struct vtn_type *type = vtn_get_type(b, w[1]);
    struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
    struct nir_ssa_def *src = vtn_src->def;
-   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);
 
    vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type));
 
@@ -707,6 +706,7 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
                "Source and destination of OpBitcast must have the same "
                "total number of bits");
-   val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
-   vtn_push_ssa(b, w[2], type, val);
+   nir_ssa_def *val =
+      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
+   vtn_push_nir_ssa(b, w[2], val);
 }
index 061ffd092820e50a88da708e08c2434dc66713ea..933cf1d407ef937bf450a4be96d4c2e3902a48e2 100644 (file)
@@ -557,10 +557,6 @@ static void
 handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode,
                              const uint32_t *w, unsigned count)
 {
-   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   val->ssa = vtn_create_ssa_value(b, dest_type);
-
    nir_intrinsic_op op;
    switch (opcode) {
    case GLSLstd450InterpolateAtCentroid:
@@ -615,13 +611,11 @@ handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode,
 
    nir_builder_instr_insert(&b->nb, &intrin->instr);
 
-   if (vec_array_deref) {
-      assert(vec_deref);
-      val->ssa->def = nir_vector_extract(&b->nb, &intrin->dest.ssa,
-                                         vec_deref->arr.index.ssa);
-   } else {
-      val->ssa->def = &intrin->dest.ssa;
-   }
+   nir_ssa_def *def = &intrin->dest.ssa;
+   if (vec_array_deref)
+      def = nir_vector_extract(&b->nb, def, vec_deref->arr.index.ssa);
+
+   vtn_push_nir_ssa(b, w[2], def);
 }
 
 bool
@@ -630,10 +624,7 @@ vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode,
 {
    switch ((enum GLSLstd450)ext_opcode) {
    case GLSLstd450Determinant: {
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-      val->ssa = rzalloc(b, struct vtn_ssa_value);
-      val->ssa->type = vtn_get_type(b, w[1])->type;
-      val->ssa->def = build_mat_det(b, vtn_ssa_value(b, w[5]));
+      vtn_push_nir_ssa(b, w[2], build_mat_det(b, vtn_ssa_value(b, w[5])));
       break;
    }
 
index 57d39ee9e64999da91e907db6ccf24806fb237a9..ba3d00a7ce4ca0ed05ecc908421b5bce614ffbaa 100644 (file)
@@ -50,9 +50,7 @@ handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
 
    nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type);
    if (result) {
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-      val->ssa = vtn_create_ssa_value(b, dest_type);
-      val->ssa->def = result;
+      vtn_push_nir_ssa(b, w[2], result);
    } else {
       vtn_assert(dest_type == glsl_void_type());
    }
@@ -256,9 +254,7 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
       }
    }
    if (load) {
-      struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, dest_type);
-      ssa->def = nir_vec(&b->nb, ncomps, components);
-      vtn_push_ssa(b, w[2], type, ssa);
+      vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
    }
 }
 
index 9d9e7883481a3df4dacf7cc4a12f1cdf31aaf2cb..fd1ce33c6beda7fd73253743c67b0a613424cf05 100644 (file)
@@ -785,6 +785,9 @@ vtn_get_type(struct vtn_builder *b, uint32_t value_id)
 
 struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id);
 
+struct vtn_value *vtn_push_nir_ssa(struct vtn_builder *b, uint32_t value_id,
+                                   nir_ssa_def *def);
+
 struct vtn_value *vtn_push_pointer(struct vtn_builder *b,
                                    uint32_t value_id,
                                    struct vtn_pointer *ptr);
index fb3eef7de717de8dca8ece9d389e0eaded2992ef..6bd6bd5873e4fc69fce696f123b78cc64d03f247 100644 (file)
@@ -2735,9 +2735,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
                            nir_imm_int(&b->nb, 0u)),
                   nir_imm_int(&b->nb, stride));
 
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-      val->ssa = vtn_create_ssa_value(b, glsl_uint_type());
-      val->ssa->def = array_length;
+      vtn_push_nir_ssa(b, w[2], array_length);
       break;
    }