zink: Use store_dest_raw instead of storing an uint
[mesa.git] / src / gallium / drivers / zink / nir_to_spirv / nir_to_spirv.c
index e609a08caf61dacc45b2733e67dd3da2e940581f..eecbca3e5dce559bcf923d9dbe1a2efc34732893 100644 (file)
@@ -614,7 +614,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
                                                      raw_type,
                                                      used_channels);
 
-      SpvId constituents[NIR_MAX_VEC_COMPONENTS];
+      SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
       for (unsigned i = 0; i < used_channels; ++i)
         constituents[i] = def;
 
@@ -627,7 +627,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
                                                      raw_type,
                                                      used_channels);
 
-      uint32_t components[NIR_MAX_VEC_COMPONENTS];
+      uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
       size_t num_components = 0;
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
          if (!nir_alu_instr_channel_used(alu, src, i))
@@ -712,7 +712,7 @@ store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
       store_reg_def(ctx, &dest->reg, result);
 }
 
-static void
+static SpvId
 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
 {
    unsigned num_components = nir_dest_num_components(*dest);
@@ -737,6 +737,7 @@ store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type t
    }
 
    store_dest_raw(ctx, dest, result);
+   return result;
 }
 
 static SpvId
@@ -889,7 +890,7 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
    }
 }
 
-static void
+static SpvId
 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
 {
    assert(!alu->dest.saturate);
@@ -928,8 +929,11 @@ static void
 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
 {
    SpvId src[nir_op_infos[alu->op].num_inputs];
-   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
+   unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
+   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
       src[i] = get_alu_src(ctx, alu, i);
+      in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
+   }
 
    SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
                                    nir_op_infos[alu->op].output_type);
@@ -1153,51 +1157,67 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
 
    case nir_op_bany_fnequal2:
    case nir_op_bany_fnequal3:
-   case nir_op_bany_fnequal4:
+   case nir_op_bany_fnequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpFOrdNotEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAny, dest_type, result);
       break;
+   }
 
    case nir_op_ball_fequal2:
    case nir_op_ball_fequal3:
-   case nir_op_ball_fequal4:
+   case nir_op_ball_fequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpFOrdEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAll, dest_type, result);
       break;
+   }
 
    case nir_op_bany_inequal2:
    case nir_op_bany_inequal3:
-   case nir_op_bany_inequal4:
+   case nir_op_bany_inequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpINotEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAny, dest_type, result);
       break;
+   }
 
    case nir_op_ball_iequal2:
    case nir_op_ball_iequal3:
-   case nir_op_ball_iequal4:
+   case nir_op_ball_iequal4: {
       assert(nir_op_infos[alu->op].num_inputs == 2);
       assert(alu_instr_src_components(alu, 0) ==
              alu_instr_src_components(alu, 1));
-      result = emit_binop(ctx, SpvOpIEqual,
+      assert(in_bit_sizes[0] == in_bit_sizes[1]);
+      /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
+      SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
+      result = emit_binop(ctx, op,
                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
                           src[0], src[1]);
       result = emit_unop(ctx, SpvOpAll, dest_type, result);
       break;
+   }
 
    case nir_op_vec2:
    case nir_op_vec3:
@@ -1489,16 +1509,18 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
           tex->op == nir_texop_txl ||
           tex->op == nir_texop_txd ||
           tex->op == nir_texop_txf ||
+          tex->op == nir_texop_txf_ms ||
           tex->op == nir_texop_txs);
    assert(tex->texture_index == tex->sampler_index);
 
    SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
-         offset = 0;
+         offset = 0, sample = 0;
    unsigned coord_components = 0;
    for (unsigned i = 0; i < tex->num_srcs; i++) {
       switch (tex->src[i].src_type) {
       case nir_tex_src_coord:
-         if (tex->op == nir_texop_txf)
+         if (tex->op == nir_texop_txf ||
+             tex->op == nir_texop_txf_ms)
             coord = get_src_int(ctx, &tex->src[i].src);
          else
             coord = get_src_float(ctx, &tex->src[i].src);
@@ -1524,6 +1546,7 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       case nir_tex_src_lod:
          assert(nir_src_num_components(tex->src[i].src) == 1);
          if (tex->op == nir_texop_txf ||
+             tex->op == nir_texop_txf_ms ||
              tex->op == nir_texop_txs)
             lod = get_src_int(ctx, &tex->src[i].src);
          else
@@ -1531,6 +1554,11 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
          assert(lod != 0);
          break;
 
+      case nir_tex_src_ms_index:
+         assert(nir_src_num_components(tex->src[i].src) == 1);
+         sample = get_src_int(ctx, &tex->src[i].src);
+         break;
+
       case nir_tex_src_comparator:
          assert(nir_src_num_components(tex->src[i].src) == 1);
          dref = get_src_float(ctx, &tex->src[i].src);
@@ -1605,10 +1633,11 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
 
    SpvId result;
-   if (tex->op == nir_texop_txf) {
+   if (tex->op == nir_texop_txf ||
+       tex->op == nir_texop_txf_ms) {
       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
       result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
-                                              image, coord, lod);
+                                              image, coord, lod, sample);
    } else {
       result = spirv_builder_emit_image_sample(&ctx->builder,
                                                actual_dest_type, load,
@@ -1689,8 +1718,7 @@ emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
    struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
    assert(he);
    SpvId result = (SpvId)(intptr_t)he->data;
-   /* uint is a bit of a lie here, it's really just an opaque type */
-   store_dest(ctx, &deref->dest, result, nir_type_uint);
+   store_dest_raw(ctx, &deref->dest, result);
 }
 
 static void