From: Jason Ekstrand Date: Thu, 18 Oct 2018 16:59:40 +0000 (-0500) Subject: nir: Add support for 1-bit data types X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=3191a823726053242696a7d96f37abd5475fadcc;p=mesa.git nir: Add support for 1-bit data types This commit adds support for 1-bit Booleans and integers. Booleans obviously take a value of true or false. Because we have to define the semantics of 1-bit signed and unsigned integers, we define uint1_t to take values of 0 and 1 and int1_t to take values of 0 and -1. 1-bit arithmetic is then well-defined in the usual way, just with fewer bits. The definition of int1_t and uint1_t doesn't usually matter but we do need something for purposes of constant folding. Reviewed-by: Eric Anholt Reviewed-by: Bas Nieuwenhuizen Tested-by: Bas Nieuwenhuizen --- diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index 249b9357c3f..3c80e03a091 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -638,6 +638,7 @@ const_value_int(int64_t i, unsigned bit_size) { nir_const_value v; switch (bit_size) { + case 1: v.b[0] = i & 1; break; case 8: v.i8[0] = i; break; case 16: v.i16[0] = i; break; case 32: v.i32[0] = i; break; @@ -1206,6 +1207,8 @@ nir_src_comp_as_int(nir_src src, unsigned comp) assert(comp < load->def.num_components); switch (load->def.bit_size) { + /* int1_t uses 0/-1 convention */ + case 1: return -(int)load->value.b[comp]; case 8: return load->value.i8[comp]; case 16: return load->value.i16[comp]; case 32: return load->value.i32[comp]; @@ -1223,6 +1226,7 @@ nir_src_comp_as_uint(nir_src src, unsigned comp) assert(comp < load->def.num_components); switch (load->def.bit_size) { + case 1: return load->value.b[comp]; case 8: return load->value.u8[comp]; case 16: return load->value.u16[comp]; case 32: return load->value.u32[comp]; @@ -1235,15 +1239,12 @@ nir_src_comp_as_uint(nir_src src, unsigned comp) bool nir_src_comp_as_bool(nir_src src, unsigned comp) { - assert(nir_src_is_const(src)); - nir_load_const_instr *load = nir_instr_as_load_const(src.ssa->parent_instr); + int64_t i = nir_src_comp_as_int(src, comp); - assert(comp < load->def.num_components); - assert(load->def.bit_size == 32); - assert(load->value.u32[comp] == NIR_TRUE || - load->value.u32[comp] == NIR_FALSE); + /* Booleans of any size use 0/-1 convention */ + assert(i == 0 || i == -1); - return load->value.u32[comp]; + return i; } double diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index d16eabdec53..dca1c80d6e2 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -118,6 +118,7 @@ typedef enum { } nir_rounding_mode; typedef union { + bool b[NIR_MAX_VEC_COMPONENTS]; float f32[NIR_MAX_VEC_COMPONENTS]; double f64[NIR_MAX_VEC_COMPONENTS]; int8_t i8[NIR_MAX_VEC_COMPONENTS]; @@ -779,17 +780,25 @@ typedef struct { unsigned write_mask : NIR_MAX_VEC_COMPONENTS; /* ignored if dest.is_ssa is true */ } nir_alu_dest; +/** NIR sized and unsized types + * + * The values in this enum are carefully chosen so that the sized type is + * just the unsized type OR the number of bits. + */ typedef enum { nir_type_invalid = 0, /* Not a valid type */ - nir_type_float, - nir_type_int, - nir_type_uint, - nir_type_bool, + nir_type_int = 2, + nir_type_uint = 4, + nir_type_bool = 6, + nir_type_float = 128, + nir_type_bool1 = 1 | nir_type_bool, nir_type_bool32 = 32 | nir_type_bool, + nir_type_int1 = 1 | nir_type_int, nir_type_int8 = 8 | nir_type_int, nir_type_int16 = 16 | nir_type_int, nir_type_int32 = 32 | nir_type_int, nir_type_int64 = 64 | nir_type_int, + nir_type_uint1 = 1 | nir_type_uint, nir_type_uint8 = 8 | nir_type_uint, nir_type_uint16 = 16 | nir_type_uint, nir_type_uint32 = 32 | nir_type_uint, @@ -799,8 +808,8 @@ typedef enum { nir_type_float64 = 64 | nir_type_float, } nir_alu_type; -#define NIR_ALU_TYPE_SIZE_MASK 0xfffffff8 -#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x00000007 +#define NIR_ALU_TYPE_SIZE_MASK 0x79 +#define NIR_ALU_TYPE_BASE_TYPE_MASK 0x86 static inline unsigned nir_alu_type_get_type_size(nir_alu_type type) diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index e0cdcd4ba23..d8abb7fd027 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -332,7 +332,10 @@ nir_imm_intN_t(nir_builder *build, uint64_t x, unsigned bit_size) memset(&v, 0, sizeof(v)); assert(bit_size <= 64); - v.i64[0] = x & (~0ull >> (64 - bit_size)); + if (bit_size == 1) + v.b[0] = x & 1; + else + v.i64[0] = x & (~0ull >> (64 - bit_size)); return nir_build_imm(build, 1, bit_size, v); } @@ -351,6 +354,13 @@ nir_imm_ivec4(nir_builder *build, int x, int y, int z, int w) return nir_build_imm(build, 4, 32, v); } +static inline nir_ssa_def * +nir_imm_boolN_t(nir_builder *build, bool x, unsigned bit_size) +{ + /* We use a 0/-1 convention for all booleans regardless of size */ + return nir_imm_intN_t(build, -(int)x, bit_size); +} + static inline nir_ssa_def * nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0, nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3) diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index 5091270b197..505cdd8baae 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -24,7 +24,9 @@ def op_bit_sizes(op): return sorted(list(sizes)) if sizes is not None else None def get_const_field(type_): - if type_base_type(type_) == 'bool': + if type_size(type_) == 1: + return 'b' + elif type_base_type(type_) == 'bool': return 'i' + str(type_size(type_)) elif type_ == "float16": return "u16" @@ -237,9 +239,12 @@ unpack_half_1x16(uint16_t u) } /* Some typed vector structures to make things like src0.y work */ +typedef int8_t int1_t; +typedef uint8_t uint1_t; typedef float float16_t; typedef float float32_t; typedef double float64_t; +typedef bool bool1_t; typedef bool bool8_t; typedef bool bool16_t; typedef bool bool32_t; @@ -274,7 +279,10 @@ struct ${type}${width}_vec { const struct ${input_types[j]}_vec src${j} = { % for k in range(op.input_sizes[j]): - % if input_types[j] == "float16": + % if input_types[j] == "int1": + /* 1-bit integers use a 0/-1 convention */ + -(int1_t)_src[${j}].b[${k}], + % elif input_types[j] == "float16": _mesa_half_to_float(_src[${j}].u16[${k}]), % else: _src[${j}].${get_const_field(input_types[j])}[${k}], @@ -299,6 +307,9 @@ struct ${type}${width}_vec { % elif "src" + str(j) not in op.const_expr: ## Avoid unused variable warnings <% continue %> + % elif input_types[j] == "int1": + /* 1-bit integers use a 0/-1 convention */ + const int1_t src${j} = -(int1_t)_src[${j}].b[_i]; % elif input_types[j] == "float16": const float src${j} = _mesa_half_to_float(_src[${j}].u16[_i]); @@ -321,7 +332,10 @@ struct ${type}${width}_vec { ## Store the current component of the actual destination to the ## value of dst. - % if output_type.startswith("bool"): + % if output_type == "int1" or output_type == "uint1": + /* 1-bit integers get truncated */ + _dst_val.b[_i] = dst & 1; + % elif output_type.startswith("bool"): ## Sanitize the C value to a proper NIR 0/-1 bool _dst_val.${get_const_field(output_type)}[_i] = -(int)dst; % elif output_type == "float16": @@ -350,7 +364,10 @@ struct ${type}${width}_vec { ## For each component in the destination, copy the value of dst to ## the actual destination. % for k in range(op.output_size): - % if output_type == "bool32": + % if output_type == "int1" or output_type == "uint1": + /* 1-bit integers get truncated */ + _dst_val.b[${k}] = dst.${"xyzw"[k]} & 1; + % elif output_type.startswith("bool"): ## Sanitize the C value to a proper NIR 0/-1 bool _dst_val.${get_const_field(output_type)}[${k}] = -(int)dst.${"xyzw"[k]}; % elif output_type == "float16": diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c index 19771fcd9dd..2a9e3396985 100644 --- a/src/compiler/nir/nir_instr_set.c +++ b/src/compiler/nir/nir_instr_set.c @@ -117,8 +117,15 @@ hash_load_const(uint32_t hash, const nir_load_const_instr *instr) { hash = HASH(hash, instr->def.num_components); - unsigned size = instr->def.num_components * (instr->def.bit_size / 8); - hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size); + if (instr->def.bit_size == 1) { + for (unsigned i = 0; i < instr->def.num_components; i++) { + uint8_t b = instr->value.b[i]; + hash = HASH(hash, b); + } + } else { + unsigned size = instr->def.num_components * (instr->def.bit_size / 8); + hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size); + } return hash; } @@ -399,8 +406,13 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) if (load1->def.bit_size != load2->def.bit_size) return false; - return memcmp(load1->value.f32, load2->value.f32, - load1->def.num_components * (load1->def.bit_size / 8u)) == 0; + if (load1->def.bit_size == 1) { + unsigned size = load1->def.num_components * sizeof(bool); + return memcmp(load1->value.b, load2->value.b, size) == 0; + } else { + unsigned size = load1->def.num_components * (load1->def.bit_size / 8); + return memcmp(load1->value.f32, load2->value.f32, size) == 0; + } } case nir_instr_type_phi: { nir_phi_instr *phi1 = nir_instr_as_phi(instr1); diff --git a/src/compiler/nir/nir_lower_load_const_to_scalar.c b/src/compiler/nir/nir_lower_load_const_to_scalar.c index b2e055f7dea..b62d32e483e 100644 --- a/src/compiler/nir/nir_lower_load_const_to_scalar.c +++ b/src/compiler/nir/nir_lower_load_const_to_scalar.c @@ -63,6 +63,9 @@ lower_load_const_instr_scalar(nir_load_const_instr *lower) case 8: load_comp->value.u8[0] = lower->value.u8[i]; break; + case 1: + load_comp->value.b[0] = lower->value.b[i]; + break; default: assert(!"invalid bit size"); } diff --git a/src/compiler/nir/nir_opt_constant_folding.c b/src/compiler/nir/nir_opt_constant_folding.c index be91a2a8fd6..5097a3bcc36 100644 --- a/src/compiler/nir/nir_opt_constant_folding.c +++ b/src/compiler/nir/nir_opt_constant_folding.c @@ -88,6 +88,9 @@ constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx) case 8: src[i].u8[j] = load_const->value.u8[instr->src[i].swizzle[j]]; break; + case 1: + src[i].b[j] = load_const->value.b[instr->src[i].swizzle[j]]; + break; default: unreachable("Invalid bit size"); } diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index 7485fe06c3f..493f44136f4 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -996,6 +996,9 @@ print_load_const_instr(nir_load_const_instr *instr, print_state *state) case 8: fprintf(fp, "0x%02x", instr->value.u8[i]); break; + case 1: + fprintf(fp, "%s", instr->value.b[i] ? "true" : "false"); + break; } } diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index c7954b39415..50f5464cef8 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -476,8 +476,9 @@ construct_value(nir_builder *build, break; case nir_type_bool: - cval = nir_imm_bool(build, c->data.u); + cval = nir_imm_boolN_t(build, c->data.u, bit_size); break; + default: unreachable("Invalid alu source type"); } diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c index 62893cad87e..c896b9a8037 100644 --- a/src/compiler/nir/nir_validate.c +++ b/src/compiler/nir/nir_validate.c @@ -818,7 +818,7 @@ validate_if(nir_if *if_stmt, validate_state *state) nir_cf_node *next_node = nir_cf_node_next(&if_stmt->cf_node); validate_assert(state, next_node->type == nir_cf_node_block); - validate_src(&if_stmt->condition, state, 32, 1); + validate_src(&if_stmt->condition, state, 0, 1); validate_assert(state, !exec_list_is_empty(&if_stmt->then_list)); validate_assert(state, !exec_list_is_empty(&if_stmt->else_list)); diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 4679fd323e6..8c9dc7d19eb 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -1561,6 +1561,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, case 8: val->constant->values[0].u8[i] = elems[i]->values[0].u8[0]; break; + case 1: + val->constant->values[0].b[i] = elems[i]->values[0].b[0]; + break; default: vtn_fail("Invalid SpvOpConstantComposite bit size"); } @@ -1734,6 +1737,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, case 8: val->constant->values[0].u8[i] = (*c)->values[col].u8[elem + i]; break; + case 1: + val->constant->values[0].b[i] = (*c)->values[col].b[elem + i]; + break; default: vtn_fail("Invalid SpvOpCompositeExtract bit size"); } @@ -1761,6 +1767,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, case 8: (*c)->values[col].u8[elem + i] = insert->constant->values[0].u8[i]; break; + case 1: + (*c)->values[col].b[elem + i] = insert->constant->values[0].b[i]; + break; default: vtn_fail("Invalid SpvOpCompositeInsert bit size"); }