nir/vtn: add support for SpvBuiltInGlobalLinearId
[mesa.git] / src / compiler / nir / nir_lower_system_values.c
index de5ccab0f38823aaff3c0fa1730a3bdf4c93076a..4436799fa527893c5b963b59b4c226517ef73b24 100644 (file)
@@ -99,6 +99,29 @@ build_local_invocation_id(nir_builder *b, unsigned bit_size)
    }
 }
 
+static nir_ssa_def*
+build_global_group_size(nir_builder *b, unsigned bit_size)
+{
+   nir_ssa_def *group_size = build_local_group_size(b, bit_size);
+   nir_ssa_def *num_work_groups = nir_u2u(b, nir_load_num_work_groups(b), bit_size);
+   return nir_imul(b, group_size, num_work_groups);
+}
+
+static nir_ssa_def*
+build_global_invocation_id(nir_builder *b, unsigned bit_size)
+{
+   /* From the GLSL man page for gl_GlobalInvocationID:
+    *
+    *    "The value of gl_GlobalInvocationID is equal to
+    *    gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID"
+    */
+   nir_ssa_def *group_size = build_local_group_size(b, bit_size);
+   nir_ssa_def *group_id = nir_u2u(b, nir_load_work_group_id(b), bit_size);
+   nir_ssa_def *local_id = build_local_invocation_id(b, bit_size);
+
+   return nir_iadd(b, nir_imul(b, group_id, group_size), local_id);
+}
+
 static bool
 convert_block(nir_block *block, nir_builder *b)
 {
@@ -133,16 +156,20 @@ convert_block(nir_block *block, nir_builder *b)
       nir_ssa_def *sysval = NULL;
       switch (var->data.location) {
       case SYSTEM_VALUE_GLOBAL_INVOCATION_ID: {
-         /* From the GLSL man page for gl_GlobalInvocationID:
-          *
-          *    "The value of gl_GlobalInvocationID is equal to
-          *    gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID"
-          */
-         nir_ssa_def *group_size = build_local_group_size(b, bit_size);
-         nir_ssa_def *group_id = nir_u2u(b, nir_load_work_group_id(b), bit_size);
-         nir_ssa_def *local_id = build_local_invocation_id(b, bit_size);
+         sysval = build_global_invocation_id(b, bit_size);
+         break;
+      }
+
+      case SYSTEM_VALUE_GLOBAL_INVOCATION_INDEX: {
+         nir_ssa_def *global_id = build_global_invocation_id(b, bit_size);
+         nir_ssa_def *global_size = build_global_group_size(b, bit_size);
 
-         sysval = nir_iadd(b, nir_imul(b, group_id, group_size), local_id);
+         /* index = id.x + ((id.y + (id.z * size.y)) * size.x) */
+         sysval = nir_imul(b, nir_channel(b, global_id, 2),
+                              nir_channel(b, global_size, 1));
+         sysval = nir_iadd(b, nir_channel(b, global_id, 1), sysval);
+         sysval = nir_imul(b, nir_channel(b, global_size, 0), sysval);
+         sysval = nir_iadd(b, nir_channel(b, global_id, 0), sysval);
          break;
       }
 
@@ -259,9 +286,7 @@ convert_block(nir_block *block, nir_builder *b)
          break;
 
       case SYSTEM_VALUE_GLOBAL_GROUP_SIZE: {
-         nir_ssa_def *group_size = build_local_group_size(b, bit_size);
-         nir_ssa_def *num_work_groups = nir_u2u(b, nir_load_num_work_groups(b), bit_size);
-         sysval = nir_imul(b, group_size, num_work_groups);
+         sysval = build_global_group_size(b, bit_size);
          break;
       }