nir/lower_subgroups: Lower ballot intrinsics to the specified bit size
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 23 Aug 2017 01:44:51 +0000 (18:44 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Tue, 7 Nov 2017 18:37:52 +0000 (10:37 -0800)
Ballot intrinsics return a bitfield of subgroups.  In GLSL and some
SPIR-V extensions, they return a uint64_t.  In SPV_KHR_shader_ballot,
they return a uvec4.  Also, some back-ends would rather pass around
32-bit values because it's easier than messing with 64-bit all the time.
To solve this mess, we make nir_lower_subgroups take a new parameter
called ballot_bit_size and it lowers whichever thing it gets in from the
source language (uint64_t or uvec4) to a scalar with the specified
number of bits.  This replaces a chunk of the old lowering code.

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_subgroups.c
src/compiler/nir/nir_opt_intrinsics.c
src/intel/compiler/brw_compiler.c
src/intel/compiler/brw_nir.c

index 7ee4bfcc5571887481546390801965f19683a9d1..c047ab7512b4b7bd44a2402a7fb7a9b58550b615 100644 (file)
@@ -1854,8 +1854,6 @@ typedef struct nir_shader_compiler_options {
     */
    bool use_interpolated_input_intrinsics;
 
-   unsigned max_subgroup_size;
-
    unsigned max_unroll_iterations;
 } nir_shader_compiler_options;
 
@@ -2486,6 +2484,7 @@ bool nir_lower_samplers_as_deref(nir_shader *shader,
                                  const struct gl_shader_program *shader_program);
 
 typedef struct nir_lower_subgroups_options {
+   uint8_t ballot_bit_size;
    bool lower_to_scalar:1;
    bool lower_vote_trivial:1;
    bool lower_subgroup_masks:1;
index 0d11dc9c23a5ef165eab2a008e595b9dd845beb6..76e831691ee2b002811d2a02b46a70328ce09708 100644 (file)
  * \file nir_opt_intrinsics.c
  */
 
+/* Converts a uint32_t or uint64_t value to uint64_t or uvec4 */
+static nir_ssa_def *
+uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
+                    unsigned num_components, unsigned bit_size)
+{
+   assert(value->num_components == 1);
+   assert(value->bit_size == 32 || value->bit_size == 64);
+
+   nir_ssa_def *zero = nir_imm_int(b, 0);
+   if (num_components > 1) {
+      /* SPIR-V uses a uvec4 for ballot values */
+      assert(num_components == 4);
+      assert(bit_size == 32);
+
+      if (value->bit_size == 32) {
+         return nir_vec4(b, value, zero, zero, zero);
+      } else {
+         assert(value->bit_size == 64);
+         return nir_vec4(b, nir_unpack_64_2x32_split_x(b, value),
+                            nir_unpack_64_2x32_split_y(b, value),
+                            zero, zero);
+      }
+   } else {
+      /* GLSL uses a uint64_t for ballot values */
+      assert(num_components == 1);
+      assert(bit_size == 64);
+
+      if (value->bit_size == 32) {
+         return nir_pack_64_2x32_split(b, value, zero);
+      } else {
+         assert(value->bit_size == 64);
+         return value;
+      }
+   }
+}
+
 static nir_ssa_def *
 lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
 {
@@ -62,7 +98,8 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
 static nir_ssa_def *
 high_subgroup_mask(nir_builder *b,
                    nir_ssa_def *count,
-                   uint64_t base_mask)
+                   uint64_t base_mask,
+                   unsigned bit_size)
 {
    /* group_mask could probably be calculated more efficiently but we want to
     * be sure not to shift by 64 if the subgroup size is 64 because the GLSL
@@ -71,10 +108,11 @@ high_subgroup_mask(nir_builder *b,
     * subgroup size is likely to be known at compile time.
     */
    nir_ssa_def *subgroup_size = nir_load_subgroup_size(b);
-   nir_ssa_def *all_bits = nir_imm_int64(b, ~0ull);
+   nir_ssa_def *all_bits = nir_imm_intN_t(b, ~0ull, bit_size);
    nir_ssa_def *shift = nir_isub(b, nir_imm_int(b, 64), subgroup_size);
    nir_ssa_def *group_mask = nir_ushr(b, all_bits, shift);
-   nir_ssa_def *higher_bits = nir_ishl(b, nir_imm_int64(b, base_mask), count);
+   nir_ssa_def *higher_bits =
+      nir_ishl(b, nir_imm_intN_t(b, base_mask, bit_size), count);
 
    return nir_iand(b, higher_bits, group_mask);
 }
@@ -109,24 +147,58 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       if (!options->lower_subgroup_masks)
          return NULL;
 
-      nir_ssa_def *count = nir_load_subgroup_invocation(b);
+      /* If either the result or the requested bit size is 64-bits then we
+       * know that we have 64-bit types and using them will probably be more
+       * efficient than messing around with 32-bit shifts and packing.
+       */
+      const unsigned bit_size = MAX2(options->ballot_bit_size,
+                                     intrin->dest.ssa.bit_size);
 
+      nir_ssa_def *count = nir_load_subgroup_invocation(b);
+      nir_ssa_def *val;
       switch (intrin->intrinsic) {
       case nir_intrinsic_load_subgroup_eq_mask:
-         return nir_ishl(b, nir_imm_int64(b, 1ull), count);
+         val = nir_ishl(b, nir_imm_intN_t(b, 1ull, bit_size), count);
+         break;
       case nir_intrinsic_load_subgroup_ge_mask:
-         return high_subgroup_mask(b, count, ~0ull);
+         val = high_subgroup_mask(b, count, ~0ull, bit_size);
+         break;
       case nir_intrinsic_load_subgroup_gt_mask:
-         return high_subgroup_mask(b, count, ~1ull);
+         val = high_subgroup_mask(b, count, ~1ull, bit_size);
+         break;
       case nir_intrinsic_load_subgroup_le_mask:
-         return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~1ull), count));
+         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
+         break;
       case nir_intrinsic_load_subgroup_lt_mask:
