nir: Add options to nir_lower_compute_system_values to control compute ID base lowering
[mesa.git] / src / compiler / nir / nir_lower_system_values.c
index bc80d184f721348e85dbaf3c01b02ec060011656..b99f655c2e0d75863185f0c7bc4e2c6b4403ddae 100644 (file)
@@ -227,6 +227,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                  nir_instr *instr, void *_options)
 {
    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   const nir_lower_compute_system_values_options *options = _options;
 
    /* All the intrinsics we care about are loads */
    if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
@@ -276,7 +277,7 @@ lower_compute_system_value_instr(nir_builder *b,
                                         nir_channel(b, local_size, 1)));
          return nir_u2u(b, nir_vec3(b, id_x, id_y, id_z), bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_invocation_index:
@@ -310,7 +311,7 @@ lower_compute_system_value_instr(nir_builder *b,
          index = nir_iadd(b, index, nir_channel(b, local_id, 0));
          return nir_u2u(b, index, bit_size);
       } else {
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       }
 
    case nir_intrinsic_load_local_group_size:
@@ -319,7 +320,7 @@ lower_compute_system_value_instr(nir_builder *b,
           * this point.  We do, however, have to make sure that the intrinsic
           * is only 32-bit.
           */
-         return sanitize_32bit_sysval(b, intrin);
+         return NULL;
       } else {
          /* using a 32 bit constant is safe here as no device/driver needs more
           * than 32 bits for the local size */
@@ -331,8 +332,9 @@ lower_compute_system_value_instr(nir_builder *b,
          return nir_u2u(b, nir_build_imm(b, 3, 32, local_size_const), bit_size);
       }
 
-   case nir_intrinsic_load_global_invocation_id: {
-      if (!b->shader->options->has_cs_global_id) {
+   case nir_intrinsic_load_global_invocation_id_zero_base: {
+      if ((options && options->has_base_work_group_id) ||
+          !b->shader->options->has_cs_global_id) {
          nir_ssa_def *group_size = nir_load_local_group_size(b);
          nir_ssa_def *group_id = nir_load_work_group_id(b, bit_size);
          nir_ssa_def *local_id = nir_load_local_invocation_id(b);
@@ -345,8 +347,21 @@ lower_compute_system_value_instr(nir_builder *b,
       }
    }
 
+   case nir_intrinsic_load_global_invocation_id: {
+      if (options && options->has_base_global_invocation_id)
+         return nir_iadd(b, nir_load_global_invocation_id_zero_base(b, bit_size),
+                            nir_load_base_global_invocation_id(b, bit_size));
+      else if (!b->shader->options->has_cs_global_id)
+         return nir_load_global_invocation_id_zero_base(b, bit_size);
+      else
+         return NULL;
+   }
+
    case nir_intrinsic_load_global_invocation_index: {
-      nir_ssa_def *global_id = nir_load_global_invocation_id(b, bit_size);
+      /* OpenCL's global_linear_id explicitly removes the global offset before computing this */
+      assert(b->shader->info.stage == MESA_SHADER_KERNEL);
+      nir_ssa_def *global_base_id = nir_load_base_global_invocation_id(b, bit_size);
+      nir_ssa_def *global_id = nir_isub(b, nir_load_global_invocation_id(b, bit_size), global_base_id);
       nir_ssa_def *global_size = build_global_group_size(b, bit_size);
 
       /* index = id.x + ((id.y + (id.z * size.y)) * size.x) */
@@ -359,13 +374,22 @@ lower_compute_system_value_instr(nir_builder *b,
       return index;
    }
 
+   case nir_intrinsic_load_work_group_id: {
+      if (options && options->has_base_work_group_id)
+         return nir_iadd(b, nir_u2u(b, nir_load_work_group_id_zero_base(b), bit_size),
+                            nir_load_base_work_group_id(b, bit_size));
+      else
+         return NULL;
+   }
+
    default:
       return NULL;
    }
 }
 
 bool
-nir_lower_compute_system_values(nir_shader *shader)
+nir_lower_compute_system_values(nir_shader *shader,
+                                const nir_lower_compute_system_values_options *options)
 {
    if (shader->info.stage != MESA_SHADER_COMPUTE &&
        shader->info.stage != MESA_SHADER_KERNEL)
@@ -374,5 +398,5 @@ nir_lower_compute_system_values(nir_shader *shader)
    return nir_shader_lower_instructions(shader,
                                         lower_compute_system_value_filter,
                                         lower_compute_system_value_instr,
-                                        NULL);
+                                        (void*)options);
 }