zink/spirv: store all values as uint.
authorDave Airlie <airlied@redhat.com>
Fri, 19 Oct 2018 01:02:26 +0000 (11:02 +1000)
committerErik Faye-Lund <erik.faye-lund@collabora.com>
Mon, 28 Oct 2019 08:51:43 +0000 (08:51 +0000)
This adds bitcasting to uint everywhere for now,
and stores all spir-v ssa values as uints.

It also casts bool to 0/0xffffffff for now
(nir 1-bit bools may be coming in the future).

This fixes a lot of piglit tests to pass now

Acked-by: Jordan Justen <jordan.l.justen@intel.com>
src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 166bdf2994913ff92d3b299d075cca97b4b43b10..e074f39df8a279cb9fd86eb6601f5a09b3d62dea 100644 (file)
@@ -35,7 +35,9 @@ struct ntv_context {
 
    gl_shader_stage stage;
    SpvId inputs[PIPE_MAX_SHADER_INPUTS][4];
+   SpvId input_types[PIPE_MAX_SHADER_INPUTS][4];
    SpvId outputs[PIPE_MAX_SHADER_OUTPUTS][4];
+   SpvId output_types[PIPE_MAX_SHADER_OUTPUTS][4];
 
    SpvId ubos[128];
    size_t num_ubos;
@@ -48,6 +50,25 @@ struct ntv_context {
    size_t num_defs;
 };
 
+static SpvId
+get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
+                  const float values[]);
+
+static SpvId
+get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
+                  const uint32_t values[]);
+
+static SpvId
+emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
+
+static SpvId
+emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
+           SpvId src0, SpvId src1);
+
+static SpvId
+emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
+           SpvId src0, SpvId src1, SpvId src2);
+
 static SpvId
 get_bvec_type(struct ntv_context *ctx, int num_components)
 {
@@ -75,9 +96,37 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 }
 
 static SpvId
-get_dest_type(struct ntv_context *ctx, nir_dest *dest)
+get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
+{
+   assert(bit_size == 32); // only 32-bit ints supported so far
+
+   SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
+   if (num_components > 1)
+      return spirv_builder_type_vector(&ctx->builder, int_type,
+                                       num_components);
+
+   assert(num_components == 1);
+   return int_type;
+}
+
+static SpvId
+get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
+{
+   assert(bit_size == 32); // only 32-bit uints supported so far
+
+   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
+   if (num_components > 1)
+      return spirv_builder_type_vector(&ctx->builder, uint_type,
+                                       num_components);
+
+   assert(num_components == 1);
+   return uint_type;
+}
+
+static SpvId
+get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
 {
-   return get_fvec_type(ctx, nir_dest_bit_size(*dest),
+   return get_uvec_type(ctx, nir_dest_bit_size(*dest),
                              nir_dest_num_components(*dest));
 }
 
@@ -159,6 +208,7 @@ emit_input(struct ntv_context *ctx, struct nir_variable *var)
    assert(var->data.location_frac < 4);
    assert(ctx->inputs[var->data.driver_location][var->data.location_frac] == 0);
    ctx->inputs[var->data.driver_location][var->data.location_frac] = var_id;
+   ctx->input_types[var->data.driver_location][var->data.location_frac] = vec_type;
 
    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
@@ -211,6 +261,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
    assert(var->data.location_frac < 4);
    assert(ctx->outputs[var->data.driver_location][var->data.location_frac] == 0);
    ctx->outputs[var->data.driver_location][var->data.location_frac] = var_id;
+   ctx->output_types[var->data.driver_location][var->data.location_frac] = vec_type;
 
    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
@@ -274,7 +325,7 @@ static void
 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
 {
    uint32_t size = glsl_count_attribute_slots(var->type, false);
-   SpvId vec4_type = get_fvec_type(ctx, 32, 4);
+   SpvId vec4_type = get_uvec_type(ctx, 32, 4);
    SpvId array_length = spirv_builder_const_uint(&ctx->builder, 32, size);
    SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
                                                array_length);
@@ -320,21 +371,27 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
 }
 
 static SpvId
