zink: use helper function to handle uvec/bvec types
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Wed, 10 Jun 2020 15:00:16 +0000 (11:00 -0400)
committerMarge Bot <eric+marge@anholt.net>
Wed, 22 Jul 2020 14:01:29 +0000 (14:01 +0000)
bit_size of 1 means we use a bool type here, 32 means uint, so we can just
handle that automatically for all relevant cases

ref shaders@glsl-vs-continue-in-switch-in-do-while

Reviewed-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5911>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 3d1e56d5f0e94765c31b33c27abe96335286b2d6..bd35fca81f29833c76baee0f3657a5fcf82d2432 100644 (file)
@@ -632,6 +632,17 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
    }
 }
 
+static SpvId
+get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_components)
+{
+   if (bit_size == 1)
+      return get_bvec_type(ctx, num_components);
+   if (bit_size == 32)
+      return get_uvec_type(ctx, bit_size, num_components);
+   unreachable("unhandled register bit size");
+   return 0;
+}
+
 static SpvId
 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
 {
@@ -656,7 +667,7 @@ get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
    assert(!reg->base_offset);
 
    SpvId var = get_var_from_reg(ctx, reg->reg);
-   SpvId type = get_uvec_type(ctx, reg->reg->bit_size, reg->reg->num_components);
+   SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components);
    return spirv_builder_emit_load(&ctx->builder, type, var);
 }
 
@@ -1503,19 +1514,17 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
    SpvId constant;
    if (num_components > 1) {
       SpvId components[num_components];
-      SpvId type;
+      SpvId type = get_vec_from_bit_size(ctx, bit_size, num_components);
       if (bit_size == 1) {
          for (int i = 0; i < num_components; i++)
             components[i] = spirv_builder_const_bool(&ctx->builder,
                                                      load_const->value[i].b);
 
-         type = get_bvec_type(ctx, num_components);
       } else {
          for (int i = 0; i < num_components; i++)
             components[i] = emit_uint_const(ctx, bit_size,
                                             load_const->value[i].u32);
 
-         type = get_uvec_type(ctx, bit_size, num_components);
       }
       constant = spirv_builder_const_composite(&ctx->builder, type,
                                                components, num_components);
@@ -2302,7 +2311,7 @@ nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info
    /* emit a block only for the variable declarations */
    start_block(&ctx, spirv_builder_new_id(&ctx.builder));
    foreach_list_typed(nir_register, reg, node, &entry->registers) {
-      SpvId type = get_uvec_type(&ctx, reg->bit_size, reg->num_components);
+      SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
       SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
                                                       SpvStorageClassFunction,
                                                       type);