v3d: Add support for CS workgroup/invocation id intrinsics.
[mesa.git] / src / broadcom / compiler / nir_to_vir.c
index f10ed5975c1683688271f3976cfe03588a051483..c1889a7d645bab34650c811ffc2fa4a0f11db3f2 100644 (file)
@@ -1899,6 +1899,32 @@ ntq_emit_intrinsic(struct v3d_compile *c, nir_intrinsic_instr *instr)
                  */
                 break;
 
+        case nir_intrinsic_load_num_work_groups:
+                for (int i = 0; i < 3; i++) {
+                        ntq_store_dest(c, &instr->dest, i,
+                                       vir_uniform(c, QUNIFORM_NUM_WORK_GROUPS,
+                                                   i));
+                }
+                break;
+
+        case nir_intrinsic_load_local_invocation_index:
+                ntq_store_dest(c, &instr->dest, 0,
+                               vir_SHR(c, c->cs_payload[1],
+                                       vir_uniform_ui(c, 32 - c->local_invocation_index_bits)));
+                break;
+
+        case nir_intrinsic_load_work_group_id:
+                ntq_store_dest(c, &instr->dest, 0,
+                               vir_AND(c, c->cs_payload[0],
+                                       vir_uniform_ui(c, 0xffff)));
+                ntq_store_dest(c, &instr->dest, 1,
+                               vir_SHR(c, c->cs_payload[0],
+                                       vir_uniform_ui(c, 16)));
+                ntq_store_dest(c, &instr->dest, 2,
+                               vir_AND(c, c->cs_payload[1],
+                                       vir_uniform_ui(c, 0xffff)));
+                break;
+
         default:
                 fprintf(stderr, "Unknown intrinsic: ");
                 nir_print_instr(&instr->instr, stderr);
@@ -2255,7 +2281,8 @@ ntq_emit_impl(struct v3d_compile *c, nir_function_impl *impl)
 static void
 nir_to_vir(struct v3d_compile *c)
 {
-        if (c->s->info.stage == MESA_SHADER_FRAGMENT) {
+        switch (c->s->info.stage) {
+        case MESA_SHADER_FRAGMENT:
                 c->payload_w = vir_MOV(c, vir_reg(QFILE_REG, 0));
                 c->payload_w_centroid = vir_MOV(c, vir_reg(QFILE_REG, 1));
                 c->payload_z = vir_MOV(c, vir_reg(QFILE_REG, 2));
@@ -2270,6 +2297,30 @@ nir_to_vir(struct v3d_compile *c)
                 } else if (c->fs_key->is_lines) {
                         c->line_x = emit_fragment_varying(c, NULL, 0, 0);
                 }
+                break;
+        case MESA_SHADER_COMPUTE:
+                if (c->s->info.system_values_read &
+                    ((1ull << SYSTEM_VALUE_LOCAL_INVOCATION_INDEX) |
+                     (1ull << SYSTEM_VALUE_WORK_GROUP_ID))) {
+                        c->cs_payload[0] = vir_MOV(c, vir_reg(QFILE_REG, 0));
+                }
+                if (c->s->info.system_values_read &
+                    ((1ull << SYSTEM_VALUE_WORK_GROUP_ID))) {
+                        c->cs_payload[1] = vir_MOV(c, vir_reg(QFILE_REG, 2));
+                }
+
+                /* Set up the division between gl_LocalInvocationIndex and
+                 * wg_in_mem in the payload reg.
+                 */
+                int wg_size = (c->s->info.cs.local_size[0] *
+                               c->s->info.cs.local_size[1] *
+                               c->s->info.cs.local_size[2]);
+                c->local_invocation_index_bits =
+                        ffs(util_next_power_of_two(MAX2(wg_size, 64))) - 1;
+                assert(c->local_invocation_index_bits <= 8);
+                break;
+        default:
+                break;
         }
 
         if (c->s->info.stage == MESA_SHADER_FRAGMENT)
@@ -2298,6 +2349,7 @@ const nir_shader_compiler_options v3d_nir_options = {
         .lower_bitfield_extract_to_shifts = true,
         .lower_bitfield_reverse = true,
         .lower_bit_count = true,
+        .lower_cs_local_id_from_index = true,
         .lower_pack_unorm_2x16 = true,
         .lower_pack_snorm_2x16 = true,
         .lower_pack_unorm_4x8 = true,