zink: Use store_dest_raw instead of storing an uint
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index 35c2fc4c0f39d71e423c09c5d17a792480c3c905..eecbca3e5dce559bcf923d9dbe1a2efc34732893 100644 (file)
@@ -38,8 +38,9 @@ struct ntv_context {
 
    SpvId ubos[128];
    size_t num_ubos;
+   SpvId image_types[PIPE_MAX_SAMPLERS];
    SpvId samplers[PIPE_MAX_SAMPLERS];
-   size_t num_samplers;
+   unsigned samplers_used : PIPE_MAX_SAMPLERS;
    SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
    size_t num_entry_ifaces;
 
@@ -56,7 +57,7 @@ struct ntv_context {
    bool block_started;
    SpvId loop_break, loop_cont;
 
-   SpvId front_face_var, vertex_id_var;
+   SpvId front_face_var, instance_id_var, vertex_id_var;
 };
 
 static SpvId
@@ -139,9 +140,9 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
+   assert(bit_size == 32); // only 32-bit ints supported so far
 
-   SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
+   SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
    if (num_components > 1)
       return spirv_builder_type_vector(&ctx->builder, int_type,
                                        num_components);
@@ -153,9 +154,9 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
+   assert(bit_size == 32); // only 32-bit uints supported so far
 
-   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
+   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
    if (num_components > 1)
       return spirv_builder_type_vector(&ctx->builder, uint_type,
                                        num_components);
@@ -167,8 +168,8 @@ get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
 {
-   return get_uvec_type(ctx, nir_dest_bit_size(*dest),
-                             nir_dest_num_components(*dest));
+   unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
+   return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
 }
 
 static SpvId
@@ -330,6 +331,7 @@ emit_output(struct ntv_context *ctx, struct nir_variable *var)
          switch (var->data.location) {
          case FRAG_RESULT_COLOR:
             spirv_builder_emit_location(&ctx->builder, var_id, 0);
+            spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
             break;
 
          case FRAG_RESULT_DEPTH:
@@ -367,7 +369,7 @@ type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
    case GLSL_SAMPLER_DIM_CUBE:
       return SpvDimCube;
    case GLSL_SAMPLER_DIM_RECT:
-      return SpvDimRect;
+      return SpvDim2D;
    case GLSL_SAMPLER_DIM_BUF:
       return SpvDimBuffer;
    case GLSL_SAMPLER_DIM_EXTERNAL:
@@ -382,33 +384,98 @@ type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
    return SpvDim2D;
 }
 
+uint32_t
+zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
+{
+   if (stage == MESA_SHADER_NONE ||
+       stage >= MESA_SHADER_COMPUTE) {
+      unreachable("not supported");
+   } else {
+      uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
+                                                 PIPE_MAX_SHADER_SAMPLER_VIEWS);
+
+      switch (type) {
+      case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+         assert(index < PIPE_MAX_CONSTANT_BUFFERS);
+         return stage_offset + index;
+
+      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+         assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
+         return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
+
+      default:
+         unreachable("unexpected type");
+      }
+   }
+}
+
 static void
 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
 {
+   const struct glsl_type *type = glsl_without_array(var->type);
+
    bool is_ms;
-   SpvDim dimension = type_to_dim(glsl_get_sampler_dim(var->type), &is_ms);
-   SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
-   SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
-                            dimension, false, glsl_sampler_type_is_array(var->type), is_ms, 1,
-                            SpvImageFormatUnknown);
+   SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
+
+   SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
+   SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
+                                               dimension, false,
+                                               glsl_sampler_type_is_array(type),
+                                               is_ms, 1,
+                                               SpvImageFormatUnknown);
 
    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
                                                          image_type);
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassUniformConstant,
                                                    sampled_type);
-   SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
-                                         SpvStorageClassUniformConstant);
 
