radeonsi: initialize the per-context compiler on demand
[mesa.git] / src / gallium / drivers / radeonsi / si_compute_prim_discard.c
index 3bed818d5add30a4ebac7b4b8dcde083745ff288..31c18e098e6df2365e63e436af315e0d84265bc1 100644 (file)
@@ -267,7 +267,6 @@ static LLVMValueRef si_expand_32bit_pointer(struct si_shader_context *ctx, LLVMV
 
 struct si_thread0_section {
        struct si_shader_context *ctx;
-       struct lp_build_if_state if_thread0;
        LLVMValueRef vgpr_result; /* a VGPR for the value on thread 0. */
        LLVMValueRef saved_exec;
 };
@@ -288,9 +287,9 @@ static void si_enter_thread0_section(struct si_shader_context *ctx,
         *
         * It could just be s_and_saveexec_b64 s, 1.
         */
-       lp_build_if(&section->if_thread0, &ctx->gallivm,
-                   LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, thread_id,
-                                 ctx->i32_0, ""));
+       ac_build_ifcc(&ctx->ac,
+                     LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, thread_id,
+                                   ctx->i32_0, ""), 12601);
 }
 
 /* Exit a section that only executes on thread 0 and broadcast the result
@@ -302,7 +301,7 @@ static void si_exit_thread0_section(struct si_thread0_section *section,
 
        LLVMBuildStore(ctx->ac.builder, *result, section->vgpr_result);
 
-       lp_build_endif(&section->if_thread0);
+       ac_build_endif(&ctx->ac, 12601);
 
        /* Broadcast the result from thread 0 to all threads. */
        *result = ac_build_readlane(&ctx->ac,
@@ -319,50 +318,51 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        ac_add_function_attr(ctx->ac.context, vs, -1, AC_FUNC_ATTR_ALWAYSINLINE);
        LLVMSetLinkage(vs, LLVMPrivateLinkage);
 
-       LLVMTypeRef const_desc_type;
+       enum ac_arg_type const_desc_type;
        if (ctx->shader->selector->info.const_buffers_declared == 1 &&
            ctx->shader->selector->info.shader_buffers_declared == 0)
-               const_desc_type = ctx->f32;
+               const_desc_type = AC_ARG_CONST_FLOAT_PTR;
        else
-               const_desc_type = ctx->v4i32;
-
-       struct si_function_info fninfo;
-       si_init_function_info(&fninfo);
-
-       LLVMValueRef index_buffers_and_constants, vertex_counter, vb_desc, const_desc;
-       LLVMValueRef base_vertex, start_instance, block_id, local_id, ordered_wave_id;
-       LLVMValueRef restart_index, vp_scale[2], vp_translate[2], smallprim_precision;
-       LLVMValueRef num_prims_udiv_multiplier, num_prims_udiv_terms, sampler_desc;
-       LLVMValueRef last_wave_prim_id, vertex_count_addr;
-
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v4i32),
-                      &index_buffers_and_constants);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &vertex_counter);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &last_wave_prim_id);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &vertex_count_addr);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v4i32),
-                      &vb_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(const_desc_type),
-                      &const_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v8i32),
-                      &sampler_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &base_vertex);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &start_instance);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &num_prims_udiv_multiplier);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &num_prims_udiv_terms);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &restart_index);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->f32, &smallprim_precision);
+               const_desc_type = AC_ARG_CONST_DESC_PTR;
+
+       memset(&ctx->args, 0, sizeof(ctx->args));
+
+       struct ac_arg param_index_buffers_and_constants, param_vertex_counter;
+       struct ac_arg param_vb_desc, param_const_desc;
+       struct ac_arg param_base_vertex, param_start_instance;
+       struct ac_arg param_block_id, param_local_id, param_ordered_wave_id;
+       struct ac_arg param_restart_index, param_smallprim_precision;
+       struct ac_arg param_num_prims_udiv_multiplier, param_num_prims_udiv_terms;
+       struct ac_arg param_sampler_desc, param_last_wave_prim_id, param_vertex_count_addr;
+
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                  &param_index_buffers_and_constants);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_vertex_counter);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_last_wave_prim_id);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_vertex_count_addr);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                  &param_vb_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, const_desc_type,
+                  &param_const_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_IMAGE_PTR,
+                  &param_sampler_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_base_vertex);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_start_instance);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_num_prims_udiv_multiplier);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_num_prims_udiv_terms);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_restart_index);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, &param_smallprim_precision);
 
        /* Block ID and thread ID inputs. */
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &block_id);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_block_id);
        if (VERTEX_COUNTER_GDS_MODE == 2)
