nir: Add a lowering pass to split 64bit phis
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
index 4462c708ec8a1f91dc6128dcc3492db28c1b9b51..541544e2474cdb0295c6181ae6d2b410b492c234 100644 (file)
@@ -222,23 +222,64 @@ lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
                   nir_imm_intN_t(b, 0, options->ballot_bit_size));
 }
 
+static nir_ssa_def *
+lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
+                         const nir_lower_subgroups_options *options)
+{
+   unsigned mask = nir_src_as_uint(intrin->src[1]);
+
+   if (mask >= 32)
+      return NULL;
+
+   nir_intrinsic_instr *swizzle = nir_intrinsic_instr_create(
+      b->shader, nir_intrinsic_masked_swizzle_amd);
+   swizzle->num_components = intrin->num_components;
+   nir_src_copy(&swizzle->src[0], &intrin->src[0], swizzle);
+   nir_intrinsic_set_swizzle_mask(swizzle, (mask << 10) | 0x1f);
+   nir_ssa_dest_init(&swizzle->instr, &swizzle->dest,
+                     intrin->dest.ssa.num_components,
+                     intrin->dest.ssa.bit_size, NULL);
+
+   if (options->lower_to_scalar && swizzle->num_components > 1) {
+      return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
+   } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
+      return lower_subgroup_op_to_32bit(b, swizzle);
+   } else {
+      nir_builder_instr_insert(b, &swizzle->instr);
+      return &swizzle->dest.ssa;
+   }
+}
+
 static nir_ssa_def *
 lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
-              bool lower_to_scalar, bool lower_to_32bit)
+              const nir_lower_subgroups_options *options)
 {
+   if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
+       options->lower_shuffle_to_swizzle_amd &&
+       nir_src_is_const(intrin->src[1])) {
+      nir_ssa_def *result =
+         lower_shuffle_to_swizzle(b, intrin, options);
+      if (result)
+         return result;
+   }
+
    nir_ssa_def *index = nir_load_subgroup_invocation(b);
+   bool is_shuffle = false;
    switch (intrin->intrinsic) {
    case nir_intrinsic_shuffle_xor:
       assert(intrin->src[1].is_ssa);
       index = nir_ixor(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
       break;
    case nir_intrinsic_shuffle_up:
       assert(intrin->src[1].is_ssa);
       index = nir_isub(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
       break;
    case nir_intrinsic_shuffle_down:
       assert(intrin->src[1].is_ssa);
       index = nir_iadd(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
       break;
    case nir_intrinsic_quad_broadcast:
       assert(intrin->src[1].is_ssa);
@@ -276,7 +317,8 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
                      intrin->dest.ssa.num_components,
                      intrin->dest.ssa.bit_size, NULL);
 
-   if (lower_to_scalar && shuffle->num_components > 1) {
+   bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
+   if (options->lower_to_scalar && shuffle->num_components > 1) {
       return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
    } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
       return lower_subgroup_op_to_32bit(b, shuffle);
@@ -301,6 +343,46 @@ build_subgroup_mask(nir_builder *b, unsigned bit_size,
                                   nir_load_subgroup_size(b)));
 }
 
+static nir_ssa_def *
+lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
+                             const nir_lower_subgroups_options *options)
+{
+   if (!options->lower_quad_broadcast_dynamic_to_const)
+      return lower_shuffle(b, intrin, options);
+
+   nir_ssa_def *dst = NULL;
+
+   for (unsigned i = 0; i < 4; ++i) {
+      nir_intrinsic_instr *qbcst =
+         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);
+
+      qbcst->num_components = intrin->num_components;
+      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
+      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
+      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
+                        intrin->dest.ssa.num_components,
+                        intrin->dest.ssa.bit_size, NULL);
+
+      nir_ssa_def *qbcst_dst = NULL;
+
+      if (options->lower_to_scalar && qbcst->num_components > 1) {
+         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+      } else {
+         nir_builder_instr_insert(b, &qbcst->instr);
+         qbcst_dst = &qbcst->dest.ssa;
+      }
+
+      if (i)
+         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
+                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
+                         qbcst_dst, dst);
+      else
+         dst = qbcst_dst;
+   }
+
+   return dst;
+}
+
 static nir_ssa_def *
 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 {
@@ -406,6 +488,32 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       assert(intrin->src[0].is_ssa);
       nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                  options->ballot_bit_size);
+
+      if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
+          intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
+         /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
+          *
+          *    "Find the most significant bit set to 1 in Value, considering
+          *    only the bits in Value required to represent all bits of the
+          *    group’s invocations.  If none of the considered bits is set to
+          *    1, the result is undefined."
+          *
+          * It has similar text for the other three.  This means that, in case
+          * the subgroup size is less than 32, we have to mask off the unused
+          * bits.  If the subgroup size is fixed and greater than or equal to
+          * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
+          * the iand.
+          *
+          * We only have to worry about this for BitCount and FindMSB because
+          * FindLSB counts from the bottom and BitfieldExtract selects
+          * individual bits.  In either case, if run outside the range of
+          * valid bits, we hit the undefined results case and we can return
+          * anything we want.
+          */
+         int_val = nir_iand(b, int_val,
+            build_subgroup_mask(b, options->ballot_bit_size, options));
+      }
+
       switch (intrin->intrinsic) {
       case nir_intrinsic_ballot_bitfield_extract:
          assert(intrin->src[1].is_ssa);
@@ -457,12 +565,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
          return lower_subgroup_op_to_32bit(b, intrin);
       break;
-
    case nir_intrinsic_shuffle_xor:
    case nir_intrinsic_shuffle_up:
    case nir_intrinsic_shuffle_down:
       if (options->lower_shuffle)
-         return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit);
+         return lower_shuffle(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
       else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
@@ -477,7 +584,7 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
           (options->lower_quad_broadcast_dynamic &&
            intrin->intrinsic == nir_intrinsic_quad_broadcast &&
            !nir_src_is_const(intrin->src[1])))
-         return lower_shuffle(b, intrin, options->lower_to_scalar, false);
+         return lower_dynamic_quad_broadcast(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin, false);
       break;