-get_src(struct ntv_context *ctx, nir_src *src)
+get_src_uint_ssa(struct ntv_context *ctx, nir_ssa_def *ssa)
+{
+   assert(ssa->index < ctx->num_defs);
+   assert(ctx->defs[ssa->index] != 0);
+   return ctx->defs[ssa->index];
+}
+
+static SpvId
+get_src_uint(struct ntv_context *ctx, nir_src *src)
 {
    assert(src->is_ssa);
-   assert(src->ssa->index < ctx->num_defs);
-   assert(ctx->defs[src->ssa->index] != 0);
-   return ctx->defs[src->ssa->index];
+   return get_src_uint_ssa(ctx, src->ssa);
 }
 
 static SpvId
-get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
+get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
 {
    assert(!alu->src[src].negate);
    assert(!alu->src[src].abs);
 
-   SpvId def = get_src(ctx, &alu->src[src].src);
+   SpvId def = get_src_uint(ctx, &alu->src[src].src);
 
    unsigned used_channels = 0;
    bool need_swizzle = false;
@@ -358,23 +415,27 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
 
    int bit_size = nir_src_bit_size(alu->src[src].src);
 
+   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
    if (used_channels == 1) {
-      SpvId result_type = spirv_builder_type_float(&ctx->builder, bit_size);
       uint32_t indices[] =  { alu->src[src].swizzle[0] };
-      return spirv_builder_emit_composite_extract(&ctx->builder, result_type,
+      return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
                                                   def, indices,
                                                   ARRAY_SIZE(indices));
    } else if (live_channels == 1) {
-      SpvId type = get_fvec_type(ctx, bit_size, used_channels);
+      SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
+                                                  used_channels);
 
       SpvId constituents[NIR_MAX_VEC_COMPONENTS];
       for (unsigned i = 0; i < used_channels; ++i)
         constituents[i] = def;
 
-      return spirv_builder_emit_composite_construct(&ctx->builder, type,
+      return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
                                                     constituents,
                                                     used_channels);
    } else {
+      SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
+                                                  used_channels);
+
       uint32_t components[NIR_MAX_VEC_COMPONENTS];
       size_t num_components = 0;
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
@@ -384,32 +445,97 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
          components[num_components++] = alu->src[src].swizzle[i];
       }
 
-      SpvId vecType = get_fvec_type(ctx, bit_size, used_channels);
-      return spirv_builder_emit_vector_shuffle(&ctx->builder, vecType,
+      return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
                                         def, def, components, num_components);
    }
 }
 
 static void
-store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
+store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
 {
    assert(result != 0);
    assert(ssa->index < ctx->num_defs);
    ctx->defs[ssa->index] = result;
 }
 
+static SpvId
+bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
+{
+   SpvId otype = get_uvec_type(ctx, 32, num_components);
+   uint32_t zeros[4] = { 0, 0, 0, 0 };
+   uint32_t ones[4] = { 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff };
+   SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
+   SpvId one = get_uvec_constant(ctx, 32, num_components, ones);
+   return emit_triop(ctx, SpvOpSelect, otype, value, one, zero);
+}
+
+static SpvId
+uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
+{
+   SpvId type = get_bvec_type(ctx, num_components);
+
+   uint32_t zeros[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   SpvId zero = get_uvec_constant(ctx, 32, num_components, zeros);
+
+   return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
+}
+
+static SpvId
+bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
+                unsigned num_components)
+{
+   SpvId type = get_uvec_type(ctx, bit_size, num_components);
+   return emit_unop(ctx, SpvOpBitcast, type, value);
+}
+
+static SpvId
+bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
+                unsigned num_components)
+{
+   SpvId type = get_ivec_type(ctx, bit_size, num_components);
+   return emit_unop(ctx, SpvOpBitcast, type, value);
+}
+
+static SpvId
+bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
+               unsigned num_components)
+{
+   SpvId type = get_fvec_type(ctx, bit_size, num_components);
+   return emit_unop(ctx, SpvOpBitcast, type, value);
+}
+
 static void