-               add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &ordered_wave_id);
-       add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &local_id);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_ordered_wave_id);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &param_local_id);
 
        /* Create the compute shader function. */
        unsigned old_type = ctx->type;
        ctx->type = PIPE_SHADER_COMPUTE;
-       si_create_function(ctx, "prim_discard_cs", NULL, 0, &fninfo, THREADGROUP_SIZE);
+       si_create_function(ctx, "prim_discard_cs", NULL, 0, THREADGROUP_SIZE);
        ctx->type = old_type;
 
        if (VERTEX_COUNTER_GDS_MODE == 1) {
@@ -377,14 +377,14 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        vs_params[num_vs_params++] = LLVMGetUndef(LLVMTypeOf(LLVMGetParam(vs, 0))); /* RW_BUFFERS */
        vs_params[num_vs_params++] = LLVMGetUndef(LLVMTypeOf(LLVMGetParam(vs, 1))); /* BINDLESS */
-       vs_params[num_vs_params++] = const_desc;
-       vs_params[num_vs_params++] = sampler_desc;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_const_desc);
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_sampler_desc);
        vs_params[num_vs_params++] = LLVMConstInt(ctx->i32,
                                        S_VS_STATE_INDEXED(key->opt.cs_indexed), 0);
-       vs_params[num_vs_params++] = base_vertex;
-       vs_params[num_vs_params++] = start_instance;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_base_vertex);
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_start_instance);
        vs_params[num_vs_params++] = ctx->i32_0; /* DrawID */
-       vs_params[num_vs_params++] = vb_desc;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_vb_desc);
 
        vs_params[(param_vertex_id = num_vs_params++)] = NULL; /* VertexID */
        vs_params[(param_instance_id = num_vs_params++)] = NULL; /* InstanceID */
@@ -397,6 +397,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        /* Load descriptors. (load 8 dwords at once) */
        LLVMValueRef input_indexbuf, output_indexbuf, tmp, desc[8];
 
+       LLVMValueRef index_buffers_and_constants = ac_get_arg(&ctx->ac, param_index_buffers_and_constants);
        tmp = LLVMBuildPointerCast(builder, index_buffers_and_constants,
                                   ac_array_in_const32_addr_space(ctx->v8i32), "");
        tmp = ac_build_load_to_sgpr(&ctx->ac, tmp, ctx->i32_0);
@@ -409,12 +410,17 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        /* Compute PrimID and InstanceID. */
        LLVMValueRef global_thread_id =
-               ac_build_imad(&ctx->ac, block_id,
-                             LLVMConstInt(ctx->i32, THREADGROUP_SIZE, 0), local_id);
+               ac_build_imad(&ctx->ac, ac_get_arg(&ctx->ac, param_block_id),
+                             LLVMConstInt(ctx->i32, THREADGROUP_SIZE, 0),
+                             ac_get_arg(&ctx->ac, param_local_id));
        LLVMValueRef prim_id = global_thread_id; /* PrimID within an instance */
        LLVMValueRef instance_id = ctx->i32_0;
 
        if (key->opt.cs_instancing) {
+               LLVMValueRef num_prims_udiv_terms =
+                       ac_get_arg(&ctx->ac, param_num_prims_udiv_terms);
+               LLVMValueRef num_prims_udiv_multiplier =
+                       ac_get_arg(&ctx->ac, param_num_prims_udiv_multiplier);
                /* Unpack num_prims_udiv_terms. */
                LLVMValueRef post_shift = LLVMBuildAnd(builder, num_prims_udiv_terms,
                                                       LLVMConstInt(ctx->i32, 0x1f, 0), "");
@@ -473,11 +479,13 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                for (unsigned i = 0; i < 3; i++) {
                        index[i] = ac_build_buffer_load_format(&ctx->ac, input_indexbuf,
                                                               index[i], ctx->i32_0, 1,
-                                                              false, true);
+                                                              0, true);
                        index[i] = ac_to_integer(&ctx->ac, index[i]);
                }
        }
 
