Revert "Revert "i965/fs: Use align1 mode on ternary instructions on Gen10+""
[mesa.git] / src / intel / compiler / brw_nir_lower_cs_intrinsics.c
index 01718eb5dd1b0f02366a41671220da844a30cea8..66eef6be0a609d0544f5c31a8b59c6a240cd178b 100644 (file)
 
 struct lower_intrinsics_state {
    nir_shader *nir;
-   struct brw_cs_prog_data *prog_data;
+   unsigned dispatch_width;
    nir_function_impl *impl;
    bool progress;
    nir_builder builder;
+   unsigned local_workgroup_size;
 };
 
-static nir_ssa_def *
-read_thread_local_id(struct lower_intrinsics_state *state)
-{
-   struct brw_cs_prog_data *prog_data = state->prog_data;
-   nir_builder *b = &state->builder;
-   nir_shader *nir = state->nir;
-   const unsigned *sizes = nir->info.cs.local_size;
-   const unsigned group_size = sizes[0] * sizes[1] * sizes[2];
-
-   /* Some programs have local_size dimensions so small that the thread local
-    * ID will always be 0.
-    */
-   if (group_size <= 8)
-      return nir_imm_int(b, 0);
-
-   if (prog_data->thread_local_id_index == -1) {
-      prog_data->thread_local_id_index = prog_data->base.nr_params;
-      brw_stage_prog_data_add_params(&prog_data->base, 1);
-      nir->num_uniforms += 4;
-   }
-   unsigned id_index = prog_data->thread_local_id_index;
-
-   nir_intrinsic_instr *load =
-      nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
-   load->num_components = 1;
-   load->src[0] = nir_src_for_ssa(nir_imm_int(b, 0));
-   nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
-   nir_intrinsic_set_base(load, id_index * sizeof(uint32_t));
-   nir_intrinsic_set_range(load, sizeof(uint32_t));
-   nir_builder_instr_insert(b, &load->instr);
-   return &load->dest.ssa;
-}
-
 static bool
 lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
                                   nir_block *block)
@@ -89,7 +57,14 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
           *    gl_LocalInvocationIndex =
           *       cs_thread_local_id + subgroup_invocation;
           */
-         nir_ssa_def *thread_local_id = read_thread_local_id(state);
+         nir_ssa_def *subgroup_id;
+         if (state->local_workgroup_size <= state->dispatch_width)
+            subgroup_id = nir_imm_int(b, 0);
+         else
+            subgroup_id = nir_load_subgroup_id(b);
+
+         nir_ssa_def *thread_local_id =
+            nir_imul(b, subgroup_id, nir_imm_int(b, state->dispatch_width));
          nir_ssa_def *channel = nir_load_subgroup_invocation(b);
          sysval = nir_iadd(b, channel, thread_local_id);
          break;
@@ -114,6 +89,7 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
          nir_ssa_def *local_index = nir_load_local_invocation_index(b);
 
          nir_const_value uvec3;
+         memset(&uvec3, 0, sizeof(uvec3));
          uvec3.u32[0] = 1;
          uvec3.u32[1] = size[0];
          uvec3.u32[2] = size[0] * size[1];
@@ -155,17 +131,18 @@ lower_cs_intrinsics_convert_impl(struct lower_intrinsics_state *state)
 
 bool
 brw_nir_lower_cs_intrinsics(nir_shader *nir,
-                            struct brw_cs_prog_data *prog_data)
+                            unsigned dispatch_width)
 {
-   assert(nir->stage == MESA_SHADER_COMPUTE);
+   assert(nir->info.stage == MESA_SHADER_COMPUTE);
 
    bool progress = false;
    struct lower_intrinsics_state state;
    memset(&state, 0, sizeof(state));
    state.nir = nir;
-   state.prog_data = prog_data;
-
-   state.prog_data->thread_local_id_index = -1;
+   state.dispatch_width = dispatch_width;
+   state.local_workgroup_size = nir->info.cs.local_size[0] *
+                                nir->info.cs.local_size[1] *
+                                nir->info.cs.local_size[2];
 
    do {
       state.progress = false;