#include "vtn_private.h"
+/* Build one NIR subgroup intrinsic of kind `nir_op` and store its result in
+ * `dst`.
+ *
+ * `src0` is the primary data operand (same vtn type as `dst`, asserted
+ * below).  `index`, when non-NULL, becomes the intrinsic's second source
+ * (e.g. an invocation index or shuffle delta).  `const_idx0`/`const_idx1`
+ * are copied into the intrinsic's const_index slots; callers use them for
+ * things like the reduction ALU op and cluster size, or pass 0.
+ *
+ * Composite (non vector/scalar) values are handled by recursing over each
+ * element, emitting one intrinsic per leaf.
+ */
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+ nir_intrinsic_op nir_op,
+ struct vtn_ssa_value *dst,
+ struct vtn_ssa_value *src0,
+ nir_ssa_def *index,
+ unsigned const_idx0,
+ unsigned const_idx1)
+{
+ /* Some of the subgroup operations take an index. SPIR-V allows this to be
+ * any integer type. To make things simpler for drivers, we only support
+ * 32-bit indices.
+ */
+ if (index && index->bit_size != 32)
+ index = nir_u2u32(&b->nb, index);
+
+ vtn_assert(dst->type == src0->type);
+ if (!glsl_type_is_vector_or_scalar(dst->type)) {
+ /* Composite value: recurse element-wise; `index` and the const
+ * indices apply uniformly to every element.
+ */
+ for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
+ vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
+ src0->elems[i], index,
+ const_idx0, const_idx1);
+ }
+ return;
+ }
+
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader, nir_op);
+ /* Size the destination from the SPIR-V result type rather than a
+ * hard-coded bit size / component count.
+ */
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ dst->type, NULL);
+ intrin->num_components = intrin->dest.ssa.num_components;
+
+ intrin->src[0] = nir_src_for_ssa(src0->def);
+ if (index)
+ intrin->src[1] = nir_src_for_ssa(index);
+
+ intrin->const_index[0] = const_idx0;
+ intrin->const_index[1] = const_idx1;
+
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ dst->def = &intrin->dest.ssa;
+}
+
/* Translate a SPIR-V group / subgroup opcode into the matching NIR subgroup
 * intrinsic.
 *
 * NOTE(review): this is a unified-diff hunk; the context between the
 * signature and the first visible case (the function's opening brace, the
 * `val` result lookup and the `switch (opcode)` header) is not shown here.
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
"OpGroupNonUniformElect must return a Bool");
nir_intrinsic_instr *elect =
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
- nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
+ /* Size the destination from the SPIR-V result type instead of a
+ * hard-coded 1x32 destination.
+ */
+ nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
+ val->type->type, NULL);
nir_builder_instr_insert(&b->nb, &elect->instr);
val->ssa->def = &elect->dest.ssa;
break;
}
- case SpvOpGroupNonUniformAll:
- case SpvOpGroupNonUniformAny:
- case SpvOpGroupNonUniformAllEqual:
- case SpvOpGroupNonUniformBroadcast:
- case SpvOpGroupNonUniformBroadcastFirst:
- case SpvOpGroupNonUniformBallot:
- case SpvOpGroupNonUniformInverseBallot:
+ /* OpGroupNonUniformBallot carries an extra Execution-scope word before
+ * its predicate, so `++w` shifts the word pointer and the shared
+ * KHR-style code below reads the predicate at w[3].
+ * Deliberate fallthrough.
+ */
+ case SpvOpGroupNonUniformBallot: ++w;
+ case SpvOpSubgroupBallotKHR: {
+ vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+ "OpGroupNonUniformBallot must return a uvec4");
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+ ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+ ballot->num_components = 4;
+ nir_builder_instr_insert(&b->nb, &ballot->instr);
+ val->ssa->def = &ballot->dest.ssa;
+ break;
+ }
+
+ case SpvOpGroupNonUniformInverseBallot: {
+ /* This one is just a BallotBitfieldExtract with subgroup invocation.
+ * We could add a NIR intrinsic but it's easier to just lower it on the
+ * spot.
+ */
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader,
+ nir_intrinsic_ballot_bitfield_extract);
+
+ intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
case SpvOpGroupNonUniformBallotBitExtract:
case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotFindLSB:
- case SpvOpGroupNonUniformBallotFindMSB:
+ case SpvOpGroupNonUniformBallotFindMSB: {
+ nir_ssa_def *src0, *src1 = NULL;
+ nir_intrinsic_op op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformBallotBitExtract:
+ op = nir_intrinsic_ballot_bitfield_extract;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ src1 = vtn_ssa_value(b, w[5])->def;
+ break;
+ case SpvOpGroupNonUniformBallotBitCount:
+ /* BitCount has a GroupOperation word at w[4]; the ballot value is
+ * therefore at w[5] (unlike FindLSB/FindMSB below, which have it
+ * at w[4]).
+ */
+ switch ((SpvGroupOperation)w[4]) {
+ case SpvGroupOperationReduce:
+ op = nir_intrinsic_ballot_bit_count_reduce;
+ break;
+ case SpvGroupOperationInclusiveScan:
+ op = nir_intrinsic_ballot_bit_count_inclusive;
+ break;
+ case SpvGroupOperationExclusiveScan:
+ op = nir_intrinsic_ballot_bit_count_exclusive;
+ break;
+ default:
+ unreachable("Invalid group operation");
+ }
+ src0 = vtn_ssa_value(b, w[5])->def;
+ break;
+ case SpvOpGroupNonUniformBallotFindLSB:
+ op = nir_intrinsic_ballot_find_lsb;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ break;
+ case SpvOpGroupNonUniformBallotFindMSB:
+ op = nir_intrinsic_ballot_find_msb;
+ src0 = vtn_ssa_value(b, w[4])->def;
+ break;
+ default:
+ unreachable("Unhandled opcode");
+ }
+
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader, op);
+
+ intrin->src[0] = nir_src_for_ssa(src0);
+ if (src1)
+ intrin->src[1] = nir_src_for_ssa(src1);
+
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
+ /* As with Ballot above: `++w` skips the Execution-scope word so the
+ * KHR-style handling reads its operand at w[3].  Deliberate fallthrough.
+ */
+ case SpvOpGroupNonUniformBroadcastFirst: ++w;
+ case SpvOpSubgroupFirstInvocationKHR:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+ val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
+ break;
+
+ /* Both Broadcast variants carry a scope word, so NonUniformBroadcast
+ * deliberately falls into GroupBroadcast's `++w` before the shared
+ * KHR-style code reads value/index at w[3]/w[4].
+ */
+ case SpvOpGroupNonUniformBroadcast:
+ case SpvOpGroupBroadcast: ++w;
+ case SpvOpSubgroupReadInvocationKHR:
+ vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+ val->ssa, vtn_ssa_value(b, w[3]),
+ vtn_ssa_value(b, w[4])->def, 0, 0);
+ break;
+
+ case SpvOpGroupNonUniformAll:
+ case SpvOpGroupNonUniformAny:
+ case SpvOpGroupNonUniformAllEqual:
+ case SpvOpGroupAll:
+ case SpvOpGroupAny:
+ case SpvOpSubgroupAllKHR:
+ case SpvOpSubgroupAnyKHR:
+ case SpvOpSubgroupAllEqualKHR: {
+ vtn_fail_if(val->type->type != glsl_bool_type(),
+ "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
+ nir_intrinsic_op op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformAll:
+ case SpvOpGroupAll:
+ case SpvOpSubgroupAllKHR:
+ op = nir_intrinsic_vote_all;
+ break;
+ case SpvOpGroupNonUniformAny:
+ case SpvOpGroupAny:
+ case SpvOpSubgroupAnyKHR:
+ op = nir_intrinsic_vote_any;
+ break;
+ case SpvOpSubgroupAllEqualKHR:
+ op = nir_intrinsic_vote_ieq;
+ break;
+ case SpvOpGroupNonUniformAllEqual:
+ /* Pick integer vs float equality from the operand's base type. */
+ switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
+ case GLSL_TYPE_FLOAT:
+ case GLSL_TYPE_FLOAT16:
+ case GLSL_TYPE_DOUBLE:
+ op = nir_intrinsic_vote_feq;
+ break;
+ case GLSL_TYPE_UINT:
+ case GLSL_TYPE_INT:
+ case GLSL_TYPE_UINT8:
+ case GLSL_TYPE_INT8:
+ case GLSL_TYPE_UINT16:
+ case GLSL_TYPE_INT16:
+ case GLSL_TYPE_UINT64:
+ case GLSL_TYPE_INT64:
+ case GLSL_TYPE_BOOL:
+ op = nir_intrinsic_vote_ieq;
+ break;
+ default:
+ unreachable("Unhandled type");
+ }
+ break;
+ default:
+ unreachable("Unhandled opcode");
+ }
+
+ /* The NonUniform/Group forms carry a scope word, so their predicate is
+ * at w[4]; the KHR forms have it at w[3].
+ */
+ nir_ssa_def *src0;
+ if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+ opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
+ opcode == SpvOpGroupNonUniformAllEqual) {
+ src0 = vtn_ssa_value(b, w[4])->def;
+ } else {
+ src0 = vtn_ssa_value(b, w[3])->def;
+ }
+ nir_intrinsic_instr *intrin =
+ nir_intrinsic_instr_create(b->nb.shader, op);
+ /* Variable-width intrinsics (src_components == 0) take their component
+ * count from the operand.
+ */
+ if (nir_intrinsic_infos[op].src_components[0] == 0)
+ intrin->num_components = src0->num_components;
+ intrin->src[0] = nir_src_for_ssa(src0);
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
+ nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+ val->ssa->def = &intrin->dest.ssa;
+ break;
+ }
+
case SpvOpGroupNonUniformShuffle:
case SpvOpGroupNonUniformShuffleXor:
case SpvOpGroupNonUniformShuffleUp:
- case SpvOpGroupNonUniformShuffleDown:
+ case SpvOpGroupNonUniformShuffleDown: {
+ nir_intrinsic_op op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformShuffle:
+ op = nir_intrinsic_shuffle;
+ break;
+ case SpvOpGroupNonUniformShuffleXor:
+ op = nir_intrinsic_shuffle_xor;
+ break;
+ case SpvOpGroupNonUniformShuffleUp:
+ op = nir_intrinsic_shuffle_up;
+ break;
+ case SpvOpGroupNonUniformShuffleDown:
+ op = nir_intrinsic_shuffle_down;
+ break;
+ default:
+ unreachable("Invalid opcode");
+ }
+ /* w[4] is the value, w[5] the invocation id / mask / delta. */
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def, 0, 0);
+ break;
+ }
+
+ case SpvOpGroupNonUniformQuadBroadcast:
+ vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
+ val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def, 0, 0);
+ break;
+
+ case SpvOpGroupNonUniformQuadSwap: {
+ /* Direction must be a constant: 0 = horizontal, 1 = vertical,
+ * 2 = diagonal.
+ */
+ unsigned direction = vtn_constant_uint(b, w[5]);
+ nir_intrinsic_op op;
+ switch (direction) {
+ case 0:
+ op = nir_intrinsic_quad_swap_horizontal;
+ break;
+ case 1:
+ op = nir_intrinsic_quad_swap_vertical;
+ break;
+ case 2:
+ op = nir_intrinsic_quad_swap_diagonal;
+ break;
+ default:
+ vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
+ }
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+ NULL, 0, 0);
+ break;
+ }
+
case SpvOpGroupNonUniformIAdd:
case SpvOpGroupNonUniformFAdd:
case SpvOpGroupNonUniformIMul:
case SpvOpGroupNonUniformLogicalAnd:
case SpvOpGroupNonUniformLogicalOr:
case SpvOpGroupNonUniformLogicalXor:
- case SpvOpGroupNonUniformQuadBroadcast:
- case SpvOpGroupNonUniformQuadSwap:
+ case SpvOpGroupIAdd:
+ case SpvOpGroupFAdd:
+ case SpvOpGroupFMin:
+ case SpvOpGroupUMin:
+ case SpvOpGroupSMin:
+ case SpvOpGroupFMax:
+ case SpvOpGroupUMax:
+ case SpvOpGroupSMax:
+ case SpvOpGroupIAddNonUniformAMD:
+ case SpvOpGroupFAddNonUniformAMD:
+ case SpvOpGroupFMinNonUniformAMD:
+ case SpvOpGroupUMinNonUniformAMD:
+ case SpvOpGroupSMinNonUniformAMD:
+ case SpvOpGroupFMaxNonUniformAMD:
+ case SpvOpGroupUMaxNonUniformAMD:
+ case SpvOpGroupSMaxNonUniformAMD: {
+ /* NOTE(review): the inner switch below also handles opcodes (e.g.
+ * NonUniformFMul, NonUniformSMin/UMin/FMin, BitwiseAnd/Or/Xor) whose
+ * outer `case` labels are not visible in this hunk — presumably they
+ * are unchanged context lines omitted from the diff; verify against
+ * the full file.
+ */
+ nir_op reduction_op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformIAdd:
+ case SpvOpGroupIAdd:
+ case SpvOpGroupIAddNonUniformAMD:
+ reduction_op = nir_op_iadd;
+ break;
+ case SpvOpGroupNonUniformFAdd:
+ case SpvOpGroupFAdd:
+ case SpvOpGroupFAddNonUniformAMD:
+ reduction_op = nir_op_fadd;
+ break;
+ case SpvOpGroupNonUniformIMul:
+ reduction_op = nir_op_imul;
+ break;
+ case SpvOpGroupNonUniformFMul:
+ reduction_op = nir_op_fmul;
+ break;
+ case SpvOpGroupNonUniformSMin:
+ case SpvOpGroupSMin:
+ case SpvOpGroupSMinNonUniformAMD:
+ reduction_op = nir_op_imin;
+ break;
+ case SpvOpGroupNonUniformUMin:
+ case SpvOpGroupUMin:
+ case SpvOpGroupUMinNonUniformAMD:
+ reduction_op = nir_op_umin;
+ break;
+ case SpvOpGroupNonUniformFMin:
+ case SpvOpGroupFMin:
+ case SpvOpGroupFMinNonUniformAMD:
+ reduction_op = nir_op_fmin;
+ break;
+ case SpvOpGroupNonUniformSMax:
+ case SpvOpGroupSMax:
+ case SpvOpGroupSMaxNonUniformAMD:
+ reduction_op = nir_op_imax;
+ break;
+ case SpvOpGroupNonUniformUMax:
+ case SpvOpGroupUMax:
+ case SpvOpGroupUMaxNonUniformAMD:
+ reduction_op = nir_op_umax;
+ break;
+ case SpvOpGroupNonUniformFMax:
+ case SpvOpGroupFMax:
+ case SpvOpGroupFMaxNonUniformAMD:
+ reduction_op = nir_op_fmax;
+ break;
+ case SpvOpGroupNonUniformBitwiseAnd:
+ case SpvOpGroupNonUniformLogicalAnd:
+ reduction_op = nir_op_iand;
+ break;
+ case SpvOpGroupNonUniformBitwiseOr:
+ case SpvOpGroupNonUniformLogicalOr:
+ reduction_op = nir_op_ior;
+ break;
+ case SpvOpGroupNonUniformBitwiseXor:
+ case SpvOpGroupNonUniformLogicalXor:
+ reduction_op = nir_op_ixor;
+ break;
+ default:
+ unreachable("Invalid reduction operation");
+ }
+
+ /* w[4] is the GroupOperation, w[5] the operand; clustered reduce adds
+ * a constant cluster size at w[6].
+ */
+ nir_intrinsic_op op;
+ unsigned cluster_size = 0;
+ switch ((SpvGroupOperation)w[4]) {
+ case SpvGroupOperationReduce:
+ op = nir_intrinsic_reduce;
+ break;
+ case SpvGroupOperationInclusiveScan:
+ op = nir_intrinsic_inclusive_scan;
+ break;
+ case SpvGroupOperationExclusiveScan:
+ op = nir_intrinsic_exclusive_scan;
+ break;
+ case SpvGroupOperationClusteredReduce:
+ op = nir_intrinsic_reduce;
+ assert(count == 7);
+ cluster_size = vtn_constant_uint(b, w[6]);
+ break;
+ default:
+ unreachable("Invalid group operation");
+ }
+
+ /* const_index[0] carries the ALU reduction op, const_index[1] the
+ * cluster size (0 = whole subgroup).
+ */
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+ NULL, reduction_op, cluster_size);
+ break;
+ }
+
default:
unreachable("Invalid SPIR-V opcode");
}