+       LLVMValueRef ordered_wave_id = ac_get_arg(&ctx->ac, param_ordered_wave_id);
+
        /* Extract the ordered wave ID. */
        if (VERTEX_COUNTER_GDS_MODE == 2) {
                ordered_wave_id = LLVMBuildLShr(builder, ordered_wave_id,
@@ -486,7 +494,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                               LLVMConstInt(ctx->i32, 0xfff, 0), "");
        }
        LLVMValueRef thread_id =
-               LLVMBuildAnd(builder, local_id, LLVMConstInt(ctx->i32, 63, 0), "");
+               LLVMBuildAnd(builder, ac_get_arg(&ctx->ac, param_local_id),
+                            LLVMConstInt(ctx->i32, 63, 0), "");
 
        /* Every other triangle in a strip has a reversed vertex order, so we
         * need to swap vertices of odd primitives to get the correct primitive
@@ -494,6 +503,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
         * restart complicates it, because a strip can start anywhere.
         */
        LLVMValueRef prim_restart_accepted = ctx->i1true;
+       LLVMValueRef vertex_counter = ac_get_arg(&ctx->ac, param_vertex_counter);
 
        if (key->opt.cs_prim_type == PIPE_PRIM_TRIANGLE_STRIP) {
                /* Without primitive restart, odd primitives have reversed orientation.
@@ -521,7 +531,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
                        for (unsigned i = 0; i < 3; i++) {
                                LLVMValueRef not_reset = LLVMBuildICmp(builder, LLVMIntNE, index[i],
-                                                                      restart_index, "");
+                                                                      ac_get_arg(&ctx->ac, param_restart_index),
+                                                                      "");
                                if (i == 0)
                                        index0_is_reset = LLVMBuildNot(builder, not_reset, "");
                                prim_restart_accepted = LLVMBuildAnd(builder, prim_restart_accepted,
@@ -597,10 +608,9 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                 *    // Just read the value that previous waves stored.
                                 *    first_is_odd = ds.ordered.add(0);
                                 */
-                               struct lp_build_if_state if_overwrite_counter;
-                               lp_build_if(&if_overwrite_counter, &ctx->gallivm,
-                                           LLVMBuildOr(builder, is_first_wave,
-                                                       current_wave_resets_index, ""));
+                               ac_build_ifcc(&ctx->ac,
+                                             LLVMBuildOr(builder, is_first_wave,
+                                                         current_wave_resets_index, ""), 12602);
                                {
                                        /* The GDS address is always 0 with ordered append. */
                                        tmp = si_build_ds_ordered_op(ctx, "swap",
@@ -608,7 +618,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                                                     1, true, false);
                                        LLVMBuildStore(builder, tmp, ret);
                                }
-                               lp_build_else(&if_overwrite_counter);
+                               ac_build_else(&ctx->ac, 12603);
                                {
                                        /* Just read the value from GDS. */
                                        tmp = si_build_ds_ordered_op(ctx, "add",
@@ -616,7 +626,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                                                     1, true, false);
                                        LLVMBuildStore(builder, tmp, ret);
                                }
-                               lp_build_endif(&if_overwrite_counter);
+                               ac_build_endif(&ctx->ac, 12602);
 
                                prev_wave_state = LLVMBuildLoad(builder, ret, "");
                                /* Ignore the return value if this is the first wave. */
@@ -667,7 +677,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                vs_params[param_vertex_id] = index[i];
                vs_params[param_instance_id] = instance_id;
 
-               LLVMValueRef ret = LLVMBuildCall(builder, vs, vs_params, num_vs_params, "");
+               LLVMValueRef ret = ac_build_call(&ctx->ac, vs, vs_params, num_vs_params);
                for (unsigned chan = 0; chan < 4; chan++)
                        pos[i][chan] = LLVMBuildExtractValue(builder, ret, chan, "");
        }
@@ -682,6 +692,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        LLVMValueRef vp = ac_build_load_invariant(&ctx->ac, index_buffers_and_constants,
                                                  LLVMConstInt(ctx->i32, 2, 0));
        vp = LLVMBuildBitCast(builder, vp, ctx->v4f32, "");
