From: Rob Clark Date: Sat, 9 Mar 2019 16:17:55 +0000 (+0100) Subject: nir+vtn: vec8+vec16 support X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a8ec4082;p=mesa.git nir+vtn: vec8+vec16 support This introduces new vec8 and vec16 instructions (which are the only instructions taking more than 4 sources), in order to construct 8 and 16 component vectors. In order to avoid fixing up the non-autogenerated nir_build_alu() sites and making them pass 16 src args for the benefit of the two instructions that take more than 4 srcs (ie vec8 and vec16), nir_build_alu() is has nir_build_alu_tail() split out and re-used by nir_build_alu2() (which is used for the > 4 src args case). v2 (Karol Herbst): use nir_build_alu2 for vec8 and vec16 use python's array multiplication syntax add nir_op_vec helper simplify nir_vec nir_build_alu_tail -> nir_builder_alu_instr_finish_and_insert use nir_build_alu for opcodes with <= 4 sources v3 (Karol Herbst): fix nir_serialize v4 (Dave Airlie): fix serialization of glsl_type handle vec8/16 in lowering of bools v5 (Karol Herbst): fix load store vectorizer Signed-off-by: Karol Herbst Reviewed-by: Dave Airlie --- diff --git a/src/compiler/glsl_types.cpp b/src/compiler/glsl_types.cpp index 4450b23a8c4..c958cc4b90d 100644 --- a/src/compiler/glsl_types.cpp +++ b/src/compiler/glsl_types.cpp @@ -2630,9 +2630,13 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type) case GLSL_TYPE_INT64: case GLSL_TYPE_BOOL: encoded.basic.interface_row_major = type->interface_row_major; - assert(type->vector_elements < 8); assert(type->matrix_columns < 8); - encoded.basic.vector_elements = type->vector_elements; + if (type->vector_elements <= 4) + encoded.basic.vector_elements = type->vector_elements; + else if (type->vector_elements == 8) + encoded.basic.vector_elements = 5; + else if (type->vector_elements == 16) + encoded.basic.vector_elements = 6; encoded.basic.matrix_columns = type->matrix_columns; encoded.basic.explicit_stride = MIN2(type->explicit_stride, 0xfffff); blob_write_uint32(blob, encoded.u32); @@ -2741,6 +2745,11 @@ decode_type_from_blob(struct blob_reader *blob) unsigned explicit_stride = encoded.basic.explicit_stride; if (explicit_stride == 0xfffff) explicit_stride = blob_read_uint32(blob); + uint32_t vector_elements = encoded.basic.vector_elements; + if (vector_elements == 5) + vector_elements = 8; + else if (vector_elements == 6) + vector_elements = 16; return glsl_type::get_instance(base_type, encoded.basic.vector_elements, encoded.basic.matrix_columns, explicit_stride, diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index fbef503f0fa..53305ae7148 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -58,10 +58,19 @@ extern "C" { #define NIR_FALSE 0u #define NIR_TRUE (~0u) -#define NIR_MAX_VEC_COMPONENTS 4 +#define NIR_MAX_VEC_COMPONENTS 16 #define NIR_MAX_MATRIX_COLUMNS 4 #define NIR_STREAM_PACKED (1 << 8) -typedef uint8_t nir_component_mask_t; +typedef uint16_t nir_component_mask_t; + +static inline bool +nir_num_components_valid(unsigned num_components) +{ + return (num_components >= 1 && + num_components <= 4) || + num_components == 8 || + num_components == 16; +} /** Defines a cast function * @@ -1030,6 +1039,8 @@ nir_op_vec(unsigned components) case 2: return nir_op_vec2; case 3: return nir_op_vec3; case 4: return nir_op_vec4; + case 8: return nir_op_vec8; + case 16: return nir_op_vec16; default: unreachable("bad component count"); } } diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index aed47593826..8b5923211db 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -874,7 +874,7 @@ nir_ssa_for_src(nir_builder *build, nir_src src, int num_components) static inline nir_ssa_def * nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn) { - static uint8_t trivial_swizzle[] = { 0, 1, 2, 3 }; + static uint8_t trivial_swizzle[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }; STATIC_ASSERT(ARRAY_SIZE(trivial_swizzle) == NIR_MAX_VEC_COMPONENTS); nir_alu_src *src = &instr->src[srcn]; diff --git a/src/compiler/nir/nir_builder_opcodes_h.py b/src/compiler/nir/nir_builder_opcodes_h.py index 53fb23ca2b3..f0d8cf1db68 100644 --- a/src/compiler/nir/nir_builder_opcodes_h.py +++ b/src/compiler/nir/nir_builder_opcodes_h.py @@ -31,14 +31,22 @@ def src_decl_list(num_srcs): return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs)) def src_list(num_srcs): - return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4)) + if num_srcs <= 4: + return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4)) + else: + return ', '.join('src' + str(i) for i in range(num_srcs)) %> % for name, opcode in sorted(opcodes.items()): static inline nir_ssa_def * nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)}) { +% if opcode.num_inputs <= 4: return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)}); +% else: + nir_ssa_def *srcs[${opcode.num_inputs}] = {${src_list(opcode.num_inputs)}}; + return nir_build_alu_src_arr(build, nir_op_${name}, srcs); +% endif } % endfor diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index 267a2615964..8b8cd5ffa69 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -292,6 +292,18 @@ struct ${type}${width}_vec { ${type}${width}_t y; ${type}${width}_t z; ${type}${width}_t w; + ${type}${width}_t e; + ${type}${width}_t f; + ${type}${width}_t g; + ${type}${width}_t h; + ${type}${width}_t i; + ${type}${width}_t j; + ${type}${width}_t k; + ${type}${width}_t l; + ${type}${width}_t m; + ${type}${width}_t n; + ${type}${width}_t o; + ${type}${width}_t p; }; % endfor % endfor @@ -324,7 +336,7 @@ struct ${type}${width}_vec { _src[${j}][${k}].${get_const_field(input_types[j])}, % endif % endfor - % for k in range(op.input_sizes[j], 4): + % for k in range(op.input_sizes[j], 16): 0, % endfor }; @@ -418,18 +430,18 @@ struct ${type}${width}_vec { % for k in range(op.output_size): % if output_type == "int1" or output_type == "uint1": /* 1-bit integers get truncated */ - _dst_val[${k}].b = dst.${"xyzw"[k]} & 1; + _dst_val[${k}].b = dst.${"xyzwefghijklmnop"[k]} & 1; % elif output_type.startswith("bool"): ## Sanitize the C value to a proper NIR 0/-1 bool - _dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzw"[k]}; + _dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzwefghijklmnop"[k]}; % elif output_type == "float16": if (nir_is_rounding_mode_rtz(execution_mode, 16)) { - _dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzw"[k]}); + _dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzwefghijklmnop"[k]}); } else { - _dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzw"[k]}); + _dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzwefghijklmnop"[k]}); } % else: - _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzw"[k]}; + _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzwefghijklmnop"[k]}; % endif % if op.name != "fquantize2f16" and type_base_type(output_type) == "float": diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c index f672699b5a4..e919370ec1a 100644 --- a/src/compiler/nir/nir_lower_alu_to_scalar.c +++ b/src/compiler/nir/nir_lower_alu_to_scalar.c @@ -117,6 +117,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data) return lower_reduction(alu, chan, merge, b); \ switch (alu->op) { + case nir_op_vec16: + case nir_op_vec8: case nir_op_vec4: case nir_op_vec3: case nir_op_vec2: diff --git a/src/compiler/nir/nir_lower_bool_to_float.c b/src/compiler/nir/nir_lower_bool_to_float.c index c07121f6d88..d552b057590 100644 --- a/src/compiler/nir/nir_lower_bool_to_float.c +++ b/src/compiler/nir/nir_lower_bool_to_float.c @@ -56,6 +56,8 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu) case nir_op_vec2: case nir_op_vec3: case nir_op_vec4: + case nir_op_vec8: + case nir_op_vec16: /* These we expect to have booleans but the opcode doesn't change */ break; diff --git a/src/compiler/nir/nir_lower_bool_to_int32.c b/src/compiler/nir/nir_lower_bool_to_int32.c index e331de488a3..0978207e72d 100644 --- a/src/compiler/nir/nir_lower_bool_to_int32.c +++ b/src/compiler/nir/nir_lower_bool_to_int32.c @@ -53,6 +53,8 @@ lower_alu_instr(nir_alu_instr *alu) case nir_op_vec2: case nir_op_vec3: case nir_op_vec4: + case nir_op_vec8: + case nir_op_vec16: case nir_op_inot: case nir_op_iand: case nir_op_ior: diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index 2ab04ed9b1d..86485e39508 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -75,7 +75,7 @@ class Opcode(object): assert isinstance(algebraic_properties, str) assert isinstance(const_expr, str) assert len(input_sizes) == len(input_types) - assert 0 <= output_size <= 4 + assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16) for size in input_sizes: assert 0 <= size <= 4 if output_size != 0: @@ -1057,6 +1057,40 @@ dst.z = src2.x; dst.w = src3.x; """) +opcode("vec8", 8, tuint, + [1] * 8, [tuint] * 8, + False, "", """ +dst.x = src0.x; +dst.y = src1.x; +dst.z = src2.x; +dst.w = src3.x; +dst.e = src4.x; +dst.f = src5.x; +dst.g = src6.x; +dst.h = src7.x; +""") + +opcode("vec16", 16, tuint, + [1] * 16, [tuint] * 16, + False, "", """ +dst.x = src0.x; +dst.y = src1.x; +dst.z = src2.x; +dst.w = src3.x; +dst.e = src4.x; +dst.f = src5.x; +dst.g = src6.x; +dst.h = src7.x; +dst.i = src8.x; +dst.j = src9.x; +dst.k = src10.x; +dst.l = src11.x; +dst.m = src12.x; +dst.n = src13.x; +dst.o = src14.x; +dst.p = src15.x; +""") + # An integer multiply instruction for address calculation. This is # similar to imul, except that the results are undefined in case of # overflow. Overflow is defined according to the size of the variable diff --git a/src/compiler/nir/nir_opt_load_store_vectorize.c b/src/compiler/nir/nir_opt_load_store_vectorize.c index c40d4ada4ee..6587251c7e7 100644 --- a/src/compiler/nir/nir_opt_load_store_vectorize.c +++ b/src/compiler/nir/nir_opt_load_store_vectorize.c @@ -643,7 +643,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size, return false; unsigned new_num_components = size / new_bit_size; - if (new_num_components > NIR_MAX_VEC_COMPONENTS) + if (!nir_num_components_valid(new_num_components)) return false; unsigned high_offset = high->offset_signed - low->offset_signed; diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index ab82f4fbb50..9ec6dcfc947 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -171,6 +171,12 @@ print_dest(nir_dest *dest, print_state *state) print_reg_dest(&dest->reg, state); } +static const char * +comp_mask_string(unsigned num_components) +{ + return (num_components > 4) ? "abcdefghijklmnop" : "xyzw"; +} + static void print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state) { @@ -206,7 +212,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state) if (!nir_alu_instr_channel_used(instr, src, i)) continue; - fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]); + fprintf(fp, "%c", comp_mask_string(live_channels)[instr->src[src].swizzle[i]]); } } @@ -224,10 +230,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state) if (!dest->dest.is_ssa && dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) { + unsigned live_channels = dest->dest.reg.reg->num_components; fprintf(fp, "."); for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) if ((dest->write_mask >> i) & 1) - fprintf(fp, "%c", "xyzw"[i]); + fprintf(fp, "%c", comp_mask_string(live_channels)[i]); } } @@ -569,8 +576,8 @@ print_var_decl(nir_variable *var, print_state *state) switch (var->data.mode) { case nir_var_shader_in: case nir_var_shader_out: - if (num_components < 4 && num_components != 0) { - const char *xyzw = "xyzw"; + if (num_components < 16 && num_components != 0) { + const char *xyzw = comp_mask_string(num_components); for (int i = 0; i < num_components; i++) components_local[i + 1] = xyzw[i + var->data.location_frac]; @@ -816,9 +823,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state) /* special case wrmask to show it as a writemask.. */ unsigned wrmask = nir_intrinsic_write_mask(instr); fprintf(fp, " wrmask="); - for (unsigned i = 0; i < 4; i++) + for (unsigned i = 0; i < instr->num_components; i++) if ((wrmask >> i) & 1) - fprintf(fp, "%c", "xyzw"[i]); + fprintf(fp, "%c", comp_mask_string(instr->num_components)[i]); break; } diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index c1b179525ab..458a4eeb1ce 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -56,7 +56,13 @@ static bool nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, const struct per_op_table *pass_op_table); -static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 }; +static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = +{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, +}; /** * Check if a source produces a value of the given type. diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c index 2a0839768b8..38e32e4ce3d 100644 --- a/src/compiler/nir/nir_validate.c +++ b/src/compiler/nir/nir_validate.c @@ -128,8 +128,7 @@ static void validate_src(nir_src *src, validate_state *state, static void validate_num_components(validate_state *state, unsigned num_components) { - validate_assert(state, num_components >= 1 && - num_components <= 4); + validate_assert(state, nir_num_components_valid(num_components)); } static void diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 29fef9fad66..63a165d5b84 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -3819,10 +3819,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, case SpvCapabilityInputAttachment: case SpvCapabilityImageGatherExtended: case SpvCapabilityStorageImageExtendedFormats: + case SpvCapabilityVector16: break; case SpvCapabilityLinkage: - case SpvCapabilityVector16: case SpvCapabilityFloat16Buffer: case SpvCapabilitySparseResidency: vtn_warn("Unsupported SPIR-V capability: %s",