X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fnir%2Fnir_lower_subgroups.c;h=f5eebb851446dce901085dd9f5c110e5688e5bd5;hb=d2dfcee7f7ebf87dae9570f1c7476eacb6240f83;hp=eca441fcf3954b2c4150037367ad3e6467a9257d;hpb=ce3af830cb6b1b6225a85aeade927db6c736412f;p=mesa.git

diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index eca441fcf39..f5eebb85144 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -292,6 +292,55 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
    return instr->type == nir_instr_type_intrinsic;
 }
 
+static nir_ssa_def *
+build_subgroup_mask(nir_builder *b, unsigned bit_size,
+                    const nir_lower_subgroups_options *options)
+{
+   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
+                   nir_isub(b, nir_imm_int(b, bit_size),
+                            nir_load_subgroup_size(b)));
+}
+
+static nir_ssa_def *
+lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
+                             const nir_lower_subgroups_options *options)
+{
+   if (!options->lower_quad_broadcast_dynamic_to_const)
+      return lower_shuffle(b, intrin, options->lower_to_scalar, false);
+
+   nir_ssa_def *dst = NULL;
+
+   for (unsigned i = 0; i < 4; ++i) {
+      nir_intrinsic_instr *qbcst =
+         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);
+
+      qbcst->num_components = intrin->num_components;
+      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
+      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
+      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
+                        intrin->dest.ssa.num_components,
+                        intrin->dest.ssa.bit_size, NULL);
+
+      nir_ssa_def *qbcst_dst = NULL;
+
+      if (options->lower_to_scalar && qbcst->num_components > 1) {
+         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+      } else {
+         nir_builder_instr_insert(b, &qbcst->instr);
+         qbcst_dst = &qbcst->dest.ssa;
+      }
+
+      if (i)
+         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
+                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
+                         qbcst_dst, dst);
+      else
+         dst = qbcst_dst;
+   }
+
+   return dst;
+}
+
 static nir_ssa_def *
 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 {
@@ -343,9 +392,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       const unsigned bit_size = MAX2(options->ballot_bit_size,
                                      intrin->dest.ssa.bit_size);
 
-      assert(options->subgroup_size <= 64);
-      uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
-
       nir_ssa_def *count = nir_load_subgroup_invocation(b);
       nir_ssa_def *val;
       switch (intrin->intrinsic) {
@@ -354,11 +400,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          break;
       case nir_intrinsic_load_subgroup_ge_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
-                        nir_imm_intN_t(b, group_mask, bit_size));
+                        build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_gt_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
-                        nir_imm_intN_t(b, group_mask, bit_size));
+                        build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_le_mask:
          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
@@ -467,13 +513,27 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
    case nir_intrinsic_quad_swap_horizontal:
    case nir_intrinsic_quad_swap_vertical:
    case nir_intrinsic_quad_swap_diagonal:
-      if (options->lower_quad)
-         return lower_shuffle(b, intrin, options->lower_to_scalar, false);
+      if (options->lower_quad ||
+          (options->lower_quad_broadcast_dynamic &&
+           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
+           !nir_src_is_const(intrin->src[1])))
+         return lower_dynamic_quad_broadcast(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
-   case nir_intrinsic_reduce:
+   case nir_intrinsic_reduce: {
+      nir_ssa_def *ret = NULL;
+      /* A cluster size greater than the subgroup size is implementation defined */
+      if (options->subgroup_size &&
+          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
+         nir_intrinsic_set_cluster_size(intrin, 0);
+         ret = NIR_LOWER_INSTR_PROGRESS;
+      }
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         ret = lower_subgroup_op_to_scalar(b, intrin, false);
+      return ret;
+   }
    case nir_intrinsic_inclusive_scan:
    case nir_intrinsic_exclusive_scan:
       if (options->lower_to_scalar && intrin->num_components > 1)
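
For reference, a minimal standalone sketch of what the two new helpers compute, written in plain C and evaluated on the CPU rather than through the NIR builder. Everything below is illustrative: the names subgroup_mask, ge_mask, gt_mask and quad_broadcast_dynamic are hypothetical and not part of Mesa; only the bit arithmetic mirrors build_subgroup_mask() and the select chain mirrors the loop in lower_dynamic_quad_broadcast() above.

#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

/* Low `subgroup_size` bits set, computed at `bit_size` width: the value
 * build_subgroup_mask() emits, ~0 >> (bit_size - subgroup_size). */
static uint64_t
subgroup_mask(unsigned bit_size, unsigned subgroup_size)
{
   assert(subgroup_size >= 1 && subgroup_size <= bit_size && bit_size <= 64);
   uint64_t all = ~0ull >> (64 - bit_size);   /* ~0 at bit_size width */
   return all >> (bit_size - subgroup_size);
}

/* subgroup_ge_mask lowering: bits invocation and up, clamped to the
 * subgroup; subgroup_gt_mask starts one bit higher (~1 instead of ~0). */
static uint64_t
ge_mask(unsigned bit_size, unsigned subgroup_size, unsigned invocation)
{
   return (~0ull << invocation) & subgroup_mask(bit_size, subgroup_size);
}

static uint64_t
gt_mask(unsigned bit_size, unsigned subgroup_size, unsigned invocation)
{
   return (~1ull << invocation) & subgroup_mask(bit_size, subgroup_size);
}

/* Scalar analogue of lower_dynamic_quad_broadcast(): evaluate all four
 * constant-index quad broadcasts and select among them, like the chain
 * of nir_bcsel instructions built in the loop above. */
static float
quad_broadcast_dynamic(const float quad[4], uint32_t idx)
{
   float dst = quad[0];                  /* the i == 0 case */
   for (uint32_t i = 1; i < 4; ++i)
      dst = (idx == i) ? quad[i] : dst;  /* one bcsel per constant index */
   return dst;
}

int
main(void)
{
   /* 32-wide subgroup in a 64-bit ballot, invocation 3: expect
    * mask = 0x00000000ffffffff, ge = 0x00000000fffffff8,
    * gt = 0x00000000fffffff0. */
   printf("mask = 0x%016" PRIx64 "\n", subgroup_mask(64, 32));
   printf("ge   = 0x%016" PRIx64 "\n", ge_mask(64, 32, 3));
   printf("gt   = 0x%016" PRIx64 "\n", gt_mask(64, 32, 3));

   const float quad[4] = { 10.0f, 11.0f, 12.0f, 13.0f };
   printf("quad_broadcast(2) = %.1f\n", quad_broadcast_dynamic(quad, 2));
   return 0;
}

Computing the mask at run time from nir_load_subgroup_size() is what lets the hunk at line 343 drop the assert(options->subgroup_size <= 64) and the compile-time group_mask constant: the subgroup size no longer has to be a compile-time constant when the pass runs.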