-store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result)
+store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
 {
    assert(dest->is_ssa);
-   store_ssa_def(ctx, &dest->ssa, result);
+   store_ssa_def_uint(ctx, &dest->ssa, result);
 }
 
 static void
-store_alu_result(struct ntv_context *ctx, nir_alu_dest *dest, SpvId result)
+store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
 {
-   assert(!dest->saturate);
-   return store_dest(ctx, &dest->dest, result);
+   unsigned num_components = nir_dest_num_components(*dest);
+   unsigned bit_size = nir_dest_bit_size(*dest);
+
+   switch (nir_alu_type_get_base_type(type)) {
+   case nir_type_bool:
+      assert(bit_size == 1);
+      result = bvec_to_uvec(ctx, result, num_components);
+      break;
+
+   case nir_type_uint:
+      break; /* nothing to do! */
+
+   case nir_type_int:
+   case nir_type_float:
+      result = bitcast_to_uvec(ctx, result, bit_size, num_components);
+      break;
+
+   default:
+      unreachable("unsupported nir_alu_type");
+   }
+
+   store_dest_uint(ctx, dest, result);
 }
 
 static SpvId
@@ -452,7 +578,7 @@ emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
 
 static SpvId
 get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
-                 float values[])
+                  const float values[])
 {
    assert(bit_size == 32);
 
@@ -471,6 +597,98 @@ get_fvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
    return spirv_builder_const_float(&ctx->builder, bit_size, values[0]);
 }
 
+static SpvId
+get_uvec_constant(struct ntv_context *ctx, int bit_size, int num_components,
+                  const uint32_t values[])
+{
+   assert(bit_size == 32);
+
+   if (num_components > 1) {
+      SpvId components[num_components];
+      for (int i = 0; i < num_components; i++)
+         components[i] = spirv_builder_const_uint(&ctx->builder, bit_size,
+                                                  values[i]);
+
+      SpvId type = get_uvec_type(ctx, bit_size, num_components);
+      return spirv_builder_const_composite(&ctx->builder, type, components,
+                                           num_components);
+   }
+
+   assert(num_components == 1);
+   return spirv_builder_const_uint(&ctx->builder, bit_size, values[0]);
+}
+
+static inline unsigned
+alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
+{
+   if (nir_op_infos[instr->op].input_sizes[src] > 0)
+      return nir_op_infos[instr->op].input_sizes[src];
+
+   if (instr->dest.dest.is_ssa)
+      return instr->dest.dest.ssa.num_components;
+   else
+      return instr->dest.dest.reg.reg->num_components;
+}
+
+static SpvId
+get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
+{
+   SpvId uint_value = get_alu_src_uint(ctx, alu, src);
+
+   unsigned num_components = alu_instr_src_components(alu, src);
+   unsigned bit_size = nir_src_bit_size(alu->src[src].src);
+   nir_alu_type type = nir_op_infos[alu->op].input_types[src];
+
+   switch (nir_alu_type_get_base_type(type)) {
+   case nir_type_bool:
+      assert(bit_size == 1);
+      return uvec_to_bvec(ctx, uint_value, num_components);
+
+   case nir_type_int:
+      return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
+
+   case nir_type_uint:
+      return uint_value;
+
+   case nir_type_float:
+      return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
+
+   default:
+      unreachable("unknown nir_alu_type");
+   }
+}
+
+static void
+store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
+{
+   assert(!alu->dest.saturate);
+   return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
+}
+
+static SpvId
+get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
+{
+   unsigned num_components = nir_dest_num_components(*dest);
+   unsigned bit_size = nir_dest_bit_size(*dest);
+
+   switch (nir_alu_type_get_base_type(type)) {
+   case nir_type_bool:
+      return get_bvec_type(ctx, num_components);
+
+   case nir_type_int:
+      return get_ivec_type(ctx, bit_size, num_components);
+
+   case nir_type_uint:
+      return get_uvec_type(ctx, bit_size, num_components);
+
+   case nir_type_float:
+      return get_fvec_type(ctx, bit_size, num_components);
+
+   default:
+      unreachable("unsupported nir_alu_type");
+   }
+}
+
 static void
 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
 {
@@ -478,7 +696,10 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
       src[i] = get_alu_src(ctx, alu, i);
 
-   SpvId dest_type = get_dest_type(ctx, &alu->dest.dest);
+   SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
+                                   nir_op_infos[alu->op].output_type);
+   unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
+   unsigned num_components = nir_dest_num_components(alu->dest.dest);
 
    SpvId result = 0;
    switch (alu->op) {
@@ -521,9 +742,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       assert(nir_op_infos[alu->op].num_inputs == 1);
       float one[4] = { 1, 1, 1, 1 };
       src[1] = src[0];
-      src[0] = get_fvec_constant(ctx, nir_dest_bit_size(alu->dest.dest),
-                                     nir_dest_num_components(alu->dest.dest),
-                                     one);
+      src[0] = get_fvec_constant(ctx, bit_size, num_components, one);
       result = emit_binop(ctx, SpvOpFDiv, dest_type, src[0], src[1]);
       }
       break;