-   if (var->name)
-      spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+   if (glsl_type_is_array(var->type)) {
+      for (int i = 0; i < glsl_get_length(var->type); ++i) {
+         SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
+                                               SpvStorageClassUniformConstant);
 
-   assert(ctx->num_samplers < ARRAY_SIZE(ctx->samplers));
-   ctx->samplers[ctx->num_samplers++] = var_id;
+         if (var->name) {
+            char element_name[100];
+            snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
+            spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+         }
 
-   spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
-                                     var->data.descriptor_set);
-   spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
+         int index = var->data.binding + i;
+         assert(!(ctx->samplers_used & (1 << index)));
+         assert(!ctx->image_types[index]);
+         ctx->image_types[index] = image_type;
+         ctx->samplers[index] = var_id;
+         ctx->samplers_used |= 1 << index;
+
+         spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
+                                           var->data.descriptor_set);
+         int binding = zink_binding(ctx->stage,
+                                    VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+                                    var->data.binding + i);
+         spirv_builder_emit_binding(&ctx->builder, var_id, binding);
+      }
+   } else {
+      SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
+                                            SpvStorageClassUniformConstant);
+
+      if (var->name)
+         spirv_builder_emit_name(&ctx->builder, var_id, var->name);
+
+      int index = var->data.binding;
+      assert(!(ctx->samplers_used & (1 << index)));
+      assert(!ctx->image_types[index]);
+      ctx->image_types[index] = image_type;
+      ctx->samplers[index] = var_id;
+      ctx->samplers_used |= 1 << index;
+
+      spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
+                                        var->data.descriptor_set);
+      int binding = zink_binding(ctx->stage,
+                                 VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+                                 var->data.binding);
+      spirv_builder_emit_binding(&ctx->builder, var_id, binding);
+   }
 }
 
 static void
@@ -448,7 +515,10 @@ emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
 
    spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
                                      var->data.descriptor_set);
-   spirv_builder_emit_binding(&ctx->builder, var_id, var->data.binding);
+   int binding = zink_binding(ctx->stage,
+                              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
+                              var->data.binding);
+   spirv_builder_emit_binding(&ctx->builder, var_id, binding);
 }
 
 static void
@@ -458,13 +528,13 @@ emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
       emit_ubo(ctx, var);
    else {
       assert(var->data.mode == nir_var_uniform);
-      if (glsl_type_is_sampler(var->type))
+      if (glsl_type_is_sampler(glsl_without_array(var->type)))
          emit_sampler(ctx, var);
    }
 }
 
 static SpvId
-get_src_uint_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
+get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
 {
    assert(ssa->index < ctx->num_defs);
    assert(ctx->defs[ssa->index] != 0);
@@ -480,7 +550,7 @@ get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
 }
 
 static SpvId
-get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
+get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
 {
    assert(reg->reg);
    assert(!reg->indirect);
@@ -492,21 +562,21 @@ get_src_uint_reg(struct ntv_context *ctx, const nir_reg_src *reg)
 }
 
 static SpvId
-get_src_uint(struct ntv_context *ctx, nir_src *src)
+get_src(struct ntv_context *ctx, nir_src *src)
 {
    if (src->is_ssa)
-      return get_src_uint_ssa(ctx, src->ssa);
+      return get_src_ssa(ctx, src->ssa);
    else
-      return get_src_uint_reg(ctx, &src->reg);
+      return get_src_reg(ctx, &src->reg);
 }
 
 static SpvId
-get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
+get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
 {
    assert(!alu->src[src].negate);
    assert(!alu->src[src].abs);
 
-   SpvId def = get_src_uint(ctx, &alu->src[src].src);
+   SpvId def = get_src(ctx, &alu->src[src].src);
 
    unsigned used_channels = 0;
    bool need_swizzle = false;
@@ -531,28 +601,33 @@ get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
    int bit_size = nir_src_bit_size(alu->src[src].src);
    assert(bit_size == 1 || bit_size == 32);
 
-   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
+   SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
+                                    spirv_builder_type_uint(&ctx->builder, bit_size);
+
    if (used_channels == 1) {
       uint32_t indices[] =  { alu->src[src].swizzle[0] };
-      return spirv_builder_emit_composite_extract(&ctx->builder, uint_type,
+      return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
                                                   def, indices,
                                                   ARRAY_SIZE(indices));
    } else if (live_channels == 1) {
-      SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
-                                                  used_channels);
+      SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
+                                                     raw_type,
+                                                     used_channels);
 
