nir: Use b2b opcodes for shared and constant memory
[mesa.git] / src / compiler / nir / nir_lower_subgroups.c
index eca441fcf3954b2c4150037367ad3e6467a9257d..f5eebb851446dce901085dd9f5c110e5688e5bd5 100644 (file)
@@ -292,6 +292,55 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
    return instr->type == nir_instr_type_intrinsic;
 }
 
+static nir_ssa_def *
+build_subgroup_mask(nir_builder *b, unsigned bit_size,
+                    const nir_lower_subgroups_options *options)
+{
+   return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
+                      nir_isub(b, nir_imm_int(b, bit_size),
+                                  nir_load_subgroup_size(b)));
+}
+
+static nir_ssa_def *
+lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
+                             const nir_lower_subgroups_options *options)
+{
+   if (!options->lower_quad_broadcast_dynamic_to_const)
+      return lower_shuffle(b, intrin, options->lower_to_scalar, false);
+
+   nir_ssa_def *dst = NULL;
+
+   for (unsigned i = 0; i < 4; ++i) {
+      nir_intrinsic_instr *qbcst =
+         nir_intrinsic_instr_create(b->shader, nir_intrinsic_quad_broadcast);
+
+      qbcst->num_components = intrin->num_components;
+      qbcst->src[1] = nir_src_for_ssa(nir_imm_int(b, i));
+      nir_src_copy(&qbcst->src[0], &intrin->src[0], qbcst);
+      nir_ssa_dest_init(&qbcst->instr, &qbcst->dest,
+                        intrin->dest.ssa.num_components,
+                        intrin->dest.ssa.bit_size, NULL);
+
+      nir_ssa_def *qbcst_dst = NULL;
+
+      if (options->lower_to_scalar && qbcst->num_components > 1) {
+         qbcst_dst = lower_subgroup_op_to_scalar(b, qbcst, false);
+      } else {
+         nir_builder_instr_insert(b, &qbcst->instr);
+         qbcst_dst = &qbcst->dest.ssa;
+      }
+
+      if (i)
+         dst = nir_bcsel(b, nir_ieq(b, intrin->src[1].ssa,
+                                    nir_src_for_ssa(nir_imm_int(b, i)).ssa),
+                         qbcst_dst, dst);
+      else
+         dst = qbcst_dst;
+   }
+
+   return dst;
+}
+
 static nir_ssa_def *
 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 {
@@ -343,9 +392,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       const unsigned bit_size = MAX2(options->ballot_bit_size,
                                      intrin->dest.ssa.bit_size);
 
-      assert(options->subgroup_size <= 64);
-      uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
-
       nir_ssa_def *count = nir_load_subgroup_invocation(b);
       nir_ssa_def *val;
       switch (intrin->intrinsic) {
@@ -354,11 +400,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
          break;
       case nir_intrinsic_load_subgroup_ge_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_gt_mask:
          val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
-                           nir_imm_intN_t(b, group_mask, bit_size));
+                           build_subgroup_mask(b, bit_size, options));
          break;
       case nir_intrinsic_load_subgroup_le_mask:
          val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));
@@ -467,13 +513,27 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
    case nir_intrinsic_quad_swap_horizontal:
    case nir_intrinsic_quad_swap_vertical:
    case nir_intrinsic_quad_swap_diagonal:
-      if (options->lower_quad)
-         return lower_shuffle(b, intrin, options->lower_to_scalar, false);
+      if (options->lower_quad ||
+          (options->lower_quad_broadcast_dynamic &&
+           intrin->intrinsic == nir_intrinsic_quad_broadcast &&
+           !nir_src_is_const(intrin->src[1])))
+         return lower_dynamic_quad_broadcast(b, intrin, options);
       else if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
 
-   case nir_intrinsic_reduce:
+   case nir_intrinsic_reduce: {
+      nir_ssa_def *ret = NULL;
+      /* A cluster size greater than the subgroup size is implemention defined */
+      if (options->subgroup_size &&
+          nir_intrinsic_cluster_size(intrin) >= options->subgroup_size) {
+         nir_intrinsic_set_cluster_size(intrin, 0);
+         ret = NIR_LOWER_INSTR_PROGRESS;
+      }
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         ret = lower_subgroup_op_to_scalar(b, intrin, false);
+      return ret;
+   }
    case nir_intrinsic_inclusive_scan:
    case nir_intrinsic_exclusive_scan:
       if (options->lower_to_scalar && intrin->num_components > 1)