From a026458020e947cc5d864cfb5b19660836b2d613 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Tue, 22 Aug 2017 18:44:51 -0700 Subject: [PATCH] nir/lower_subgroups: Lower ballot intrinsics to the specified bit size Ballot intrinsics return a bitfield of subgroups. In GLSL and some SPIR-V extensions, they return a uint64_t. In SPV_KHR_shader_ballot, they return a uvec4. Also, some back-ends would rather pass around 32-bit values because it's easier than messing with 64-bit all the time. To solve this mess, we make nir_lower_subgroups take a new parameter called ballot_bit_size and it lowers whichever thing it gets in from the source language (uint64_t or uvec4) to a scalar with the specified number of bits. This replaces a chunk of the old lowering code. Reviewed-by: Lionel Landwerlin Reviewed-by: Iago Toral Quiroga --- src/compiler/nir/nir.h | 3 +- src/compiler/nir/nir_lower_subgroups.c | 92 +++++++++++++++++++++++--- src/compiler/nir/nir_opt_intrinsics.c | 18 ----- src/intel/compiler/brw_compiler.c | 1 - src/intel/compiler/brw_nir.c | 1 + 5 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 7ee4bfcc557..c047ab7512b 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -1854,8 +1854,6 @@ typedef struct nir_shader_compiler_options { */ bool use_interpolated_input_intrinsics; - unsigned max_subgroup_size; - unsigned max_unroll_iterations; } nir_shader_compiler_options; @@ -2486,6 +2484,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader, const struct gl_shader_program *shader_program); typedef struct nir_lower_subgroups_options { + uint8_t ballot_bit_size; bool lower_to_scalar:1; bool lower_vote_trivial:1; bool lower_subgroup_masks:1; diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 0d11dc9c23a..76e831691ee 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -28,6 +28,42 @@ * \file nir_opt_intrinsics.c */ +/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */ +static nir_ssa_def * +uint_to_ballot_type(nir_builder *b, nir_ssa_def *value, + unsigned num_components, unsigned bit_size) +{ + assert(value->num_components == 1); + assert(value->bit_size == 32 || value->bit_size == 64); + + nir_ssa_def *zero = nir_imm_int(b, 0); + if (num_components > 1) { + /* SPIR-V uses a uvec4 for ballot values */ + assert(num_components == 4); + assert(bit_size == 32); + + if (value->bit_size == 32) { + return nir_vec4(b, value, zero, zero, zero); + } else { + assert(value->bit_size == 64); + return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value), + nir_unpack_64_2x32_split_y(b, value), + zero, zero); + } + } else { + /* GLSL uses a uint64_t for ballot values */ + assert(num_components == 1); + assert(bit_size == 64); + + if (value->bit_size == 32) { + return nir_pack_64_2x32_split(b, value, zero); + } else { + assert(value->bit_size == 64); + return value; + } + } +} + static nir_ssa_def * lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) { @@ -62,7 +98,8 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin) static nir_ssa_def * high_subgroup_mask(nir_builder *b, nir_ssa_def *count, - uint64_t base_mask) + uint64_t base_mask, + unsigned bit_size) { /* group_mask could probably be calculated more efficiently but we want to * be sure not to shift by 64 if the subgroup size is 64 because the GLSL @@ -71,10 +108,11 @@ high_subgroup_mask(nir_builder *b, * subgroup size is likely to be known at compile time. */ nir_ssa_def *subgroup_size = nir_load_subgroup_size(b); - nir_ssa_def *all_bits = nir_imm_int64(b, ~0ull); + nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size); nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size); nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift); - nir_ssa_def *higher_bits = nir_ishl(b, nir_imm_int64(b, base_mask), count); + nir_ssa_def *higher_bits = + nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count); return nir_iand(b, higher_bits, group_mask); } @@ -109,24 +147,58 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin, if (!options->lower_subgroup_masks) return NULL; - nir_ssa_def *count = nir_load_subgroup_invocation(b); + /* If either the result or the requested bit size is 64-bits then we + * know that we have 64-bit types and using them will probably be more + * efficient than messing around with 32-bit shifts and packing. + */ + const unsigned bit_size = MAX2(options->ballot_bit_size, + intrin->dest.ssa.bit_size); + nir_ssa_def *count = nir_load_subgroup_invocation(b); + nir_ssa_def *val; switch (intrin->intrinsic) { case nir_intrinsic_load_subgroup_eq_mask: - return nir_ishl(b, nir_imm_int64(b, 1ull), count); + val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count); + break; case nir_intrinsic_load_subgroup_ge_mask: - return high_subgroup_mask(b, count, ~0ull); + val = high_subgroup_mask(b, count, ~0ull, bit_size); + break; case nir_intrinsic_load_subgroup_gt_mask: - return high_subgroup_mask(b, count, ~1ull); + val = high_subgroup_mask(b, count, ~1ull, bit_size); + break; case nir_intrinsic_load_subgroup_le_mask: - return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~1ull), count)); + val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count)); + break; case nir_intrinsic_load_subgroup_lt_mask: - return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~0ull), count)); + val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count)); + break; default: unreachable("you seriously can't tell this is unreachable?"); } - break; + + return uint_to_ballot_type(b, val, + intrin->dest.ssa.num_components, + intrin->dest.ssa.bit_size); + } + + case nir_intrinsic_ballot: { + if (intrin->dest.ssa.num_components == 1 && + intrin->dest.ssa.bit_size == options->ballot_bit_size) + return NULL; + + nir_intrinsic_instr *ballot = + nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot); + ballot->num_components = 1; + nir_ssa_dest_init(&ballot->instr, &ballot->dest, + 1, options->ballot_bit_size, NULL); + nir_src_copy(&ballot->src[0], &intrin->src[0], ballot); + nir_builder_instr_insert(b, &ballot->instr); + + return uint_to_ballot_type(b, &ballot->dest.ssa, + intrin->dest.ssa.num_components, + intrin->dest.ssa.bit_size); } + default: break; } diff --git a/src/compiler/nir/nir_opt_intrinsics.c b/src/compiler/nir/nir_opt_intrinsics.c index 98c8b1a01e8..eb394af0c10 100644 --- a/src/compiler/nir/nir_opt_intrinsics.c +++ b/src/compiler/nir/nir_opt_intrinsics.c @@ -54,24 +54,6 @@ opt_intrinsics_impl(nir_function_impl *impl) if (nir_src_as_const_value(intrin->src[0])) replacement = nir_imm_int(&b, NIR_TRUE); break; - case nir_intrinsic_ballot: { - assert(b.shader->options->max_subgroup_size != 0); - if (b.shader->options->max_subgroup_size > 32 || - intrin->dest.ssa.bit_size <= 32) - continue; - - nir_intrinsic_instr *ballot = - nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot); - nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL); - nir_src_copy(&ballot->src[0], &intrin->src[0], ballot); - - nir_builder_instr_insert(&b, &ballot->instr); - - replacement = nir_pack_64_2x32_split(&b, - &ballot->dest.ssa, - nir_imm_int(&b, 0)); - break; - } default: break; } diff --git a/src/intel/compiler/brw_compiler.c b/src/intel/compiler/brw_compiler.c index 8c709b55a10..e89aeacc7d2 100644 --- a/src/intel/compiler/brw_compiler.c +++ b/src/intel/compiler/brw_compiler.c @@ -57,7 +57,6 @@ static const struct nir_shader_compiler_options scalar_nir_options = { .lower_unpack_snorm_4x8 = true, .lower_unpack_unorm_2x16 = true, .lower_unpack_unorm_4x8 = true, - .max_subgroup_size = 32, .max_unroll_iterations = 32, }; diff --git a/src/intel/compiler/brw_nir.c b/src/intel/compiler/brw_nir.c index f599f748a8a..0d59d36ca63 100644 --- a/src/intel/compiler/brw_nir.c +++ b/src/intel/compiler/brw_nir.c @@ -637,6 +637,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir) OPT(nir_lower_system_values); const nir_lower_subgroups_options subgroups_options = { + .ballot_bit_size = 32, .lower_to_scalar = true, .lower_subgroup_masks = true, .lower_vote_trivial = !is_scalar, -- 2.30.2