nir_intrinsic_op nir_op,
struct vtn_ssa_value *dst,
struct vtn_ssa_value *src0,
- nir_ssa_def *index)
+ nir_ssa_def *index,
+ unsigned const_idx0,
+ unsigned const_idx1)
{
/* Some of the subgroup operations take an index. SPIR-V allows this to be
* any integer type. To make things simpler for drivers, we only support
+ /* Composite (non-vector/scalar) destinations are handled element-wise;
+  * the two new immediates must be threaded through the recursion too. */
if (!glsl_type_is_vector_or_scalar(dst->type)) {
for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
- src0->elems[i], index);
+ src0->elems[i], index,
+ const_idx0, const_idx1);
}
return;
}
if (index)
intrin->src[1] = nir_src_for_ssa(index);
+ /* Forward the caller-supplied immediates into the intrinsic's const_index
+  * slots.  Reduction-style callers pass a nir_op and a cluster size here;
+  * every other caller passes 0, 0. */
+ intrin->const_index[0] = const_idx0;
+ intrin->const_index[1] = const_idx1;
+
nir_builder_instr_insert(&b->nb, &intrin->instr);
dst->def = &intrin->dest.ssa;
"OpGroupNonUniformElect must return a Bool");
nir_intrinsic_instr *elect =
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
- nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
+ /* Size the destination from the SPIR-V result type instead of a
+  * hard-coded 1x32 component, so the boolean result's bit size follows
+  * the declared type. */
+ nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
+ val->type->type, NULL);
nir_builder_instr_insert(&b->nb, &elect->instr);
val->ssa->def = &elect->dest.ssa;
break;
}
- case SpvOpGroupNonUniformBallot: {
+ /* OpGroupNonUniformBallot carries an Execution-scope operand that the KHR
+  * form lacks; bumping w by one lets both share the code below, which reads
+  * the predicate at w[3]. */
+ case SpvOpGroupNonUniformBallot: ++w;
+ /* fallthrough */
+ case SpvOpSubgroupBallotKHR: {
vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
"OpGroupNonUniformBallot must return a uvec4");
nir_intrinsic_instr *ballot =
nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
- ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+ ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
ballot->num_components = 4;
nir_builder_instr_insert(&b->nb, &ballot->instr);
intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
- nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
if (src1)
intrin->src[1] = nir_src_for_ssa(src1);
- nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
break;
}
- case SpvOpGroupNonUniformBroadcastFirst:
+ /* Forms with an Execution-scope operand execute the ++w so the shared
+  * lowering can read their operands at the scope-less KHR positions
+  * (value at w[3], index at w[4]). */
+ case SpvOpGroupNonUniformBroadcastFirst: ++w;
+ /* fallthrough */
+ case SpvOpSubgroupFirstInvocationKHR:
vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
- val->ssa, vtn_ssa_value(b, w[4]), NULL);
+ val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
break;
case SpvOpGroupNonUniformBroadcast:
+ /* Note: SpvOpGroupNonUniformBroadcast falls through the label below and
+  * therefore also executes the ++w, as it too carries an Execution scope. */
+ case SpvOpGroupBroadcast: ++w;
+ /* fallthrough */
+ case SpvOpSubgroupReadInvocationKHR:
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
- val->ssa, vtn_ssa_value(b, w[4]),
- vtn_ssa_value(b, w[5])->def);
+ val->ssa, vtn_ssa_value(b, w[3]),
+ vtn_ssa_value(b, w[4])->def, 0, 0);
break;
case SpvOpGroupNonUniformAll:
case SpvOpGroupNonUniformAny:
- case SpvOpGroupNonUniformAllEqual: {
+ case SpvOpGroupNonUniformAllEqual:
+ case SpvOpGroupAll:
+ case SpvOpGroupAny:
+ case SpvOpSubgroupAllKHR:
+ case SpvOpSubgroupAnyKHR:
+ case SpvOpSubgroupAllEqualKHR: {
vtn_fail_if(val->type->type != glsl_bool_type(),
"OpGroupNonUniform(All|Any|AllEqual) must return a bool");
nir_intrinsic_op op;
switch (opcode) {
case SpvOpGroupNonUniformAll:
+ case SpvOpGroupAll:
+ case SpvOpSubgroupAllKHR:
op = nir_intrinsic_vote_all;
break;
case SpvOpGroupNonUniformAny:
+ case SpvOpGroupAny:
+ case SpvOpSubgroupAnyKHR:
op = nir_intrinsic_vote_any;
break;
- case SpvOpGroupNonUniformAllEqual: {
- switch (glsl_get_base_type(val->type->type)) {
+ /* NOTE(review): SubgroupAllEqualKHR always uses vote_ieq, i.e. bitwise
+  * equality, even for float sources — confirm this matches the
+  * SPV_KHR_subgroup_vote semantics. */
+ case SpvOpSubgroupAllEqualKHR:
+ op = nir_intrinsic_vote_ieq;
+ break;
+ /* For the NonUniform form, dispatch on the *argument's* base type (not
+  * the bool result type, as the removed line did) to pick float vs.
+  * integer equality voting. */
+ case SpvOpGroupNonUniformAllEqual:
+ switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
case GLSL_TYPE_FLOAT:
+ case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_DOUBLE:
op = nir_intrinsic_vote_feq;
break;
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
+ case GLSL_TYPE_UINT8:
+ case GLSL_TYPE_INT8:
+ case GLSL_TYPE_UINT16:
+ case GLSL_TYPE_INT16:
case GLSL_TYPE_UINT64:
case GLSL_TYPE_INT64:
case GLSL_TYPE_BOOL:
+ /* NOTE(review): the integer/bool labels above should select
+  * nir_intrinsic_vote_ieq and break before a default: arm reaches this
+  * unreachable — the intervening lines appear elided from this hunk;
+  * verify against the full file. */
unreachable("Unhandled type");
}
break;
- }
default:
unreachable("Unhandled opcode");
}
- nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;
-
+ /* KHR ops lack the Execution-scope operand, so their value operand sits
+  * at w[3] rather than w[4]. */
+ nir_ssa_def *src0;
+ if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+ opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
+ opcode == SpvOpGroupNonUniformAllEqual) {
+ src0 = vtn_ssa_value(b, w[4])->def;
+ } else {
+ src0 = vtn_ssa_value(b, w[3])->def;
+ }
nir_intrinsic_instr *intrin =
nir_intrinsic_instr_create(b->nb.shader, op);
- intrin->num_components = src0->num_components;
+ /* Only set num_components when the intrinsic declares a variable-sized
+  * first source (src_components[0] == 0); leave fixed-size intrinsics
+  * alone. */
+ if (nir_intrinsic_infos[op].src_components[0] == 0)
+ intrin->num_components = src0->num_components;
intrin->src[0] = nir_src_for_ssa(src0);
- nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+ nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+ val->type->type, NULL);
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
case SpvOpGroupNonUniformShuffle:
case SpvOpGroupNonUniformShuffleXor:
case SpvOpGroupNonUniformShuffleUp:
- case SpvOpGroupNonUniformShuffleDown:
+ case SpvOpGroupNonUniformShuffleDown: {
+ /* All four shuffle variants share the same operand layout: value at
+  * w[4], id/mask/delta at w[5]; only the NIR intrinsic differs. */
+ nir_intrinsic_op op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformShuffle:
+ op = nir_intrinsic_shuffle;
+ break;
+ case SpvOpGroupNonUniformShuffleXor:
+ op = nir_intrinsic_shuffle_xor;
+ break;
+ case SpvOpGroupNonUniformShuffleUp:
+ op = nir_intrinsic_shuffle_up;
+ break;
+ case SpvOpGroupNonUniformShuffleDown:
+ op = nir_intrinsic_shuffle_down;
+ break;
+ default:
+ unreachable("Invalid opcode");
+ }
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def, 0, 0);
+ break;
+ }
+
+ /* Quad broadcast: value at w[4], quad lane index at w[5]. */
+ case SpvOpGroupNonUniformQuadBroadcast:
+ vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
+ val->ssa, vtn_ssa_value(b, w[4]),
+ vtn_ssa_value(b, w[5])->def, 0, 0);
+ break;
+
+ case SpvOpGroupNonUniformQuadSwap: {
+ /* QuadSwap's direction operand must be a constant (vtn_constant_uint);
+  * 0 = horizontal, 1 = vertical, 2 = diagonal.  Anything else is invalid
+  * SPIR-V and reported via vtn_fail rather than unreachable(). */
+ unsigned direction = vtn_constant_uint(b, w[5]);
+ nir_intrinsic_op op;
+ switch (direction) {
+ case 0:
+ op = nir_intrinsic_quad_swap_horizontal;
+ break;
+ case 1:
+ op = nir_intrinsic_quad_swap_vertical;
+ break;
+ case 2:
+ op = nir_intrinsic_quad_swap_diagonal;
+ break;
+ default:
+ vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
+ }
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+ NULL, 0, 0);
+ break;
+ }
+
case SpvOpGroupNonUniformIAdd:
case SpvOpGroupNonUniformFAdd:
case SpvOpGroupNonUniformIMul:
case SpvOpGroupNonUniformLogicalAnd:
case SpvOpGroupNonUniformLogicalOr:
case SpvOpGroupNonUniformLogicalXor:
- case SpvOpGroupNonUniformQuadBroadcast:
- case SpvOpGroupNonUniformQuadSwap:
+ /* The OpGroup* (Groups capability) and AMD *NonUniformAMD variants share
+  * the NonUniform ops' operand layout (GroupOperation at w[4], value at
+  * w[5], optional cluster size at w[6]), so they are folded into this one
+  * handler. */
+ case SpvOpGroupIAdd:
+ case SpvOpGroupFAdd:
+ case SpvOpGroupFMin:
+ case SpvOpGroupUMin:
+ case SpvOpGroupSMin:
+ case SpvOpGroupFMax:
+ case SpvOpGroupUMax:
+ case SpvOpGroupSMax:
+ case SpvOpGroupIAddNonUniformAMD:
+ case SpvOpGroupFAddNonUniformAMD:
+ case SpvOpGroupFMinNonUniformAMD:
+ case SpvOpGroupUMinNonUniformAMD:
+ case SpvOpGroupSMinNonUniformAMD:
+ case SpvOpGroupFMaxNonUniformAMD:
+ case SpvOpGroupUMaxNonUniformAMD:
+ case SpvOpGroupSMaxNonUniformAMD: {
+ /* Step 1: map the SPIR-V opcode onto the NIR ALU op used as the
+  * reduction operator. */
+ nir_op reduction_op;
+ switch (opcode) {
+ case SpvOpGroupNonUniformIAdd:
+ case SpvOpGroupIAdd:
+ case SpvOpGroupIAddNonUniformAMD:
+ reduction_op = nir_op_iadd;
+ break;
+ case SpvOpGroupNonUniformFAdd:
+ case SpvOpGroupFAdd:
+ case SpvOpGroupFAddNonUniformAMD:
+ reduction_op = nir_op_fadd;
+ break;
+ case SpvOpGroupNonUniformIMul:
+ reduction_op = nir_op_imul;
+ break;
+ case SpvOpGroupNonUniformFMul:
+ reduction_op = nir_op_fmul;
+ break;
+ case SpvOpGroupNonUniformSMin:
+ case SpvOpGroupSMin:
+ case SpvOpGroupSMinNonUniformAMD:
+ reduction_op = nir_op_imin;
+ break;
+ case SpvOpGroupNonUniformUMin:
+ case SpvOpGroupUMin:
+ case SpvOpGroupUMinNonUniformAMD:
+ reduction_op = nir_op_umin;
+ break;
+ case SpvOpGroupNonUniformFMin:
+ case SpvOpGroupFMin:
+ case SpvOpGroupFMinNonUniformAMD:
+ reduction_op = nir_op_fmin;
+ break;
+ case SpvOpGroupNonUniformSMax:
+ case SpvOpGroupSMax:
+ case SpvOpGroupSMaxNonUniformAMD:
+ reduction_op = nir_op_imax;
+ break;
+ case SpvOpGroupNonUniformUMax:
+ case SpvOpGroupUMax:
+ case SpvOpGroupUMaxNonUniformAMD:
+ reduction_op = nir_op_umax;
+ break;
+ case SpvOpGroupNonUniformFMax:
+ case SpvOpGroupFMax:
+ case SpvOpGroupFMaxNonUniformAMD:
+ reduction_op = nir_op_fmax;
+ break;
+ case SpvOpGroupNonUniformBitwiseAnd:
+ case SpvOpGroupNonUniformLogicalAnd:
+ reduction_op = nir_op_iand;
+ break;
+ case SpvOpGroupNonUniformBitwiseOr:
+ case SpvOpGroupNonUniformLogicalOr:
+ reduction_op = nir_op_ior;
+ break;
+ case SpvOpGroupNonUniformBitwiseXor:
+ case SpvOpGroupNonUniformLogicalXor:
+ reduction_op = nir_op_ixor;
+ break;
+ default:
+ unreachable("Invalid reduction operation");
+ }
+
+ /* Step 2: pick reduce vs. inclusive/exclusive scan from the
+  * GroupOperation literal at w[4]; ClusteredReduce additionally supplies
+  * a constant cluster size at w[6]. */
+ nir_intrinsic_op op;
+ unsigned cluster_size = 0;
+ switch ((SpvGroupOperation)w[4]) {
+ case SpvGroupOperationReduce:
+ op = nir_intrinsic_reduce;
+ break;
+ case SpvGroupOperationInclusiveScan:
+ op = nir_intrinsic_inclusive_scan;
+ break;
+ case SpvGroupOperationExclusiveScan:
+ op = nir_intrinsic_exclusive_scan;
+ break;
+ case SpvGroupOperationClusteredReduce:
+ op = nir_intrinsic_reduce;
+ /* NOTE(review): malformed SPIR-V would trip this assert; the file's
+  * vtn_fail_if convention would be gentler — confirm intended policy. */
+ assert(count == 7);
+ cluster_size = vtn_constant_uint(b, w[6]);
+ break;
+ default:
+ unreachable("Invalid group operation");
+ }
+
+ /* The ALU op and cluster size travel as const_index immediates through
+  * vtn_build_subgroup_instr. */
+ vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+ NULL, reduction_op, cluster_size);
+ break;
+ }
+
default:
unreachable("Invalid SPIR-V opcode");
}