+       LLVMValueRef vp_scale[2], vp_translate[2];
        vp_scale[0] = ac_llvm_extract_elem(&ctx->ac, vp, 0);
        vp_scale[1] = ac_llvm_extract_elem(&ctx->ac, vp, 1);
        vp_translate[0] = ac_llvm_extract_elem(&ctx->ac, vp, 2);
@@ -701,7 +712,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        LLVMValueRef accepted =
                ac_cull_triangle(&ctx->ac, pos, prim_restart_accepted,
-                                vp_scale, vp_translate, smallprim_precision,
+                                vp_scale, vp_translate,
+                                ac_get_arg(&ctx->ac, param_smallprim_precision),
                                 &options);
 
        LLVMValueRef accepted_threadmask = ac_get_i1_sgpr_mask(&ctx->ac, accepted);
@@ -760,15 +772,14 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                         *    previous = ds.ordered.add(num_prims_accepted) // add the primitive count
                         * }
                         */
-                       struct lp_build_if_state if_first_wave;
-                       lp_build_if(&if_first_wave, &ctx->gallivm, is_first_wave);
+                       ac_build_ifcc(&ctx->ac, is_first_wave, 12604);
                        {
                                /* The GDS address is always 0 with ordered append. */
                                si_build_ds_ordered_op(ctx, "swap", ordered_wave_id,
                                                       num_prims_accepted, 0, true, true);
                                LLVMBuildStore(builder, ctx->i32_0, tmp_store);
                        }
-                       lp_build_else(&if_first_wave);
+                       ac_build_else(&ctx->ac, 12605);
                        {
                                LLVMBuildStore(builder,
                                               si_build_ds_ordered_op(ctx, "add", ordered_wave_id,
@@ -776,7 +787,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                                                      true, true),
                                               tmp_store);
                        }
-                       lp_build_endif(&if_first_wave);
+                       ac_build_endif(&ctx->ac, 12604);
 
                        start = LLVMBuildLoad(builder, tmp_store, "");
                }
@@ -789,20 +800,20 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
         * event like this.
         */
        if (VERTEX_COUNTER_GDS_MODE == 2) {
-               struct lp_build_if_state if_last_wave;
-               lp_build_if(&if_last_wave, &ctx->gallivm,
-                           LLVMBuildICmp(builder, LLVMIntEQ, global_thread_id,
-                                         last_wave_prim_id, ""));
+               ac_build_ifcc(&ctx->ac,
+                             LLVMBuildICmp(builder, LLVMIntEQ, global_thread_id,
+                                           ac_get_arg(&ctx->ac, param_last_wave_prim_id), ""),
+                             12606);
                LLVMValueRef count = LLVMBuildAdd(builder, start, num_prims_accepted, "");
                count = LLVMBuildMul(builder, count,
                                     LLVMConstInt(ctx->i32, vertices_per_prim, 0), "");
 
-               /* VI needs to disable caching, so that the CP can see the stored value.
+               /* GFX8 needs to disable caching, so that the CP can see the stored value.
                 * MTYPE=3 bypasses TC L2.
                 */
                if (ctx->screen->info.chip_class <= GFX8) {
                        LLVMValueRef desc[] = {
-                               vertex_count_addr,
+                               ac_get_arg(&ctx->ac, param_vertex_count_addr),
                                LLVMConstInt(ctx->i32,
                                        S_008F04_BASE_ADDRESS_HI(ctx->screen->info.address32_hi), 0),
                                LLVMConstInt(ctx->i32, 4, 0),
@@ -811,12 +822,14 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                        };
                        LLVMValueRef rsrc = ac_build_gather_values(&ctx->ac, desc, 4);
                        ac_build_buffer_store_dword(&ctx->ac, rsrc, count, 1, ctx->i32_0,
-                                                   ctx->i32_0, 0, true, true, true, false);
+                                                   ctx->i32_0, 0, ac_glc | ac_slc);
                } else {
                        LLVMBuildStore(builder, count,
-                                      si_expand_32bit_pointer(ctx, vertex_count_addr));
+                                      si_expand_32bit_pointer(ctx,
+                                                              ac_get_arg(&ctx->ac,
+                                                                         param_vertex_count_addr)));
                }
