diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index f18ad00c370afcc70722cb9a5c7b5fcf243d7871..70d736b040f0dd44a7393f7615c0865ed5ec3772 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
  * \file nir_opt_intrinsics.c
  */
 
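+/* Emits a copy of a 64-bit intrinsic that operates on a single 32-bit half
+ * of its first source: component 0 selects the low 32 bits, component 1
+ * the high 32 bits.
+ */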
+static nir_intrinsic_instr *
+lower_subgroups_64bit_split_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
+                                      unsigned int component)
+{
+   nir_ssa_def *comp;
+   if (component == 0)
+      comp = nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa);
+   else
+      comp = nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa);
+
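+   /* Re-emit the intrinsic as a 1-component, 32-bit copy of the original. */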
+   nir_intrinsic_instr *intr = nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
+   nir_ssa_dest_init(&intr->instr, &intr->dest, 1, 32, NULL);
+   intr->const_index[0] = intrin->const_index[0];
+   intr->const_index[1] = intrin->const_index[1];
+   intr->src[0] = nir_src_for_ssa(comp);
+   if (nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2)
+      nir_src_copy(&intr->src[1], &intrin->src[1], intr);
+
+   intr->num_components = 1;
+   nir_builder_instr_insert(b, &intr->instr);
+   return intr;
+}
+
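+/* Lowers a 64-bit subgroup operation to a pair of 32-bit ones: the source
+ * is unpacked into its low and high halves, the intrinsic is emitted once
+ * per half, and the two 32-bit results are packed back together.  This is
+ * only valid for operations like shuffles and reads, which move values
+ * between invocations without combining the two halves arithmetically.
+ */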
+static nir_ssa_def *
+lower_subgroup_op_to_32bit(nir_builder *b, nir_intrinsic_instr *intrin)
+{
+   assert(intrin->src[0].ssa->bit_size == 64);
+   nir_intrinsic_instr *intr_x = lower_subgroups_64bit_split_intrinsic(b, intrin, 0);
+   nir_intrinsic_instr *intr_y = lower_subgroups_64bit_split_intrinsic(b, intrin, 1);
+   return nir_pack_64_2x32_split(b, &intr_x->dest.ssa, &intr_y->dest.ssa);
+}
+
 static nir_ssa_def *
 ballot_type_to_uint(nir_builder *b, nir_ssa_def *value, unsigned bit_size)
 {
@@ -80,7 +112,8 @@ uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
 }
 
 static nir_ssa_def *
-lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin,
+                            bool lower_to_32bit)
 {
    /* This is safe to call on scalar things but it would be silly */
    assert(intrin->dest.ssa.num_components > 1);
@@ -107,9 +140,12 @@ lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
       chan_intrin->const_index[0] = intrin->const_index[0];
       chan_intrin->const_index[1] = intrin->const_index[1];
 
-      nir_builder_instr_insert(b, &chan_intrin->instr);
-
-      reads[i] = &chan_intrin->dest.ssa;
+      if (lower_to_32bit && chan_intrin->src[0].ssa->bit_size == 64) {
+         reads[i] = lower_subgroup_op_to_32bit(b, chan_intrin);
+      } else {
+         nir_builder_instr_insert(b, &chan_intrin->instr);
+         reads[i] = &chan_intrin->dest.ssa;
+      }
    }
 
    return nir_vec(b, reads, intrin->num_components);
@@ -141,9 +177,54 @@ lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
    return result;
 }
 
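+/* Lowers vote_ieq/vote_feq using ballot: each invocation compares its
+ * value to the one read from the first active invocation, and the vote
+ * passes iff the ballot of the inverted comparison is zero, i.e. no
+ * invocation saw a mismatch on any component.
+ */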
+static nir_ssa_def *
+lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
+                        const nir_lower_subgroups_options *options)
+{
+   assert(intrin->src[0].is_ssa);
+   nir_ssa_def *value = intrin->src[0].ssa;
+
+   /* We have to implicitly lower to scalar */
+   nir_ssa_def *all_eq = NULL;
+   for (unsigned i = 0; i < intrin->num_components; i++) {
+      nir_intrinsic_instr *rfi =
+         nir_intrinsic_instr_create(b->shader,
+                                    nir_intrinsic_read_first_invocation);
+      nir_ssa_dest_init(&rfi->instr, &rfi->dest,
+                        1, value->bit_size, NULL);
+      rfi->num_components = 1;
+      rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
+      nir_builder_instr_insert(b, &rfi->instr);
+
+      nir_ssa_def *is_eq;
+      if (intrin->intrinsic == nir_intrinsic_vote_feq) {
+         is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+      } else {
+         is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+      }
+
+      if (all_eq == NULL) {
+         all_eq = is_eq;
+      } else {
+         all_eq = nir_iand(b, all_eq, is_eq);
+      }
+   }
+
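+   /* The vote passes iff ballot(!all_eq) == 0, i.e. no active invocation
+    * disagreed with the first invocation.
+    */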
+   nir_intrinsic_instr *ballot =
+      nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+   nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+                     1, options->ballot_bit_size, NULL);
+   ballot->num_components = 1;
+   ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
+   nir_builder_instr_insert(b, &ballot->instr);
+
+   return nir_ieq(b, &ballot->dest.ssa,
+                  nir_imm_intN_t(b, 0, options->ballot_bit_size));
+}
+
 static nir_ssa_def *
 lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
-              bool lower_to_scalar)
+              bool lower_to_scalar, bool lower_to_32bit)
 {
    nir_ssa_def *index = nir_load_subgroup_invocation(b);
    switch (intrin->intrinsic) {
@@ -196,7 +277,9 @@ lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
                      intrin->dest.ssa.bit_size, NULL);
 
    if (lower_to_scalar && shuffle->num_components > 1) {
-      return lower_subgroup_op_to_scalar(b, shuffle);
+      return lower_subgroup_op_to_scalar(b, shuffle, lower_to_32bit);
+   } else if (lower_to_32bit && shuffle->src[0].ssa->bit_size == 64) {
+      return lower_subgroup_op_to_32bit(b, shuffle);
    } else {
       nir_builder_instr_insert(b, &shuffle->instr);
       return &shuffle->dest.ssa;
@@ -217,7 +300,10 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_vote_feq:
    case nir_intrinsic_vote_ieq:
       if (options->lower_vote_trivial)
-         return nir_imm_int(b, NIR_TRUE);
+         return nir_imm_true(b);
+
+      if (options->lower_vote_eq_to_ballot)
+         return lower_vote_eq_to_ballot(b, intrin, options);
 
       if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_vote_eq_to_scalar(b, intrin);
@@ -231,7 +317,7 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_read_invocation:
    case nir_intrinsic_read_first_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
 
    case nir_intrinsic_load_subgroup_eq_mask:
@@ -311,7 +397,7 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
          assert(intrin->src[1].is_ssa);
          return nir_i2b(b, nir_iand(b, nir_ushr(b, int_val,
                                                    intrin->src[1].ssa),
-                                       nir_imm_int(b, 1)));
+                                       nir_imm_intN_t(b, 1, options->ballot_bit_size)));
       case nir_intrinsic_ballot_bit_count_reduce:
          return nir_bit_count(b, int_val);
       case nir_intrinsic_ballot_find_lsb:
@@ -353,16 +439,20 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
 
    case nir_intrinsic_shuffle:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+         return lower_subgroup_op_to_32bit(b, intrin);
       break;
 
    case nir_intrinsic_shuffle_xor:
    case nir_intrinsic_shuffle_up:
    case nir_intrinsic_shuffle_down:
       if (options->lower_shuffle)
-         return lower_shuffle(b, intrin, options->lower_to_scalar);
+         return lower_shuffle(b, intrin, options->lower_to_scalar, options->lower_shuffle_to_32bit);
       else if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, options->lower_shuffle_to_32bit);
+      else if (options->lower_shuffle_to_32bit && intrin->src[0].ssa->bit_size == 64)
+         return lower_subgroup_op_to_32bit(b, intrin);
       break;
 
    case nir_intrinsic_quad_broadcast:
@@ -370,16 +460,16 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_quad_swap_vertical:
    case nir_intrinsic_quad_swap_diagonal:
       if (options->lower_quad)
-         return lower_shuffle(b, intrin, options->lower_to_scalar);
+         return lower_shuffle(b, intrin, options->lower_to_scalar, false);
       else if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
 
    case nir_intrinsic_reduce:
    case nir_intrinsic_inclusive_scan:
    case nir_intrinsic_exclusive_scan:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_subgroup_op_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin, false);
       break;
 
    default:
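
For reference, a minimal sketch of how a backend might enable the options
added here when calling nir_lower_subgroups(). The option and function names
come from this diff and the existing nir_lower_subgroups_options struct in
nir.h; the concrete values (wave size, ballot width) are illustrative
assumptions, not part of the patch:

   const nir_lower_subgroups_options opts = {
      .subgroup_size = 64,              /* assumed wave size */
      .ballot_bit_size = 64,            /* ballots as 64-bit masks */
      .lower_shuffle = true,
      .lower_shuffle_to_32bit = true,   /* new: split 64-bit shuffles */
      .lower_vote_eq_to_ballot = true,  /* new: vote_ieq/feq via ballot */
   };
   nir_lower_subgroups(shader, &opts);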