spirv/nir: add support for AMD_shader_ballot and Groups capability
authorDaniel Schürmann <daniel.schuermann@campus.tu-berlin.de>
Wed, 9 May 2018 18:41:23 +0000 (20:41 +0200)
committerConnor Abbott <cwabbott0@gmail.com>
Thu, 13 Jun 2019 12:44:23 +0000 (12:44 +0000)
This commit also renames existing AMD capabilities:
 - gcn_shader -> amd_gcn_shader
 - trinary_minmax -> amd_trinary_minmax

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
src/amd/vulkan/radv_shader.c
src/compiler/shader_info.h
src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_amd.c
src/compiler/spirv/vtn_private.h
src/compiler/spirv/vtn_subgroup.c

index 37fdbfe33f53ef317a900e8c75ea51028066a3fc..3e1098de79f1272b3d0c6dff78197e4c680d2f64 100644 (file)
@@ -245,6 +245,9 @@ radv_shader_compile_to_nir(struct radv_device *device,
                const struct spirv_to_nir_options spirv_options = {
                        .lower_ubo_ssbo_access_to_offsets = true,
                        .caps = {
+                               .amd_gcn_shader = true,
+                               .amd_shader_ballot = false,
+                               .amd_trinary_minmax = true,
                                .derivative_group = true,
                                .descriptor_array_dynamic_indexing = true,
                                .descriptor_array_non_uniform_indexing = true,
@@ -253,7 +256,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
                                .draw_parameters = true,
                                .float16 = true,
                                .float64 = true,
-                               .gcn_shader = true,
                                .geometry_streams = true,
                                .image_read_without_format = true,
                                .image_write_without_format = true,
@@ -277,7 +279,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
                                .subgroup_vote = true,
                                .tessellation = true,
                                .transform_feedback = true,
-                               .trinary_minmax = true,
                                .variable_pointers = true,
                        },
                        .ubo_addr_format = nir_address_format_32bit_index_offset,
index 32d87b234ecdf34b584e38a0276b17e773c2b8a0..46588c327f9e90a9626ff02a03b8d8349e509b8b 100644 (file)
@@ -45,7 +45,6 @@ struct spirv_supported_capabilities {
    bool fragment_shader_sample_interlock;
    bool fragment_shader_pixel_interlock;
    bool geometry_streams;
-   bool gcn_shader;
    bool image_ms_array;
    bool image_read_without_format;
    bool image_write_without_format;
@@ -72,9 +71,11 @@ struct spirv_supported_capabilities {
    bool subgroup_vote;
    bool tessellation;
    bool transform_feedback;
-   bool trinary_minmax;
    bool variable_pointers;
    bool float16;
+   bool amd_gcn_shader;
+   bool amd_shader_ballot;
+   bool amd_trinary_minmax;
 };
 
 typedef struct shader_info {
index 326f4b0d4110b3e5eac2aef507b4c9cffdef5a6e..ccd11fa63294a04b344696b9b3d6d5cb59620b11 100644 (file)
@@ -394,10 +394,13 @@ vtn_handle_extension(struct vtn_builder *b, SpvOp opcode,
       if (strcmp(ext, "GLSL.std.450") == 0) {
          val->ext_handler = vtn_handle_glsl450_instruction;
       } else if ((strcmp(ext, "SPV_AMD_gcn_shader") == 0)
-                && (b->options && b->options->caps.gcn_shader)) {
+                && (b->options && b->options->caps.amd_gcn_shader)) {
          val->ext_handler = vtn_handle_amd_gcn_shader_instruction;
+      } else if ((strcmp(ext, "SPV_AMD_shader_ballot") == 0)
+                && (b->options && b->options->caps.amd_shader_ballot)) {
+         val->ext_handler = vtn_handle_amd_shader_ballot_instruction;
       } else if ((strcmp(ext, "SPV_AMD_shader_trinary_minmax") == 0)
-                && (b->options && b->options->caps.trinary_minmax)) {
+                && (b->options && b->options->caps.amd_trinary_minmax)) {
          val->ext_handler = vtn_handle_amd_shader_trinary_minmax_instruction;
       } else if (strcmp(ext, "OpenCL.std") == 0) {
          val->ext_handler = vtn_handle_opencl_instruction;
@@ -3612,7 +3615,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityImageReadWrite:
       case SpvCapabilityImageMipmap:
       case SpvCapabilityPipes:
-      case SpvCapabilityGroups:
       case SpvCapabilityDeviceEnqueue:
       case SpvCapabilityLiteralSampler:
       case SpvCapabilityGenericPointer:
@@ -3677,6 +3679,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          spv_check_supported(subgroup_arithmetic, cap);
          break;
 
+      case SpvCapabilityGroups:
+         spv_check_supported(amd_shader_ballot, cap);
+         break;
+
       case SpvCapabilityVariablePointersStorageBuffer:
       case SpvCapabilityVariablePointers:
          spv_check_supported(variable_pointers, cap);
@@ -4525,12 +4531,31 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformLogicalXor:
    case SpvOpGroupNonUniformQuadBroadcast:
    case SpvOpGroupNonUniformQuadSwap:
+   case SpvOpGroupAll:
+   case SpvOpGroupAny:
+   case SpvOpGroupBroadcast:
+   case SpvOpGroupIAdd:
+   case SpvOpGroupFAdd:
+   case SpvOpGroupFMin:
+   case SpvOpGroupUMin:
+   case SpvOpGroupSMin:
+   case SpvOpGroupFMax:
+   case SpvOpGroupUMax:
+   case SpvOpGroupSMax:
    case SpvOpSubgroupBallotKHR:
    case SpvOpSubgroupFirstInvocationKHR:
    case SpvOpSubgroupReadInvocationKHR:
    case SpvOpSubgroupAllKHR:
    case SpvOpSubgroupAnyKHR:
    case SpvOpSubgroupAllEqualKHR:
+   case SpvOpGroupIAddNonUniformAMD:
+   case SpvOpGroupFAddNonUniformAMD:
+   case SpvOpGroupFMinNonUniformAMD:
+   case SpvOpGroupUMinNonUniformAMD:
+   case SpvOpGroupSMinNonUniformAMD:
+   case SpvOpGroupFMaxNonUniformAMD:
+   case SpvOpGroupUMaxNonUniformAMD:
+   case SpvOpGroupSMaxNonUniformAMD:
       vtn_handle_subgroup(b, opcode, w, count);
       break;
 
index 0d5b429783bf76ca69999e587a61370ae1280099..23f8930faa26444fcd9edd76ee568199092fc7d1 100644 (file)
@@ -56,6 +56,67 @@ vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode,
    return true;
 }
 
+bool
+vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode,
+                                         const uint32_t *w, unsigned count)
+{
+   const struct glsl_type *dest_type =
+                           vtn_value(b, w[1], vtn_value_type_type)->type->type;
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   val->ssa = vtn_create_ssa_value(b, dest_type);
+
+   unsigned num_args;
+   nir_intrinsic_op op;
+   switch ((enum ShaderBallotAMD)ext_opcode) {
+   case SwizzleInvocationsAMD:
+      num_args = 1;
+      op = nir_intrinsic_quad_swizzle_amd;
+      break;
+   case SwizzleInvocationsMaskedAMD:
+      num_args = 1;
+      op = nir_intrinsic_masked_swizzle_amd;
+      break;
+   case WriteInvocationAMD:
+      num_args = 3;
+      op = nir_intrinsic_write_invocation_amd;
+      break;
+   case MbcntAMD:
+      num_args = 1;
+      op = nir_intrinsic_mbcnt_amd;
+      break;
+   default:
+      unreachable("Invalid opcode");
+   }
+
+   nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->nb.shader, op);
+   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, dest_type, NULL);
+   intrin->num_components = intrin->dest.ssa.num_components;
+
+   for (unsigned i = 0; i < num_args; i++)
+      intrin->src[i] = nir_src_for_ssa(vtn_ssa_value(b, w[i + 5])->def);
+
+   if (intrin->intrinsic == nir_intrinsic_quad_swizzle_amd) {
+      struct vtn_value *val = vtn_value(b, w[6], vtn_value_type_constant);
+      unsigned mask = val->constant->values[0][0].u32 |
+                      val->constant->values[0][1].u32 << 2 |
+                      val->constant->values[0][2].u32 << 4 |
+                      val->constant->values[0][3].u32 << 6;
+      nir_intrinsic_set_swizzle_mask(intrin, mask);
+
+   } else if (intrin->intrinsic == nir_intrinsic_masked_swizzle_amd) {
+      struct vtn_value *val = vtn_value(b, w[6], vtn_value_type_constant);
+      unsigned mask = val->constant->values[0][0].u32 |
+                      val->constant->values[0][1].u32 << 5 |
+                      val->constant->values[0][2].u32 << 10;
+      nir_intrinsic_set_swizzle_mask(intrin, mask);
+   }
+
+   nir_builder_instr_insert(&b->nb, &intrin->instr);
+   val->ssa->def = &intrin->dest.ssa;
+
+   return true;
+}
+
 bool
 vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                                  const uint32_t *w, unsigned count)
index c0595709852156df03cc160dd49e3cec7dde8393..1b2be93621f4e2432c0c3137d07efe422e0f9fab 100644 (file)
@@ -833,6 +833,9 @@ vtn_u64_literal(const uint32_t *w)
 bool vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                            const uint32_t *words, unsigned count);
 
