nir: Embed the shader_info in the nir_shader again
[mesa.git] / src / compiler / nir / nir_lower_system_values.c
index 79f6bedc990bc3e5c3c804c0b17a378eb94a9e3e..810100a08160303b8a39da5c3787bf411767331d 100644 (file)
 #include "nir.h"
 #include "nir_builder.h"
 
-struct lower_system_values_state {
-   nir_builder builder;
-   bool progress;
-};
-
 static bool
-convert_block(nir_block *block, void *void_state)
+convert_block(nir_block *block, nir_builder *b)
 {
-   struct lower_system_values_state *state = void_state;
-
-   nir_builder *b = &state->builder;
+   bool progress = false;
 
-   nir_foreach_instr_safe(block, instr) {
+   nir_foreach_instr_safe(instr, block) {
       if (instr->type != nir_instr_type_intrinsic)
          continue;
 
@@ -55,7 +48,7 @@ convert_block(nir_block *block, void *void_state)
 
       b->cursor = nir_after_instr(&load_var->instr);
 
-      nir_ssa_def *sysval;
+      nir_ssa_def *sysval = NULL;
       switch (var->data.location) {
       case SYSTEM_VALUE_GLOBAL_INVOCATION_ID: {
          /* From the GLSL man page for gl_GlobalInvocationID:
@@ -65,89 +58,97 @@ convert_block(nir_block *block, void *void_state)
           */
 
          nir_const_value local_size;
-         local_size.u[0] = b->shader->info.cs.local_size[0];
-         local_size.u[1] = b->shader->info.cs.local_size[1];
-         local_size.u[2] = b->shader->info.cs.local_size[2];
+         local_size.u32[0] = b->shader->info.cs.local_size[0];
+         local_size.u32[1] = b->shader->info.cs.local_size[1];
+         local_size.u32[2] = b->shader->info.cs.local_size[2];
 
-         nir_ssa_def *group_id =
-            nir_load_system_value(b, nir_intrinsic_load_work_group_id, 0);
-         nir_ssa_def *local_id =
-            nir_load_system_value(b, nir_intrinsic_load_local_invocation_id, 0);
+         nir_ssa_def *group_id = nir_load_work_group_id(b);
+         nir_ssa_def *local_id = nir_load_local_invocation_id(b);
 
          sysval = nir_iadd(b, nir_imul(b, group_id,
-                                          nir_build_imm(b, 3, local_size)),
+                                       nir_build_imm(b, 3, 32, local_size)),
                               local_id);
          break;
       }
 
       case SYSTEM_VALUE_LOCAL_INVOCATION_INDEX: {
+         /* If lower_cs_local_index_from_id is true, then we derive the local
+          * index from the local id.
+          */
+         if (!b->shader->options->lower_cs_local_index_from_id)
+            break;
+
          /* From the GLSL man page for gl_LocalInvocationIndex:
           *
-          *    ?The value of gl_LocalInvocationIndex is equal to
+          *    "The value of gl_LocalInvocationIndex is equal to
           *    gl_LocalInvocationID.z * gl_WorkGroupSize.x *
           *    gl_WorkGroupSize.y + gl_LocalInvocationID.y *
           *    gl_WorkGroupSize.x + gl_LocalInvocationID.x"
           */
-         nir_ssa_def *local_id =
-            nir_load_system_value(b, nir_intrinsic_load_local_invocation_id, 0);
-
-         unsigned stride_y = b->shader->info.cs.local_size[0];
-         unsigned stride_z = b->shader->info.cs.local_size[0] *
-                             b->shader->info.cs.local_size[1];
-
-         sysval = nir_iadd(b, nir_imul(b, nir_channel(b, local_id, 2),
-                                          nir_imm_int(b, stride_z)),
-                              nir_iadd(b, nir_imul(b, nir_channel(b, local_id, 1),
-                                                      nir_imm_int(b, stride_y)),
-                                          nir_channel(b, local_id, 0)));
+         nir_ssa_def *local_id = nir_load_local_invocation_id(b);
+
+         nir_ssa_def *size_x =
+            nir_imm_int(b, b->shader->info.cs.local_size[0]);
+         nir_ssa_def *size_y =
+            nir_imm_int(b, b->shader->info.cs.local_size[1]);
+
+         sysval = nir_imul(b, nir_channel(b, local_id, 2),
+                              nir_imul(b, size_x, size_y));
+         sysval = nir_iadd(b, sysval,
+                              nir_imul(b, nir_channel(b, local_id, 1), size_x));
+         sysval = nir_iadd(b, sysval, nir_channel(b, local_id, 0));
          break;
       }
 
       case SYSTEM_VALUE_VERTEX_ID:
          if (b->shader->options->vertex_id_zero_based) {
             sysval = nir_iadd(b,
-               nir_load_system_value(b, nir_intrinsic_load_vertex_id_zero_base, 0),
-               nir_load_system_value(b, nir_intrinsic_load_base_vertex, 0));
+                              nir_load_vertex_id_zero_base(b),
+                              nir_load_base_vertex(b));
          } else {
-            sysval = nir_load_system_value(b, nir_intrinsic_load_vertex_id, 0);
+            sysval = nir_load_vertex_id(b);
          }
          break;
 
       case SYSTEM_VALUE_INSTANCE_INDEX:
          sysval = nir_iadd(b,
-            nir_load_system_value(b, nir_intrinsic_load_instance_id, 0),
-            nir_load_system_value(b, nir_intrinsic_load_base_instance, 0));
+                           nir_load_instance_id(b),
+                           nir_load_base_instance(b));
          break;
 
-      default: {
+      default:
+         break;
+      }
+
+      if (sysval == NULL) {
          nir_intrinsic_op sysval_op =
             nir_intrinsic_from_system_value(var->data.location);
          sysval = nir_load_system_value(b, sysval_op, 0);
-         break;
-      } /* default */
       }
 
       nir_ssa_def_rewrite_uses(&load_var->dest.ssa, nir_src_for_ssa(sysval));
       nir_instr_remove(&load_var->instr);
 
-      state->progress = true;
+      progress = true;
    }
 
-   return true;
+   return progress;
 }
 
 static bool
 convert_impl(nir_function_impl *impl)
 {
-   struct lower_system_values_state state;
+   bool progress = false;
+   nir_builder builder;
+   nir_builder_init(&builder, impl);
 
-   state.progress = false;
-   nir_builder_init(&state.builder, impl);
+   nir_foreach_block(block, impl) {
+      progress |= convert_block(block, &builder);
+   }
 
-   nir_foreach_block(impl, convert_block, &state);
    nir_metadata_preserve(impl, nir_metadata_block_index |
                                nir_metadata_dominance);
-   return state.progress;
+   return progress;
 }
 
 bool
@@ -155,7 +156,7 @@ nir_lower_system_values(nir_shader *shader)
 {
    bool progress = false;
 
-   nir_foreach_function(shader, function) {
+   nir_foreach_function(function, shader) {
       if (function->impl)
          progress = convert_impl(function->impl) || progress;
    }