nir: Add options to nir_lower_compute_system_values to control compute ID base lowering
authorJesse Natalie <jenatali@microsoft.com>
Fri, 21 Aug 2020 17:40:45 +0000 (10:40 -0700)
committerMarge Bot <eric+marge@anholt.net>
Fri, 21 Aug 2020 22:07:05 +0000 (22:07 +0000)
If no options are provided, existing intrinsics are used.
If the lowering pass indicates there should be offsets used for global
invocation ID or work group ID, then those instructions are lowered to
include the offset.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5891>

13 files changed:
src/amd/vulkan/radv_shader.c
src/broadcom/compiler/vir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_system_values.c
src/freedreno/vulkan/tu_shader.c
src/gallium/auxiliary/nir/tgsi_to_nir.c
src/gallium/drivers/freedreno/ir3/ir3_cmdline.c
src/gallium/frontends/clover/nir/invocation.cpp
src/gallium/frontends/vallium/val_pipeline.c
src/intel/compiler/brw_nir.c
src/mesa/state_tracker/st_glsl_to_nir.cpp
src/mesa/state_tracker/st_nir_builtins.c
src/mesa/state_tracker/st_program.c

index 01a22b7a7579319b5673a38617f9d750b25e87b9..1d227efe4afeea6fe66c0ca05a12b5e325a33df0 100644 (file)
@@ -540,7 +540,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
                NIR_PASS_V(nir, nir_propagate_invariant);
 
                NIR_PASS_V(nir, nir_lower_system_values);
-               NIR_PASS_V(nir, nir_lower_compute_system_values);
+               NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
                NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
 
index 83eb4d14b384dd131383ff52775a543b48145cb1..a0b08918f3216bdb2aaf250b1765d97767c53f5d 100644 (file)
@@ -586,7 +586,7 @@ v3d_lower_nir(struct v3d_compile *c)
 
         NIR_PASS_V(c->s, nir_lower_tex, &tex_options);
         NIR_PASS_V(c->s, nir_lower_system_values);
