* \file nir_lower_subgroups.c
*/
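+/* Builds a single-component, 32-bit copy of a 64-bit subgroup intrinsic
+ * that operates on one half of its first source: component 0 selects the
+ * low dword, component 1 the high dword.
+ */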
+static nir_intrinsic_instr *
+lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
+ unsigned int component)
+{
+ nir_ssa_def *comp;
+ if (component == 0)
+ comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
+ else
+ comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
+
+ nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
+ nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
+ intr->const_index[0] = intrin->const_index[0];
+ intr->const_index[1] = intrin->const_index[1];
+ intr->src[0] = nir_src_for_ssa(comp);
+ if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
+ nir_src_copy(&intr->src[1], &intrin->src[1], intr);
+
+ intr->num_components = 1;
+ nir_builder_instr_insert(b, &intr->instr);
+ return intr;
+}
+
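+/* Lowers a 64-bit subgroup operation to two 32-bit operations, one per
+ * dword, and packs the results back into a 64-bit value.
+ */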
+static nir_ssa_def *
+lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
+{
+ assert(intrin->src[0].ssa->bit_size == 64);
+ nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
+ nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
+ return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
+}
+
static nir_ssa_def *
ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
{
}
static nir_ssa_def *
-lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
+ bool lower_to_32bit)
{
/* This is safe to call on scalar things but it would be silly */
assert(intrin->dest.ssa.num_components > 1);
nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
}
- nir_builder_instr_insert(b, &chan_intrin->instr);
+ chan_intrin->const_index[0] = intrin->const_index[0];
+ chan_intrin->const_index[1] = intrin->const_index[1];
- reads[i] = &chan_intrin->dest.ssa;
+ if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
+ reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
+ } else {
+ nir_builder_instr_insert(b, &chan_intrin->instr);
+ reads[i] = &chan_intrin->dest.ssa;
+ }
}
return nir_vec(b, reads, intrin->num_components);
return result;
}
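+/* Implements vote_feq/vote_ieq with read_first_invocation and a ballot:
+ * each invocation compares its value against the first active invocation's,
+ * and the vote passes iff no active invocation disagrees.
+ */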
+static nir_ssa_def *
+lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
+ const nir_lower_subgroups_options *options)
+{
+ assert(intrin->src[0].is_ssa);
+ nir_ssa_def *value = intrin->src[0].ssa;
+
+ /* We have to implicitly lower to scalar */
+ nir_ssa_def *all_eq = NULL;
+ for (unsigned i = 0; i < intrin->num_components; i++) {
+ nir_intrinsic_instr *rfi =
+ nir_intrinsic_instr_create(b->shader,
+ nir_intrinsic_read_first_invocation);
+ nir_ssa_dest_init(&rfi->instr, &rfi->dest,
+ 1, value->bit_size, NULL);
+ rfi->num_components = 1;
+ rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
+ nir_builder_instr_insert(b, &rfi->instr);
+
+ nir_ssa_def *is_eq;
+ if (intrin->intrinsic == nir_intrinsic_vote_feq) {
+ is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+ } else {
+ is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+ }
+
+ if (all_eq == NULL) {
+ all_eq = is_eq;
+ } else {
+ all_eq = nir_iand(b, all_eq, is_eq);
+ }
+ }
+
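+ /* The vote passes iff the ballot of the disagreeing invocations is
+ * empty.
+ */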
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+ 1, options->ballot_bit_size, NULL);
+ ballot->num_components = 1;
+ ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
+ nir_builder_instr_insert(b, &ballot->instr);
+
+ return nir_ieq(b, &ballot->dest.ssa,
+ nir_imm_intN_t(b, 0, options->ballot_bit_size));
+}
+
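+/* Lowers a constant shuffle_xor to AMD's masked swizzle, which can only
+ * swizzle lanes within groups of 32 invocations.  Returns NULL when the
+ * mask is too wide so the caller can fall back to a real shuffle.
+ */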
+static nir_ssa_def *
+lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
+ const nir_lower_subgroups_options *options)
+{
+ unsigned mask = nir_src_as_uint(intrin->src[1]);
+
+ if (mask >= 32)
+ return NULL;
+
+ nir_intrinsic_instr *swizzle = nir_intrinsic_instr_create(
+ b->shader, nir_intrinsic_masked_swizzle_amd);
+ swizzle->num_components = intrin->num_components;
+ nir_src_copy(&swizzle->src[0], &intrin->src[0], swizzle);
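+ /* The swizzle mask packs and_mask in bits 0-4, or_mask in bits 5-9 and
+ * xor_mask in bits 10-14; keep all five lane bits and xor in the constant.
+ */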
+ nir_intrinsic_set_swizzle_mask(swizzle, (mask << 10) | 0x1f);
+ nir_ssa_dest_init(&swizzle->instr, &swizzle->dest,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size, NULL);
+
+ if (options->lower_to_scalar && swizzle->num_components > 1) {
+ return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
+ } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
+ return lower_subgroup_op_to_32bit(b, swizzle);
+ } else {
+ nir_builder_instr_insert(b, &swizzle->instr);
+ return &swizzle->dest.ssa;
+ }
+}
+
static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
- bool lower_to_scalar)
+ const nir_lower_subgroups_options *options)
{
+ if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
+ options->lower_shuffle_to_swizzle_amd &&
+ nir_src_is_const(intrin->src[1])) {
+ nir_ssa_def *result =
+ lower_shuffle_to_swizzle(b, intrin, options);
+ if (result)
+ return result;
+ }
+
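+ /* Otherwise turn the operation into a plain shuffle by computing the
+ * absolute source invocation index up front. */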
nir_ssa_def *index = nir_load_subgroup_invocation(b);
+ bool is_shuffle = false;
switch (intrin->intrinsic) {
case nir_intrinsic_shuffle_xor:
assert(intrin->src[1].is_ssa);
index = nir_ixor(b, index, intrin->src[1].ssa);
+ is_shuffle = true;
break;
case nir_intrinsic_shuffle_up:
assert(intrin->src[1].is_ssa);
index = nir_isub(b, index, intrin->src[1].ssa);
+ is_shuffle = true;
break;
case nir_intrinsic_shuffle_down:
assert(intrin->src[1].is_ssa);
index = nir_iadd(b, index, intrin->src[1].ssa);
+ is_shuffle = true;
break;
case nir_intrinsic_quad_broadcast:
assert(intrin->src[1].is_ssa);
intrin->dest.ssa.num_components,
intrin->dest.ssa.bit_size, NULL);
- if (lower_to_scalar && shuffle->num_components > 1) {
- return lower_subgroup_op_to_scalar(b, shuffle);
+ bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
+ if (options->lower_to_scalar && shuffle->num_components > 1) {
+ return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
+ } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
+ return lower_subgroup_op_to_32bit(b, shuffle);
} else {
nir_builder_instr_insert(b, &shuffle->instr);
return &shuffle->dest.ssa;
}
}
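+/* Callback for nir_shader_lower_instructions(): every intrinsic is a
+ * candidate; lower_subgroups_instr() returns NULL for those it does not
+ * handle.
+ */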
+static bool
+lower_subgroups_filter(const nir_instr *instr, const void *_options)
+{
+ return instr->type == nir_instr_type_intrinsic;
+}
+
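+/* Builds a mask with one bit set per invocation in the subgroup, i.e.
+ * ~0 >> (bit_size - subgroup_size), using the run-time subgroup size.
+ */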
+static nir_ssa_def *
+build_subgroup_mask(nir_builder *b, unsigned bit_size,
+ const nir_lower_subgroups_options *options)
+{
+ return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
+ nir_isub(b, nir_imm_int(b, bit_size),
+ nir_load_subgroup_size(b)));
+}
+
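+/* Lowers a quad broadcast with a non-constant index either to a shuffle
+ * or, when requested, to four constant-index quad broadcasts whose
+ * results are selected by comparing the dynamic index against each lane.
+ */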
+static nir_ssa_def *
+lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
+ const nir_lower_subgroups_options *options)
+{
+ if (!options->lower_quad_broadcast_dynamic_to_const)
+ return lower_shuffle(b, intrin, options);
+
+ nir_ssa_def *dst = NULL;
+
+ for (unsigned i = 0; i < 4; ++i) {
+ nir_intrinsic_instr *qbcst =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);
+
+ qbcst->num_components = intrin->num_components;
+ qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
+ nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
+ nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
+ intrin->dest.ssa.num_components,
+ intrin->dest.ssa.bit_size, NULL);
+
+ nir_ssa_def *qbcst_dst = NULL;
+
+ if (options->lower_to_scalar && qbcst->num_components > 1) {
+ qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+ } else {
+ nir_builder_instr_insert(b, &qbcst->instr);
+ qbcst_dst = &qbcst->dest.ssa;
+ }
+
+ if (i)
+ dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
+ nir_imm_int(b, i)),
+ qbcst_dst, dst);
+ else
+ dst = qbcst_dst;
+ }
+
+ return dst;
+}
+
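+/* Lowers one subgroup intrinsic.  Returns the replacement SSA value,
+ * NIR_LOWER_INSTR_PROGRESS if the intrinsic was rewritten in place, or
+ * NULL if it was left untouched.
+ */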
static nir_ssa_def *
-lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
- const nir_lower_subgroups_options *options)
+lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
{
+ const nir_lower_subgroups_options *options = _options;
+
+ nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
switch (intrin->intrinsic) {
case nir_intrinsic_vote_any:
case nir_intrinsic_vote_all:
case nir_intrinsic_vote_feq:
case nir_intrinsic_vote_ieq:
if (options->lower_vote_trivial)
- return nir_imm_int(b, NIR_TRUE);
+ return nir_imm_true(b);
+
+ if (options->lower_vote_eq_to_ballot)
+ return lower_vote_eq_to_ballot(b, intrin, options);
if (options->lower_to_scalar && intrin->num_components > 1)
return lower_vote_eq_to_scalar(b, intrin);
break;
case nir_intrinsic_read_invocation:
case nir_intrinsic_read_first_invocation:
if (options->lower_to_scalar && intrin->num_components > 1)
- return lower_subgroup_op_to_scalar(b, intrin);
+ return lower_subgroup_op_to_scalar(b, intrin, false);
break;
case nir_intrinsic_load_subgroup_eq_mask:
const unsigned bit_size = MAX2(options->ballot_bit_size,
intrin->dest.ssa.bit_size);
- assert(options->subgroup_size <= 64);
- uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
-
nir_ssa_def *count = nir_load_subgroup_invocation(b);
nir_ssa_def *val;
switch (intrin->intrinsic) {
break;
case nir_intrinsic_load_subgroup_ge_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
- nir_imm_intN_t(b, group_mask, bit_size));
+ build_subgroup_mask(b, bit_size, options));
break;
case nir_intrinsic_load_subgroup_gt_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
- nir_imm_intN_t(b, group_mask, bit_size));
+ build_subgroup_mask(b, bit_size, options));
break;
case nir_intrinsic_load_subgroup_le_mask:
val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
assert(intrin->src[0].is_ssa);
nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
options->ballot_bit_size);
+
+ if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
+ intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
+ /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
+ *
+ * "Find the most significant bit set to 1 in Value, considering
+ * only the bits in Value required to represent all bits of the
+ * group’s invocations. If none of the considered bits is set to
+ * 1, the result is undefined."
+ *
+ * It has similar text for the other three. This means that, in case
+ * the subgroup size is less than 32, we have to mask off the unused
+ * bits. If the subgroup size is fixed and greater than or equal to
+ * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
+ * the iand.
+ *
+ * We only have to worry about this for BitCount and FindMSB because
+ * FindLSB counts from the bottom and BitfieldExtract selects
+ * individual bits. In either case, if run outside the range of
+ * valid bits, we hit the undefined results case and we can return
+ * anything we want.
+ */
+ int_val = nir_iand(b, int_val,
+ build_subgroup_mask(b, options->ballot_bit_size, options));
+ }
+
switch (intrin->intrinsic) {
case nir_intrinsic_ballot_bitfield_extract:
assert(intrin->src[1].is_ssa);
return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
intrin->src[1].ssa),
- nir_imm_int(b, 1)));
+ nir_imm_intN_t(b, 1, options->ballot_bit_size)));
case nir_intrinsic_ballot_bit_count_reduce:
return nir_bit_count(b, int_val);
case nir_intrinsic_ballot_find_lsb:
case nir_intrinsic_shuffle:
if (options->lower_to_scalar && intrin->num_components > 1)
- return lower_subgroup_op_to_scalar(b, intrin);
+ return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+ else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+ return lower_subgroup_op_to_32bit(b, intrin);
break;
-
case nir_intrinsic_shuffle_xor:
case nir_intrinsic_shuffle_up:
case nir_intrinsic_shuffle_down:
if (options->lower_shuffle)
- return lower_shuffle(b, intrin, options->lower_to_scalar);
+ return lower_shuffle(b, intrin, options);
else if (options->lower_to_scalar && intrin->num_components > 1)
- return lower_subgroup_op_to_scalar(b, intrin);
+ return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+ else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+ return lower_subgroup_op_to_32bit(b, intrin);
break;
case nir_intrinsic_quad_broadcast:
case nir_intrinsic_quad_swap_horizontal:
case nir_intrinsic_quad_swap_vertical:
case nir_intrinsic_quad_swap_diagonal:
- if (options->lower_quad)
- return lower_shuffle(b, intrin, options->lower_to_scalar);
+ if (options->lower_quad ||
+ (options->lower_quad_broadcast_dynamic &&
+ intrin->intrinsic == nir_intrinsic_quad_broadcast &&
+ !nir_src_is_const(intrin->src[1])))
+ return lower_dynamic_quad_broadcast(b, intrin, options);
else if (options->lower_to_scalar && intrin->num_components > 1)
- return lower_subgroup_op_to_scalar(b, intrin);
+ return lower_subgroup_op_to_scalar(b, intrin, false);
+ break;
+
+ case nir_intrinsic_reduce: {
+ nir_ssa_def *ret = NULL;
+ /* A cluster size greater than the subgroup size is implementation-defined */
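+ /* Cluster size 0 means a reduction across the whole subgroup. */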
+ if (options->subgroup_size &&
+ nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
+ nir_intrinsic_set_cluster_size(intrin, 0);
+ ret = NIR_LOWER_INSTR_PROGRESS;
+ }
+ if (options->lower_to_scalar && intrin->num_components > 1)
+ ret = lower_subgroup_op_to_scalar(b, intrin, false);
+ return ret;
+ }
+ case nir_intrinsic_inclusive_scan:
+ case nir_intrinsic_exclusive_scan:
+ if (options->lower_to_scalar && intrin->num_components > 1)
+ return lower_subgroup_op_to_scalar(b, intrin, false);
break;
default:
return NULL;
}
-static bool
-lower_subgroups_impl(nir_function_impl *impl,
- const nir_lower_subgroups_options *options)
-{
- nir_builder b;
- nir_builder_init(&b, impl);
- bool progress = false;
-
- nir_foreach_block(block, impl) {
- nir_foreach_instr_safe(instr, block) {
- if (instr->type != nir_instr_type_intrinsic)
- continue;
-
- nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
- b.cursor = nir_before_instr(instr);
-
- nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
- if (!lower)
- continue;
-
- nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
- nir_instr_remove(instr);
- progress = true;
- }
- }
-
- return progress;
-}
-
bool
nir_lower_subgroups(nir_shader *shader,
const nir_lower_subgroups_options *options)
{
- bool progress = false;
-
- nir_foreach_function(function, shader) {
- if (!function->impl)
- continue;
-
- if (lower_subgroups_impl(function->impl, options)) {
- progress = true;
- nir_metadata_preserve(function->impl, nir_metadata_block_index |
- nir_metadata_dominance);
- }
- }
-
- return progress;
+ return nir_shader_lower_instructions(shader,
+ lower_subgroups_filter,
+ lower_subgroups_instr,
+ (void *)options);
}