-      SpvId constituents[NIR_MAX_VEC_COMPONENTS];
+      SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
       for (unsigned i = 0; i < used_channels; ++i)
         constituents[i] = def;
 
-      return spirv_builder_emit_composite_construct(&ctx->builder, uvec_type,
+      return spirv_builder_emit_composite_construct(&ctx->builder,
+                                                    raw_vec_type,
                                                     constituents,
                                                     used_channels);
    } else {
-      SpvId uvec_type = spirv_builder_type_vector(&ctx->builder, uint_type,
-                                                  used_channels);
+      SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
+                                                     raw_type,
+                                                     used_channels);
 
-      uint32_t components[NIR_MAX_VEC_COMPONENTS];
+      uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
       size_t num_components = 0;
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
          if (!nir_alu_instr_channel_used(alu, src, i))
@@ -561,13 +636,14 @@ get_alu_src_uint(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
          components[num_components++] = alu->src[src].swizzle[i];
       }
 
-      return spirv_builder_emit_vector_shuffle(&ctx->builder, uvec_type,
-                                        def, def, components, num_components);
+      return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
+                                               def, def, components,
+                                               num_components);
    }
 }
 
 static void
-store_ssa_def_uint(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
+store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
 {
    assert(result != 0);
    assert(ssa->index < ctx->num_defs);
@@ -581,15 +657,6 @@ emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
    return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
 }
 
-static SpvId
-bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
-{
-   SpvId otype = get_uvec_type(ctx, 32, num_components);
-   SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
-   SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
-   return emit_select(ctx, otype, value, one, zero);
-}
-
 static SpvId
 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
 {
@@ -637,39 +704,40 @@ store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
 }
 
 static void
-store_dest_uint(struct ntv_context *ctx, nir_dest *dest, SpvId result)
+store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
 {
    if (dest->is_ssa)
-      store_ssa_def_uint(ctx, &dest->ssa, result);
+      store_ssa_def(ctx, &dest->ssa, result);
    else
       store_reg_def(ctx, &dest->reg, result);
 }
 
-static void
+static SpvId
 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
 {
    unsigned num_components = nir_dest_num_components(*dest);
    unsigned bit_size = nir_dest_bit_size(*dest);
 
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_bool:
-      assert(bit_size == 1);
-      result = bvec_to_uvec(ctx, result, num_components);
-      break;
+   if (bit_size != 1) {
+      switch (nir_alu_type_get_base_type(type)) {
+      case nir_type_bool:
+         assert("bool should have bit-size 1");
 
-   case nir_type_uint:
-      break; /* nothing to do! */
+      case nir_type_uint:
+         break; /* nothing to do! */
 
-   case nir_type_int:
-   case nir_type_float:
-      result = bitcast_to_uvec(ctx, result, bit_size, num_components);
-      break;
+      case nir_type_int:
+      case nir_type_float:
+         result = bitcast_to_uvec(ctx, result, bit_size, num_components);
+         break;
 
-   default:
-      unreachable("unsupported nir_alu_type");
+      default:
+         unreachable("unsupported nir_alu_type");
+      }
    }
 
-   store_dest_uint(ctx, dest, result);
+   store_dest_raw(ctx, dest, result);
+   return result;
 }
 
 static SpvId
@@ -794,36 +862,40 @@ alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
 static SpvId
 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
 {
-   SpvId uint_value = get_alu_src_uint(ctx, alu, src);
+   SpvId raw_value = get_alu_src_raw(ctx, alu, src);
 
    unsigned num_components = alu_instr_src_components(alu, src);
    unsigned bit_size = nir_src_bit_size(alu->src[src].src);
    nir_alu_type type = nir_op_infos[alu->op].input_types[src];
 
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_bool:
-      assert(bit_size == 1);
-      return uvec_to_bvec(ctx, uint_value, num_components);
+   if (bit_size == 1)
+      return raw_value;
+   else {
+      switch (nir_alu_type_get_base_type(type)) {
+      case nir_type_bool:
+         unreachable("bool should have bit-size 1");
 
-   case nir_type_int:
-      return bitcast_to_ivec(ctx, uint_value, bit_size, num_components);
+      case nir_type_int:
+         return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
 
-   case nir_type_uint:
-      return uint_value;
+      case nir_type_uint:
+         return raw_value;
 
-   case nir_type_float:
-      return bitcast_to_fvec(ctx, uint_value, bit_size, num_components);
+      case nir_type_float:
+         return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
 
-   default:
-      unreachable("unknown nir_alu_type");
+      default:
+         unreachable("unknown nir_alu_type");
+      }
    }
 }
 
-static void
+static SpvId
 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
 {
    assert(!alu->dest.saturate);
-   return store_dest(ctx, &alu->dest.dest, result, nir_op_infos[alu->op].output_type);
+   return store_dest(ctx, &alu->dest.dest, result,
+                     nir_op_infos[alu->op].output_type);
 }
 
 static SpvId
@@ -832,9 +904,12 @@ get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
    unsigned num_components = nir_dest_num_components(*dest);
    unsigned bit_size = nir_dest_bit_size(*dest);
 
+   if (bit_size == 1)
+      return get_bvec_type(ctx, num_components);
+
    switch (nir_alu_type_get_base_type(type)) {
    case nir_type_bool:
-      return get_bvec_type(ctx, num_components);
+      unreachable("bool should have bit-size 1");
 
    case nir_type_int:
       return get_ivec_type(ctx, bit_size, num_components);
@@ -854,8 +929,11 @@ static void
 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
 {
    SpvId src[nir_op_infos[alu->op].num_inputs];
-   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
+   unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
+   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
       src[i] = get_alu_src(ctx, alu, i);
+      in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
+   }
 
    SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
                                    nir_op_infos[alu->op].output_type);
@@ -878,14 +956,24 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    UNOP(nir_op_ineg, SpvOpSNegate)
    UNOP(nir_op_fneg, SpvOpFNegate)
    UNOP(nir_op_fddx, SpvOpDPdx)
+   UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
+   UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
    UNOP(nir_op_fddy, SpvOpDPdy)
+   UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
+   UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
    UNOP(nir_op_f2i32, SpvOpConvertFToS)
    UNOP(nir_op_f2u32, SpvOpConvertFToU)
    UNOP(nir_op_i2f32, SpvOpConvertSToF)
    UNOP(nir_op_u2f32, SpvOpConvertUToF)
-   UNOP(nir_op_inot, SpvOpNot)
 #undef UNOP
 
+   case nir_op_inot:
+      if (bit_size == 1)
+         result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
+      else
+         result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
+      break;
+
    case nir_op_b2i32:
       assert(nir_op_infos[alu->op].num_inputs == 1);
       result = emit_select(ctx, dest_type, src[0],
@@ -936,6 +1024,13 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
                                             nir_src_bit_size(alu->src[0].src),
                                             num_components, 0));
       break;
