From 9903f10636566834a7563b6828c52fe40c5b0d71 Mon Sep 17 00:00:00 2001 From: Erik Faye-Lund Date: Mon, 10 Feb 2020 15:45:22 +0100 Subject: [PATCH] zink: do not convert bools to/from uint Since bools are the only 1-bit type, we always know if an SSA-def is a bool or not. So we don't need to marshal it to uint. So let's simplify the code a bit here. Tested-by: Marge Bot Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 93 +++++++++---------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 46e623a0c6d..66fcbdcb9ce 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -140,9 +140,9 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component static SpvId get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components) { - assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far + assert(bit_size == 32); // only 32-bit ints supported so far - SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32)); + SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size); if (num_components > 1) return spirv_builder_type_vector(&ctx->builder, int_type, num_components); @@ -154,9 +154,9 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component static SpvId get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components) { - assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far + assert(bit_size == 32); // only 32-bit uints supported so far - SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32)); + SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size); if (num_components > 1) return spirv_builder_type_vector(&ctx->builder, uint_type, num_components); @@ -168,8 +168,8 @@ get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component static SpvId get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest) { - return get_uvec_type(ctx, nir_dest_bit_size(*dest), - nir_dest_num_components(*dest)); + unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32); + return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest)); } static SpvId @@ -601,7 +601,9 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) int bit_size = nir_src_bit_size(alu->src[src].src); assert(bit_size == 1 || bit_size == 32); - SpvId raw_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32)); + SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) : + spirv_builder_type_uint(&ctx->builder, bit_size); + if (used_channels == 1) { uint32_t indices[] = { alu->src[src].swizzle[0] }; return spirv_builder_emit_composite_extract(&ctx->builder, raw_type, @@ -655,15 +657,6 @@ emit_select(struct ntv_context *ctx, SpvId type, SpvId cond, return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false); } -static SpvId -bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components) -{ - SpvId otype = get_uvec_type(ctx, 32, num_components); - SpvId zero = get_uvec_constant(ctx, 32, num_components, 0); - SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX); - return emit_select(ctx, otype, value, one, zero); -} - static SpvId uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components) { @@ -725,22 +718,22 @@ store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type t unsigned num_components = nir_dest_num_components(*dest); unsigned bit_size = nir_dest_bit_size(*dest); - switch (nir_alu_type_get_base_type(type)) { - case nir_type_bool: - assert(bit_size == 1); - result = bvec_to_uvec(ctx, result, num_components); - break; + if (bit_size != 1) { + switch (nir_alu_type_get_base_type(type)) { + case nir_type_bool: + assert("bool should have bit-size 1"); - case nir_type_uint: - break; /* nothing to do! */ + case nir_type_uint: + break; /* nothing to do! */ - case nir_type_int: - case nir_type_float: - result = bitcast_to_uvec(ctx, result, bit_size, num_components); - break; + case nir_type_int: + case nir_type_float: + result = bitcast_to_uvec(ctx, result, bit_size, num_components); + break; - default: - unreachable("unsupported nir_alu_type"); + default: + unreachable("unsupported nir_alu_type"); + } } store_dest_raw(ctx, dest, result); @@ -874,22 +867,25 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src) unsigned bit_size = nir_src_bit_size(alu->src[src].src); nir_alu_type type = nir_op_infos[alu->op].input_types[src]; - switch (nir_alu_type_get_base_type(type)) { - case nir_type_bool: - assert(bit_size == 1); - return uvec_to_bvec(ctx, raw_value, num_components); + if (bit_size == 1) + return raw_value; + else { + switch (nir_alu_type_get_base_type(type)) { + case nir_type_bool: + unreachable("bool should have bit-size 1"); - case nir_type_int: - return bitcast_to_ivec(ctx, raw_value, bit_size, num_components); + case nir_type_int: + return bitcast_to_ivec(ctx, raw_value, bit_size, num_components); - case nir_type_uint: - return raw_value; + case nir_type_uint: + return raw_value; - case nir_type_float: - return bitcast_to_fvec(ctx, raw_value, bit_size, num_components); + case nir_type_float: + return bitcast_to_fvec(ctx, raw_value, bit_size, num_components); - default: - unreachable("unknown nir_alu_type"); + default: + unreachable("unknown nir_alu_type"); + } } } @@ -907,9 +903,12 @@ get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type) unsigned num_components = nir_dest_num_components(*dest); unsigned bit_size = nir_dest_bit_size(*dest); + if (bit_size == 1) + return get_bvec_type(ctx, num_components); + switch (nir_alu_type_get_base_type(type)) { case nir_type_bool: - return get_bvec_type(ctx, num_components); + unreachable("bool should have bit-size 1"); case nir_type_int: return get_ivec_type(ctx, bit_size, num_components); @@ -1231,9 +1230,6 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const) constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32); } - if (bit_size == 1) - constant = bvec_to_uvec(ctx, constant, num_components); - store_ssa_def(ctx, &load_const->def, constant); } @@ -1283,6 +1279,9 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr) num_components); } + if (nir_dest_bit_size(intr->dest) == 1) + result = uvec_to_bvec(ctx, result, num_components); + store_dest(ctx, &intr->dest, result, nir_type_uint); } else unreachable("uniform-addressing not yet supported"); @@ -1767,10 +1766,8 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list); static SpvId get_src_bool(struct ntv_context *ctx, nir_src *src) { - SpvId def = get_src(ctx, src); assert(nir_src_bit_size(*src) == 1); - unsigned num_components = nir_src_num_components(*src); - return uvec_to_bvec(ctx, def, num_components); + return get_src(ctx, src); } static void -- 2.30.2