nir/lower_subgroups: Properly lower masks when subgroup_size == 0
author Jason Ekstrand <jason@jlekstrand.net>
Thu, 11 Jul 2019 03:20:00 +0000 (22:20 -0500)
committer Jason Ekstrand <jason@jlekstrand.net>
Wed, 24 Jul 2019 17:55:40 +0000 (12:55 -0500)
Instead of building a constant mask (which depends on knowing the
subgroup size), we build an expression.  Because the pass uses the
nir_shader_lower_instructions helper, subgroup lowering will be run on
any newly emitted instructions as well as the previously existing
instructions.  In particular, if the subgroup size is known, the newly
emitted subgroup_size intrinsic will get turned into a constant and a
later constant folding pass will clean it up.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/compiler/nir/nir_lower_subgroups.c

index eca441fcf3954b2c4150037367ad3e6467a9257d..1e2e3f0eebf2003d834f1fe76a1d9e5f27e757f4 100644 (file)
@@ -292,6 +292,15 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
    return instr->type == nir_instr_type_intrinsic;
 }
 
+static nir_ssa_def * /* Emits a NIR expression for the subgroup mask instead of a precomputed constant. */
+build_subgroup_mask(nir_builder *b, unsigned bit_size,
+                    const nir_lower_subgroups_options *options) /* NOTE(review): options is not read in this body. */
+{
+   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size), /* value shifted: ~0ull immediate at bit_size bits */
+                      nir_isub(b, nir_imm_int(b, bit_size), /* shift amount: bit_size minus ... */
+                                  nir_load_subgroup_size(b))); /* ... the subgroup size, loaded via intrinsic rather than read from options */
+} /* Mirrors the removed constant: ~0ull >> (64 - options->subgroup_size); presumably leaves the low subgroup-size bits set. */
+
 static nir_ssa_def *
 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 {
@@ -343,9 +352,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       const unsigned bit_size = MAX2(options->ballot_bit_size,
                                      intrin->dest.ssa.bit_size);
 
-      assert(options->subgroup_size <= 64);
-      uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
-
       nir_ssa_def *count = nir_load_subgroup_invocation(b);
       nir_ssa_def *val;
       switch (intrin->intrinsic) {
@@ -354,11 +360,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          break;
       case nir_intrinsic_load_subgroup_ge_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_gt_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_le_mask:
          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));