+   case nir_op_i2b1:
+      assert(nir_op_infos[alu->op].num_inputs == 1);
+      result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
+                          get_ivec_constant(ctx,
+                                            nir_src_bit_size(alu->src[0].src),
+                                            num_components, 0));
+      break;
 
 
 #define BINOP(nir_op, spirv_op) \
@@ -957,8 +1052,6 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BINOP(nir_op_fmod, SpvOpFMod)
    BINOP(nir_op_ilt, SpvOpSLessThan)
    BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
-   BINOP(nir_op_ieq, SpvOpIEqual)
-   BINOP(nir_op_ine, SpvOpINotEqual)
    BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
    BINOP(nir_op_flt, SpvOpFOrdLessThan)
    BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
@@ -967,10 +1060,23 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
    BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
    BINOP(nir_op_ushr, SpvOpShiftRightLogical)
-   BINOP(nir_op_iand, SpvOpBitwiseAnd)
-   BINOP(nir_op_ior, SpvOpBitwiseOr)
 #undef BINOP
 
+#define BINOP_LOG(nir_op, spv_op, spv_log_op) \
+   case nir_op: \
+      assert(nir_op_infos[alu->op].num_inputs == 2); \
+      if (nir_src_bit_size(alu->src[0].src) == 1) \
+         result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
+      else \
+         result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
+      break;
+
+   BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
+   BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
+   BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
+   BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
+#undef BINOP_LOG
+
 #define BUILTIN_BINOP(nir_op, spirv_op) \
    case nir_op: \
       assert(nir_op_infos[alu->op].num_inputs == 2); \