@@ -631,20 +850,20 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       return;
    }
 
-   store_alu_result(ctx, &alu->dest, result);
+   store_alu_result(ctx, alu, result);
 }
 
 static void
 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
 {
-   float values[NIR_MAX_VEC_COMPONENTS];
+   uint32_t values[NIR_MAX_VEC_COMPONENTS];
    for (int i = 0; i < load_const->def.num_components; ++i)
-      values[i] = load_const->value[i].f32;
+      values[i] = load_const->value[i].u32;
 
-   SpvId constant = get_fvec_constant(ctx, load_const->def.bit_size,
-                                            load_const->def.num_components,
-                                            values);
-   store_ssa_def(ctx, &load_const->def, constant);
+   SpvId constant = get_uvec_constant(ctx, load_const->def.bit_size,
+                                           load_const->def.num_components,
+                                           values);
+   store_ssa_def_uint(ctx, &load_const->def, constant);
 }
 
 static void
@@ -652,17 +871,22 @@ emit_load_input(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
    nir_const_value *const_offset = nir_src_as_const_value(intr->src[0]);
    if (const_offset) {
-      SpvId type = get_dest_type(ctx, &intr->dest);
-
       int driver_location = (int)nir_intrinsic_base(intr) + const_offset->u32;
       assert(driver_location < PIPE_MAX_SHADER_INPUTS);
       int location_frac = nir_intrinsic_component(intr);
       assert(location_frac < 4);
 
       SpvId ptr = ctx->inputs[driver_location][location_frac];
-      assert(ptr > 0);
+      SpvId type = ctx->input_types[driver_location][location_frac];
+      assert(ptr && type);
 
-      store_dest(ctx, &intr->dest, spirv_builder_emit_load(&ctx->builder, type, ptr));
+      SpvId result = spirv_builder_emit_load(&ctx->builder, type, ptr);
+
+      unsigned num_components = nir_dest_num_components(intr->dest);
+      unsigned bit_size = nir_dest_bit_size(intr->dest);
+      result = bitcast_to_uvec(ctx, result, bit_size, num_components);
+
+      store_dest_uint(ctx, &intr->dest, result);
    } else
       unreachable("input-addressing not yet supported");
 }
@@ -676,10 +900,10 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 
    nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
    if (const_offset) {
-      SpvId vec4_type = get_fvec_type(ctx, 32, 4);
+      SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
       SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                       SpvStorageClassUniform,
-                                                      vec4_type);
+                                                      uvec4_type);
 
       unsigned idx = const_offset->u32;
       SpvId member = spirv_builder_const_uint(&ctx->builder, 32, 0);
@@ -688,9 +912,9 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
                                                   ctx->ubos[0], offsets,
                                                   ARRAY_SIZE(offsets));
-      SpvId result = spirv_builder_emit_load(&ctx->builder, vec4_type, ptr);
+      SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
 
-      SpvId type = get_dest_type(ctx, &intr->dest);
+      SpvId type = get_dest_uvec_type(ctx, &intr->dest);
       unsigned num_components = nir_dest_num_components(intr->dest);
       if (num_components == 1) {
          uint32_t components[] = { 0 };
@@ -700,10 +924,10 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                                        1);
       } else if (num_components < 4) {
          SpvId constituents[num_components];
-         SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
+         SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
          for (uint32_t i = 0; i < num_components; ++i)
             constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
-                                                                   float_type,
+                                                                   uint_type,
                                                                    result, &i,
                                                                    1);
 
