Added a few more stubs so that control reaches DestroyDevice().
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
index c97849bf8ba0dc865521fae372ddf0bb2fcc26a8..541544e2474cdb0295c6181ae6d2b410b492c234 100644
  * \file nir_opt_intrinsics.c
  */
 
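+/* Emit a copy of the given intrinsic that operates on a single 32-bit
+ * half of its 64-bit first source: component 0 is the low half,
+ * component 1 the high half.
+ */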
+static nir_intrinsic_instr *
+lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
+                                      unsigned int component)
+{
+   nir_ssa_def *comp;
+   if (component == 0)
+      comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
+   else
+      comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
+
+   nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
+   nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
+   intr->const_index[0] = intrin->const_index[0];
+   intr->const_index[1] = intrin->const_index[1];
+   intr->src[0] = nir_src_for_ssa(comp);
+   if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
+      nir_src_copy(&intr->src[1], &intrin->src[1], intr);
+
+   intr->num_components = 1;
+   nir_builder_instr_insert(b, &intr->instr);
+   return intr;
+}
+
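+/* Lower a 64-bit subgroup operation to two 32-bit ones by running the
+ * intrinsic on each half of the split source and repacking the results.
+ */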
+static nir_ssa_def *
+lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
+{
+   assert(intrin->src[0].ssa->bit_size == 64);
+   nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
+   nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
+   return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
+}
+
 static nir_ssa_def *
 ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
 {
@@ -80,7 +112,8 @@ uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
 }
 
 static nir_ssa_def *
-lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
+                            bool lower_to_32bit)
 {
    /* This is safe to call on scalar things but it would be silly */
    assert(intrin->dest.ssa.num_components > 1);
@@ -99,12 +132,20 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
       /* value */
       chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
       /* invocation */
-      if (intrin->intrinsic == nir_intrinsic_read_invocation)
+      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
+         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
          nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
+      }
 
-      nir_builder_instr_insert(b, &chan_intrin->instr);
+      chan_intrin->const_index[0] = intrin->const_index[0];
+      chan_intrin->const_index[1] = intrin->const_index[1];
 
-      reads[i] = &chan_intrin->dest.ssa;
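+      /* 64-bit channels can additionally be split into two 32-bit
+       * operations on their halves.
+       */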
+      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
+         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
+      } else {
+         nir_builder_instr_insert(b, &chan_intrin->instr);
+         reads[i] = &chan_intrin->dest.ssa;
+      }
    }
 
    return nir_vec(b, reads, intrin->num_components);
@@ -137,9 +178,217 @@ lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
 }
 
 static nir_ssa_def *
-lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
-                       const nir_lower_subgroups_options *options)
+lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
+                        const nir_lower_subgroups_options *options)
+{
+   assert(intrin->src[0].is_ssa);
+   nir_ssa_def *value = intrin->src[0].ssa;
+
+   /* We have to implicitly lower to scalar */
+   nir_ssa_def *all_eq = NULL;
+   for (unsigned i = 0; i < intrin->num_components; i++) {
+      nir_intrinsic_instr *rfi =
+         nir_intrinsic_instr_create(b->shader,
+                                    nir_intrinsic_read_first_invocation);
+      nir_ssa_dest_init(&rfi->instr, &rfi->dest,
+                        1, value->bit_size, NULL);
+      rfi->num_components = 1;
+      rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
+      nir_builder_instr_insert(b, &rfi->instr);
+
+      nir_ssa_def *is_eq;
+      if (intrin->intrinsic == nir_intrinsic_vote_feq) {
+         is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+      } else {
+         is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+      }
+
+      if (all_eq == NULL) {
+         all_eq = is_eq;
+      } else {
+         all_eq = nir_iand(b, all_eq, is_eq);
+      }
+   }
+
+   nir_intrinsic_instr *ballot =
+      nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+   nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+                     1, options->ballot_bit_size, NULL);
+   ballot->num_components = 1;
+   ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
+   nir_builder_instr_insert(b, &ballot->instr);
+
+   return nir_ieq(b, &ballot->dest.ssa,
+                  nir_imm_intN_t(b, 0, options->ballot_bit_size));
+}
+
+static nir_ssa_def *
+lower_shuffle_to_swizzle(nir_builder *b, nir_intrinsic_instr *intrin,
+                         const nir_lower_subgroups_options *options)
 {
+   unsigned mask = nir_src_as_uint(intrin->src[1]);
+
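+   /* The swizzle's xor mask is only 5 bits wide, so XOR masks of 32 and
+    * above cannot be expressed as a masked swizzle.
+    */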
+   if (mask >= 32)
+      return NULL;
+
+   nir_intrinsic_instr *swizzle = nir_intrinsic_instr_create(
+      b->shader, nir_intrinsic_masked_swizzle_amd);
+   swizzle->num_components = intrin->num_components;
+   nir_src_copy(&swizzle->src[0], &intrin->src[0], swizzle);
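+   /* The swizzle mask packs and_mask, or_mask and xor_mask into bits
+    * [4:0], [9:5] and [14:10].  With and_mask = 0x1f and or_mask = 0 the
+    * lane index is left intact, so this computes lane ^ mask.
+    */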
+   nir_intrinsic_set_swizzle_mask(swizzle, (mask << 10) | 0x1f);
+   nir_ssa_dest_init(&swizzle->instr, &swizzle->dest,
+                     intrin->dest.ssa.num_components,
+                     intrin->dest.ssa.bit_size, NULL);
+
+   if (options->lower_to_scalar && swizzle->num_components > 1) {
+      return lower_subgroup_op_to_scalar(b, swizzle, options->lower_shuffle_to_32bit);
+   } else if (options->lower_shuffle_to_32bit && swizzle->src[0].ssa->bit_size == 64) {
+      return lower_subgroup_op_to_32bit(b, swizzle);
+   } else {
+      nir_builder_instr_insert(b, &swizzle->instr);
+      return &swizzle->dest.ssa;
+   }
+}
+
+static nir_ssa_def *
+lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
+              const nir_lower_subgroups_options *options)
+{
+   if (intrin->intrinsic == nir_intrinsic_shuffle_xor &&
+       options->lower_shuffle_to_swizzle_amd &&
+       nir_src_is_const(intrin->src[1])) {
+      nir_ssa_def *result =
+         lower_shuffle_to_swizzle(b, intrin, options);
+      if (result)
+         return result;
+   }
+
+   nir_ssa_def *index = nir_load_subgroup_invocation(b);
+   bool is_shuffle = false;
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_shuffle_xor:
+      assert(intrin->src[1].is_ssa);
+      index = nir_ixor(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
+      break;
+   case nir_intrinsic_shuffle_up:
+      assert(intrin->src[1].is_ssa);
+      index = nir_isub(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
+      break;
+   case nir_intrinsic_shuffle_down:
+      assert(intrin->src[1].is_ssa);
+      index = nir_iadd(b, index, intrin->src[1].ssa);
+      is_shuffle = true;
+      break;
+   case nir_intrinsic_quad_broadcast:
+      assert(intrin->src[1].is_ssa);
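+      /* Replace the low two bits of the invocation index with the
+       * broadcast lane, selecting that lane within each quad.
+       */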
+      index = nir_ior(b, nir_iand(b, index, nir_imm_int(b, ~0x3)),
+                         intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_quad_swap_horizontal:
+      /* For Quad operations, subgroups are divided into quads where
+       * (invocation % 4) is the index within a 2x2 square arranged as
+       * follows:
+       *
+       *    +---+---+
+       *    | 0 | 1 |
+       *    +---+---+
+       *    | 2 | 3 |
+       *    +---+---+
+       */
+      index = nir_ixor(b, index, nir_imm_int(b, 0x1));
+      break;
+   case nir_intrinsic_quad_swap_vertical:
+      index = nir_ixor(b, index, nir_imm_int(b, 0x2));
+      break;
+   case nir_intrinsic_quad_swap_diagonal:
+      index = nir_ixor(b, index, nir_imm_int(b, 0x3));
+      break;
+   default:
+      unreachable("Invalid intrinsic");
+   }
+
+   nir_intrinsic_instr *shuffle =
+      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
+   shuffle->num_components = intrin->num_components;
+   nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle);
+   shuffle->src[1] = nir_src_for_ssa(index);
+   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
+                     intrin->dest.ssa.num_components,
+                     intrin->dest.ssa.bit_size, NULL);
+
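+   /* Only relative shuffles may additionally be split into 32-bit
+    * halves; quad operations keep their 64-bit sources intact.
+    */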
+   bool lower_to_32bit = options->lower_shuffle_to_32bit && is_shuffle;
+   if (options->lower_to_scalar && shuffle->num_components > 1) {
+      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
+   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
+      return lower_subgroup_op_to_32bit(b, shuffle);
+   } else {
+      nir_builder_instr_insert(b, &shuffle->instr);
+      return &shuffle->dest.ssa;
+   }
+}
+
+static bool
+lower_subgroups_filter(const nir_instr *instr, const void *_options)
+{
+   return instr->type == nir_instr_type_intrinsic;
+}
+
+static nir_ssa_def *
+build_subgroup_mask(nir_builder *b, unsigned bit_size,
+                    const nir_lower_subgroups_options *options)
+{
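+   /* Shifting ~0 right by (bit_size - subgroup_size) leaves exactly the
+    * bottom subgroup_size bits set.
+    */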
+   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
+                      nir_isub(b, nir_imm_int(b, bit_size),
+                                  nir_load_subgroup_size(b)));
+}
+
+static nir_ssa_def *
+lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
+                             const nir_lower_subgroups_options *options)
+{
+   if (!options->lower_quad_broadcast_dynamic_to_const)
+      return lower_shuffle(b, intrin, options);
+
+   nir_ssa_def *dst = NULL;
+
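+   /* Emit a quad_broadcast for each of the four constant lane indices
+    * and pick the one matching the dynamic index with a chain of bcsels.
+    */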
+   for (unsigned i = 0; i < 4; ++i) {
+      nir_intrinsic_instr *qbcst =
+         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);
+
+      qbcst->num_components = intrin->num_components;
+      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
+      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
+      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
+                        intrin->dest.ssa.num_components,
+                        intrin->dest.ssa.bit_size, NULL);
+
+      nir_ssa_def *qbcst_dst = NULL;
+
+      if (options->lower_to_scalar && qbcst->num_components > 1) {
+         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+      } else {
+         nir_builder_instr_insert(b, &qbcst->instr);
+         qbcst_dst = &qbcst->dest.ssa;
+      }
+
+      if (i)
+         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
+                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
+                         qbcst_dst, dst);
+      else
+         dst = qbcst_dst;
+   }
+
+   return dst;
+}
+
+static nir_ssa_def *
+lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
+{
+   const nir_lower_subgroups_options *options = _options;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
    switch (intrin->intrinsic) {
    case nir_intrinsic_vote_any:
    case nir_intrinsic_vote_all:
@@ -150,7 +399,10 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_vote_feq:
    case nir_intrinsic_vote_ieq:
       if (options->lower_vote_trivial)
-         return nir_imm_int(b, NIR_TRUE);
+         return nir_imm_true(b);
+
+      if (options->lower_vote_eq_to_ballot)
+         return lower_vote_eq_to_ballot(b, intrin, options);
 
       if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_vote_eq_to_scalar(b, intrin);
@@ -164,7 +416,7 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_read_invocation:
    case nir_intrinsic_read_first_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_read_invocation_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
 
    case nir_intrinsic_load_subgroup_eq_mask:
@@ -182,9 +434,6 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       const unsigned bit_size = MAX2(options->ballot_bit_size,
                                      intrin->dest.ssa.bit_size);
 
-      assert(options->subgroup_size <= 64);
-      uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
-
       nir_ssa_def *count = nir_load_subgroup_invocation(b);
       nir_ssa_def *val;
       switch (intrin->intrinsic) {
@@ -193,11 +442,11 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
          break;
       case nir_intrinsic_load_subgroup_ge_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_gt_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_le_mask:
          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
@@ -239,12 +488,38 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       assert(intrin->src[0].is_ssa);
       nir_ssa_def *int_val = ballot_type_to_uint(b, intrin->src[0].ssa,
                                                  options->ballot_bit_size);
+
+      if (intrin->intrinsic != nir_intrinsic_ballot_bitfield_extract &&
+          intrin->intrinsic != nir_intrinsic_ballot_find_lsb) {
+         /* For OpGroupNonUniformBallotFindMSB, the SPIR-V Spec says:
+          *
+          *    "Find the most significant bit set to 1 in Value, considering
+          *    only the bits in Value required to represent all bits of the
+          *    group’s invocations.  If none of the considered bits is set to
+          *    1, the result is undefined."
+          *
+          * It has similar text for the other three.  This means that, in case
+          * the subgroup size is less than 32, we have to mask off the unused
+          * bits.  If the subgroup size is fixed and greater than or equal to
+          * 32, the mask will be 0xffffffff and nir_opt_algebraic will delete
+          * the iand.
+          *
+          * We only have to worry about this for BitCount and FindMSB because
+          * FindLSB counts from the bottom and BitfieldExtract selects
+          * individual bits.  In either case, if run outside the range of
+          * valid bits, we hit the undefined results case and we can return
+          * anything we want.
+          */
+         int_val = nir_iand(b, int_val,
+            build_subgroup_mask(b, options->ballot_bit_size, options));
+      }
+
       switch (intrin->intrinsic) {
       case nir_intrinsic_ballot_bitfield_extract:
          assert(intrin->src[1].is_ssa);
          return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
                                                    intrin->src[1].ssa),
-                                       nir_imm_int(b, 1)));
+                                       nir_imm_intN_t(b, 1, options->ballot_bit_size)));
       case nir_intrinsic_ballot_bit_count_reduce:
          return nir_bit_count(b, int_val);
       case nir_intrinsic_ballot_find_lsb:
@@ -284,58 +559,67 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
    }
 
-   default:
+   case nir_intrinsic_shuffle:
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+         return lower_subgroup_op_to_32bit(b, intrin);
+      break;
+   case nir_intrinsic_shuffle_xor:
+   case nir_intrinsic_shuffle_up:
+   case nir_intrinsic_shuffle_down:
+      if (options->lower_shuffle)
+         return lower_shuffle(b, intrin, options);
+      else if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+         return lower_subgroup_op_to_32bit(b, intrin);
       break;
-   }
-
-   return NULL;
-}
-
-static bool
-lower_subgroups_impl(nir_function_impl *impl,
-                     const nir_lower_subgroups_options *options)
-{
-   nir_builder b;
-   nir_builder_init(&b, impl);
-   bool progress = false;
-
-   nir_foreach_block(block, impl) {
-      nir_foreach_instr_safe(instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-         b.cursor = nir_before_instr(instr);
 
-         nir_ssa_def *lower = lower_subgroups_intrin(&b, intrin, options);
-         if (!lower)
-            continue;
+   case nir_intrinsic_quad_broadcast:
+   case nir_intrinsic_quad_swap_horizontal:
+   case nir_intrinsic_quad_swap_vertical:
+   case nir_intrinsic_quad_swap_diagonal:
+      if (options->lower_quad ||
+          (options->lower_quad_broadcast_dynamic &&
+           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
+           !nir_src_is_const(intrin->src[1])))
+         return lower_dynamic_quad_broadcast(b, intrin, options);
+      else if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin, false);
+      break;
 
-         nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_src_for_ssa(lower));
-         nir_instr_remove(instr);
-         progress = true;
+   case nir_intrinsic_reduce: {
+      nir_ssa_def *ret = NULL;
+      /* A cluster size greater than the subgroup size is
+       * implementation-defined, so treat it as a reduction over the
+       * whole subgroup (cluster size 0).
+       */
+      if (options->subgroup_size &&
+          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
+         nir_intrinsic_set_cluster_size(intrin, 0);
+         ret = NIR_LOWER_INSTR_PROGRESS;
       }
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         ret = lower_subgroup_op_to_scalar(b, intrin, false);
+      return ret;
+   }
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan:
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin, false);
+      break;
+
+   default:
+      break;
    }
 
-   return progress;
+   return NULL;
 }
 
 bool
 nir_lower_subgroups(nir_shader *shader,
                     const nir_lower_subgroups_options *options)
 {
-   bool progress = false;
-
-   nir_foreach_function(function, shader) {
-      if (!function->impl)
-         continue;
-
-      if (lower_subgroups_impl(function->impl, options)) {
-         progress = true;
-         nir_metadata_preserve(function->impl, nir_metadata_block_index |
-                                               nir_metadata_dominance);
-      }
-   }
-
-   return progress;
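+   /* lower_subgroups_instr returns NULL to leave an intrinsic untouched
+    * and NIR_LOWER_INSTR_PROGRESS when it only modified it in place.
+    */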
+   return nir_shader_lower_instructions(shader,
+                                        lower_subgroups_filter,
+                                        lower_subgroups_instr,
+                                        (void *)options);
 }