-        NIR_PASS_V(c->s, nir_lower_compute_system_values);
+        NIR_PASS_V(c->s, nir_lower_compute_system_values, NULL);
 
         NIR_PASS_V(c->s, nir_lower_vars_to_scratch,
                    nir_var_function_temp,
index 7432afd8d94051260a8d2eb78080bb3264ae01c9..005f7625a602047cc8c269f1a9855a6c42e3f189 100644 (file)
@@ -4276,7 +4276,13 @@ bool nir_lower_subgroups(nir_shader *shader,
 
 bool nir_lower_system_values(nir_shader *shader);
 
-bool nir_lower_compute_system_values(nir_shader *shader);
+typedef struct nir_lower_compute_system_values_options {
+   bool has_base_global_invocation_id:1;
+   bool has_base_work_group_id:1;
+} nir_lower_compute_system_values_options;
+
+bool nir_lower_compute_system_values(nir_shader *shader,
+                                     const nir_lower_compute_system_values_options *options);
 
 enum PACKED nir_lower_tex_packing {
    nir_lower_tex_packing_none = 0,
index bc80d184f721348e85dbaf3c01b02ec060011656..b99f655c2e0d75863185f0c7bc4e2c6b4403ddae 100644 (file)
@@ -227,6 +227,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                  nir_instr *instr, void *_options)
 {
    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   const nir_lower_compute_system_values_options *options = _options;
 
    /* All the intrinsics we care about are loads */
    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
@@ -276,7 +277,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                         nir_channel(b, local_size, 1)));
          return nir_u2u(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_invocation_index:
@@ -310,7 +311,7 @@ lower_compute_system_value_instr(nir_builder *b,
          index = nir_iadd(b, index, nir_channel(b, local_id, 0));
          return nir_u2u(b, index, bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_group_size:
@@ -319,7 +320,7 @@ lower_compute_system_value_instr(nir_builder *b,
           * this point.  We do, however, have to make sure that the intrinsic
           * is only 32-bit.
           */
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       } else {
          /* using a 32 bit constant is safe here as no device/driver needs more
           * than 32 bits for the local size */
@@ -331,8 +332,9 @@ lower_compute_system_value_instr(nir_builder *b,
          return nir_u2u(b, nir_build_imm(b, 3, 32, local_size_const), bit_size);
       }
 
-   case nir_intrinsic_load_global_invocation_id: {
-      if (!b->shader->options->has_cs_global_id) {
+   case nir_intrinsic_load_global_invocation_id_zero_base: {
+      if ((options && options->has_base_work_group_id) ||
+          !b->shader->options->has_cs_global_id) {
          nir_ssa_def *group_size = nir_load_local_group_size(b);
          nir_ssa_def *group_id = nir_load_work_group_id(b, bit_size);
          nir_ssa_def *local_id = nir_load_local_invocation_id(b);
@@ -345,8 +347,21 @@ lower_compute_system_value_instr(nir_builder *b,
       }
    }
 
+   case nir_intrinsic_load_global_invocation_id: {
+      if (options && options->has_base_global_invocation_id)
+         return nir_iadd(b, nir_load_global_invocation_id_zero_base(b, bit_size),
+                            nir_load_base_global_invocation_id(b, bit_size));
+      else if (!b->shader->options->has_cs_global_id)
+         return nir_load_global_invocation_id_zero_base(b, bit_size);
+      else
+         return NULL;
+   }
+
    case nir_intrinsic_load_global_invocation_index: {
-      nir_ssa_def *global_id = nir_load_global_invocation_id(b, bit_size);
+      /* OpenCL's global_linear_id explicitly removes the global offset before computing this */
+      assert(b->shader->info.stage == MESA_SHADER_KERNEL);
+      nir_ssa_def *global_base_id = nir_load_base_global_invocation_id(b, bit_size);
+      nir_ssa_def *global_id = nir_isub(b, nir_load_global_invocation_id(b, bit_size), global_base_id);
       nir_ssa_def *global_size = build_global_group_size(b, bit_size);
 
       /* index = id.x + ((id.y + (id.z * size.y)) * size.x) */
@@ -359,13 +374,22 @@ lower_compute_system_value_instr(nir_builder *b,
       return index;
    }
 
+   case nir_intrinsic_load_work_group_id: {
+      if (options && options->has_base_work_group_id)
+         return nir_iadd(b, nir_u2u(b, nir_load_work_group_id_zero_base(b), bit_size),
+                            nir_load_base_work_group_id(b, bit_size));
+      else
+         return NULL;
+   }
+
    default:
       return NULL;
    }
 }
 
 bool
-nir_lower_compute_system_values(nir_shader *shader)
+nir_lower_compute_system_values(nir_shader *shader,
+                                const nir_lower_compute_system_values_options *options)
 {
    if (shader->info.stage != MESA_SHADER_COMPUTE &&
        shader->info.stage != MESA_SHADER_KERNEL)
@@ -374,5 +398,5 @@ nir_lower_compute_system_values(nir_shader *shader)
    return nir_shader_lower_instructions(shader,
                                         lower_compute_system_value_filter,
                                         lower_compute_system_value_instr,
-                                        NULL);
+                                        (void*)options);
 }
index aca5ea02f71dfb1e9963d8425c7b73470a114fad..afdb7c7de748ea9b6b842ab7deca04d7f0e7adf6 100644 (file)
@@ -765,7 +765,7 @@ tu_shader_create(struct tu_device *dev,
    nir_assign_io_var_locations(nir, nir_var_shader_out, &nir->num_outputs, stage);
 
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    NIR_PASS_V(nir, nir_lower_frexp);
 
index 7b1b055af370799fd3c2b691d3d0b918b48e9c2b..1070f22537a821aeccaabcc92ebde569efadf4b4 100644 (file)
@@ -2559,7 +2559,7 @@ ttn_finalize_nir(struct ttn_compile *c, struct pipe_screen *screen)
    NIR_PASS_V(nir, nir_split_var_copies);
    NIR_PASS_V(nir, nir_lower_var_copies);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    if (c->cap_packed_uniforms)
       NIR_PASS_V(nir, nir_lower_uniforms_to_ubo, 16);
index 54227e4886d792f85ef92d25b5c3ad99ed924d60..e80312d25716c565452a53b73086c20b86b0fb87 100644 (file)
@@ -185,7 +185,7 @@ load_glsl(unsigned num_files, char* const* files, gl_shader_stage stage)
                        ir3_glsl_type_size);
 
        NIR_PASS_V(nir, nir_lower_system_values);
-       NIR_PASS_V(nir, nir_lower_compute_system_values);
+       NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
        NIR_PASS_V(nir, nir_lower_frexp);
        NIR_PASS_V(nir, nir_lower_io,
@@ -403,7 +403,7 @@ main(int argc, char **argv)
                /* TODO do this somewhere else */
                nir_lower_int64(nir);
                nir_lower_system_values(nir);
-               nir_lower_compute_system_values(nir);
+               nir_lower_compute_system_values(nir, NULL);
        } else if (num_files > 0) {
                nir = load_glsl(num_files, filenames, stage);
        } else {
index f57472052222bed771774668b5833b31ac7ebf1f..3656a3c1ef9d43d05f091f3e8479edf03eff3a9c 100644 (file)
@@ -169,7 +169,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
                  spirv_options.global_addr_format);
 
       NIR_PASS_V(nir, nir_lower_system_values);
-      NIR_PASS_V(nir, nir_lower_compute_system_values);
+      NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
       if (compiler_options->lower_int64_options)
          NIR_PASS_V(nir, nir_lower_int64);
index 779e2ecdc5f2f6cfea93a5daea135bc249600e26..fb0a88a051d49f38c70bb4a91a84407510acb31f 100644 (file)
@@ -562,7 +562,7 @@ val_shader_compile_to_ir(struct val_pipeline *pipeline,
    if (stage == MESA_SHADER_FRAGMENT)
       val_lower_input_attachments(nir, false);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
    nir_remove_dead_variables(nir, nir_var_uniform, NULL);
index 024cea7df5bfd0f1c6da35304f46e00db46ddcf7..4f1b56289c7cd2b73b44f02508e2ae5016b65931 100644 (file)
@@ -709,7 +709,7 @@ brw_preprocess_nir(const struct brw_compiler *compiler, nir_shader *nir,
    }
 
    OPT(nir_lower_system_values);
-   OPT(nir_lower_compute_system_values);
+   OPT(nir_lower_compute_system_values, NULL);
 
    const nir_lower_subgroups_options subgroups_options = {
       .ballot_bit_size = 32,
index b5b85ae46c05c9eb3325f3522a8f3bfcd7300960..b4d78c86d6ca84f5b5395d0530158293b4591cc9 100644 (file)
@@ -771,7 +771,7 @@ st_link_nir(struct gl_context *ctx,
                  st->pipe->screen);
 
       NIR_PASS_V(nir, nir_lower_system_values);
-      NIR_PASS_V(nir, nir_lower_compute_system_values);
+      NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
       NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
 
index cd078f3a5617d56055cd1416d8ce0b6906a2cb8a..1e295264e0bad4e904b0366d4eae1ee2303c05c8 100644 (file)
@@ -43,7 +43,7 @@ st_nir_finish_builtin_shader(struct st_context *st,
    NIR_PASS_V(nir, nir_split_var_copies);
    NIR_PASS_V(nir, nir_lower_var_copies);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    if (nir->options->lower_to_scalar) {
       nir_variable_mode mask =
index 7ba6344051c0edb3819baa962effc6dd2b5045ef..bfc6d90922a4fec629ec70e20e8238fd4fd36ea1 100644 (file)
@@ -388,7 +388,7 @@ st_translate_prog_to_nir(struct st_context *st, struct gl_program *prog,
 
    NIR_PASS_V(nir, st_nir_lower_wpos_ytransform, prog, screen);
    NIR_PASS_V(nir, nir_lower_system_values);
-   NIR_PASS_V(nir, nir_lower_compute_system_values);
+   NIR_PASS_V(nir, nir_lower_compute_system_values, NULL);
 
    /* Optimise NIR */
    NIR_PASS_V(nir, nir_opt_constant_folding);