@@ -713,7 +937,7 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                                          num_components);
       }
 
-      store_dest(ctx, &intr->dest, result);
+      store_dest_uint(ctx, &intr->dest, result);
    } else
       unreachable("uniform-addressing not yet supported");
 }
@@ -731,8 +955,10 @@ emit_store_output(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       SpvId ptr = ctx->outputs[driver_location][location_frac];
       assert(ptr > 0);
 
-      SpvId src = get_src(ctx, &intr->src[0]);
-      spirv_builder_emit_store(&ctx->builder, ptr, src);
+      SpvId src = get_src_uint(ctx, &intr->src[0]);
+      SpvId spirv_type = ctx->output_types[driver_location][location_frac];
+      SpvId result = emit_unop(ctx, SpvOpBitcast, spirv_type, src);
+      spirv_builder_emit_store(&ctx->builder, ptr, result);
    } else
       unreachable("output-addressing not yet supported");
 }
@@ -766,8 +992,17 @@ emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
    SpvId type = get_fvec_type(ctx, undef->def.bit_size,
                               undef->def.num_components);
 
-   store_ssa_def(ctx, &undef->def,
-                 spirv_builder_emit_undef(&ctx->builder, type));
+   store_ssa_def_uint(ctx, &undef->def,
+                      spirv_builder_emit_undef(&ctx->builder, type));
+}
+
+static SpvId
+get_src_float(struct ntv_context *ctx, nir_src *src)
+{
+   SpvId def = get_src_uint(ctx, src);
+   unsigned num_components = nir_src_num_components(*src);
+   unsigned bit_size = nir_src_bit_size(*src);
+   return bitcast_to_fvec(ctx, def, bit_size, num_components);
 }
 
 static void
@@ -779,17 +1014,18 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
    bool has_proj = false;
    SpvId coord = 0, proj;
-   unsigned coord_size;
+   unsigned coord_components;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
       case nir_tex_src_coord:
-         coord = get_src(ctx, &tex->src[i].src);
-         coord_size = nir_src_num_components(tex->src[i].src);
+         coord = get_src_float(ctx, &tex->src[i].src);
+         coord_components = nir_src_num_components(tex->src[i].src);
          break;
 
       case nir_tex_src_projector:
          has_proj = true;
-         proj = get_src(ctx, &tex->src[i].src);
+         proj = get_src_float(ctx, &tex->src[i].src);
+         assert(nir_src_num_components(tex->src[i].src) == 1);
          break;
 
       default:
@@ -811,25 +1047,25 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
    SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
                                         ctx->samplers[tex->texture_index]);
 
-   SpvId dest_type = get_dest_type(ctx, &tex->dest);
+   SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
 
    SpvId result;
    if (has_proj) {
-      SpvId constituents[coord_size + 1];
+      SpvId constituents[coord_components + 1];
       SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
-      for (uint32_t i = 0; i < coord_size; ++i)
+      for (uint32_t i = 0; i < coord_components; ++i)
          constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
                                               float_type,
                                               coord,
                                               &i, 1);
 
-      constituents[coord_size++] = proj;
+      constituents[coord_components++] = proj;
 
-      SpvId vec_type = get_fvec_type(ctx, 32, coord_size);
+      SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
       SpvId merged = spirv_builder_emit_composite_construct(&ctx->builder,
                                                             vec_type,
                                                             constituents,
-                                                            coord_size);
+                                                            coord_components);
 
       result = spirv_builder_emit_image_sample_proj_implicit_lod(&ctx->builder,
                                                                  dest_type,
@@ -842,7 +1078,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
    spirv_builder_emit_decoration(&ctx->builder, result,
                                  SpvDecorationRelaxedPrecision);
 
-   store_dest(ctx, &tex->dest, result);
+   store_dest(ctx, &tex->dest, result, tex->dest_type);
 }
 
 static void