+bool vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode,
+                                              const uint32_t *w, unsigned count);
+
 bool vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                                                      const uint32_t *words, unsigned count);
 #endif /* _VTN_PRIVATE_H_ */
index ce795ec2cb51c8c61641e7621cf4bfe32841bf29..8339b1a4862c131d9855a39b57e8253e3a709aa5 100644 (file)
@@ -183,7 +183,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                                val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
       break;
 
-   case SpvOpGroupNonUniformBroadcast: ++w;
+   case SpvOpGroupNonUniformBroadcast:
+   case SpvOpGroupBroadcast: ++w;
    case SpvOpSubgroupReadInvocationKHR:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                val->ssa, vtn_ssa_value(b, w[3]),
@@ -193,6 +194,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformAll:
    case SpvOpGroupNonUniformAny:
    case SpvOpGroupNonUniformAllEqual:
+   case SpvOpGroupAll:
+   case SpvOpGroupAny:
    case SpvOpSubgroupAllKHR:
    case SpvOpSubgroupAnyKHR:
    case SpvOpSubgroupAllEqualKHR: {
@@ -201,10 +204,12 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       nir_intrinsic_op op;
       switch (opcode) {
       case SpvOpGroupNonUniformAll:
+      case SpvOpGroupAll:
       case SpvOpSubgroupAllKHR:
          op = nir_intrinsic_vote_all;
          break;
       case SpvOpGroupNonUniformAny:
+      case SpvOpGroupAny:
       case SpvOpSubgroupAnyKHR:
          op = nir_intrinsic_vote_any;
          break;
@@ -232,8 +237,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       }
 
       nir_ssa_def *src0;
-      if (opcode == SpvOpGroupNonUniformAll ||
-          opcode == SpvOpGroupNonUniformAny ||
+      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
           opcode == SpvOpGroupNonUniformAllEqual) {
          src0 = vtn_ssa_value(b, w[4])->def;
       } else {
@@ -319,13 +324,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformBitwiseXor:
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
-   case SpvOpGroupNonUniformLogicalXor: {
+   case SpvOpGroupNonUniformLogicalXor:
+   case SpvOpGroupIAdd:
+   case SpvOpGroupFAdd:
+   case SpvOpGroupFMin:
+   case SpvOpGroupUMin:
+   case SpvOpGroupSMin:
+   case SpvOpGroupFMax:
+   case SpvOpGroupUMax:
+   case SpvOpGroupSMax:
+   case SpvOpGroupIAddNonUniformAMD:
+   case SpvOpGroupFAddNonUniformAMD:
+   case SpvOpGroupFMinNonUniformAMD:
+   case SpvOpGroupUMinNonUniformAMD:
+   case SpvOpGroupSMinNonUniformAMD:
+   case SpvOpGroupFMaxNonUniformAMD:
+   case SpvOpGroupUMaxNonUniformAMD:
+   case SpvOpGroupSMaxNonUniformAMD: {
       nir_op reduction_op;
       switch (opcode) {
       case SpvOpGroupNonUniformIAdd:
+      case SpvOpGroupIAdd:
+      case SpvOpGroupIAddNonUniformAMD:
          reduction_op = nir_op_iadd;
          break;
       case SpvOpGroupNonUniformFAdd:
+      case SpvOpGroupFAdd:
+      case SpvOpGroupFAddNonUniformAMD:
          reduction_op = nir_op_fadd;
          break;
       case SpvOpGroupNonUniformIMul:
@@ -335,21 +360,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          reduction_op = nir_op_fmul;
          break;
       case SpvOpGroupNonUniformSMin:
+      case SpvOpGroupSMin:
+      case SpvOpGroupSMinNonUniformAMD:
          reduction_op = nir_op_imin;
          break;
       case SpvOpGroupNonUniformUMin:
+      case SpvOpGroupUMin:
+      case SpvOpGroupUMinNonUniformAMD:
          reduction_op = nir_op_umin;
          break;
       case SpvOpGroupNonUniformFMin:
+      case SpvOpGroupFMin:
+      case SpvOpGroupFMinNonUniformAMD:
          reduction_op = nir_op_fmin;
          break;
       case SpvOpGroupNonUniformSMax:
+      case SpvOpGroupSMax:
+      case SpvOpGroupSMaxNonUniformAMD:
          reduction_op = nir_op_imax;
          break;
       case SpvOpGroupNonUniformUMax:
+      case SpvOpGroupUMax:
+      case SpvOpGroupUMaxNonUniformAMD:
          reduction_op = nir_op_umax;
          break;
       case SpvOpGroupNonUniformFMax:
+      case SpvOpGroupFMax:
+      case SpvOpGroupFMaxNonUniformAMD:
          reduction_op = nir_op_fmax;
          break;
       case SpvOpGroupNonUniformBitwiseAnd: