spirv: Add subgroup ballot support
author Jason Ekstrand <jason.ekstrand@intel.com>
Tue, 22 Aug 2017 23:53:05 +0000 (16:53 -0700)
committer Jason Ekstrand <jason.ekstrand@intel.com>
Wed, 7 Mar 2018 20:13:47 +0000 (12:13 -0800)
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
src/compiler/shader_info.h
src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_subgroup.c
src/compiler/spirv/vtn_variables.c

index 0ba32349d0329a1ab98b72d481049d629f6164e4..f876111b0b182e37348dc0dba4757f0b68594bc6 100644 (file)
@@ -45,6 +45,7 @@ struct spirv_supported_capabilities {
    bool variable_pointers;
    bool storage_16bit;
    bool shader_viewport_index_layer;
+   bool subgroup_ballot;
    bool subgroup_basic;
 };
 
index 4d2c1533d24b8960e3a4f4f30ec376bc01f59abb..38a1df9fd2106a07d001bb50fe778801327a82b1 100644 (file)
@@ -3296,6 +3296,11 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          spv_check_supported(subgroup_basic, cap);
          break;
 
+      case SpvCapabilitySubgroupBallotKHR:
+      case SpvCapabilityGroupNonUniformBallot:
+         spv_check_supported(subgroup_ballot, cap);
+         break;
+
       case SpvCapabilityVariablePointersStorageBuffer:
       case SpvCapabilityVariablePointers:
          spv_check_supported(variable_pointers, cap);
index 033c43e601c1f4dae8eafbdb76b19eadfb057ed1..a86f0cb2832a0f3791f10f12a67a417623f45405 100644 (file)
 
 #include "vtn_private.h"
 
+/* Emit the subgroup intrinsic nir_op for src0 and store the result in dst.
+ *
+ * dst and src0 must have the same type.  When that type is not a vector or
+ * scalar, the function recurses over the elements of both values, so one
+ * intrinsic is emitted per vector/scalar leaf.  index is optional; when it
+ * is non-NULL it is attached as the second source of every emitted
+ * intrinsic.
+ */
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+                         nir_intrinsic_op nir_op,
+                         struct vtn_ssa_value *dst,
+                         struct vtn_ssa_value *src0,
+                         nir_ssa_def *index)
+{
+   /* SPIR-V lets the index operand be any integer type.  Drivers are only
+    * expected to deal with 32-bit indices, so anything else is converted
+    * here, once, before any recursion happens.
+    */
+   if (index != NULL && index->bit_size != 32)
+      index = nir_u2u32(&b->nb, index);
+
+   vtn_assert(dst->type == src0->type);
+
+   if (glsl_type_is_vector_or_scalar(dst->type)) {
+      /* Leaf case: build a single intrinsic whose destination matches the
+       * type of dst.
+       */
+      nir_intrinsic_instr *instr =
+         nir_intrinsic_instr_create(b->nb.shader, nir_op);
+      nir_ssa_dest_init_for_type(&instr->instr, &instr->dest,
+                                 dst->type, NULL);
+      instr->num_components = instr->dest.ssa.num_components;
+
+      instr->src[0] = nir_src_for_ssa(src0->def);
+      if (index != NULL)
+         instr->src[1] = nir_src_for_ssa(index);
+
+      nir_builder_instr_insert(&b->nb, &instr->instr);
+
+      dst->def = &instr->dest.ssa;
+      return;
+   }
+
+   /* Composite case: handle each element pair on its own. */
+   const unsigned num_elems = glsl_get_length(dst->type);
+   for (unsigned e = 0; e < num_elems; e++) {
+      vtn_build_subgroup_instr(b, nir_op, dst->elems[e],
+                               src0->elems[e], index);
+   }
+}
+
+
 void
 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count)
@@ -43,17 +81,106 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
-   case SpvOpGroupNonUniformAll:
-   case SpvOpGroupNonUniformAny:
-   case SpvOpGroupNonUniformAllEqual:
-   case SpvOpGroupNonUniformBroadcast:
-   case SpvOpGroupNonUniformBroadcastFirst:
-   case SpvOpGroupNonUniformBallot:
-   case SpvOpGroupNonUniformInverseBallot:
+   case SpvOpGroupNonUniformBallot: {
+      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+                  "OpGroupNonUniformBallot must return a uvec4");
+      nir_intrinsic_instr *ballot =
+         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+      ballot->num_components = 4;
+      nir_builder_instr_insert(&b->nb, &ballot->instr);
+      val->ssa->def = &ballot->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformInverseBallot: {
+      /* This one is just a BallotBitfieldExtract with subgroup invocation.
+       * We could add a NIR intrinsic but it's easier to just lower it on the
+       * spot.
+       */
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader,
+                                    nir_intrinsic_ballot_bitfield_extract);
+
+      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
    case SpvOpGroupNonUniformBallotBitExtract:
    case SpvOpGroupNonUniformBallotBitCount:
    case SpvOpGroupNonUniformBallotFindLSB:
-   case SpvOpGroupNonUniformBallotFindMSB:
+   case SpvOpGroupNonUniformBallotFindMSB: {
+      nir_ssa_def *src0, *src1 = NULL;
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformBallotBitExtract:
+         op = nir_intrinsic_ballot_bitfield_extract;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         src1 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotBitCount:
+         switch ((SpvGroupOperation)w[4]) {
+         case SpvGroupOperationReduce:
+            op = nir_intrinsic_ballot_bit_count_reduce;
+            break;
+         case SpvGroupOperationInclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_inclusive;
+            break;
+         case SpvGroupOperationExclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_exclusive;
+            break;
+         default:
+            unreachable("Invalid group operation");
+         }
+         src0 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindLSB:
+         op = nir_intrinsic_ballot_find_lsb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindMSB:
+         op = nir_intrinsic_ballot_find_msb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      default:
+         unreachable("Unhandled opcode");
+      }
+
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, op);
+
+      intrin->src[0] = nir_src_for_ssa(src0);
+      if (src1)
+         intrin->src[1] = nir_src_for_ssa(src1);
+
+      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformBroadcastFirst:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+                               val->ssa, vtn_ssa_value(b, w[4]), NULL);
+      break;
+
+   case SpvOpGroupNonUniformBroadcast:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+                               val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def);
+      break;
+
+   case SpvOpGroupNonUniformAll:
+   case SpvOpGroupNonUniformAny:
+   case SpvOpGroupNonUniformAllEqual:
    case SpvOpGroupNonUniformShuffle:
    case SpvOpGroupNonUniformShuffleXor:
    case SpvOpGroupNonUniformShuffleUp:
index 68e1adf8152b0f59da999148d86a82f71c092bb2..61caaafa31179297a66b26de78516b432416c676 100644 (file)
@@ -1317,6 +1317,26 @@ vtn_get_builtin_location(struct vtn_builder *b,
       *location = SYSTEM_VALUE_VIEW_INDEX;
       set_mode_system_value(b, mode);
       break;
+   case SpvBuiltInSubgroupEqMask:
+      *location = SYSTEM_VALUE_SUBGROUP_EQ_MASK,
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GE_MASK,
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupGtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_GT_MASK,
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLeMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LE_MASK,
+      set_mode_system_value(b, mode);
+      break;
+   case SpvBuiltInSubgroupLtMask:
+      *location = SYSTEM_VALUE_SUBGROUP_LT_MASK,
+      set_mode_system_value(b, mode);
+      break;
    default:
       vtn_fail("unsupported builtin");
    }