@@ -988,6 +1094,9 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
       result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
       break;
 
+   case nir_op_fdph:
+      unreachable("should already be lowered away");
+
    case nir_op_seq:
    case nir_op_sne:
    case nir_op_slt:
@@ -1048,51 +1157,67 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
 
    case nir_op_bany_fnequal2:
    case nir_op_bany_fnequal3:
-   case nir_op_bany_fnequal4:
+   case nir_op_bany_fnequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpFOrdNotEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAny, dest_type, result);
       break;
+   }
 
    case nir_op_ball_fequal2:
    case nir_op_ball_fequal3:
-   case nir_op_ball_fequal4:
+   case nir_op_ball_fequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpFOrdEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAll, dest_type, result);
       break;
+   }
 
    case nir_op_bany_inequal2:
    case nir_op_bany_inequal3:
-   case nir_op_bany_inequal4:
+   case nir_op_bany_inequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpINotEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAny, dest_type, result);
       break;
+   }
 
    case nir_op_ball_iequal2:
    case nir_op_ball_iequal3:
-   case nir_op_ball_iequal4:
+   case nir_op_ball_iequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpIEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAll, dest_type, result);
       break;
+   }
 
    case nir_op_vec2:
    case nir_op_vec3:
@@ -1149,10 +1274,7 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
          constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
    }
 
-   if (bit_size == 1)
-      constant = bvec_to_uvec(ctx, constant, num_components);
-
-   store_ssa_def_uint(ctx, &load_const->def, constant);
+   store_ssa_def(ctx, &load_const->def, constant);
 }
 
 static void
@@ -1201,7 +1323,10 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                                          num_components);
       }
 
-      store_dest_uint(ctx, &intr->dest, result);
+      if (nir_dest_bit_size(intr->dest) == 1)
+         result = uvec_to_bvec(ctx, result, num_components);
+
+      store_dest(ctx, &intr->dest, result, nir_type_uint);
    } else
       unreachable("uniform-addressing not yet supported");
 }
@@ -1219,8 +1344,7 @@ emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 static void
 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
-   /* uint is a bit of a lie here; it's really just a pointer */
-   SpvId ptr = get_src_uint(ctx, intr->src);
+   SpvId ptr = get_src(ctx, intr->src);
 
    nir_variable *var = nir_intrinsic_get_var(intr, 0);
    SpvId result = spirv_builder_emit_load(&ctx->builder,
@@ -1229,15 +1353,14 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    unsigned num_components = nir_dest_num_components(intr->dest);
    unsigned bit_size = nir_dest_bit_size(intr->dest);
    result = bitcast_to_uvec(ctx, result, bit_size, num_components);
-   store_dest_uint(ctx, &intr->dest, result);
+   store_dest(ctx, &intr->dest, result, nir_type_uint);
 }
 
 static void
 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
-   /* uint is a bit of a lie here; it's really just a pointer */
-   SpvId ptr = get_src_uint(ctx, &intr->src[0]);
-   SpvId src = get_src_uint(ctx, &intr->src[1]);
+   SpvId ptr = get_src(ctx, &intr->src[0]);
+   SpvId src = get_src(ctx, &intr->src[1]);
 
    nir_variable *var = nir_intrinsic_get_var(intr, 0);
    SpvId type = get_glsl_type(ctx, glsl_without_array(var->type));
@@ -1276,8 +1399,23 @@ emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
                                           ctx->front_face_var);
    assert(1 == nir_dest_num_components(intr->dest));
-   result = bvec_to_uvec(ctx, result, 1);
-   store_dest_uint(ctx, &intr->dest, result);
+   store_dest(ctx, &intr->dest, result, nir_type_bool);
+}
+
+static void
+emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
+{
+   SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
+   if (!ctx->instance_id_var)
+      ctx->instance_id_var = create_builtin_var(ctx, var_type,
+                                               SpvStorageClassInput,
+                                               "gl_InstanceId",
+                                               SpvBuiltInInstanceIndex);
+
+   SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
+                                          ctx->instance_id_var);
+   assert(1 == nir_dest_num_components(intr->dest));
+   store_dest(ctx, &intr->dest, result, nir_type_uint);
 }
 
 static void
@@ -1293,7 +1431,7 @@ emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
                                           ctx->vertex_id_var);
    assert(1 == nir_dest_num_components(intr->dest));
-   store_dest_uint(ctx, &intr->dest, result);
+   store_dest(ctx, &intr->dest, result, nir_type_uint);
 }
 
 static void
@@ -1320,6 +1458,10 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       emit_load_front_face(ctx, intr);
       break;
 
+   case nir_intrinsic_load_instance_id:
+      emit_load_instance_id(ctx, intr);
+      break;
+
    case nir_intrinsic_load_vertex_id:
       emit_load_vertex_id(ctx, intr);
       break;
@@ -1337,14 +1479,14 @@ emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
    SpvId type = get_uvec_type(ctx, undef->def.bit_size,
                               undef->def.num_components);
 
-   store_ssa_def_uint(ctx, &undef->def,
-                      spirv_builder_emit_undef(&ctx->builder, type));
+   store_ssa_def(ctx, &undef->def,
+                 spirv_builder_emit_undef(&ctx->builder, type));
 }
 
 static SpvId
 get_src_float(struct ntv_context *ctx, nir_src *src)
 {
-   SpvId def = get_src_uint(ctx, src);
+   SpvId def = get_src(ctx, src);
    unsigned num_components = nir_src_num_components(*src);
    unsigned bit_size = nir_src_bit_size(*src);
    return bitcast_to_fvec(ctx, def, bit_size, num_components);
@@ -1353,7 +1495,7 @@ get_src_float(struct ntv_context *ctx, nir_src *src)
 static SpvId
 get_src_int(struct ntv_context *ctx, nir_src *src)
 {
-   SpvId def = get_src_uint(ctx, src);
+   SpvId def = get_src(ctx, src);
    unsigned num_components = nir_src_num_components(*src);
    unsigned bit_size = nir_src_bit_size(*src);
    return bitcast_to_ivec(ctx, def, bit_size, num_components);
@@ -1367,17 +1509,18 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
           tex->op == nir_texop_txl ||
           tex->op == nir_texop_txd ||
           tex->op == nir_texop_txf ||
+          tex->op == nir_texop_txf_ms ||
           tex->op == nir_texop_txs);
-   assert(tex->op == nir_texop_txs ||
-          nir_alu_type_get_base_type(tex->dest_type) == nir_type_float);
    assert(tex->texture_index == tex->sampler_index);
 
-   SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0;
+   SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
+         offset = 0, sample = 0;
    unsigned coord_components = 0;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
       case nir_tex_src_coord:
-         if (tex->op == nir_texop_txf)
+         if (tex->op == nir_texop_txf ||
+             tex->op == nir_texop_txf_ms)
             coord = get_src_int(ctx, &tex->src[i].src);
          else
             coord = get_src_float(ctx, &tex->src[i].src);
@@ -1390,6 +1533,10 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          assert(proj != 0);
          break;
 
+      case nir_tex_src_offset:
+         offset = get_src_int(ctx, &tex->src[i].src);
+         break;
+
       case nir_tex_src_bias:
          assert(tex->op == nir_texop_txb);
          bias = get_src_float(ctx, &tex->src[i].src);
@@ -1399,6 +1546,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       case nir_tex_src_lod:
          assert(nir_src_num_components(tex->src[i].src) == 1);
          if (tex->op == nir_texop_txf ||
+             tex->op == nir_texop_txf_ms ||
              tex->op == nir_texop_txs)
             lod = get_src_int(ctx, &tex->src[i].src);
          else
@@ -1406,6 +1554,11 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          assert(lod != 0);
          break;
 
+      case nir_tex_src_ms_index:
+         assert(nir_src_num_components(tex->src[i].src) == 1);
+         sample = get_src_int(ctx, &tex->src[i].src);
+         break;
+
       case nir_tex_src_comparator:
          assert(nir_src_num_components(tex->src[i].src) == 1);
          dref = get_src_float(ctx, &tex->src[i].src);
@@ -1433,16 +1586,11 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       assert(lod != 0);
    }
 
-   bool is_ms;
-   SpvDim dimension = type_to_dim(tex->sampler_dim, &is_ms);
-   SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
-   SpvId image_type = spirv_builder_type_image(&ctx->builder, float_type,
-                            dimension, false, tex->is_array, is_ms, 1,
-                            SpvImageFormatUnknown);
+   SpvId image_type = ctx->image_types[tex->texture_index];
    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
                                                          image_type);
 
-   assert(tex->texture_index < ctx->num_samplers);
+   assert(ctx->samplers_used & (1u << tex->texture_index));
    SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
                                         ctx->samplers[tex->texture_index]);
 
@@ -1457,7 +1605,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       return;
    }
 
-   if (proj) {
+   if (proj && coord_components > 0) {
       SpvId constituents[coord_components + 1];
       if (coord_components == 1)
          constituents[0] = coord;
@@ -1482,25 +1630,27 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
 
    SpvId actual_dest_type = dest_type;
    if (dref)
-      actual_dest_type = float_type;
+      actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
 
    SpvId result;
-   if (tex->op == nir_texop_txf) {
+   if (tex->op == nir_texop_txf ||
+       tex->op == nir_texop_txf_ms) {
       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
       result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
-                                              image, coord, lod);
+                                              image, coord, lod, sample);
    } else {
       result = spirv_builder_emit_image_sample(&ctx->builder,
                                                actual_dest_type, load,
                                                coord,
                                                proj != 0,
-                                               lod, bias, dref, dx, dy);
+                                               lod, bias, dref, dx, dy,
+                                               offset);
    }
 
    spirv_builder_emit_decoration(&ctx->builder, result,
                                  SpvDecorationRelaxedPrecision);
 
-   if (dref) {
+   if (dref && nir_dest_num_components(tex->dest) > 1) {
       SpvId components[4] = { result, result, result, result };
       result = spirv_builder_emit_composite_construct(&ctx->builder,
                                                       dest_type,
@@ -1568,8 +1718,7 @@ emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
    struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
    assert(he);
    SpvId result = (SpvId)(intptr_t)he->data;
-   /* uint is a bit of a lie here, it's really just an opaque type */
-   store_dest_uint(ctx, &deref->dest, result);
+   store_dest_raw(ctx, &deref->dest, result);
 }
 
 static void
@@ -1592,7 +1741,7 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
       unreachable("Unsupported nir_variable_mode\n");
    }
 
-   SpvId index = get_src_uint(ctx, &deref->arr.index);
+   SpvId index = get_src(ctx, &deref->arr.index);
 
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                storage_class,
@@ -1600,10 +1749,10 @@ emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
 
    SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
                                                   ptr_type,
-                                                  get_src_uint(ctx, &deref->parent),
+                                                  get_src(ctx, &deref->parent),
                                                   &index, 1);
    /* uint is a bit of a lie here, it's really just an opaque type */
-   store_dest_uint(ctx, &deref->dest, result);
+   store_dest(ctx, &deref->dest, result, nir_type_uint);
 }
 
 static void
@@ -1669,10 +1818,8 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
 static SpvId
 get_src_bool(struct ntv_context *ctx, nir_src *src)
 {
-   SpvId def = get_src_uint(ctx, src);
    assert(nir_src_bit_size(*src) == 1);
-   unsigned num_components = nir_src_num_components(*src);
-   return uvec_to_bvec(ctx, def, num_components);
+   return get_src(ctx, src);
 }
 
 static void
@@ -1794,6 +1941,7 @@ nir_to_spirv(struct nir_shader *s)
    if (s->info.stage == MESA_SHADER_FRAGMENT) {
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
+      spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
    }
 
    ctx.stage = s->info.stage;