intel/compiler: Add support for variable workgroup size
[mesa.git] / src / intel / compiler / brw_nir_lower_cs_intrinsics.c
index 434ad005281172e3268359c12e9de411fc33d172..2393011312c92bbb36068ea95a419dbfd045be7d 100644 (file)
@@ -72,8 +72,16 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
             nir_ssa_def *channel = nir_load_subgroup_invocation(b);
             nir_ssa_def *linear = nir_iadd(b, channel, thread_local_id);
 
-            nir_ssa_def *size_x = nir_imm_int(b, nir->info.cs.local_size[0]);
-            nir_ssa_def *size_y = nir_imm_int(b, nir->info.cs.local_size[1]);
+            nir_ssa_def *size_x;
+            nir_ssa_def *size_y;
+            if (state->nir->info.cs.local_size_variable) {
+               nir_ssa_def *size_xyz = nir_load_local_group_size(b);
+               size_x = nir_channel(b, size_xyz, 0);
+               size_y = nir_channel(b, size_xyz, 1);
+            } else {
+               size_x = nir_imm_int(b, nir->info.cs.local_size[0]);
+               size_y = nir_imm_int(b, nir->info.cs.local_size[1]);
+            }
 
             /* The local invocation index and ID must respect the following
              *
@@ -152,12 +160,26 @@ lower_cs_intrinsics_convert_block(struct lower_intrinsics_state *state,
          break;
 
       case nir_intrinsic_load_num_subgroups: {
-         unsigned local_workgroup_size =
-            nir->info.cs.local_size[0] * nir->info.cs.local_size[1] *
-            nir->info.cs.local_size[2];
-         unsigned num_subgroups =
-            DIV_ROUND_UP(local_workgroup_size, state->dispatch_width);
-         sysval = nir_imm_int(b, num_subgroups);
+         if (state->nir->info.cs.local_size_variable) {
+            nir_ssa_def *size_xyz = nir_load_local_group_size(b);
+            nir_ssa_def *size_x = nir_channel(b, size_xyz, 0);
+            nir_ssa_def *size_y = nir_channel(b, size_xyz, 1);
+            nir_ssa_def *size_z = nir_channel(b, size_xyz, 2);
+            nir_ssa_def *size = nir_imul(b, nir_imul(b, size_x, size_y), size_z);
+
+            /* Calculate the equivalent of DIV_ROUND_UP. */
+            sysval = nir_idiv(b,
+                              nir_iadd_imm(b,
+                                 nir_iadd_imm(b, size, state->dispatch_width), -1),
+                              nir_imm_int(b, state->dispatch_width));
+         } else {
+            unsigned local_workgroup_size =
+               nir->info.cs.local_size[0] * nir->info.cs.local_size[1] *
+               nir->info.cs.local_size[2];
+            unsigned num_subgroups =
+               DIV_ROUND_UP(local_workgroup_size, state->dispatch_width);
+            sysval = nir_imm_int(b, num_subgroups);
+         }
          break;
       }
 
@@ -198,16 +220,21 @@ brw_nir_lower_cs_intrinsics(nir_shader *nir,
       .dispatch_width = dispatch_width,
    };
 
-   assert(!nir->info.cs.local_size_variable);
-   state.local_workgroup_size = nir->info.cs.local_size[0] *
-                                nir->info.cs.local_size[1] *
-                                nir->info.cs.local_size[2];
+   if (!nir->info.cs.local_size_variable) {
+      state.local_workgroup_size = nir->info.cs.local_size[0] *
+                                   nir->info.cs.local_size[1] *
+                                   nir->info.cs.local_size[2];
+   } else {
+      state.local_workgroup_size = nir->info.cs.max_variable_local_size;
+   }
 
    /* Constraints from NV_compute_shader_derivatives. */
-   if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS) {
+   if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_QUADS &&
+       !nir->info.cs.local_size_variable) {
       assert(nir->info.cs.local_size[0] % 2 == 0);
       assert(nir->info.cs.local_size[1] % 2 == 0);
-   } else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR) {
+   } else if (nir->info.cs.derivative_group == DERIVATIVE_GROUP_LINEAR &&
+              !nir->info.cs.local_size_variable) {
       assert(state.local_workgroup_size % 4 == 0);
    }