-               lp_build_endif(&if_last_wave);
+               ac_build_endif(&ctx->ac, 12606);
        } else {
                /* For unordered modes that increment a vertex count instead of
                 * primitive count, convert it into the primitive index.
@@ -828,8 +841,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        /* Now we need to store the indices of accepted primitives into
         * the output index buffer.
         */
-       struct lp_build_if_state if_accepted;
-       lp_build_if(&if_accepted, &ctx->gallivm, accepted);
+       ac_build_ifcc(&ctx->ac, accepted, 16607);
        {
                /* Get the number of bits set before the index of this thread. */
                LLVMValueRef prim_index = ac_build_mbcnt(&ctx->ac, accepted_threadmask);
@@ -863,10 +875,10 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                        vdata = ac_build_expand_to_vec4(&ctx->ac, vdata, 3);
 
                ac_build_buffer_store_format(&ctx->ac, output_indexbuf, vdata,
-                                            vindex, ctx->i32_0, 3, true,
-                                            INDEX_STORES_USE_SLC, true);
+                                            vindex, ctx->i32_0, 3,
+                                            ac_glc | (INDEX_STORES_USE_SLC ? ac_slc : 0));
        }
-       lp_build_endif(&if_accepted);
+       ac_build_endif(&ctx->ac, 16607);
 
        LLVMBuildRetVoid(builder);
 }
@@ -927,6 +939,9 @@ static bool si_shader_select_prim_discard_cs(struct si_context *sctx,
        sctx->cs_prim_discard_state.cso = sctx->vs_shader.cso;
        sctx->cs_prim_discard_state.current = NULL;
 
+       if (!sctx->compiler.passes)
+               si_init_compiler(sctx->screen, &sctx->compiler);
+
        struct si_compiler_ctx_state compiler_state;
        compiler_state.compiler = &sctx->compiler;
        compiler_state.debug = sctx->debug;
@@ -982,7 +997,7 @@ static bool si_initialize_prim_discard_cmdbuf(struct si_context *sctx)
                                                 SI_RESOURCE_FLAG_UNMAPPABLE,
                                                 PIPE_USAGE_DEFAULT,
                                                 sctx->index_ring_size_per_ib * 2,
-                                                2 * 1024 * 1024);
+                                                sctx->screen->info.pte_fragment_size);
                if (!sctx->index_ring)
                        return false;
        }
@@ -1103,7 +1118,7 @@ si_prepare_prim_discard_or_split_draw(struct si_context *sctx,
 
        /* The compute IB is always chained, but we need to call cs_check_space to add more space. */
        struct radeon_cmdbuf *cs = sctx->prim_discard_compute_cs;
-       bool compute_has_space = sctx->ws->cs_check_space(cs, need_compute_dw, false);
+       ASSERTED bool compute_has_space = sctx->ws->cs_check_space(cs, need_compute_dw, false);
        assert(compute_has_space);
        assert(si_check_ring_space(sctx, out_indexbuf_size));
        return SI_PRIM_DISCARD_ENABLED;
@@ -1196,6 +1211,8 @@ void si_dispatch_prim_discard_cs_and_draw(struct si_context *sctx,
 
                /* This needs to be done at the beginning of IBs due to possible
                 * TTM buffer moves in the kernel.
+                *
+                * TODO: update for GFX10
                 */
                si_emit_surface_sync(sctx, cs,
                                     S_0085F0_TC_ACTION_ENA(1) |
@@ -1424,8 +1441,10 @@ void si_dispatch_prim_discard_cs_and_draw(struct si_context *sctx,
                                S_00B84C_LDS_SIZE(shader->config.lds_size));
 
                radeon_set_sh_reg(cs, R_00B854_COMPUTE_RESOURCE_LIMITS,
-                       si_get_compute_resource_limits(sctx->screen, WAVES_PER_TG,
-                                                      MAX_WAVES_PER_SH, THREADGROUPS_PER_CU));
+                       ac_get_compute_resource_limits(&sctx->screen->info,
+                                                      WAVES_PER_TG,
+                                                      MAX_WAVES_PER_SH,
+                                                      THREADGROUPS_PER_CU));
                sctx->compute_ib_last_shader = shader;
        }