-         return nir_inot(b, nir_ishl(b, nir_imm_int64(b, ~0ull), count));
+         val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count));
+         break;
       default:
          unreachable("you seriously can't tell this is unreachable?");
       }
-      break;
+
+      return uint_to_ballot_type(b, val,
+                                 intrin->dest.ssa.num_components,
+                                 intrin->dest.ssa.bit_size);
+   }
+
+   case nir_intrinsic_ballot: {
+      if (intrin->dest.ssa.num_components == 1 &&
+          intrin->dest.ssa.bit_size == options->ballot_bit_size)
+         return NULL;
+
+      nir_intrinsic_instr *ballot =
+         nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+      ballot->num_components = 1;
+      nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+                        1, options->ballot_bit_size, NULL);
+      nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
+      nir_builder_instr_insert(b, &ballot->instr);
+
+      return uint_to_ballot_type(b, &ballot->dest.ssa,
+                                 intrin->dest.ssa.num_components,
+                                 intrin->dest.ssa.bit_size);
    }
+
    default:
       break;
    }
index 98c8b1a01e85c5319632acdbadff99e3d81bb938..eb394af0c10830b678c057859d7654b57e317bca 100644 (file)
@@ -54,24 +54,6 @@ opt_intrinsics_impl(nir_function_impl *impl)
             if (nir_src_as_const_value(intrin->src[0]))
                replacement = nir_imm_int(&b, NIR_TRUE);
             break;
-         case nir_intrinsic_ballot: {
-            assert(b.shader->options->max_subgroup_size != 0);
-            if (b.shader->options->max_subgroup_size > 32 ||
-                intrin->dest.ssa.bit_size <= 32)
-               continue;
-
-            nir_intrinsic_instr *ballot =
-               nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
-            nir_ssa_dest_init(&ballot->instr, &ballot->dest, 1, 32, NULL);
-            nir_src_copy(&ballot->src[0], &intrin->src[0], ballot);
-
-            nir_builder_instr_insert(&b, &ballot->instr);
-
-            replacement = nir_pack_64_2x32_split(&b,
-                                                 &ballot->dest.ssa,
-                                                 nir_imm_int(&b, 0));
-            break;
-         }
          default:
             break;
          }
index 8c709b55a10bd6169c3b9e8d33960114102cb356..e89aeacc7d2f3a829f3d105235b3135a91f2e8a5 100644 (file)
@@ -57,7 +57,6 @@ static const struct nir_shader_compiler_options scalar_nir_options = {
    .lower_unpack_snorm_4x8 = true,
    .lower_unpack_unorm_2x16 = true,
    .lower_unpack_unorm_4x8 = true,
-   .max_subgroup_size = 32,
    .max_unroll_iterations = 32,
 };
 
index f599f748a8a29fa0a2833f099d256256be3619ab..0d59d36ca6388e77299a6627f879bb657792dbc5 100644 (file)
@@ -637,6 +637,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir)
    OPT(nir_lower_system_values);
 
    const nir_lower_subgroups_options subgroups_options = {
+      .ballot_bit_size = 32,
       .lower_to_scalar = true,
       .lower_subgroup_masks = true,
       .lower_vote_trivial = !is_scalar,