From: Jason Ekstrand Date: Fri, 28 Apr 2017 11:45:50 +0000 (-0700) Subject: spirv: Add initial subgroup support X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=adc077797aa8abb74a7aa539b1601e405f150f5f;p=mesa.git spirv: Add initial subgroup support Reviewed-by: Samuel Iglesias Gonsálvez Reviewed-by: Iago Toral Quiroga --- diff --git a/src/compiler/Makefile.sources b/src/compiler/Makefile.sources index 84dbc26d198..841bc8fec91 100644 --- a/src/compiler/Makefile.sources +++ b/src/compiler/Makefile.sources @@ -304,4 +304,5 @@ SPIRV_FILES = \ spirv/vtn_cfg.c \ spirv/vtn_glsl450.c \ spirv/vtn_private.h \ + spirv/vtn_subgroup.c \ spirv/vtn_variables.c diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index 859a0c1e62d..7a7d6fd07aa 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -192,6 +192,7 @@ files_libnir = files( '../spirv/vtn_cfg.c', '../spirv/vtn_glsl450.c', '../spirv/vtn_private.h', + '../spirv/vtn_subgroup.c', '../spirv/vtn_variables.c', ) diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h index 00aee77c002..0ba32349d03 100644 --- a/src/compiler/shader_info.h +++ b/src/compiler/shader_info.h @@ -45,6 +45,7 @@ struct spirv_supported_capabilities { bool variable_pointers; bool storage_16bit; bool shader_viewport_index_layer; + bool subgroup_basic; }; typedef struct shader_info { diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 98fdd700b0b..4d2c1533d24 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -3292,6 +3292,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, spv_check_supported(multiview, cap); break; + case SpvCapabilityGroupNonUniform: + spv_check_supported(subgroup_basic, cap); + break; + case SpvCapabilityVariablePointersStorageBuffer: case SpvCapabilityVariablePointers: spv_check_supported(variable_pointers, cap); @@ -3873,6 +3877,43 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, vtn_handle_barrier(b, opcode, w, count); break; + case SpvOpGroupNonUniformElect: + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: + case SpvOpGroupNonUniformBroadcast: + case SpvOpGroupNonUniformBroadcastFirst: + case SpvOpGroupNonUniformBallot: + case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotFindLSB: + case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformShuffle: + case SpvOpGroupNonUniformShuffleXor: + case SpvOpGroupNonUniformShuffleUp: + case SpvOpGroupNonUniformShuffleDown: + case SpvOpGroupNonUniformIAdd: + case SpvOpGroupNonUniformFAdd: + case SpvOpGroupNonUniformIMul: + case SpvOpGroupNonUniformFMul: + case SpvOpGroupNonUniformSMin: + case SpvOpGroupNonUniformUMin: + case SpvOpGroupNonUniformFMin: + case SpvOpGroupNonUniformSMax: + case SpvOpGroupNonUniformUMax: + case SpvOpGroupNonUniformFMax: + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalAnd: + case SpvOpGroupNonUniformLogicalOr: + case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformQuadBroadcast: + case SpvOpGroupNonUniformQuadSwap: + vtn_handle_subgroup(b, opcode, w, count); + break; + default: vtn_fail("Unhandled opcode"); } diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index 817f78ce42d..316def08027 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -713,6 +713,9 @@ nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count); +void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count); + bool vtn_handle_glsl450_instruction(struct vtn_builder *b, uint32_t ext_opcode, const uint32_t *words, unsigned count); diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c new file mode 100644 index 00000000000..033c43e601c --- /dev/null +++ b/src/compiler/spirv/vtn_subgroup.c @@ -0,0 +1,82 @@ +/* + * Copyright © 2016 Intel Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#include "vtn_private.h" + +void +vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count) +{ + struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa); + + val->ssa = vtn_create_ssa_value(b, val->type->type); + + switch (opcode) { + case SpvOpGroupNonUniformElect: { + vtn_fail_if(val->type->type != glsl_bool_type(), + "OpGroupNonUniformElect must return a Bool"); + nir_intrinsic_instr *elect = + nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect); + nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL); + nir_builder_instr_insert(&b->nb, &elect->instr); + val->ssa->def = &elect->dest.ssa; + break; + } + + case SpvOpGroupNonUniformAll: + case SpvOpGroupNonUniformAny: + case SpvOpGroupNonUniformAllEqual: + case SpvOpGroupNonUniformBroadcast: + case SpvOpGroupNonUniformBroadcastFirst: + case SpvOpGroupNonUniformBallot: + case SpvOpGroupNonUniformInverseBallot: + case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotFindLSB: + case SpvOpGroupNonUniformBallotFindMSB: + case SpvOpGroupNonUniformShuffle: + case SpvOpGroupNonUniformShuffleXor: + case SpvOpGroupNonUniformShuffleUp: + case SpvOpGroupNonUniformShuffleDown: + case SpvOpGroupNonUniformIAdd: + case SpvOpGroupNonUniformFAdd: + case SpvOpGroupNonUniformIMul: + case SpvOpGroupNonUniformFMul: + case SpvOpGroupNonUniformSMin: + case SpvOpGroupNonUniformUMin: + case SpvOpGroupNonUniformFMin: + case SpvOpGroupNonUniformSMax: + case SpvOpGroupNonUniformUMax: + case SpvOpGroupNonUniformFMax: + case SpvOpGroupNonUniformBitwiseAnd: + case SpvOpGroupNonUniformBitwiseOr: + case SpvOpGroupNonUniformBitwiseXor: + case SpvOpGroupNonUniformLogicalAnd: + case SpvOpGroupNonUniformLogicalOr: + case SpvOpGroupNonUniformLogicalXor: + case SpvOpGroupNonUniformQuadBroadcast: + case SpvOpGroupNonUniformQuadSwap: + default: + unreachable("Invalid SPIR-V opcode"); + } +} diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index da511a44b08..68e1adf8152 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -1293,6 +1293,22 @@ vtn_get_builtin_location(struct vtn_builder *b, *location = SYSTEM_VALUE_DRAW_ID; set_mode_system_value(b, mode); break; + case SpvBuiltInSubgroupSize: + *location = SYSTEM_VALUE_SUBGROUP_SIZE; + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupId: + *location = SYSTEM_VALUE_SUBGROUP_ID; + set_mode_system_value(b, mode); + break; + case SpvBuiltInSubgroupLocalInvocationId: + *location = SYSTEM_VALUE_SUBGROUP_INVOCATION; + set_mode_system_value(b, mode); + break; + case SpvBuiltInNumSubgroups: + *location = SYSTEM_VALUE_NUM_SUBGROUPS; + set_mode_system_value(b, mode); + break; case SpvBuiltInDeviceIndex: *location = SYSTEM_VALUE_DEVICE_INDEX; set_mode_system_value(b, mode);