intel/cs: Push subgroup ID instead of base thread ID
authorJason Ekstrand <jason.ekstrand@intel.com>
Thu, 24 Aug 2017 18:40:31 +0000 (11:40 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Tue, 7 Nov 2017 18:37:52 +0000 (10:37 -0800)
We're going to want subgroup ID for SPIR-V subgroups eventually anyway.
We really only want to push one and calculate the other from it.  It
makes a bit more sense to push the subgroup ID because it's simpler to
calculate and because it's a real API thing.  The only advantage to
pushing the base thread ID is to avoid a single SHL in the shader.

Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
src/compiler/nir/nir_intrinsics.h
src/intel/compiler/brw_compiler.h
src/intel/compiler/brw_fs.cpp
src/intel/compiler/brw_fs.h
src/intel/compiler/brw_fs_nir.cpp
src/intel/compiler/brw_nir.h
src/intel/compiler/brw_nir_lower_cs_intrinsics.c
src/intel/vulkan/anv_cmd_buffer.c
src/mesa/drivers/dri/i965/gen6_constant_state.c

index 47022dd135bf8a02dcf7ea191ab6985d2e89e5ca..bb8cfac6620abc8bfa5a6ac3714545cec0ff6063 100644 (file)
@@ -355,6 +355,7 @@ SYSTEM_VALUE(subgroup_ge_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_gt_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_le_mask, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(subgroup_lt_mask, 1, 0, xx, xx, xx)
+SYSTEM_VALUE(subgroup_id, 1, 0, xx, xx, xx)
 
 /* Blend constant color values.  Float values are clamped. */
 SYSTEM_VALUE(blend_const_color_r_float, 1, 0, xx, xx, xx)
@@ -364,9 +365,6 @@ SYSTEM_VALUE(blend_const_color_a_float, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(blend_const_color_rgba8888_unorm, 1, 0, xx, xx, xx)
 SYSTEM_VALUE(blend_const_color_aaaa8888_unorm, 1, 0, xx, xx, xx)
 
-/* Intel specific system values */
-SYSTEM_VALUE(intel_thread_local_id, 1, 0, xx, xx, xx)
-
 /**
  * Barycentric coordinate intrinsics.
  *
index 662f645e183ce978ad853ba8a01ccbe9a97b2adc..df6ee0185462d979b3bfa1c782220fa192224112 100644 (file)
@@ -552,7 +552,7 @@ enum brw_param_builtin {
    BRW_PARAM_BUILTIN_TESS_LEVEL_INNER_X,
    BRW_PARAM_BUILTIN_TESS_LEVEL_INNER_Y,
 
-   BRW_PARAM_BUILTIN_THREAD_LOCAL_ID,
+   BRW_PARAM_BUILTIN_SUBGROUP_ID,
 };
 
 #define BRW_PARAM_BUILTIN_CLIP_PLANE(idx, comp) \
index 006b72b19e115e6fb686536421893494bc7ebff0..40e64a482018c5777d66d23799190ec2a43689ce 100644 (file)
@@ -996,7 +996,7 @@ fs_visitor::import_uniforms(fs_visitor *v)
    this->push_constant_loc = v->push_constant_loc;
    this->pull_constant_loc = v->pull_constant_loc;
    this->uniforms = v->uniforms;
-   this->thread_local_id = v->thread_local_id;
+   this->subgroup_id = v->subgroup_id;
 }
 
 void
@@ -1931,14 +1931,14 @@ set_push_pull_constant_loc(unsigned uniform, int *chunk_start,
 }
 
 static int
-get_thread_local_id_param_index(const brw_stage_prog_data *prog_data)
+get_subgroup_id_param_index(const brw_stage_prog_data *prog_data)
 {
    if (prog_data->nr_params == 0)
       return -1;
 
    /* The local thread id is always the last parameter in the list */
    uint32_t last_param = prog_data->param[prog_data->nr_params - 1];
-   if (last_param == BRW_PARAM_BUILTIN_THREAD_LOCAL_ID)
+   if (last_param == BRW_PARAM_BUILTIN_SUBGROUP_ID)
       return prog_data->nr_params - 1;
 
    return -1;
@@ -2019,7 +2019,7 @@ fs_visitor::assign_constant_locations()
       }
    }
 
-   int thread_local_id_index = get_thread_local_id_param_index(stage_prog_data);
+   int subgroup_id_index = get_subgroup_id_param_index(stage_prog_data);
 
    /* Only allow 16 registers (128 uniform components) as push constants.
     *
@@ -2030,7 +2030,7 @@ fs_visitor::assign_constant_locations()
     * brw_curbe.c.
     */
    unsigned int max_push_components = 16 * 8;
-   if (thread_local_id_index >= 0)
+   if (subgroup_id_index >= 0)
       max_push_components--; /* Save a slot for the thread ID */
 
    /* We push small arrays, but no bigger than 16 floats.  This is big enough
@@ -2075,8 +2075,8 @@ fs_visitor::assign_constant_locations()
       if (!is_live[u])
          continue;
 
-      /* Skip thread_local_id_index to put it in the last push register. */
-      if (thread_local_id_index == (int)u)
+      /* Skip subgroup_id_index to put it in the last push register. */
+      if (subgroup_id_index == (int)u)
          continue;
 
       set_push_pull_constant_loc(u, &chunk_start, &max_chunk_bitsize,
@@ -2090,8 +2090,8 @@ fs_visitor::assign_constant_locations()
    }
 
    /* Add the CS local thread ID uniform at the end of the push constants */
-   if (thread_local_id_index >= 0)
-      push_constant_loc[thread_local_id_index] = num_push_constants++;
+   if (subgroup_id_index >= 0)
+      push_constant_loc[subgroup_id_index] = num_push_constants++;
 
    /* As the uniforms are going to be reordered, stash the old array and
     * create two new arrays for push/pull params.
@@ -6778,20 +6778,20 @@ cs_fill_push_const_info(const struct gen_device_info *devinfo,
                         struct brw_cs_prog_data *cs_prog_data)
 {
    const struct brw_stage_prog_data *prog_data = &cs_prog_data->base;
-   int thread_local_id_index = get_thread_local_id_param_index(prog_data);
+   int subgroup_id_index = get_subgroup_id_param_index(prog_data);
    bool cross_thread_supported = devinfo->gen > 7 || devinfo->is_haswell;
 
    /* The thread ID should be stored in the last param dword */
-   assert(thread_local_id_index == -1 ||
-          thread_local_id_index == (int)prog_data->nr_params - 1);
+   assert(subgroup_id_index == -1 ||
+          subgroup_id_index == (int)prog_data->nr_params - 1);
 
    unsigned cross_thread_dwords, per_thread_dwords;
    if (!cross_thread_supported) {
       cross_thread_dwords = 0u;
       per_thread_dwords = prog_data->nr_params;
-   } else if (thread_local_id_index >= 0) {
+   } else if (subgroup_id_index >= 0) {
       /* Fill all but the last register with cross-thread payload */
-      cross_thread_dwords = 8 * (thread_local_id_index / 8);
+      cross_thread_dwords = 8 * (subgroup_id_index / 8);
       per_thread_dwords = prog_data->nr_params - cross_thread_dwords;
       assert(per_thread_dwords > 0 && per_thread_dwords <= 8);
    } else {
@@ -6834,7 +6834,7 @@ compile_cs_to_nir(const struct brw_compiler *compiler,
 {
    nir_shader *shader = nir_shader_clone(mem_ctx, src_shader);
    shader = brw_nir_apply_sampler_key(shader, compiler, &key->tex, true);
-   brw_nir_lower_cs_intrinsics(shader);
+   brw_nir_lower_cs_intrinsics(shader, dispatch_width);
    return brw_postprocess_nir(shader, compiler, true);
 }
 
index f51a4d8889bd651b1ed2583782b52caaed5b078b..40dd83f45e45793c2324a1a3e9da9912de4784cd 100644 (file)
@@ -315,7 +315,7 @@ public:
     */
    int *push_constant_loc;
 
-   fs_reg thread_local_id;
+   fs_reg subgroup_id;
    fs_reg frag_depth;
    fs_reg frag_stencil;
    fs_reg sample_mask;
index 77d8bae4db60755bc9a944d8d436a7c0960b852b..39e7e692874804755bb04c8a58c47922be94d823 100644 (file)
@@ -95,8 +95,8 @@ fs_visitor::nir_setup_uniforms()
        */
       assert(uniforms == prog_data->nr_params);
       uint32_t *param = brw_stage_prog_data_add_params(prog_data, 1);
-      *param = BRW_PARAM_BUILTIN_THREAD_LOCAL_ID;
-      thread_local_id = fs_reg(UNIFORM, uniforms++, BRW_REGISTER_TYPE_UD);
+      *param = BRW_PARAM_BUILTIN_SUBGROUP_ID;
+      subgroup_id = fs_reg(UNIFORM, uniforms++, BRW_REGISTER_TYPE_UD);
    }
 }
 
@@ -3422,8 +3422,8 @@ fs_visitor::nir_emit_cs_intrinsic(const fs_builder &bld,
       cs_prog_data->uses_barrier = true;
       break;
 
-   case nir_intrinsic_load_intel_thread_local_id:
-      bld.MOV(retype(dest, BRW_REGISTER_TYPE_UD), thread_local_id);
+   case nir_intrinsic_load_subgroup_id:
+      bld.MOV(retype(dest, BRW_REGISTER_TYPE_UD), subgroup_id);
       break;
 
    case nir_intrinsic_load_local_invocation_id:
index 3e407122681a69e8c8e0783528ec901d03370af4..0118cfadc1f34c287399bad01c8c336d82bd991e 100644 (file)
@@ -95,7 +95,8 @@ void brw_nir_analyze_boolean_resolves(nir_shader *nir);
 nir_shader *brw_preprocess_nir(const struct brw_compiler *compiler,
                                nir_shader *nir);
 
-bool brw_nir_lower_cs_intrinsics(nir_shader *nir);
+bool brw_nir_lower_cs_intrinsics(nir_shader *nir,
+                                 unsigned dispatch_width);
 void brw_nir_lower_vs_inputs(nir_shader *nir,
                              bool use_legacy_snorm_formula,
                              const uint8_t *vs_attrib_wa_flags);
index 07d2dccd0412e2a2184298055022fd068abf258c..66eef6be0a609d0544f5c31a8b59c6a240cd178b 100644 (file)
@@ -26,6 +26,7 @@
 
 struct lower_intrinsics_state {
    nir_shader *nir;
+   unsigned dispatch_width;
    nir_function_impl *impl;
    bool progress;
    nir_builder builder;
@@ -56,12 +57,14 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
           *    gl_LocalInvocationIndex =
           *       cs_thread_local_id + subgroup_invocation;
           */
-         nir_ssa_def *thread_local_id;
-         if (state->local_workgroup_size <= 8)
-            thread_local_id = nir_imm_int(b, 0);
+         nir_ssa_def *subgroup_id;
+         if (state->local_workgroup_size <= state->dispatch_width)
+            subgroup_id = nir_imm_int(b, 0);
          else
-            thread_local_id = nir_load_intel_thread_local_id(b);
+            subgroup_id = nir_load_subgroup_id(b);
 
+         nir_ssa_def *thread_local_id =
+            nir_imul(b, subgroup_id, nir_imm_int(b, state->dispatch_width));
          nir_ssa_def *channel = nir_load_subgroup_invocation(b);
          sysval = nir_iadd(b, channel, thread_local_id);
          break;
@@ -127,7 +130,8 @@ lower_cs_intrinsics_convert_impl(struct lower_intrinsics_state *state)
 }
 
 bool
-brw_nir_lower_cs_intrinsics(nir_shader *nir)
+brw_nir_lower_cs_intrinsics(nir_shader *nir,
+                            unsigned dispatch_width)
 {
    assert(nir->info.stage == MESA_SHADER_COMPUTE);
 
@@ -135,6 +139,7 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir)
    struct lower_intrinsics_state state;
    memset(&state, 0, sizeof(state));
    state.nir = nir;
+   state.dispatch_width = dispatch_width;
    state.local_workgroup_size = nir->info.cs.local_size[0] *
                                 nir->info.cs.local_size[1] *
                                 nir->info.cs.local_size[2];
index b45f8f83757c64f6382e78143572493c6d0d92bd..69acafaae26960378d30eca8096dd4237941882e 100644 (file)
@@ -710,7 +710,7 @@ anv_cmd_buffer_cs_push_constants(struct anv_cmd_buffer *cmd_buffer)
       for (unsigned i = 0;
            i < cs_prog_data->push.cross_thread.dwords;
            i++) {
-         assert(prog_data->param[i] != BRW_PARAM_BUILTIN_THREAD_LOCAL_ID);
+         assert(prog_data->param[i] != BRW_PARAM_BUILTIN_SUBGROUP_ID);
          u32_map[i] = anv_push_constant_value(data, prog_data->param[i]);
       }
    }
@@ -722,8 +722,8 @@ anv_cmd_buffer_cs_push_constants(struct anv_cmd_buffer *cmd_buffer)
                  cs_prog_data->push.cross_thread.regs);
          unsigned src = cs_prog_data->push.cross_thread.dwords;
          for ( ; src < prog_data->nr_params; src++, dst++) {
-            if (prog_data->param[src] == BRW_PARAM_BUILTIN_THREAD_LOCAL_ID) {
-               u32_map[dst] = t * cs_prog_data->simd_size;
+            if (prog_data->param[src] == BRW_PARAM_BUILTIN_SUBGROUP_ID) {
+               u32_map[dst] = t;
             } else {
                u32_map[dst] =
                   anv_push_constant_value(data, prog_data->param[src]);
index acf7454cef580025795985937ea86deafac6ec8d..d89e7bde24b43c6d5aada7e8eeb8bfbaaa0dc713 100644 (file)
@@ -317,7 +317,7 @@ brw_upload_cs_push_constants(struct brw_context *brw,
       for (unsigned i = 0;
            i < cs_prog_data->push.cross_thread.dwords;
            i++) {
-         assert(prog_data->param[i] != BRW_PARAM_BUILTIN_THREAD_LOCAL_ID);
+         assert(prog_data->param[i] != BRW_PARAM_BUILTIN_SUBGROUP_ID);
          param_copy[i] = brw_param_value(brw, prog, stage_state,
                                          prog_data->param[i]);
       }
@@ -330,8 +330,8 @@ brw_upload_cs_push_constants(struct brw_context *brw,
                  cs_prog_data->push.cross_thread.regs);
          unsigned src = cs_prog_data->push.cross_thread.dwords;
          for ( ; src < prog_data->nr_params; src++, dst++) {
-            if (prog_data->param[src] == BRW_PARAM_BUILTIN_THREAD_LOCAL_ID) {
-               param[dst] = t * cs_prog_data->simd_size;
+            if (prog_data->param[src] == BRW_PARAM_BUILTIN_SUBGROUP_ID) {
+               param[dst] = t;
             } else {
                param[dst] = brw_param_value(brw, prog, stage_state,
                                             prog_data->param[src]);