From 57bff0a546c8ebe9a09335200719cb9e13d6aea9 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Tue, 29 Aug 2017 20:10:35 -0700 Subject: [PATCH] spirv: Add support for subgroup arithmetic Reviewed-by: Lionel Landwerlin Reviewed-by: Iago Toral Quiroga --- src/compiler/shader_info.h | 1 + src/compiler/spirv/spirv_to_nir.c | 4 ++ src/compiler/spirv/vtn_subgroup.c | 97 ++++++++++++++++++++++++++++--- 3 files changed, 94 insertions(+), 8 deletions(-) diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h index fd740e0d489..4cbe461d381 100644 --- a/src/compiler/shader_info.h +++ b/src/compiler/shader_info.h @@ -45,6 +45,7 @@ struct spirv_supported_capabilities { bool variable_pointers; bool storage_16bit; bool shader_viewport_index_layer; + bool subgroup_arithmetic; bool subgroup_ballot; bool subgroup_basic; bool subgroup_quad; diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index a8b545ec866..19862ab612f 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -3313,6 +3313,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, case SpvCapabilityGroupNonUniformQuad: spv_check_supported(subgroup_quad, cap); + case SpvCapabilityGroupNonUniformArithmetic: + case SpvCapabilityGroupNonUniformClustered: + spv_check_supported(subgroup_arithmetic, cap); + case SpvCapabilityVariablePointersStorageBuffer: case SpvCapabilityVariablePointers: spv_check_supported(variable_pointers, cap); diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index 1204c5945c8..bd3143962be 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b, nir_intrinsic_op nir_op, struct vtn_ssa_value *dst, struct vtn_ssa_value *src0, - nir_ssa_def *index) + nir_ssa_def *index, + unsigned const_idx0, + unsigned const_idx1) { /* Some of the subgroup operations take an index. SPIR-V allows this to be * any integer type. To make things simpler for drivers, we only support @@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b, if (!glsl_type_is_vector_or_scalar(dst->type)) { for (unsigned i = 0; i < glsl_get_length(dst->type); i++) { vtn_build_subgroup_instr(b, nir_op, dst->elems[i], - src0->elems[i], index); + src0->elems[i], index, + const_idx0, const_idx1); } return; } @@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b, if (index) intrin->src[1] = nir_src_for_ssa(index); + intrin->const_index[0] = const_idx0; + intrin->const_index[1] = const_idx1; + nir_builder_instr_insert(&b->nb, &intrin->instr); dst->def = &intrin->dest.ssa; @@ -169,13 +175,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpGroupNonUniformBroadcastFirst: vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation, - val->ssa, vtn_ssa_value(b, w[4]), NULL); + val->ssa, vtn_ssa_value(b, w[4]), NULL, 0, 0); break; case SpvOpGroupNonUniformBroadcast: vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation, val->ssa, vtn_ssa_value(b, w[4]), - vtn_ssa_value(b, w[5])->def); + vtn_ssa_value(b, w[5])->def, 0, 0); break; case SpvOpGroupNonUniformAll: @@ -248,14 +254,14 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, unreachable("Invalid opcode"); } vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), - vtn_ssa_value(b, w[5])->def); + vtn_ssa_value(b, w[5])->def, 0, 0); break; } case SpvOpGroupNonUniformQuadBroadcast: vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast, val->ssa, vtn_ssa_value(b, w[4]), - vtn_ssa_value(b, w[5])->def); + vtn_ssa_value(b, w[5])->def, 0, 0); break; case SpvOpGroupNonUniformQuadSwap: { @@ -272,7 +278,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, op = nir_intrinsic_quad_swap_diagonal; break; } - vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), NULL); + vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]), + NULL, 0, 0); break; } @@ -291,7 +298,81 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpGroupNonUniformBitwiseXor: case SpvOpGroupNonUniformLogicalAnd: case SpvOpGroupNonUniformLogicalOr: - case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformLogicalXor: { + nir_op reduction_op; + switch (opcode) { + case SpvOpGroupNonUniformIAdd: + reduction_op = nir_op_iadd; + break; + case SpvOpGroupNonUniformFAdd: + reduction_op = nir_op_fadd; + break; + case SpvOpGroupNonUniformIMul: + reduction_op = nir_op_imul; + break; + case SpvOpGroupNonUniformFMul: + reduction_op = nir_op_fmul; + break; + case SpvOpGroupNonUniformSMin: + reduction_op = nir_op_imin; + break; + case SpvOpGroupNonUniformUMin: + reduction_op = nir_op_umin; + break; + case SpvOpGroupNonUniformFMin: + reduction_op = nir_op_fmin; + break; + case SpvOpGroupNonUniformSMax: + reduction_op = nir_op_imax; + break; + case SpvOpGroupNonUniformUMax: + reduction_op = nir_op_umax; + break; + case SpvOpGroupNonUniformFMax: + reduction_op = nir_op_fmax; + break; + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformLogicalAnd: + reduction_op = nir_op_iand; + break; + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformLogicalOr: + reduction_op = nir_op_ior; + break; + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalXor: + reduction_op = nir_op_ixor; + break; + default: + unreachable("Invalid reduction operation"); + } + + nir_intrinsic_op op; + unsigned cluster_size = 0; + switch ((SpvGroupOperation)w[4]) { + case SpvGroupOperationReduce: + op = nir_intrinsic_reduce; + break; + case SpvGroupOperationInclusiveScan: + op = nir_intrinsic_inclusive_scan; + break; + case SpvGroupOperationExclusiveScan: + op = nir_intrinsic_exclusive_scan; + break; + case SpvGroupOperationClusteredReduce: + op = nir_intrinsic_reduce; + assert(count == 7); + cluster_size = vtn_constant_value(b, w[6])->values[0].u32[0]; + break; + default: + unreachable("Invalid group operation"); + } + + vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]), + NULL, reduction_op, cluster_size); + break; + } + default: unreachable("Invalid SPIR-V opcode"); } -- 2.30.2