nir: Add subgroup shuffle intrinsics and lowering
author: Jason Ekstrand <jason.ekstrand@intel.com>
Thu, 7 Dec 2017 05:41:47 +0000 (21:41 -0800)
committer: Jason Ekstrand <jason.ekstrand@intel.com>
Wed, 7 Mar 2018 20:13:47 +0000 (12:13 -0800)
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
src/compiler/nir/nir.h
src/compiler/nir/nir_intrinsics.h
src/compiler/nir/nir_lower_subgroups.c

index 804a6561bfff8c0cf506a2a1b66e2a9ce0050523..4480fff58a19a4d7a256dcf8b7639180a1fb185c 100644 (file)
@@ -2544,6 +2544,7 @@ typedef struct nir_lower_subgroups_options {
    bool lower_to_scalar:1;
    bool lower_vote_trivial:1;
    bool lower_subgroup_masks:1;
+   bool lower_shuffle:1;
 } nir_lower_subgroups_options;
 
 bool nir_lower_subgroups(nir_shader *shader,
index 7543d83e786e2e39e9c854193cd2691fc8c1e9bd..8c155a609cbcdc8dc9ca352bd1670849536a4934 100644 (file)
@@ -157,6 +157,16 @@ INTRINSIC(ballot_find_lsb, 1, ARR(4), true, 1, 0,
 INTRINSIC(ballot_find_msb, 1, ARR(4), true, 1, 0,
           0, xx, xx, xx, NIR_INTRINSIC_CAN_ELIMINATE)
 
+/** Shuffle operations from SPIR-V. */
+INTRINSIC(shuffle, 2, ARR(0, 1), true, 0, 0,
+          0, xx, xx, xx, NIR_INTRINSIC_CAN_ELIMINATE)
+INTRINSIC(shuffle_xor, 2, ARR(0, 1), true, 0, 0,
+          0, xx, xx, xx, NIR_INTRINSIC_CAN_ELIMINATE)
+INTRINSIC(shuffle_up, 2, ARR(0, 1), true, 0, 0,
+          0, xx, xx, xx, NIR_INTRINSIC_CAN_ELIMINATE)
+INTRINSIC(shuffle_down, 2, ARR(0, 1), true, 0, 0,
+          0, xx, xx, xx, NIR_INTRINSIC_CAN_ELIMINATE)
+
 /**
  * Basic Geometry Shader intrinsics.
  *
index c97849bf8ba0dc865521fae372ddf0bb2fcc26a8..03c5a1c7b3645faedd47561bdc5509d0cfe551f6 100644 (file)
@@ -80,7 +80,7 @@ uint_to_ballot_type(nir_builder *b, nir_ssa_def *value,
 }
 
 static nir_ssa_def *
-lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
+lower_subgroup_op_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
 {
    /* This is safe to call on scalar things but it would be silly */
    assert(intrin->dest.ssa.num_components > 1);
@@ -99,8 +99,10 @@ lower_read_invocation_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
       /* value */
       chan_intrin->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
       /* invocation */
-      if (intrin->intrinsic == nir_intrinsic_read_invocation)
+      if (nir_intrinsic_infos[intrin->intrinsic].num_srcs > 1) {
+         assert(nir_intrinsic_infos[intrin->intrinsic].num_srcs == 2);
          nir_src_copy(&chan_intrin->src[1], &intrin->src[1], chan_intrin);
+      }
 
       nir_builder_instr_insert(b, &chan_intrin->instr);
 
@@ -136,6 +138,45 @@ lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
    return result;
 }
 
+/* Lower shuffle_xor/shuffle_up/shuffle_down to a plain nir_intrinsic_shuffle
+ * by turning the relative delta/mask in src[1] into an absolute invocation
+ * index.  If lower_to_scalar is set and the value is a vector, the resulting
+ * shuffle is additionally split into per-channel shuffles.
+ */
+static nir_ssa_def *
+lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
+              bool lower_to_scalar)
+{
+   /* Start from this invocation's own index within the subgroup and combine
+    * it with src[1] according to the SPIR-V shuffle variant.
+    */
+   nir_ssa_def *index = nir_load_subgroup_invocation(b);
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_shuffle_xor:
+      /* shuffle_xor: source invocation is (self ^ mask). */
+      assert(intrin->src[1].is_ssa);
+      index = nir_ixor(b, index, intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_shuffle_up:
+      /* shuffle_up: read from a lower invocation, (self - delta). */
+      assert(intrin->src[1].is_ssa);
+      index = nir_isub(b, index, intrin->src[1].ssa);
+      break;
+   case nir_intrinsic_shuffle_down:
+      /* shuffle_down: read from a higher invocation, (self + delta). */
+      assert(intrin->src[1].is_ssa);
+      index = nir_iadd(b, index, intrin->src[1].ssa);
+      break;
+   default:
+      unreachable("Invalid intrinsic");
+   }
+
+   /* Build the generic shuffle: src[0] is the value (copied from the
+    * original intrinsic), src[1] is the absolute index computed above.
+    * The destination mirrors the original's component count and bit size.
+    */
+   nir_intrinsic_instr *shuffle =
+      nir_intrinsic_instr_create(b->shader, nir_intrinsic_shuffle);
+   shuffle->num_components = intrin->num_components;
+   nir_src_copy(&shuffle->src[0], &intrin->src[0], shuffle);
+   shuffle->src[1] = nir_src_for_ssa(index);
+   nir_ssa_dest_init(&shuffle->instr, &shuffle->dest,
+                     intrin->dest.ssa.num_components,
+                     intrin->dest.ssa.bit_size, NULL);
+
+   if (lower_to_scalar && shuffle->num_components > 1) {
+      /* Hand the vector shuffle off to the scalarizing pass; it inserts the
+       * per-channel instructions itself and returns the recombined vector.
+       */
+      return lower_subgroup_op_to_scalar(b, shuffle);
+   } else {
+      nir_builder_instr_insert(b, &shuffle->instr);
+      return &shuffle->dest.ssa;
+   }
+}
+
+
 static nir_ssa_def *
 lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
                        const nir_lower_subgroups_options *options)
@@ -164,7 +205,7 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
    case nir_intrinsic_read_invocation:
    case nir_intrinsic_read_first_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
-         return lower_read_invocation_to_scalar(b, intrin);
+         return lower_subgroup_op_to_scalar(b, intrin);
       break;
 
    case nir_intrinsic_load_subgroup_eq_mask:
@@ -284,6 +325,20 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
       return nir_ieq(b, nir_load_subgroup_invocation(b), &first->dest.ssa);
    }
 
+   case nir_intrinsic_shuffle:
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin);
+      break;
+
+   case nir_intrinsic_shuffle_xor:
+   case nir_intrinsic_shuffle_up:
+   case nir_intrinsic_shuffle_down:
+      if (options->lower_shuffle)
+         return lower_shuffle(b, intrin, options->lower_to_scalar);
+      else if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin);
+      break;
+
    default:
       break;
    }