zink: do not convert bools to/from uint
authorErik Faye-Lund <erik.faye-lund@collabora.com>
Mon, 10 Feb 2020 14:45:22 +0000 (15:45 +0100)
committerMarge Bot <eric+marge@anholt.net>
Mon, 17 Feb 2020 12:46:54 +0000 (12:46 +0000)
Since bools are the only 1-bit type, we always know if an SSA-def is a
bool or not. So we don't need to marshal it to uint.

So let's simplify the code a bit here.

Tested-by: Marge Bot <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/3763>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/3763>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 46e623a0c6de4fa2f4b8bf3864598713f1d74ae7..66fcbdcb9cef84819624ab6d975f42f7a6af8689 100644 (file)
@@ -140,9 +140,9 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
+   assert(bit_size == 32); // only 32-bit ints supported so far
 
-   SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
+   SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
    if (num_components > 1)
       return spirv_builder_type_vector(&ctx->builder, int_type,
                                        num_components);
@@ -154,9 +154,9 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
+   assert(bit_size == 32); // only 32-bit uints supported so far
 
-   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
+   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
    if (num_components > 1)
       return spirv_builder_type_vector(&ctx->builder, uint_type,
                                        num_components);
@@ -168,8 +168,8 @@ get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
 {
-   return get_uvec_type(ctx, nir_dest_bit_size(*dest),
-                             nir_dest_num_components(*dest));
+   unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
+   return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
 }
 
 static SpvId
@@ -601,7 +601,9 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
    int bit_size = nir_src_bit_size(alu->src[src].src);
    assert(bit_size == 1 || bit_size == 32);
 
-   SpvId raw_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
+   SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
+                                    spirv_builder_type_uint(&ctx->builder, bit_size);
+
    if (used_channels == 1) {
       uint32_t indices[] =  { alu->src[src].swizzle[0] };
       return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
@@ -655,15 +657,6 @@ emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
    return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
 }
 
-static SpvId
-bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
-{
-   SpvId otype = get_uvec_type(ctx, 32, num_components);
-   SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
-   SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
-   return emit_select(ctx, otype, value, one, zero);
-}
-
 static SpvId
 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
 {
@@ -725,22 +718,22 @@ store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type t
    unsigned num_components = nir_dest_num_components(*dest);
    unsigned bit_size = nir_dest_bit_size(*dest);
 
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_bool:
-      assert(bit_size == 1);
-      result = bvec_to_uvec(ctx, result, num_components);
-      break;
+   if (bit_size != 1) {
+      switch (nir_alu_type_get_base_type(type)) {
+      case nir_type_bool:
+         assert("bool should have bit-size 1");
 
-   case nir_type_uint:
-      break; /* nothing to do! */
+      case nir_type_uint:
+         break; /* nothing to do! */
 
-   case nir_type_int:
-   case nir_type_float:
-      result = bitcast_to_uvec(ctx, result, bit_size, num_components);
-      break;
+      case nir_type_int:
+      case nir_type_float:
+         result = bitcast_to_uvec(ctx, result, bit_size, num_components);
+         break;
 
-   default:
-      unreachable("unsupported nir_alu_type");
+      default:
+         unreachable("unsupported nir_alu_type");
+      }
    }
 
    store_dest_raw(ctx, dest, result);
@@ -874,22 +867,25 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
    unsigned bit_size = nir_src_bit_size(alu->src[src].src);
    nir_alu_type type = nir_op_infos[alu->op].input_types[src];
 
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_bool:
-      assert(bit_size == 1);
-      return uvec_to_bvec(ctx, raw_value, num_components);
+   if (bit_size == 1)
+      return raw_value;
+   else {
+      switch (nir_alu_type_get_base_type(type)) {
+      case nir_type_bool:
+         unreachable("bool should have bit-size 1");
 
-   case nir_type_int:
-      return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
+      case nir_type_int:
+         return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
 
-   case nir_type_uint:
-      return raw_value;
+      case nir_type_uint:
+         return raw_value;
 
-   case nir_type_float:
-      return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
+      case nir_type_float:
+         return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
 
-   default:
-      unreachable("unknown nir_alu_type");
+      default:
+         unreachable("unknown nir_alu_type");
+      }
    }
 }
 
@@ -907,9 +903,12 @@ get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
    unsigned num_components = nir_dest_num_components(*dest);
    unsigned bit_size = nir_dest_bit_size(*dest);
 
+   if (bit_size == 1)
+      return get_bvec_type(ctx, num_components);
+
    switch (nir_alu_type_get_base_type(type)) {
    case nir_type_bool:
-      return get_bvec_type(ctx, num_components);
+      unreachable("bool should have bit-size 1");
 
    case nir_type_int:
       return get_ivec_type(ctx, bit_size, num_components);
@@ -1231,9 +1230,6 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
          constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
    }
 
-   if (bit_size == 1)
-      constant = bvec_to_uvec(ctx, constant, num_components);
-
    store_ssa_def(ctx, &load_const->def, constant);
 }
 
@@ -1283,6 +1279,9 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                                          num_components);
       }
 
+      if (nir_dest_bit_size(intr->dest) == 1)
+         result = uvec_to_bvec(ctx, result, num_components);
+
       store_dest(ctx, &intr->dest, result, nir_type_uint);
    } else
       unreachable("uniform-addressing not yet supported");
@@ -1767,10 +1766,8 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
 static SpvId
 get_src_bool(struct ntv_context *ctx, nir_src *src)
 {
-   SpvId def = get_src(ctx, src);
    assert(nir_src_bit_size(*src) == 1);
-   unsigned num_components = nir_src_num_components(*src);
-   return uvec_to_bvec(ctx, def, num_components);
+   return get_src(ctx, src);
 }
 
 static void