intel/fs: Add and use a new load_simd_width_intel intrinsic
[mesa.git] / src / intel / compiler / brw_nir_lower_cs_intrinsics.c
index 434ad005281172e3268359c12e9de411fc33d172..883fc4699247fd74f9be6ab5c7b2d665c56c12f2 100644 (file)
@@ -26,7 +26,6 @@
 
 struct lower_intrinsics_state {
    nir_shader *nir;
-   unsigned dispatch_width;
    nir_function_impl *impl;
    bool progress;
    nir_builder builder;
@@ -61,19 +60,23 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
          if (!local_index) {
             assert(!local_id);
 
-            nir_ssa_def *subgroup_id;
-            if (state->local_workgroup_size <= state->dispatch_width)
-               subgroup_id = nir_imm_int(b, 0);
-            else
-               subgroup_id = nir_load_subgroup_id(b);
+            nir_ssa_def *subgroup_id = nir_load_subgroup_id(b);
 
             nir_ssa_def *thread_local_id =
-               nir_imul_imm(b, subgroup_id, state->dispatch_width);
+               nir_imul(b, subgroup_id, nir_load_simd_width_intel(b));
             nir_ssa_def *channel = nir_load_subgroup_invocation(b);
             nir_ssa_def *linear = nir_iadd(b, channel, thread_local_id);
 
-            nir_ssa_def *size_x = nir_imm_int(b, nir->info.cs.local_size[0]);
-            nir_ssa_def *size_y = nir_imm_int(b, nir->info.cs.local_size[1]);
+            nir_ssa_def *size_x;
+            nir_ssa_def *size_y;
+            if (state->nir->info.cs.local_size_variable) {
+               nir_ssa_def *size_xyz = nir_load_local_group_size(b);
+               size_x = nir_channel(b, size_xyz, 0);
+               size_y = nir_channel(b, size_xyz, 1);
+            } else {
+               size_x = nir_imm_int(b, nir->info.cs.local_size[0]);
+               size_y = nir_imm_int(b, nir->info.cs.local_size[1]);
+            }
 
             /* The local invocation index and ID must respect the following
              *
@@ -143,21 +146,25 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
          break;
       }
 
-      case nir_intrinsic_load_subgroup_id:
-         if (state->local_workgroup_size > 8)
-            continue;
-
-         /* For small workgroup sizes, we know subgroup_id will be zero */
-         sysval = nir_imm_int(b, 0);
-         break;
-
       case nir_intrinsic_load_num_subgroups: {
-         unsigned local_workgroup_size =
-            nir->info.cs.local_size[0] * nir->info.cs.local_size[1] *
-            nir->info.cs.local_size[2];
-         unsigned num_subgroups =
-            DIV_ROUND_UP(local_workgroup_size, state->dispatch_width);
-         sysval = nir_imm_int(b, num_subgroups);
+         nir_ssa_def *size;
+         if (state->nir->info.cs.local_size_variable) {
+            nir_ssa_def *size_xyz = nir_load_local_group_size(b);
+            nir_ssa_def *size_x = nir_channel(b, size_xyz, 0);
+            nir_ssa_def *size_y = nir_channel(b, size_xyz, 1);
+            nir_ssa_def *size_z = nir_channel(b, size_xyz, 2);
+            size = nir_imul(b, nir_imul(b, size_x, size_y), size_z);
+         } else {
+            size = nir_imm_int(b, nir->info.cs.local_size[0] *
+                                  nir->info.cs.local_size[1] *
+                                  nir->info.cs.local_size[2]);
+         }
+
+         /* Calculate the equivalent of DIV_ROUND_UP. */
+         nir_ssa_def *simd_width = nir_load_simd_width_intel(b);
+         sysval =
+            nir_udiv(b, nir_iadd_imm(b, nir_iadd(b, size, simd_width), -1),
+                        simd_width);
          break;
       }
 
@@ -188,26 +195,29 @@ lower_cs_intrinsics_convert_impl(struct lower_intrinsics_state *state)
 }
 
 bool
-brw_nir_lower_cs_intrinsics(nir_shader *nir,
-                            unsigned dispatch_width)
+brw_nir_lower_cs_intrinsics(nir_shader *nir)
 {
    assert(nir->info.stage == MESA_SHADER_COMPUTE);
 
    struct lower_intrinsics_state state = {
       .nir = nir,
-      .dispatch_width = dispatch_width,
    };
 
-   assert(!nir->info.cs.local_size_variable);
-   state.local_workgroup_size = nir->info.cs.local_size[0] *
-                                nir->info.cs.local_size[1] *
-                                nir->info.cs.local_size[2];
+   if (!nir->info.cs.local_size_variable) {
+      state.local_workgroup_size = nir->info.cs.local_size[0] *
+                                   nir->info.cs.local_size[1] *
+                                   nir->info.cs.local_size[2];
+   } else {
+      state.local_workgroup_size = nir->info.cs.max_variable_local_size;
+   }
 
    /* Constraints from NV_compute_shader_derivatives. */
-   if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS) {
+   if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS &&
+       !nir->info.cs.local_size_variable) {
       assert(nir->info.cs.local_size[0] % 2 == 0);
       assert(nir->info.cs.local_size[1] % 2 == 0);
-   } else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR) {
+   } else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR &&
+              !nir->info.cs.local_size_variable) {
       assert(state.local_workgroup_size % 4 == 0);
    }