ac/nir, radv, radeonsi: Switch to using ac_shader_args
authorConnor Abbott <cwabbott0@gmail.com>
Mon, 11 Nov 2019 11:50:12 +0000 (12:50 +0100)
committerConnor Abbott <cwabbott0@gmail.com>
Mon, 25 Nov 2019 13:17:10 +0000 (14:17 +0100)
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Acked-by: Marek Olšák <marek.olsak@amd.com>
13 files changed:
src/amd/llvm/ac_nir_to_llvm.c
src/amd/llvm/ac_nir_to_llvm.h
src/amd/llvm/ac_shader_abi.h
src/amd/vulkan/meson.build
src/amd/vulkan/radv_nir_to_llvm.c
src/amd/vulkan/radv_shader_args.h [new file with mode: 0644]
src/gallium/drivers/radeonsi/gfx10_shader_ngg.c
src/gallium/drivers/radeonsi/si_compute_prim_discard.c
src/gallium/drivers/radeonsi/si_shader.c
src/gallium/drivers/radeonsi/si_shader_internal.h
src/gallium/drivers/radeonsi/si_shader_nir.c
src/gallium/drivers/radeonsi/si_shader_tgsi_mem.c
src/gallium/drivers/radeonsi/si_shader_tgsi_setup.c

index 8fae7bb5b77a105b082b59b64b8682ab1aa2e03e..9e9ddf62555dd63b9673aef998a99c08a0a2cf57 100644 (file)
@@ -38,6 +38,7 @@
 struct ac_nir_context {
        struct ac_llvm_context ac;
        struct ac_shader_abi *abi;
+       const struct ac_shader_args *args;
 
        gl_shader_stage stage;
        shader_info *info;
@@ -1435,16 +1436,22 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
                offset += LLVMConstIntGetZExtValue(src0);
                offset /= 4;
 
-               offset -= ctx->abi->base_inline_push_consts;
+               offset -= ctx->args->base_inline_push_consts;
 
-               if (offset + count <= ctx->abi->num_inline_push_consts) {
+               unsigned num_inline_push_consts = ctx->args->num_inline_push_consts;
+               if (offset + count <= num_inline_push_consts) {
+                       LLVMValueRef push_constants[num_inline_push_consts];
+                       for (unsigned i = 0; i < num_inline_push_consts; i++)
+                               push_constants[i] = ac_get_arg(&ctx->ac,
+                                                              ctx->args->inline_push_consts[i]);
                        return ac_build_gather_values(&ctx->ac,
-                                                     ctx->abi->inline_push_consts + offset,
+                                                     push_constants + offset,
                                                      count);
                }
        }
 
-       ptr = LLVMBuildGEP(ctx->ac.builder, ctx->abi->push_constants, &addr, 1, "");
+       ptr = LLVMBuildGEP(ctx->ac.builder,
+                          ac_get_arg(&ctx->ac, ctx->args->push_constants), &addr, 1, "");
 
        if (instr->dest.ssa.bit_size == 8) {
                unsigned load_dwords = instr->dest.ssa.num_components > 1 ? 2 : 1;
@@ -2902,7 +2909,8 @@ visit_load_local_invocation_index(struct ac_nir_context *ctx)
 {
        LLVMValueRef result;
        LLVMValueRef thread_id = ac_get_thread_id(&ctx->ac);
-       result = LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+       result = LLVMBuildAnd(ctx->ac.builder,
+                             ac_get_arg(&ctx->ac, ctx->args->tg_size),
                              LLVMConstInt(ctx->ac.i32, 0xfc0, false), "");
 
        if (ctx->ac.wave_size == 32)
@@ -2917,7 +2925,8 @@ visit_load_subgroup_id(struct ac_nir_context *ctx)
 {
        if (ctx->stage == MESA_SHADER_COMPUTE) {
                LLVMValueRef result;
-               result = LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+               result = LLVMBuildAnd(ctx->ac.builder,
+                                     ac_get_arg(&ctx->ac, ctx->args->tg_size),
                                LLVMConstInt(ctx->ac.i32, 0xfc0, false), "");
                return LLVMBuildLShr(ctx->ac.builder, result,  LLVMConstInt(ctx->ac.i32, 6, false), "");
        } else {
@@ -2929,7 +2938,8 @@ static LLVMValueRef
 visit_load_num_subgroups(struct ac_nir_context *ctx)
 {
        if (ctx->stage == MESA_SHADER_COMPUTE) {
-               return LLVMBuildAnd(ctx->ac.builder, ctx->abi->tg_size,
+               return LLVMBuildAnd(ctx->ac.builder,
+                                   ac_get_arg(&ctx->ac, ctx->args->tg_size),
                                    LLVMConstInt(ctx->ac.i32, 0x3f, false), "");
        } else {
                return LLVMConstInt(ctx->ac.i32, 1, false);
@@ -3059,8 +3069,10 @@ static LLVMValueRef load_sample_pos(struct ac_nir_context *ctx)
        LLVMValueRef values[2];
        LLVMValueRef pos[2];
 
-       pos[0] = ac_to_float(&ctx->ac, ctx->abi->frag_pos[0]);
-       pos[1] = ac_to_float(&ctx->ac, ctx->abi->frag_pos[1]);
+       pos[0] = ac_to_float(&ctx->ac,
+                            ac_get_arg(&ctx->ac, ctx->args->frag_pos[0]));
+       pos[1] = ac_to_float(&ctx->ac,
+                            ac_get_arg(&ctx->ac, ctx->args->frag_pos[1]));
 
        values[0] = ac_build_fract(&ctx->ac, pos[0], 32);
        values[1] = ac_build_fract(&ctx->ac, pos[1], 32);
@@ -3077,19 +3089,19 @@ static LLVMValueRef lookup_interp_param(struct ac_nir_context *ctx,
        case INTERP_MODE_SMOOTH:
        case INTERP_MODE_NONE:
                if (location == INTERP_CENTER)
-                       return ctx->abi->persp_center;
+                       return ac_get_arg(&ctx->ac, ctx->args->persp_center);
                else if (location == INTERP_CENTROID)
                        return ctx->abi->persp_centroid;
                else if (location == INTERP_SAMPLE)
-                       return ctx->abi->persp_sample;
+                       return ac_get_arg(&ctx->ac, ctx->args->persp_sample);
                break;
        case INTERP_MODE_NOPERSPECTIVE:
                if (location == INTERP_CENTER)
-                       return ctx->abi->linear_center;
+                       return ac_get_arg(&ctx->ac, ctx->args->linear_center);
                else if (location == INTERP_CENTROID)
                        return ctx->abi->linear_centroid;
                else if (location == INTERP_SAMPLE)
-                       return ctx->abi->linear_sample;
+                       return ac_get_arg(&ctx->ac, ctx->args->linear_sample);
                break;
        }
        return NULL;
@@ -3203,10 +3215,10 @@ static LLVMValueRef load_interpolated_input(struct ac_nir_context *ctx,
                LLVMValueRef llvm_chan = LLVMConstInt(ctx->ac.i32, comp_start + comp, false);
                if (bitsize == 16) {
                        values[comp] = ac_build_fs_interp_f16(&ctx->ac, llvm_chan, attr_number,
-                                                             ctx->abi->prim_mask, i, j);
+                                                             ac_get_arg(&ctx->ac, ctx->args->prim_mask), i, j);
                } else {
                        values[comp] = ac_build_fs_interp(&ctx->ac, llvm_chan, attr_number,
-                                                         ctx->abi->prim_mask, i, j);
+                                                         ac_get_arg(&ctx->ac, ctx->args->prim_mask), i, j);
                }
        }
 
@@ -3234,7 +3246,7 @@ static LLVMValueRef load_flat_input(struct ac_nir_context *ctx,
                                                      LLVMConstInt(ctx->ac.i32, 2, false),
                                                      llvm_chan,
                                                      attr_number,
-                                                     ctx->abi->prim_mask);
+                                                     ac_get_arg(&ctx->ac, ctx->args->prim_mask));
                values[chan] = LLVMBuildBitCast(ctx->ac.builder, values[chan], ctx->ac.i32, "");
                values[chan] = LLVMBuildTruncOrBitCast(ctx->ac.builder, values[chan],
                                                       bit_size == 16 ? ctx->ac.i16 : ctx->ac.i32, "");
@@ -3274,8 +3286,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                LLVMValueRef values[3];
 
                for (int i = 0; i < 3; i++) {
-                       values[i] = ctx->abi->workgroup_ids[i] ?
-                                   ctx->abi->workgroup_ids[i] : ctx->ac.i32_0;
+                       values[i] = ctx->args->workgroup_ids[i].used ?
+                                   ac_get_arg(&ctx->ac, ctx->args->workgroup_ids[i]) : ctx->ac.i32_0;
                }
 
                result = ac_build_gather_values(&ctx->ac, values, 3);
@@ -3289,51 +3301,56 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->load_local_group_size(ctx->abi);
                break;
        case nir_intrinsic_load_vertex_id:
-               result = LLVMBuildAdd(ctx->ac.builder, ctx->abi->vertex_id,
-                                     ctx->abi->base_vertex, "");
+               result = LLVMBuildAdd(ctx->ac.builder,
+                                     ac_get_arg(&ctx->ac, ctx->args->vertex_id),
+                                     ac_get_arg(&ctx->ac, ctx->args->base_vertex), "");
                break;
        case nir_intrinsic_load_vertex_id_zero_base: {
                result = ctx->abi->vertex_id;
                break;
        }
        case nir_intrinsic_load_local_invocation_id: {
-               result = ctx->abi->local_invocation_ids;
+               result = ac_get_arg(&ctx->ac, ctx->args->local_invocation_ids);
                break;
        }
        case nir_intrinsic_load_base_instance:
-               result = ctx->abi->start_instance;
+               result = ac_get_arg(&ctx->ac, ctx->args->start_instance);
                break;
        case nir_intrinsic_load_draw_id:
-               result = ctx->abi->draw_id;
+               result = ac_get_arg(&ctx->ac, ctx->args->draw_id);
                break;
        case nir_intrinsic_load_view_index:
-               result = ctx->abi->view_index;
+               result = ac_get_arg(&ctx->ac, ctx->args->view_index);
                break;
        case nir_intrinsic_load_invocation_id:
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       result = ac_unpack_param(&ctx->ac, ctx->abi->tcs_rel_ids, 8, 5);
+                       result = ac_unpack_param(&ctx->ac,
+                                                ac_get_arg(&ctx->ac, ctx->args->tcs_rel_ids),
+                                                8, 5);
                } else {
                        if (ctx->ac.chip_class >= GFX10) {
                                result = LLVMBuildAnd(ctx->ac.builder,
-                                                     ctx->abi->gs_invocation_id,
+                                                     ac_get_arg(&ctx->ac, ctx->args->gs_invocation_id),
                                                      LLVMConstInt(ctx->ac.i32, 127, 0), "");
                        } else {
-                               result = ctx->abi->gs_invocation_id;
+                               result = ac_get_arg(&ctx->ac, ctx->args->gs_invocation_id);
                        }
                }
                break;
        case nir_intrinsic_load_primitive_id:
                if (ctx->stage == MESA_SHADER_GEOMETRY) {
-                       result = ctx->abi->gs_prim_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->gs_prim_id);
                } else if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       result = ctx->abi->tcs_patch_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->tcs_patch_id);
                } else if (ctx->stage == MESA_SHADER_TESS_EVAL) {
-                       result = ctx->abi->tes_patch_id;
+                       result = ac_get_arg(&ctx->ac, ctx->args->tes_patch_id);
                } else
                        fprintf(stderr, "Unknown primitive id intrinsic: %d", ctx->stage);
                break;
        case nir_intrinsic_load_sample_id:
-               result = ac_unpack_param(&ctx->ac, ctx->abi->ancillary, 8, 4);
+               result = ac_unpack_param(&ctx->ac,
+                                        ac_get_arg(&ctx->ac, ctx->args->ancillary),
+                                        8, 4);
                break;
        case nir_intrinsic_load_sample_pos:
                result = load_sample_pos(ctx);
@@ -3343,10 +3360,11 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                break;
        case nir_intrinsic_load_frag_coord: {
                LLVMValueRef values[4] = {
-                       ctx->abi->frag_pos[0],
-                       ctx->abi->frag_pos[1],
-                       ctx->abi->frag_pos[2],
-                       ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, ctx->abi->frag_pos[3])
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[0]),
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[1]),
+                       ac_get_arg(&ctx->ac, ctx->args->frag_pos[2]),
+                       ac_build_fdiv(&ctx->ac, ctx->ac.f32_1,
+                                     ac_get_arg(&ctx->ac, ctx->args->frag_pos[3]))
                };
                result = ac_to_integer(&ctx->ac,
                                       ac_build_gather_values(&ctx->ac, values, 4));
@@ -3356,7 +3374,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->inputs[ac_llvm_reg_index_soa(VARYING_SLOT_LAYER, 0)];
                break;
        case nir_intrinsic_load_front_face:
-               result = ctx->abi->front_face;
+               result = ac_get_arg(&ctx->ac, ctx->args->front_face);
                break;
        case nir_intrinsic_load_helper_invocation:
                result = ac_build_load_helper_invocation(&ctx->ac);
@@ -3375,7 +3393,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->instance_id;
                break;
        case nir_intrinsic_load_num_work_groups:
-               result = ctx->abi->num_work_groups;
+               result = ac_get_arg(&ctx->ac, ctx->args->num_work_groups);
                break;
        case nir_intrinsic_load_local_invocation_index:
                result = visit_load_local_invocation_index(ctx);
@@ -4714,13 +4732,14 @@ setup_shared(struct ac_nir_context *ctx,
 }
 
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
-                     struct nir_shader *nir)
+                     const struct ac_shader_args *args, struct nir_shader *nir)
 {
        struct ac_nir_context ctx = {};
        struct nir_function *func;
 
        ctx.ac = *ac;
        ctx.abi = abi;
+       ctx.args = args;
 
        ctx.stage = nir->info.stage;
        ctx.info = &nir->info;
index 4782d9fc9d64d70b439d15965aa1845e063a8e6d..7c2d6b319553118f071c6682373382e028d3b819 100644 (file)
@@ -34,6 +34,7 @@ struct nir_shader;
 struct nir_variable;
 struct ac_llvm_context;
 struct ac_shader_abi;
+struct ac_shader_args;
 
 /* Interpolation locations */
 #define INTERP_CENTER 0
@@ -50,7 +51,7 @@ void ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class);
 bool ac_are_tessfactors_def_in_all_invocs(const struct nir_shader *nir);
 
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
-                     struct nir_shader *nir);
+                     const struct ac_shader_args *args, struct nir_shader *nir);
 
 void
 ac_handle_shader_output_decl(struct ac_llvm_context *ctx,
index 01c321506ef7c68698a52584325c3495aab50e4c..ea977fc0cfa6ebe7e3d29f0c01a1c508ba0ed915 100644 (file)
@@ -25,6 +25,7 @@
 #define AC_SHADER_ABI_H
 
 #include <llvm-c/Core.h>
+#include <assert.h>
 #include "ac_shader_args.h"
 
 #include "compiler/shader_enums.h"
@@ -49,47 +50,14 @@ enum ac_descriptor_type {
  * radv to share a compiler backend.
  */
 struct ac_shader_abi {
-       LLVMValueRef base_vertex;
-       LLVMValueRef start_instance;
-       LLVMValueRef draw_id;
+       LLVMValueRef outputs[AC_LLVM_MAX_OUTPUTS * 4];
+
+       /* These input registers sometimes need to be fixed up. */
        LLVMValueRef vertex_id;
        LLVMValueRef instance_id;
-       LLVMValueRef tcs_patch_id;
-       LLVMValueRef tcs_rel_ids;
-       LLVMValueRef tes_patch_id;
-       LLVMValueRef gs_prim_id;
-       LLVMValueRef gs_invocation_id;
-
-       /* PS */
-       LLVMValueRef frag_pos[4];
-       LLVMValueRef front_face;
-       LLVMValueRef ancillary;
-       LLVMValueRef sample_coverage;
-       LLVMValueRef prim_mask;
-       LLVMValueRef color0;
-       LLVMValueRef color1;
+       LLVMValueRef persp_centroid, linear_centroid;
+       LLVMValueRef color0, color1;
        LLVMValueRef user_data;
-       LLVMValueRef persp_sample;
-       LLVMValueRef persp_center;
-       LLVMValueRef persp_centroid;
-       LLVMValueRef linear_sample;
-       LLVMValueRef linear_center;
-       LLVMValueRef linear_centroid;
-
-       /* CS */
-       LLVMValueRef local_invocation_ids;
-       LLVMValueRef num_work_groups;
-       LLVMValueRef workgroup_ids[3];
-       LLVMValueRef tg_size;
-
-       /* Vulkan only */
-       LLVMValueRef push_constants;
-       LLVMValueRef inline_push_consts[AC_MAX_INLINE_PUSH_CONSTS];
-       unsigned num_inline_push_consts;
-       unsigned base_inline_push_consts;
-       LLVMValueRef view_index;
-
-       LLVMValueRef outputs[AC_LLVM_MAX_OUTPUTS * 4];
 
        /* For VS and PS: pre-loaded shader inputs.
         *
index 54d2f319f0bf7a3c6aad58b3c170c515ab993d3f..37e76cc41a9f866617c305c8b2181f6bd176e89f 100644 (file)
@@ -101,6 +101,7 @@ libradv_files = files(
   'radv_radeon_winsys.h',
   'radv_shader.c',
   'radv_shader.h',
+  'radv_shader_args.h',
   'radv_shader_helper.h',
   'radv_shader_info.c',
   'radv_query.c',
index 148a571fc79aa6e4e9e1a70e493bdf023232ee72..060dbcf2afb6608c0f1ee881f97fa260c6b46970 100644 (file)
@@ -28,6 +28,7 @@
 #include "radv_private.h"
 #include "radv_shader.h"
 #include "radv_shader_helper.h"
+#include "radv_shader_args.h"
 #include "nir/nir.h"
 
 #include <llvm-c/Core.h>
 
 struct radv_shader_context {
        struct ac_llvm_context ac;
-       const struct radv_nir_compiler_options *options;
-       struct radv_shader_info *shader_info;
        const struct nir_shader *shader;
        struct ac_shader_abi abi;
+       const struct radv_shader_args *args;
+
+       gl_shader_stage stage;
 
        unsigned max_workgroup_size;
        LLVMContextRef context;
        LLVMValueRef main_function;
 
        LLVMValueRef descriptor_sets[MAX_SETS];
+
        LLVMValueRef ring_offsets;
 
-       LLVMValueRef vertex_buffers;
        LLVMValueRef rel_auto_id;
-       LLVMValueRef vs_prim_id;
-       LLVMValueRef es2gs_offset;
-
-       LLVMValueRef oc_lds;
-       LLVMValueRef merged_wave_info;
-       LLVMValueRef tess_factor_offset;
-       LLVMValueRef tes_rel_patch_id;
-       LLVMValueRef tes_u;
-       LLVMValueRef tes_v;
-
-       /* HW GS */
-       /* On gfx10:
-        *  - bits 0..10: ordered_wave_id
-        *  - bits 12..20: number of vertices in group
-        *  - bits 22..30: number of primitives in group
-        */
-       LLVMValueRef gs_tg_info;
-       LLVMValueRef gs2vs_offset;
+
        LLVMValueRef gs_wave_id;
        LLVMValueRef gs_vtx_offset[6];
 
@@ -87,19 +72,10 @@ struct radv_shader_context {
        LLVMValueRef hs_ring_tess_offchip;
        LLVMValueRef hs_ring_tess_factor;
 
-       /* Streamout */
-       LLVMValueRef streamout_buffers;
-       LLVMValueRef streamout_write_idx;
-       LLVMValueRef streamout_config;
-       LLVMValueRef streamout_offset[4];
-
-       gl_shader_stage stage;
-
        LLVMValueRef inputs[RADEON_LLVM_MAX_INPUTS * 4];
 
        uint64_t output_mask;
 
-       bool is_gs_copy_shader;
        LLVMValueRef gs_next_vertex[4];
        LLVMValueRef gs_curprim_verts[4];
        LLVMValueRef gs_generated_prims[4];
@@ -119,14 +95,6 @@ struct radv_shader_output_values {
        unsigned usage_mask;
 };
 
-enum radeon_llvm_calling_convention {
-       RADEON_LLVM_AMDGPU_VS = 87,
-       RADEON_LLVM_AMDGPU_GS = 88,
-       RADEON_LLVM_AMDGPU_PS = 89,
-       RADEON_LLVM_AMDGPU_CS = 90,
-       RADEON_LLVM_AMDGPU_HS = 93,
-};
-
 static inline struct radv_shader_context *
 radv_shader_context_from_abi(struct ac_shader_abi *abi)
 {
@@ -138,9 +106,11 @@ static LLVMValueRef get_rel_patch_id(struct radv_shader_context *ctx)
 {
        switch (ctx->stage) {
        case MESA_SHADER_TESS_CTRL:
-               return ac_unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 0, 8);
+               return ac_unpack_param(&ctx->ac,
+                                      ac_get_arg(&ctx->ac, ctx->args->ac.tcs_rel_ids),
+                                      0, 8);
        case MESA_SHADER_TESS_EVAL:
-               return ctx->tes_rel_patch_id;
+               return ac_get_arg(&ctx->ac, ctx->args->tes_rel_patch_id);
                break;
        default:
                unreachable("Illegal stage");
@@ -150,12 +120,12 @@ static LLVMValueRef get_rel_patch_id(struct radv_shader_context *ctx)
 static unsigned
 get_tcs_num_patches(struct radv_shader_context *ctx)
 {
-       unsigned num_tcs_input_cp = ctx->options->key.tcs.input_vertices;
+       unsigned num_tcs_input_cp = ctx->args->options->key.tcs.input_vertices;
        unsigned num_tcs_output_cp = ctx->shader->info.tess.tcs_vertices_out;
        uint32_t input_vertex_size = ctx->tcs_num_inputs * 16;
-       uint32_t input_patch_size = ctx->options->key.tcs.input_vertices * input_vertex_size;
-       uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
-       uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->tcs.patch_outputs_written);
+       uint32_t input_patch_size = ctx->args->options->key.tcs.input_vertices * input_vertex_size;
+       uint32_t num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
+       uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->args->shader_info->tcs.patch_outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
        uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        uint32_t output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
@@ -177,19 +147,19 @@ get_tcs_num_patches(struct radv_shader_context *ctx)
         *
         * Test: dEQP-VK.tessellation.shader_input_output.barrier
         */
-       if (ctx->options->chip_class >= GFX7 && ctx->options->family != CHIP_STONEY)
+       if (ctx->args->options->chip_class >= GFX7 && ctx->args->options->family != CHIP_STONEY)
                hardware_lds_size = 65536;
 
        num_patches = MIN2(num_patches, hardware_lds_size / (input_patch_size + output_patch_size));
        /* Make sure the output data fits in the offchip buffer */
-       num_patches = MIN2(num_patches, (ctx->options->tess_offchip_block_dw_size * 4) / output_patch_size);
+       num_patches = MIN2(num_patches, (ctx->args->options->tess_offchip_block_dw_size * 4) / output_patch_size);
        /* Not necessary for correctness, but improves performance. The
         * specific value is taken from the proprietary driver.
         */
        num_patches = MIN2(num_patches, 40);
 
        /* GFX6 bug workaround - limit LS-HS threadgroups to only one wave. */
-       if (ctx->options->chip_class == GFX6) {
+       if (ctx->args->options->chip_class == GFX6) {
                unsigned one_wave = 64 / MAX2(num_tcs_input_cp, num_tcs_output_cp);
                num_patches = MIN2(num_patches, one_wave);
        }
@@ -199,7 +169,7 @@ get_tcs_num_patches(struct radv_shader_context *ctx)
 static unsigned
 calculate_tess_lds_size(struct radv_shader_context *ctx)
 {
-       unsigned num_tcs_input_cp = ctx->options->key.tcs.input_vertices;
+       unsigned num_tcs_input_cp = ctx->args->options->key.tcs.input_vertices;
        unsigned num_tcs_output_cp;
        unsigned num_tcs_outputs, num_tcs_patch_outputs;
        unsigned input_vertex_size, output_vertex_size;
@@ -210,8 +180,8 @@ calculate_tess_lds_size(struct radv_shader_context *ctx)
        unsigned lds_size;
 
        num_tcs_output_cp = ctx->shader->info.tess.tcs_vertices_out;
-       num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
-       num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->tcs.patch_outputs_written);
+       num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
+       num_tcs_patch_outputs = util_last_bit64(ctx->args->shader_info->tcs.patch_outputs_written);
 
        input_vertex_size = ctx->tcs_num_inputs * 16;
        output_vertex_size = num_tcs_outputs * 16;
@@ -251,9 +221,9 @@ calculate_tess_lds_size(struct radv_shader_context *ctx)
 static LLVMValueRef
 get_tcs_in_patch_stride(struct radv_shader_context *ctx)
 {
-       assert (ctx->stage == MESA_SHADER_TESS_CTRL);
+       assert(ctx->stage == MESA_SHADER_TESS_CTRL);
        uint32_t input_vertex_size = ctx->tcs_num_inputs * 16;
-       uint32_t input_patch_size = ctx->options->key.tcs.input_vertices * input_vertex_size;
+       uint32_t input_patch_size = ctx->args->options->key.tcs.input_vertices * input_vertex_size;
 
        input_patch_size /= 4;
        return LLVMConstInt(ctx->ac.i32, input_patch_size, false);
@@ -262,8 +232,8 @@ get_tcs_in_patch_stride(struct radv_shader_context *ctx)
 static LLVMValueRef
 get_tcs_out_patch_stride(struct radv_shader_context *ctx)
 {
-       uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
-       uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->shader_info->tcs.patch_outputs_written);
+       uint32_t num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
+       uint32_t num_tcs_patch_outputs = util_last_bit64(ctx->args->shader_info->tcs.patch_outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
        uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        uint32_t output_patch_size = pervertex_output_patch_size + num_tcs_patch_outputs * 16;
@@ -274,7 +244,7 @@ get_tcs_out_patch_stride(struct radv_shader_context *ctx)
 static LLVMValueRef
 get_tcs_out_vertex_stride(struct radv_shader_context *ctx)
 {
-       uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
+       uint32_t num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
        output_vertex_size /= 4;
        return LLVMConstInt(ctx->ac.i32, output_vertex_size, false);
@@ -285,7 +255,7 @@ get_tcs_out_patch0_offset(struct radv_shader_context *ctx)
 {
        assert (ctx->stage == MESA_SHADER_TESS_CTRL);
        uint32_t input_vertex_size = ctx->tcs_num_inputs * 16;
-       uint32_t input_patch_size = ctx->options->key.tcs.input_vertices * input_vertex_size;
+       uint32_t input_patch_size = ctx->args->options->key.tcs.input_vertices * input_vertex_size;
        uint32_t output_patch0_offset = input_patch_size;
        unsigned num_patches = ctx->tcs_num_patches;
 
@@ -299,10 +269,10 @@ get_tcs_out_patch0_patch_data_offset(struct radv_shader_context *ctx)
 {
        assert (ctx->stage == MESA_SHADER_TESS_CTRL);
        uint32_t input_vertex_size = ctx->tcs_num_inputs * 16;
-       uint32_t input_patch_size = ctx->options->key.tcs.input_vertices * input_vertex_size;
+       uint32_t input_patch_size = ctx->args->options->key.tcs.input_vertices * input_vertex_size;
        uint32_t output_patch0_offset = input_patch_size;
 
-       uint32_t num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
+       uint32_t num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
        uint32_t output_vertex_size = num_tcs_outputs * 16;
        uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
        unsigned num_patches = ctx->tcs_num_patches;
@@ -345,87 +315,16 @@ get_tcs_out_current_patch_data_offset(struct radv_shader_context *ctx)
                             patch0_patch_data_offset);
 }
 
-#define MAX_ARGS 64
-struct arg_info {
-       LLVMTypeRef types[MAX_ARGS];
-       LLVMValueRef *assign[MAX_ARGS];
-       uint8_t count;
-       uint8_t sgpr_count;
-       uint8_t num_sgprs_used;
-       uint8_t num_vgprs_used;
-};
-
-enum radv_arg_regfile {
-       ARG_SGPR,
-       ARG_VGPR,
-};
-
-static void
-add_arg(struct arg_info *info, enum radv_arg_regfile regfile, LLVMTypeRef type,
-       LLVMValueRef *param_ptr)
-{
-       assert(info->count < MAX_ARGS);
-
-       info->assign[info->count] = param_ptr;
-       info->types[info->count] = type;
-       info->count++;
-
-       if (regfile == ARG_SGPR) {
-               info->num_sgprs_used += ac_get_type_size(type) / 4;
-               info->sgpr_count++;
-       } else {
-               assert(regfile == ARG_VGPR);
-               info->num_vgprs_used += ac_get_type_size(type) / 4;
-       }
-}
-
-static void assign_arguments(LLVMValueRef main_function,
-                            struct arg_info *info)
-{
-       unsigned i;
-       for (i = 0; i < info->count; i++) {
-               if (info->assign[i])
-                       *info->assign[i] = LLVMGetParam(main_function, i);
-       }
-}
-
 static LLVMValueRef
-create_llvm_function(LLVMContextRef ctx, LLVMModuleRef module,
-                     LLVMBuilderRef builder, LLVMTypeRef *return_types,
-                     unsigned num_return_elems,
-                    struct arg_info *args,
+create_llvm_function(struct ac_llvm_context *ctx, LLVMModuleRef module,
+                     LLVMBuilderRef builder,
+                    struct ac_shader_args *args,
+                    enum ac_llvm_calling_convention convention,
                     unsigned max_workgroup_size,
                     const struct radv_nir_compiler_options *options)
 {
-       LLVMTypeRef main_function_type, ret_type;
-       LLVMBasicBlockRef main_function_body;
-
-       if (num_return_elems)
-               ret_type = LLVMStructTypeInContext(ctx, return_types,
-                                                  num_return_elems, true);
-       else
-               ret_type = LLVMVoidTypeInContext(ctx);
-
-       /* Setup the function */
-       main_function_type =
-           LLVMFunctionType(ret_type, args->types, args->count, 0);
        LLVMValueRef main_function =
-           LLVMAddFunction(module, "main", main_function_type);
-       main_function_body =
-           LLVMAppendBasicBlockInContext(ctx, main_function, "main_body");
-       LLVMPositionBuilderAtEnd(builder, main_function_body);
-
-       LLVMSetFunctionCallConv(main_function, RADEON_LLVM_AMDGPU_CS);
-       for (unsigned i = 0; i < args->sgpr_count; ++i) {
-               LLVMValueRef P = LLVMGetParam(main_function, i);
-
-               ac_add_function_attr(ctx, main_function, i + 1, AC_FUNC_ATTR_INREG);
-
-               if (LLVMGetTypeKind(LLVMTypeOf(P)) == LLVMPointerTypeKind) {
-                       ac_add_function_attr(ctx, main_function, i + 1, AC_FUNC_ATTR_NOALIAS);
-                       ac_add_attr_dereferenceable(P, UINT64_MAX);
-               }
-       }
+               ac_build_main(args, ctx, convention, "main", ctx->voidt, module);
 
        if (options->address32_hi) {
                ac_llvm_add_target_dep_function_attr(main_function,
@@ -449,29 +348,29 @@ set_loc(struct radv_userdata_info *ud_info, uint8_t *sgpr_idx,
 }
 
 static void
-set_loc_shader(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx,
+set_loc_shader(struct radv_shader_args *args, int idx, uint8_t *sgpr_idx,
               uint8_t num_sgprs)
 {
        struct radv_userdata_info *ud_info =
-               &ctx->shader_info->user_sgprs_locs.shader_data[idx];
+               &args->shader_info->user_sgprs_locs.shader_data[idx];
        assert(ud_info);
 
        set_loc(ud_info, sgpr_idx, num_sgprs);
 }
 
 static void
-set_loc_shader_ptr(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx)
+set_loc_shader_ptr(struct radv_shader_args *args, int idx, uint8_t *sgpr_idx)
 {
        bool use_32bit_pointers = idx != AC_UD_SCRATCH_RING_OFFSETS;
 
-       set_loc_shader(ctx, idx, sgpr_idx, use_32bit_pointers ? 1 : 2);
+       set_loc_shader(args, idx, sgpr_idx, use_32bit_pointers ? 1 : 2);
 }
 
 static void
-set_loc_desc(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx)
+set_loc_desc(struct radv_shader_args *args, int idx, uint8_t *sgpr_idx)
 {
        struct radv_userdata_locations *locs =
-               &ctx->shader_info->user_sgprs_locs;
+               &args->shader_info->user_sgprs_locs;
        struct radv_userdata_info *ud_info = &locs->descriptor_sets[idx];
        assert(ud_info);
 
@@ -486,22 +385,22 @@ struct user_sgpr_info {
        uint8_t remaining_sgprs;
 };
 
-static bool needs_view_index_sgpr(struct radv_shader_context *ctx,
+static bool needs_view_index_sgpr(struct radv_shader_args *args,
                                  gl_shader_stage stage)
 {
        switch (stage) {
        case MESA_SHADER_VERTEX:
-               if (ctx->shader_info->needs_multiview_view_index ||
-                   (!ctx->options->key.vs_common_out.as_es && !ctx->options->key.vs_common_out.as_ls && ctx->options->key.has_multiview_view_index))
+               if (args->shader_info->needs_multiview_view_index ||
+                   (!args->options->key.vs_common_out.as_es && !args->options->key.vs_common_out.as_ls && args->options->key.has_multiview_view_index))
                        return true;
                break;
        case MESA_SHADER_TESS_EVAL:
-               if (ctx->shader_info->needs_multiview_view_index || (!ctx->options->key.vs_common_out.as_es && ctx->options->key.has_multiview_view_index))
+               if (args->shader_info->needs_multiview_view_index || (!args->options->key.vs_common_out.as_es && args->options->key.has_multiview_view_index))
                        return true;
                break;
        case MESA_SHADER_GEOMETRY:
        case MESA_SHADER_TESS_CTRL:
-               if (ctx->shader_info->needs_multiview_view_index)
+               if (args->shader_info->needs_multiview_view_index)
                        return true;
                break;
        default:
@@ -511,62 +410,62 @@ static bool needs_view_index_sgpr(struct radv_shader_context *ctx,
 }
 
 static uint8_t
-count_vs_user_sgprs(struct radv_shader_context *ctx)
+count_vs_user_sgprs(struct radv_shader_args *args)
 {
        uint8_t count = 0;
 
-       if (ctx->shader_info->vs.has_vertex_buffers)
+       if (args->shader_info->vs.has_vertex_buffers)
                count++;
-       count += ctx->shader_info->vs.needs_draw_id ? 3 : 2;
+       count += args->shader_info->vs.needs_draw_id ? 3 : 2;
 
        return count;
 }
 
-static void allocate_inline_push_consts(struct radv_shader_context *ctx,
+static void allocate_inline_push_consts(struct radv_shader_args *args,
                                        struct user_sgpr_info *user_sgpr_info)
 {
        uint8_t remaining_sgprs = user_sgpr_info->remaining_sgprs;
 
        /* Only supported if shaders use push constants. */
-       if (ctx->shader_info->min_push_constant_used == UINT8_MAX)
+       if (args->shader_info->min_push_constant_used == UINT8_MAX)
                return;
 
        /* Only supported if shaders don't have indirect push constants. */
-       if (ctx->shader_info->has_indirect_push_constants)
+       if (args->shader_info->has_indirect_push_constants)
                return;
 
        /* Only supported for 32-bit push constants. */
-       if (!ctx->shader_info->has_only_32bit_push_constants)
+       if (!args->shader_info->has_only_32bit_push_constants)
                return;
 
        uint8_t num_push_consts =
-               (ctx->shader_info->max_push_constant_used -
-                ctx->shader_info->min_push_constant_used) / 4;
+               (args->shader_info->max_push_constant_used -
+                args->shader_info->min_push_constant_used) / 4;
 
        /* Check if the number of user SGPRs is large enough. */
        if (num_push_consts < remaining_sgprs) {
-               ctx->shader_info->num_inline_push_consts = num_push_consts;
+               args->shader_info->num_inline_push_consts = num_push_consts;
        } else {
-               ctx->shader_info->num_inline_push_consts = remaining_sgprs;
+               args->shader_info->num_inline_push_consts = remaining_sgprs;
        }
 
        /* Clamp to the maximum number of allowed inlined push constants. */
-       if (ctx->shader_info->num_inline_push_consts > AC_MAX_INLINE_PUSH_CONSTS)
-               ctx->shader_info->num_inline_push_consts = AC_MAX_INLINE_PUSH_CONSTS;
+       if (args->shader_info->num_inline_push_consts > AC_MAX_INLINE_PUSH_CONSTS)
+               args->shader_info->num_inline_push_consts = AC_MAX_INLINE_PUSH_CONSTS;
 
-       if (ctx->shader_info->num_inline_push_consts == num_push_consts &&
-           !ctx->shader_info->loads_dynamic_offsets) {
+       if (args->shader_info->num_inline_push_consts == num_push_consts &&
+           !args->shader_info->loads_dynamic_offsets) {
                /* Disable the default push constants path if all constants are
                 * inlined and if shaders don't use dynamic descriptors.
                 */
-               ctx->shader_info->loads_push_constants = false;
+               args->shader_info->loads_push_constants = false;
        }
 
-       ctx->shader_info->base_inline_push_consts =
-               ctx->shader_info->min_push_constant_used / 4;
+       args->shader_info->base_inline_push_consts =
+               args->shader_info->min_push_constant_used / 4;
 }
 
-static void allocate_user_sgprs(struct radv_shader_context *ctx,
+static void allocate_user_sgprs(struct radv_shader_args *args,
                                gl_shader_stage stage,
                                bool has_previous_stage,
                                gl_shader_stage previous_stage,
@@ -582,34 +481,34 @@ static void allocate_user_sgprs(struct radv_shader_context *ctx,
            stage == MESA_SHADER_VERTEX ||
            stage == MESA_SHADER_TESS_CTRL ||
            stage == MESA_SHADER_TESS_EVAL ||
-           ctx->is_gs_copy_shader)
+           args->is_gs_copy_shader)
                user_sgpr_info->need_ring_offsets = true;
 
        if (stage == MESA_SHADER_FRAGMENT &&
-           ctx->shader_info->ps.needs_sample_positions)
+           args->shader_info->ps.needs_sample_positions)
                user_sgpr_info->need_ring_offsets = true;
 
        /* 2 user sgprs will nearly always be allocated for scratch/rings */
-       if (ctx->options->supports_spill || user_sgpr_info->need_ring_offsets) {
+       if (args->options->supports_spill || user_sgpr_info->need_ring_offsets) {
                user_sgpr_count += 2;
        }
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               if (ctx->shader_info->cs.uses_grid_size)
+               if (args->shader_info->cs.uses_grid_size)
                        user_sgpr_count += 3;
                break;
        case MESA_SHADER_FRAGMENT:
-               user_sgpr_count += ctx->shader_info->ps.needs_sample_positions;
+               user_sgpr_count += args->shader_info->ps.needs_sample_positions;
                break;
        case MESA_SHADER_VERTEX:
-               if (!ctx->is_gs_copy_shader)
-                       user_sgpr_count += count_vs_user_sgprs(ctx);
+               if (!args->is_gs_copy_shader)
+                       user_sgpr_count += count_vs_user_sgprs(args);
                break;
        case MESA_SHADER_TESS_CTRL:
                if (has_previous_stage) {
                        if (previous_stage == MESA_SHADER_VERTEX)
-                               user_sgpr_count += count_vs_user_sgprs(ctx);
+                               user_sgpr_count += count_vs_user_sgprs(args);
                }
                break;
        case MESA_SHADER_TESS_EVAL:
@@ -617,7 +516,7 @@ static void allocate_user_sgprs(struct radv_shader_context *ctx,
        case MESA_SHADER_GEOMETRY:
                if (has_previous_stage) {
                        if (previous_stage == MESA_SHADER_VERTEX) {
-                               user_sgpr_count += count_vs_user_sgprs(ctx);
+                               user_sgpr_count += count_vs_user_sgprs(args);
                        }
                }
                break;
@@ -628,16 +527,16 @@ static void allocate_user_sgprs(struct radv_shader_context *ctx,
        if (needs_view_index)
                user_sgpr_count++;
 
-       if (ctx->shader_info->loads_push_constants)
+       if (args->shader_info->loads_push_constants)
                user_sgpr_count++;
 
-       if (ctx->shader_info->so.num_outputs)
+       if (args->shader_info->so.num_outputs)
                user_sgpr_count++;
 
-       uint32_t available_sgprs = ctx->options->chip_class >= GFX9 && stage != MESA_SHADER_COMPUTE ? 32 : 16;
+       uint32_t available_sgprs = args->options->chip_class >= GFX9 && stage != MESA_SHADER_COMPUTE ? 32 : 16;
        uint32_t remaining_sgprs = available_sgprs - user_sgpr_count;
        uint32_t num_desc_set =
-               util_bitcount(ctx->shader_info->desc_set_used_mask);
+               util_bitcount(args->shader_info->desc_set_used_mask);
 
        if (remaining_sgprs < num_desc_set) {
                user_sgpr_info->indirect_all_descriptor_sets = true;
@@ -646,166 +545,184 @@ static void allocate_user_sgprs(struct radv_shader_context *ctx,
                user_sgpr_info->remaining_sgprs = remaining_sgprs - num_desc_set;
        }
 
-       allocate_inline_push_consts(ctx, user_sgpr_info);
+       allocate_inline_push_consts(args, user_sgpr_info);
 }
 
 static void
-declare_global_input_sgprs(struct radv_shader_context *ctx,
-                          const struct user_sgpr_info *user_sgpr_info,
-                          struct arg_info *args,
-                          LLVMValueRef *desc_sets)
+declare_global_input_sgprs(struct radv_shader_args *args,
+                          const struct user_sgpr_info *user_sgpr_info)
 {
-       LLVMTypeRef type = ac_array_in_const32_addr_space(ctx->ac.i8);
-
        /* 1 for each descriptor set */
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
-               uint32_t mask = ctx->shader_info->desc_set_used_mask;
+               uint32_t mask = args->shader_info->desc_set_used_mask;
 
                while (mask) {
                        int i = u_bit_scan(&mask);
 
-                       add_arg(args, ARG_SGPR, type, &ctx->descriptor_sets[i]);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_PTR,
+                                  &args->descriptor_sets[i]);
                }
        } else {
-               add_arg(args, ARG_SGPR, ac_array_in_const32_addr_space(type),
-                       desc_sets);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_PTR_PTR,
+                          &args->descriptor_sets[0]);
        }
 
-       if (ctx->shader_info->loads_push_constants) {
+       if (args->shader_info->loads_push_constants) {
                /* 1 for push constants and dynamic descriptors */
-               add_arg(args, ARG_SGPR, type, &ctx->abi.push_constants);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_PTR,
+                          &args->ac.push_constants);
        }
 
-       for (unsigned i = 0; i < ctx->shader_info->num_inline_push_consts; i++) {
-               add_arg(args, ARG_SGPR, ctx->ac.i32,
-                       &ctx->abi.inline_push_consts[i]);
+       for (unsigned i = 0; i < args->shader_info->num_inline_push_consts; i++) {
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                          &args->ac.inline_push_consts[i]);
        }
-       ctx->abi.num_inline_push_consts = ctx->shader_info->num_inline_push_consts;
-       ctx->abi.base_inline_push_consts = ctx->shader_info->base_inline_push_consts;
+       args->ac.num_inline_push_consts = args->shader_info->num_inline_push_consts;
+       args->ac.base_inline_push_consts = args->shader_info->base_inline_push_consts;
 
-       if (ctx->shader_info->so.num_outputs) {
-               add_arg(args, ARG_SGPR,
-                       ac_array_in_const32_addr_space(ctx->ac.v4i32),
-                       &ctx->streamout_buffers);
+       if (args->shader_info->so.num_outputs) {
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                          &args->streamout_buffers);
        }
 }
 
 static void
-declare_vs_specific_input_sgprs(struct radv_shader_context *ctx,
+declare_vs_specific_input_sgprs(struct radv_shader_args *args,
                                gl_shader_stage stage,
                                bool has_previous_stage,
-                               gl_shader_stage previous_stage,
-                               struct arg_info *args)
+                               gl_shader_stage previous_stage)
 {
-       if (!ctx->is_gs_copy_shader &&
+       if (!args->is_gs_copy_shader &&
            (stage == MESA_SHADER_VERTEX ||
             (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
-               if (ctx->shader_info->vs.has_vertex_buffers) {
-                       add_arg(args, ARG_SGPR,
-                               ac_array_in_const32_addr_space(ctx->ac.v4i32),
-                               &ctx->vertex_buffers);
+               if (args->shader_info->vs.has_vertex_buffers) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                                  &args->vertex_buffers);
                }
-               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.base_vertex);
-               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.start_instance);
-               if (ctx->shader_info->vs.needs_draw_id) {
-                       add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.draw_id);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.base_vertex);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.start_instance);
+               if (args->shader_info->vs.needs_draw_id) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.draw_id);
                }
        }
 }
 
 static void
-declare_vs_input_vgprs(struct radv_shader_context *ctx, struct arg_info *args)
+declare_vs_input_vgprs(struct radv_shader_args *args)
 {
-       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.vertex_id);
-       if (!ctx->is_gs_copy_shader) {
-               if (ctx->options->key.vs_common_out.as_ls) {
-                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
-                       if (ctx->ac.chip_class >= GFX10) {
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* user vgpr */
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.vertex_id);
+       if (!args->is_gs_copy_shader) {
+               if (args->options->key.vs_common_out.as_ls) {
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->rel_auto_id);
+                       if (args->options->chip_class >= GFX10) {
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* user vgpr */
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.instance_id);
                        } else {
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.instance_id);
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* unused */
                        }
                } else {
-                       if (ctx->ac.chip_class >= GFX10) {
-                               if (ctx->options->key.vs_common_out.as_ngg) {
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* user vgpr */
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* user vgpr */
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+                       if (args->options->chip_class >= GFX10) {
+                               if (args->options->key.vs_common_out.as_ngg) {
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* user vgpr */
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* user vgpr */
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.instance_id);
                                } else {
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
-                                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* unused */
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->vs_prim_id);
+                                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.instance_id);
                                }
                        } else {
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
-                               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.instance_id);
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->vs_prim_id);
+                               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* unused */
                        }
                }
        }
 }
 
 static void
-declare_streamout_sgprs(struct radv_shader_context *ctx, gl_shader_stage stage,
-                       struct arg_info *args)
+declare_streamout_sgprs(struct radv_shader_args *args, gl_shader_stage stage)
 {
        int i;
 
-       if (ctx->options->use_ngg_streamout)
+       if (args->options->use_ngg_streamout) {
+               if (stage == MESA_SHADER_TESS_EVAL)
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
                return;
+       }
 
        /* Streamout SGPRs. */
-       if (ctx->shader_info->so.num_outputs) {
+       if (args->shader_info->so.num_outputs) {
                assert(stage == MESA_SHADER_VERTEX ||
                       stage == MESA_SHADER_TESS_EVAL);
 
-               if (stage != MESA_SHADER_TESS_EVAL) {
-                       add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->streamout_config);
-               } else {
-                       args->assign[args->count - 1] = &ctx->streamout_config;
-                       args->types[args->count - 1] = ctx->ac.i32;
-               }
-
-               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->streamout_write_idx);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->streamout_config);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->streamout_write_idx);
+       } else if (stage == MESA_SHADER_TESS_EVAL) {
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
        }
 
        /* A streamout buffer offset is loaded if the stride is non-zero. */
        for (i = 0; i < 4; i++) {
-               if (!ctx->shader_info->so.strides[i])
+               if (!args->shader_info->so.strides[i])
                        continue;
 
-               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->streamout_offset[i]);
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->streamout_offset[i]);
        }
 }
 
 static void
-declare_tes_input_vgprs(struct radv_shader_context *ctx, struct arg_info *args)
+declare_tes_input_vgprs(struct radv_shader_args *args)
 {
-       add_arg(args, ARG_VGPR, ctx->ac.f32, &ctx->tes_u);
-       add_arg(args, ARG_VGPR, ctx->ac.f32, &ctx->tes_v);
-       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->tes_rel_patch_id);
-       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.tes_patch_id);
+       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->tes_u);
+       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->tes_v);
+       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->tes_rel_patch_id);
+       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.tes_patch_id);
 }
 
 static void
-set_global_input_locs(struct radv_shader_context *ctx,
+set_global_input_locs(struct radv_shader_args *args,
                      const struct user_sgpr_info *user_sgpr_info,
-                     LLVMValueRef desc_sets, uint8_t *user_sgpr_idx)
+                     uint8_t *user_sgpr_idx)
 {
-       uint32_t mask = ctx->shader_info->desc_set_used_mask;
+       uint32_t mask = args->shader_info->desc_set_used_mask;
 
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
                while (mask) {
                        int i = u_bit_scan(&mask);
 
-                       set_loc_desc(ctx, i, user_sgpr_idx);
+                       set_loc_desc(args, i, user_sgpr_idx);
                }
        } else {
-               set_loc_shader_ptr(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS,
-                                  user_sgpr_idx);
+               set_loc_shader_ptr(args, AC_UD_INDIRECT_DESCRIPTOR_SETS,
+                                  user_sgpr_idx);
+
+               args->shader_info->need_indirect_descriptor_sets = true;
+       }
+
+       if (args->shader_info->loads_push_constants) {
+               set_loc_shader_ptr(args, AC_UD_PUSH_CONSTANTS, user_sgpr_idx);
+       }
 
+       if (args->shader_info->num_inline_push_consts) {
+               set_loc_shader(args, AC_UD_INLINE_PUSH_CONSTANTS, user_sgpr_idx,
+                              args->shader_info->num_inline_push_consts);
+       }
+
+       if (args->streamout_buffers.used) {
+               set_loc_shader_ptr(args, AC_UD_STREAMOUT_BUFFERS,
+                                  user_sgpr_idx);
+       }
+}
+
+static void
+load_descriptor_sets(struct radv_shader_context *ctx)
+{
+       uint32_t mask = ctx->args->shader_info->desc_set_used_mask;
+       if (ctx->args->shader_info->need_indirect_descriptor_sets) {
+               LLVMValueRef desc_sets =
+                       ac_get_arg(&ctx->ac, ctx->args->descriptor_sets[0]);
                while (mask) {
                        int i = u_bit_scan(&mask);
 
@@ -814,75 +731,63 @@ set_global_input_locs(struct radv_shader_context *ctx,
                                                      LLVMConstInt(ctx->ac.i32, i, false));
 
                }
+       } else {
+               while (mask) {
+                       int i = u_bit_scan(&mask);
 
-               ctx->shader_info->need_indirect_descriptor_sets = true;
-       }
-
-       if (ctx->shader_info->loads_push_constants) {
-               set_loc_shader_ptr(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx);
-       }
-
-       if (ctx->shader_info->num_inline_push_consts) {
-               set_loc_shader(ctx, AC_UD_INLINE_PUSH_CONSTANTS, user_sgpr_idx,
-                              ctx->shader_info->num_inline_push_consts);
-       }
-
-       if (ctx->streamout_buffers) {
-               set_loc_shader_ptr(ctx, AC_UD_STREAMOUT_BUFFERS,
-                              user_sgpr_idx);
+                       ctx->descriptor_sets[i] =
+                               ac_get_arg(&ctx->ac, ctx->args->descriptor_sets[i]);
+               }
        }
 }
 
+
 static void
-set_vs_specific_input_locs(struct radv_shader_context *ctx,
+set_vs_specific_input_locs(struct radv_shader_args *args,
                           gl_shader_stage stage, bool has_previous_stage,
                           gl_shader_stage previous_stage,
                           uint8_t *user_sgpr_idx)
 {
-       if (!ctx->is_gs_copy_shader &&
+       if (!args->is_gs_copy_shader &&
            (stage == MESA_SHADER_VERTEX ||
             (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
-               if (ctx->shader_info->vs.has_vertex_buffers) {
-                       set_loc_shader_ptr(ctx, AC_UD_VS_VERTEX_BUFFERS,
+               if (args->shader_info->vs.has_vertex_buffers) {
+                       set_loc_shader_ptr(args, AC_UD_VS_VERTEX_BUFFERS,
                                           user_sgpr_idx);
                }
 
                unsigned vs_num = 2;
-               if (ctx->shader_info->vs.needs_draw_id)
+               if (args->shader_info->vs.needs_draw_id)
                        vs_num++;
 
-               set_loc_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE,
+               set_loc_shader(args, AC_UD_VS_BASE_VERTEX_START_INSTANCE,
                               user_sgpr_idx, vs_num);
        }
 }
 
-static void set_llvm_calling_convention(LLVMValueRef func,
-                                        gl_shader_stage stage)
+static enum ac_llvm_calling_convention
+get_llvm_calling_convention(LLVMValueRef func, gl_shader_stage stage)
 {
-       enum radeon_llvm_calling_convention calling_conv;
-
        switch (stage) {
        case MESA_SHADER_VERTEX:
        case MESA_SHADER_TESS_EVAL:
-               calling_conv = RADEON_LLVM_AMDGPU_VS;
+               return AC_LLVM_AMDGPU_VS;
                break;
        case MESA_SHADER_GEOMETRY:
-               calling_conv = RADEON_LLVM_AMDGPU_GS;
+               return AC_LLVM_AMDGPU_GS;
                break;
        case MESA_SHADER_TESS_CTRL:
-               calling_conv = RADEON_LLVM_AMDGPU_HS;
+               return AC_LLVM_AMDGPU_HS;
                break;
        case MESA_SHADER_FRAGMENT:
-               calling_conv = RADEON_LLVM_AMDGPU_PS;
+               return AC_LLVM_AMDGPU_PS;
                break;
        case MESA_SHADER_COMPUTE:
-               calling_conv = RADEON_LLVM_AMDGPU_CS;
+               return AC_LLVM_AMDGPU_CS;
                break;
        default:
                unreachable("Unhandle shader type");
        }
-
-       LLVMSetFunctionCallConv(func, calling_conv);
 }
 
 /* Returns whether the stage is a stage that can be directly before the GS */
@@ -891,19 +796,16 @@ static bool is_pre_gs_stage(gl_shader_stage stage)
        return stage == MESA_SHADER_VERTEX || stage == MESA_SHADER_TESS_EVAL;
 }
 
-static void create_function(struct radv_shader_context *ctx,
-                            gl_shader_stage stage,
-                            bool has_previous_stage,
-                            gl_shader_stage previous_stage)
+static void declare_inputs(struct radv_shader_args *args,
+                          gl_shader_stage stage,
+                          bool has_previous_stage,
+                          gl_shader_stage previous_stage)
 {
-       uint8_t user_sgpr_idx;
        struct user_sgpr_info user_sgpr_info;
-       struct arg_info args = {};
-       LLVMValueRef desc_sets;
-       bool needs_view_index = needs_view_index_sgpr(ctx, stage);
+       bool needs_view_index = needs_view_index_sgpr(args, stage);
 
-       if (ctx->ac.chip_class >= GFX10) {
-               if (is_pre_gs_stage(stage) && ctx->options->key.vs_common_out.as_ngg) {
+       if (args->options->chip_class >= GFX10) {
+               if (is_pre_gs_stage(stage) && args->options->key.vs_common_out.as_ngg) {
                        /* On GFX10, VS is merged into GS for NGG. */
                        previous_stage = stage;
                        stage = MESA_SHADER_GEOMETRY;
@@ -911,256 +813,244 @@ static void create_function(struct radv_shader_context *ctx,
                }
        }
 
-       allocate_user_sgprs(ctx, stage, has_previous_stage,
+       for (int i = 0; i < MAX_SETS; i++)
+               args->shader_info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
+       for (int i = 0; i < AC_UD_MAX_UD; i++)
+               args->shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
+
+
+       allocate_user_sgprs(args, stage, has_previous_stage,
                            previous_stage, needs_view_index, &user_sgpr_info);
 
-       if (user_sgpr_info.need_ring_offsets && !ctx->options->supports_spill) {
-               add_arg(&args, ARG_SGPR, ac_array_in_const_addr_space(ctx->ac.v4i32),
-                       &ctx->ring_offsets);
+       if (user_sgpr_info.need_ring_offsets && !args->options->supports_spill) {
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_DESC_PTR,
+                          &args->ring_offsets);
        }
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                          &desc_sets);
+               declare_global_input_sgprs(args, &user_sgpr_info);
 
-               if (ctx->shader_info->cs.uses_grid_size) {
-                       add_arg(&args, ARG_SGPR, ctx->ac.v3i32,
-                               &ctx->abi.num_work_groups);
+               if (args->shader_info->cs.uses_grid_size) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 3, AC_ARG_INT,
+                                  &args->ac.num_work_groups);
                }
 
                for (int i = 0; i < 3; i++) {
-                       ctx->abi.workgroup_ids[i] = NULL;
-                       if (ctx->shader_info->cs.uses_block_id[i]) {
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->abi.workgroup_ids[i]);
+                       if (args->shader_info->cs.uses_block_id[i]) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &args->ac.workgroup_ids[i]);
                        }
                }
 
-               if (ctx->shader_info->cs.uses_local_invocation_idx)
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->abi.tg_size);
-               add_arg(&args, ARG_VGPR, ctx->ac.v3i32,
-                       &ctx->abi.local_invocation_ids);
+               if (args->shader_info->cs.uses_local_invocation_idx) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->ac.tg_size);
+               }
+
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_INT,
+                          &args->ac.local_invocation_ids);
                break;
        case MESA_SHADER_VERTEX:
-               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                          &desc_sets);
+               declare_global_input_sgprs(args, &user_sgpr_info);
 
-               declare_vs_specific_input_sgprs(ctx, stage, has_previous_stage,
-                                               previous_stage, &args);
+               declare_vs_specific_input_sgprs(args, stage, has_previous_stage,
+                                               previous_stage);
 
-               if (needs_view_index)
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->abi.view_index);
-               if (ctx->options->key.vs_common_out.as_es) {
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->es2gs_offset);
-               } else if (ctx->options->key.vs_common_out.as_ls) {
+               if (needs_view_index) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->ac.view_index);
+               }
+
+               if (args->options->key.vs_common_out.as_es) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                               &args->es2gs_offset);
+               } else if (args->options->key.vs_common_out.as_ls) {
                        /* no extra parameters */
                } else {
-                       declare_streamout_sgprs(ctx, stage, &args);
+                       declare_streamout_sgprs(args, stage);
                }
 
-               declare_vs_input_vgprs(ctx, &args);
+               declare_vs_input_vgprs(args);
                break;
        case MESA_SHADER_TESS_CTRL:
                if (has_previous_stage) {
                        // First 6 system regs
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->merged_wave_info);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->tess_factor_offset);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->oc_lds);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->merged_wave_info);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->tess_factor_offset);
 
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // scratch offset
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // scratch offset
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // unknown
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // unknown
 
-                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                                  &desc_sets);
+                       declare_global_input_sgprs(args, &user_sgpr_info);
 
-                       declare_vs_specific_input_sgprs(ctx, stage,
+                       declare_vs_specific_input_sgprs(args, stage,
                                                        has_previous_stage,
-                                                       previous_stage, &args);
+                                                       previous_stage);
 
-                       if (needs_view_index)
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->abi.view_index);
+                       if (needs_view_index) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &args->ac.view_index);
+                       }
 
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.tcs_patch_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.tcs_rel_ids);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                 &args->ac.tcs_patch_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.tcs_rel_ids);
 
-                       declare_vs_input_vgprs(ctx, &args);
+                       declare_vs_input_vgprs(args);
                } else {
-                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                                  &desc_sets);
-
-                       if (needs_view_index)
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->abi.view_index);
-
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->tess_factor_offset);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.tcs_patch_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.tcs_rel_ids);
+                       declare_global_input_sgprs(args, &user_sgpr_info);
+
+                       if (needs_view_index) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &args->ac.view_index);
+                       }
+
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->oc_lds);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->tess_factor_offset);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.tcs_patch_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.tcs_rel_ids);
                }
                break;
        case MESA_SHADER_TESS_EVAL:
-               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                          &desc_sets);
+               declare_global_input_sgprs(args, &user_sgpr_info);
 
                if (needs_view_index)
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->abi.view_index);
-
-               if (ctx->options->key.vs_common_out.as_es) {
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->es2gs_offset);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                               &args->ac.view_index);
+
+               if (args->options->key.vs_common_out.as_es) {
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->oc_lds);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                               &args->es2gs_offset);
                } else {
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL);
-                       declare_streamout_sgprs(ctx, stage, &args);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+                       declare_streamout_sgprs(args, stage);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->oc_lds);
                }
-               declare_tes_input_vgprs(ctx, &args);
+               declare_tes_input_vgprs(args);
                break;
        case MESA_SHADER_GEOMETRY:
                if (has_previous_stage) {
                        // First 6 system regs
-                       if (ctx->options->key.vs_common_out.as_ngg) {
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->gs_tg_info);
+                       if (args->options->key.vs_common_out.as_ngg) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                       &args->gs_tg_info);
                        } else {
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->gs2vs_offset);
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                       &args->gs2vs_offset);
                        }
 
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                               &ctx->merged_wave_info);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &args->merged_wave_info);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->oc_lds);
 
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // scratch offset
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // scratch offset
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // unknown
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); // unknown
 
-                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                                  &desc_sets);
+                       declare_global_input_sgprs(args, &user_sgpr_info);
 
                        if (previous_stage != MESA_SHADER_TESS_EVAL) {
-                               declare_vs_specific_input_sgprs(ctx, stage,
+                               declare_vs_specific_input_sgprs(args, stage,
                                                                has_previous_stage,
-                                                               previous_stage,
-                                                               &args);
+                                                               previous_stage);
+                       }
+
+                       if (needs_view_index) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &args->ac.view_index);
                        }
 
-                       if (needs_view_index)
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->abi.view_index);
-
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[0]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[2]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.gs_prim_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.gs_invocation_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[4]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[0]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[2]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.gs_prim_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.gs_invocation_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[4]);
 
                        if (previous_stage == MESA_SHADER_VERTEX) {
-                               declare_vs_input_vgprs(ctx, &args);
+                               declare_vs_input_vgprs(args);
                        } else {
-                               declare_tes_input_vgprs(ctx, &args);
+                               declare_tes_input_vgprs(args);
                        }
                } else {
-                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                                  &desc_sets);
-
-                       if (needs_view_index)
-                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
-                                       &ctx->abi.view_index);
-
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->gs2vs_offset);
-                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->gs_wave_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[0]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[1]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.gs_prim_id);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[2]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[3]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[4]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->gs_vtx_offset[5]);
-                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
-                               &ctx->abi.gs_invocation_id);
+                       declare_global_input_sgprs(args, &user_sgpr_info);
+
+                       if (needs_view_index) {
+                               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &args->ac.view_index);
+                       }
+
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->gs2vs_offset);
+                       ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->gs_wave_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[0]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[1]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.gs_prim_id);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[2]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[3]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[4]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->gs_vtx_offset[5]);
+                       ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT,
+                                  &args->ac.gs_invocation_id);
                }
                break;
        case MESA_SHADER_FRAGMENT:
-               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
-                                          &desc_sets);
-
-               add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->abi.prim_mask);
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.persp_sample);
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.persp_center);
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.persp_centroid);
-               add_arg(&args, ARG_VGPR, ctx->ac.v3i32, NULL); /* persp pull model */
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.linear_sample);
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.linear_center);
-               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->abi.linear_centroid);
-               add_arg(&args, ARG_VGPR, ctx->ac.f32, NULL);  /* line stipple tex */
-               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[0]);
-               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[1]);
-               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[2]);
-               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[3]);
-               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.front_face);
-               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.ancillary);
-               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.sample_coverage);
-               add_arg(&args, ARG_VGPR, ctx->ac.i32, NULL);  /* fixed pt */
+               declare_global_input_sgprs(args, &user_sgpr_info);
+
+               ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, &args->ac.prim_mask);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.persp_sample);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.persp_center);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.persp_centroid);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 3, AC_ARG_INT, NULL); /* persp pull model */
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.linear_sample);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.linear_center);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_INT, &args->ac.linear_centroid);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, NULL);  /* line stipple tex */
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->ac.frag_pos[0]);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->ac.frag_pos[1]);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->ac.frag_pos[2]);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &args->ac.frag_pos[3]);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.front_face);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.ancillary);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.sample_coverage);
+               ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, NULL);  /* fixed pt */
                break;
        default:
                unreachable("Shader stage not implemented");
        }
 
-       ctx->main_function = create_llvm_function(
-           ctx->context, ctx->ac.module, ctx->ac.builder, NULL, 0, &args,
-           ctx->max_workgroup_size, ctx->options);
-       set_llvm_calling_convention(ctx->main_function, stage);
-
-
-       ctx->shader_info->num_input_vgprs = 0;
-       ctx->shader_info->num_input_sgprs = ctx->options->supports_spill ? 2 : 0;
-
-       ctx->shader_info->num_input_sgprs += args.num_sgprs_used;
+       args->shader_info->num_input_vgprs = 0;
+       args->shader_info->num_input_sgprs = args->options->supports_spill ? 2 : 0;
+       args->shader_info->num_input_sgprs += args->ac.num_sgprs_used;
 
-       if (ctx->stage != MESA_SHADER_FRAGMENT)
-               ctx->shader_info->num_input_vgprs = args.num_vgprs_used;
+       if (stage != MESA_SHADER_FRAGMENT)
+               args->shader_info->num_input_vgprs = args->ac.num_vgprs_used;
 
-       assign_arguments(ctx->main_function, &args);
+       uint8_t user_sgpr_idx = 0;
 
-       user_sgpr_idx = 0;
-
-       if (ctx->options->supports_spill || user_sgpr_info.need_ring_offsets) {
-               set_loc_shader_ptr(ctx, AC_UD_SCRATCH_RING_OFFSETS,
+       if (args->options->supports_spill || user_sgpr_info.need_ring_offsets) {
+               set_loc_shader_ptr(args, AC_UD_SCRATCH_RING_OFFSETS,
                                   &user_sgpr_idx);
-               if (ctx->options->supports_spill) {
-                       ctx->ring_offsets = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.implicit.buffer.ptr",
-                                                              LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_CONST),
-                                                              NULL, 0, AC_FUNC_ATTR_READNONE);
-                       ctx->ring_offsets = LLVMBuildBitCast(ctx->ac.builder, ctx->ring_offsets,
-                                                            ac_array_in_const_addr_space(ctx->ac.v4i32), "");
-               }
        }
 
        /* For merged shaders the user SGPRs start at 8, with 8 system SGPRs in front (including
@@ -1168,41 +1058,41 @@ static void create_function(struct radv_shader_context *ctx,
        if (has_previous_stage)
                user_sgpr_idx = 0;
 
-       set_global_input_locs(ctx, &user_sgpr_info, desc_sets, &user_sgpr_idx);
+       set_global_input_locs(args, &user_sgpr_info, &user_sgpr_idx);
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               if (ctx->shader_info->cs.uses_grid_size) {
-                       set_loc_shader(ctx, AC_UD_CS_GRID_SIZE,
+               if (args->shader_info->cs.uses_grid_size) {
+                       set_loc_shader(args, AC_UD_CS_GRID_SIZE,
                                       &user_sgpr_idx, 3);
                }
                break;
        case MESA_SHADER_VERTEX:
-               set_vs_specific_input_locs(ctx, stage, has_previous_stage,
+               set_vs_specific_input_locs(args, stage, has_previous_stage,
                                           previous_stage, &user_sgpr_idx);
-               if (ctx->abi.view_index)
-                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+               if (args->ac.view_index.used)
+                       set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_TESS_CTRL:
-               set_vs_specific_input_locs(ctx, stage, has_previous_stage,
+               set_vs_specific_input_locs(args, stage, has_previous_stage,
                                           previous_stage, &user_sgpr_idx);
-               if (ctx->abi.view_index)
-                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+               if (args->ac.view_index.used)
+                       set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_TESS_EVAL:
-               if (ctx->abi.view_index)
-                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+               if (args->ac.view_index.used)
+                       set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_GEOMETRY:
                if (has_previous_stage) {
                        if (previous_stage == MESA_SHADER_VERTEX)
-                               set_vs_specific_input_locs(ctx, stage,
+                               set_vs_specific_input_locs(args, stage,
                                                           has_previous_stage,
                                                           previous_stage,
                                                           &user_sgpr_idx);
                }
-               if (ctx->abi.view_index)
-                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+               if (args->ac.view_index.used)
+                       set_loc_shader(args, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_FRAGMENT:
                break;
@@ -1210,14 +1100,46 @@ static void create_function(struct radv_shader_context *ctx,
                unreachable("Shader stage not implemented");
        }
 
+       args->shader_info->num_user_sgprs = user_sgpr_idx;
+}
+
+static void create_function(struct radv_shader_context *ctx,
+                            gl_shader_stage stage,
+                            bool has_previous_stage)
+{
+       if (ctx->ac.chip_class >= GFX10) {
+               if (is_pre_gs_stage(stage) && ctx->args->options->key.vs_common_out.as_ngg) {
+                       /* On GFX10, VS is merged into GS for NGG. */
+                       stage = MESA_SHADER_GEOMETRY;
+                       has_previous_stage = true;
+               }
+       }
+
+       ctx->main_function = create_llvm_function(
+           &ctx->ac, ctx->ac.module, ctx->ac.builder, &ctx->args->ac,
+           get_llvm_calling_convention(ctx->main_function, stage),
+           ctx->max_workgroup_size,
+           ctx->args->options);
+
+       if (ctx->args->options->supports_spill) {
+               ctx->ring_offsets = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.implicit.buffer.ptr",
+                                                      LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_CONST),
+                                                      NULL, 0, AC_FUNC_ATTR_READNONE);
+               ctx->ring_offsets = LLVMBuildBitCast(ctx->ac.builder, ctx->ring_offsets,
+                                                    ac_array_in_const_addr_space(ctx->ac.v4i32), "");
+       } else if (ctx->args->ring_offsets.used) {
+               ctx->ring_offsets = ac_get_arg(&ctx->ac, ctx->args->ring_offsets);
+       }
+
+       load_descriptor_sets(ctx);
+
        if (stage == MESA_SHADER_TESS_CTRL ||
-           (stage == MESA_SHADER_VERTEX && ctx->options->key.vs_common_out.as_ls) ||
+           (stage == MESA_SHADER_VERTEX && ctx->args->options->key.vs_common_out.as_ls) ||
            /* GFX9 has the ESGS ring buffer in LDS. */
            (stage == MESA_SHADER_GEOMETRY && has_previous_stage)) {
                ac_declare_lds_as_pointer(&ctx->ac);
        }
 
-       ctx->shader_info->num_user_sgprs = user_sgpr_idx;
 }
 
 
@@ -1227,7 +1149,7 @@ radv_load_resource(struct ac_shader_abi *abi, LLVMValueRef index,
 {
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
        LLVMValueRef desc_ptr = ctx->descriptor_sets[desc_set];
-       struct radv_pipeline_layout *pipeline_layout = ctx->options->layout;
+       struct radv_pipeline_layout *pipeline_layout = ctx->args->options->layout;
        struct radv_descriptor_set_layout *layout = pipeline_layout->set[desc_set].layout;
        unsigned base_offset = layout->binding[binding].offset;
        LLVMValueRef offset, stride;
@@ -1236,7 +1158,7 @@ radv_load_resource(struct ac_shader_abi *abi, LLVMValueRef index,
            layout->binding[binding].type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC) {
                unsigned idx = pipeline_layout->set[desc_set].dynamic_offset_start +
                        layout->binding[binding].dynamic_offset_offset;
-               desc_ptr = ctx->abi.push_constants;
+               desc_ptr = ac_get_arg(&ctx->ac, ctx->args->ac.push_constants);
                base_offset = pipeline_layout->push_constant_size + 16 * idx;
                stride = LLVMConstInt(ctx->ac.i32, 16, false);
        } else
@@ -1269,7 +1191,7 @@ radv_load_resource(struct ac_shader_abi *abi, LLVMValueRef index,
 
                LLVMValueRef desc_components[4] = {
                        LLVMBuildPtrToInt(ctx->ac.builder, desc_ptr, ctx->ac.intptr, ""),
-                       LLVMConstInt(ctx->ac.i32, S_008F04_BASE_ADDRESS_HI(ctx->options->address32_hi), false),
+                       LLVMConstInt(ctx->ac.i32, S_008F04_BASE_ADDRESS_HI(ctx->args->options->address32_hi), false),
                        /* High limit to support variable sizes. */
                        LLVMConstInt(ctx->ac.i32, 0xffffffff, false),
                        LLVMConstInt(ctx->ac.i32, desc_type, false),
@@ -1305,9 +1227,9 @@ static LLVMValueRef get_non_vertex_index_offset(struct radv_shader_context *ctx)
        uint32_t num_patches = ctx->tcs_num_patches;
        uint32_t num_tcs_outputs;
        if (ctx->stage == MESA_SHADER_TESS_CTRL)
-               num_tcs_outputs = util_last_bit64(ctx->shader_info->tcs.outputs_written);
+               num_tcs_outputs = util_last_bit64(ctx->args->shader_info->tcs.outputs_written);
        else
-               num_tcs_outputs = ctx->options->key.tes.tcs_num_outputs;
+               num_tcs_outputs = ctx->args->options->key.tes.tcs_num_outputs;
 
        uint32_t output_vertex_size = num_tcs_outputs * 16;
        uint32_t pervertex_output_patch_size = ctx->shader->info.tess.tcs_vertices_out * output_vertex_size;
@@ -1476,6 +1398,7 @@ store_tcs_output(struct ac_shader_abi *abi,
        LLVMValueRef dw_addr;
        LLVMValueRef stride = NULL;
        LLVMValueRef buf_addr = NULL;
+       LLVMValueRef oc_lds = ac_get_arg(&ctx->ac, ctx->args->oc_lds);
        unsigned param;
        bool store_lds = true;
 
@@ -1532,13 +1455,13 @@ store_tcs_output(struct ac_shader_abi *abi,
 
                if (!is_tess_factor && writemask != 0xF)
                        ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, value, 1,
-                                                   buf_addr, ctx->oc_lds,
+                                                   buf_addr, oc_lds,
                                                    4 * (base + chan), ac_glc, false);
        }
 
        if (writemask == 0xF) {
                ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, src, 4,
-                                           buf_addr, ctx->oc_lds,
+                                           buf_addr, oc_lds,
                                            (base * 4), ac_glc, false);
        }
 }
@@ -1560,6 +1483,7 @@ load_tes_input(struct ac_shader_abi *abi,
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
        LLVMValueRef buf_addr;
        LLVMValueRef result;
+       LLVMValueRef oc_lds = ac_get_arg(&ctx->ac, ctx->args->oc_lds);
        unsigned param = shader_io_get_unique_index(location);
 
        if ((location == VARYING_SLOT_CLIP_DIST0 || location == VARYING_SLOT_CLIP_DIST1) && is_compact) {
@@ -1578,7 +1502,7 @@ load_tes_input(struct ac_shader_abi *abi,
        buf_addr = LLVMBuildAdd(ctx->ac.builder, buf_addr, comp_offset, "");
 
        result = ac_build_buffer_load(&ctx->ac, ctx->hs_ring_tess_offchip, num_components, NULL,
-                                     buf_addr, ctx->oc_lds, is_compact ? (4 * const_index) : 0, ac_glc, true, false);
+                                     buf_addr, oc_lds, is_compact ? (4 * const_index) : 0, ac_glc, true, false);
        result = ac_trim_vector(&ctx->ac, result, num_components);
        return result;
 }
@@ -1711,7 +1635,7 @@ static LLVMValueRef load_sample_position(struct ac_shader_abi *abi,
                               ac_array_in_const_addr_space(ctx->ac.v2f32), "");
 
        uint32_t sample_pos_offset =
-               radv_get_sample_pos_offset(ctx->options->key.fs.num_samples);
+               radv_get_sample_pos_offset(ctx->args->options->key.fs.num_samples);
 
        sample_id =
                LLVMBuildAdd(ctx->ac.builder, sample_id,
@@ -1727,11 +1651,11 @@ static LLVMValueRef load_sample_mask_in(struct ac_shader_abi *abi)
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
        uint8_t log2_ps_iter_samples;
 
-       if (ctx->shader_info->ps.force_persample) {
+       if (ctx->args->shader_info->ps.force_persample) {
                log2_ps_iter_samples =
-                       util_logbase2(ctx->options->key.fs.num_samples);
+                       util_logbase2(ctx->args->options->key.fs.num_samples);
        } else {
-               log2_ps_iter_samples = ctx->options->key.fs.log2_ps_iter_samples;
+               log2_ps_iter_samples = ctx->args->options->key.fs.log2_ps_iter_samples;
        }
 
        /* The bit pattern matches that used by fixed function fragment
@@ -1748,9 +1672,10 @@ static LLVMValueRef load_sample_mask_in(struct ac_shader_abi *abi)
        uint32_t ps_iter_mask = ps_iter_masks[log2_ps_iter_samples];
 
        LLVMValueRef result, sample_id;
-       sample_id = ac_unpack_param(&ctx->ac, abi->ancillary, 8, 4);
+       sample_id = ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->ac.ancillary), 8, 4);
        sample_id = LLVMBuildShl(ctx->ac.builder, LLVMConstInt(ctx->ac.i32, ps_iter_mask, false), sample_id, "");
-       result = LLVMBuildAnd(ctx->ac.builder, sample_id, abi->sample_coverage, "");
+       result = LLVMBuildAnd(ctx->ac.builder, sample_id,
+                             ac_get_arg(&ctx->ac, ctx->args->ac.sample_coverage), "");
        return result;
 }
 
@@ -1767,7 +1692,7 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
        unsigned offset = 0;
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
 
-       if (ctx->options->key.vs_common_out.as_ngg) {
+       if (ctx->args->options->key.vs_common_out.as_ngg) {
                gfx10_ngg_gs_emit_vertex(ctx, stream, addrs);
                return;
        }
@@ -1784,7 +1709,7 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
        can_emit = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, gs_next_vertex,
                                 LLVMConstInt(ctx->ac.i32, ctx->shader->info.gs.vertices_out, false), "");
 
-       bool use_kill = !ctx->shader_info->gs.writes_memory;
+       bool use_kill = !ctx->args->shader_info->gs.writes_memory;
        if (use_kill)
                ac_build_kill_if_false(&ctx->ac, can_emit);
        else
@@ -1792,9 +1717,9 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
 
        for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
                unsigned output_usage_mask =
-                       ctx->shader_info->gs.output_usage_mask[i];
+                       ctx->args->shader_info->gs.output_usage_mask[i];
                uint8_t output_stream =
-                       ctx->shader_info->gs.output_streams[i];
+                       ctx->args->shader_info->gs.output_streams[i];
                LLVMValueRef *out_ptr = &addrs[i * 4];
                int length = util_last_bit(output_usage_mask);
 
@@ -1823,8 +1748,10 @@ visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addr
                        ac_build_buffer_store_dword(&ctx->ac,
                                                    ctx->gsvs_ring[stream],
                                                    out_val, 1,
-                                                   voffset, ctx->gs2vs_offset, 0,
-                                                   ac_glc | ac_slc, true);
+                                                   voffset,
+                                                   ac_get_arg(&ctx->ac,
+                                                              ctx->args->gs2vs_offset),
+                                                   0, ac_glc | ac_slc, true);
                }
        }
 
@@ -1845,7 +1772,7 @@ visit_end_primitive(struct ac_shader_abi *abi, unsigned stream)
 {
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
 
-       if (ctx->options->key.vs_common_out.as_ngg) {
+       if (ctx->args->options->key.vs_common_out.as_ngg) {
                LLVMBuildStore(ctx->ac.builder, ctx->ac.i32_0, ctx->gs_curprim_verts[stream]);
                return;
        }
@@ -1859,8 +1786,8 @@ load_tess_coord(struct ac_shader_abi *abi)
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
 
        LLVMValueRef coord[4] = {
-               ctx->tes_u,
-               ctx->tes_v,
+               ac_get_arg(&ctx->ac, ctx->args->tes_u),
+               ac_get_arg(&ctx->ac, ctx->args->tes_v),
                ctx->ac.f32_0,
                ctx->ac.f32_0,
        };
@@ -1876,13 +1803,14 @@ static LLVMValueRef
 load_patch_vertices_in(struct ac_shader_abi *abi)
 {
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
-       return LLVMConstInt(ctx->ac.i32, ctx->options->key.tcs.input_vertices, false);
+       return LLVMConstInt(ctx->ac.i32, ctx->args->options->key.tcs.input_vertices, false);
 }
 
 
 static LLVMValueRef radv_load_base_vertex(struct ac_shader_abi *abi)
 {
-       return abi->base_vertex;
+       struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
+       return ac_get_arg(&ctx->ac, ctx->args->ac.base_vertex);
 }
 
 static LLVMValueRef radv_load_ssbo(struct ac_shader_abi *abi,
@@ -1928,7 +1856,7 @@ static LLVMValueRef radv_get_sampler_desc(struct ac_shader_abi *abi,
 {
        struct radv_shader_context *ctx = radv_shader_context_from_abi(abi);
        LLVMValueRef list = ctx->descriptor_sets[descriptor_set];
-       struct radv_descriptor_set_layout *layout = ctx->options->layout->set[descriptor_set].layout;
+       struct radv_descriptor_set_layout *layout = ctx->args->options->layout->set[descriptor_set].layout;
        struct radv_descriptor_set_binding_layout *binding = layout->binding + base_index;
        unsigned offset = binding->offset;
        unsigned stride = binding->size;
@@ -2134,14 +2062,14 @@ static void
 handle_vs_input_decl(struct radv_shader_context *ctx,
                     struct nir_variable *variable)
 {
-       LLVMValueRef t_list_ptr = ctx->vertex_buffers;
+       LLVMValueRef t_list_ptr = ac_get_arg(&ctx->ac, ctx->args->vertex_buffers);
        LLVMValueRef t_offset;
        LLVMValueRef t_list;
        LLVMValueRef input;
        LLVMValueRef buffer_index;
        unsigned attrib_count = glsl_count_attribute_slots(variable->type, true);
        uint8_t input_usage_mask =
-               ctx->shader_info->vs.input_usage_mask[variable->data.location];
+               ctx->args->shader_info->vs.input_usage_mask[variable->data.location];
        unsigned num_input_channels = util_last_bit(input_usage_mask);
 
        variable->data.driver_location = variable->data.location * 4;
@@ -2150,14 +2078,14 @@ handle_vs_input_decl(struct radv_shader_context *ctx,
        for (unsigned i = 0; i < attrib_count; ++i) {
                LLVMValueRef output[4];
                unsigned attrib_index = variable->data.location + i - VERT_ATTRIB_GENERIC0;
-               unsigned attrib_format = ctx->options->key.vs.vertex_attribute_formats[attrib_index];
+               unsigned attrib_format = ctx->args->options->key.vs.vertex_attribute_formats[attrib_index];
                unsigned data_format = attrib_format & 0x0f;
                unsigned num_format = (attrib_format >> 4) & 0x07;
                bool is_float = num_format != V_008F0C_BUF_NUM_FORMAT_UINT &&
                                num_format != V_008F0C_BUF_NUM_FORMAT_SINT;
 
-               if (ctx->options->key.vs.instance_rate_inputs & (1u << attrib_index)) {
-                       uint32_t divisor = ctx->options->key.vs.instance_rate_divisors[attrib_index];
+               if (ctx->args->options->key.vs.instance_rate_inputs & (1u << attrib_index)) {
+                       uint32_t divisor = ctx->args->options->key.vs.instance_rate_divisors[attrib_index];
 
                        if (divisor) {
                                buffer_index = ctx->abi.instance_id;
@@ -2170,21 +2098,27 @@ handle_vs_input_decl(struct radv_shader_context *ctx,
                                buffer_index = ctx->ac.i32_0;
                        }
 
-                       buffer_index = LLVMBuildAdd(ctx->ac.builder, ctx->abi.start_instance, buffer_index, "");
-               } else
-                       buffer_index = LLVMBuildAdd(ctx->ac.builder, ctx->abi.vertex_id,
-                                                   ctx->abi.base_vertex, "");
+                       buffer_index = LLVMBuildAdd(ctx->ac.builder,
+                                                   ac_get_arg(&ctx->ac,
+                                                              ctx->args->ac.start_instance),\
+                                                   buffer_index, "");
+               } else {
+                       buffer_index = LLVMBuildAdd(ctx->ac.builder,
+                                                   ctx->abi.vertex_id,
+                                                   ac_get_arg(&ctx->ac,
+                                                              ctx->args->ac.base_vertex), "");
+               }
 
                /* Adjust the number of channels to load based on the vertex
                 * attribute format.
                 */
                unsigned num_format_channels = get_num_channels_from_data_format(data_format);
                unsigned num_channels = MIN2(num_input_channels, num_format_channels);
-               unsigned attrib_binding = ctx->options->key.vs.vertex_attribute_bindings[attrib_index];
-               unsigned attrib_offset = ctx->options->key.vs.vertex_attribute_offsets[attrib_index];
-               unsigned attrib_stride = ctx->options->key.vs.vertex_attribute_strides[attrib_index];
+               unsigned attrib_binding = ctx->args->options->key.vs.vertex_attribute_bindings[attrib_index];
+               unsigned attrib_offset = ctx->args->options->key.vs.vertex_attribute_offsets[attrib_index];
+               unsigned attrib_stride = ctx->args->options->key.vs.vertex_attribute_strides[attrib_index];
 
-               if (ctx->options->key.vs.post_shuffle & (1 << attrib_index)) {
+               if (ctx->args->options->key.vs.post_shuffle & (1 << attrib_index)) {
                        /* Always load, at least, 3 channels for formats that
                         * need to be shuffled because X<->Z.
                         */
@@ -2213,7 +2147,7 @@ handle_vs_input_decl(struct radv_shader_context *ctx,
                                                     num_channels,
                                                     data_format, num_format, 0, true);
 
-               if (ctx->options->key.vs.post_shuffle & (1 << attrib_index)) {
+               if (ctx->args->options->key.vs.post_shuffle & (1 << attrib_index)) {
                        LLVMValueRef c[4];
                        c[0] = ac_llvm_extract_elem(&ctx->ac, input, 2);
                        c[1] = ac_llvm_extract_elem(&ctx->ac, input, 1);
@@ -2235,7 +2169,7 @@ handle_vs_input_decl(struct radv_shader_context *ctx,
                        }
                }
 
-               unsigned alpha_adjust = (ctx->options->key.vs.alpha_adjust >> (attrib_index * 2)) & 3;
+               unsigned alpha_adjust = (ctx->args->options->key.vs.alpha_adjust >> (attrib_index * 2)) & 3;
                output[3] = adjust_vertex_fetch_alpha(ctx, alpha_adjust, output[3]);
 
                for (unsigned chan = 0; chan < 4; chan++) {
@@ -2272,10 +2206,21 @@ prepare_interp_optimize(struct radv_shader_context *ctx,
                        uses_center = true;
        }
 
+       ctx->abi.persp_centroid = ac_get_arg(&ctx->ac, ctx->args->ac.persp_centroid);
+       ctx->abi.linear_centroid = ac_get_arg(&ctx->ac, ctx->args->ac.linear_centroid);
+
        if (uses_center && uses_centroid) {
-               LLVMValueRef sel = LLVMBuildICmp(ctx->ac.builder, LLVMIntSLT, ctx->abi.prim_mask, ctx->ac.i32_0, "");
-               ctx->abi.persp_centroid = LLVMBuildSelect(ctx->ac.builder, sel, ctx->abi.persp_center, ctx->abi.persp_centroid, "");
-               ctx->abi.linear_centroid = LLVMBuildSelect(ctx->ac.builder, sel, ctx->abi.linear_center, ctx->abi.linear_centroid, "");
+               LLVMValueRef sel = LLVMBuildICmp(ctx->ac.builder, LLVMIntSLT,
+                                                ac_get_arg(&ctx->ac, ctx->args->ac.prim_mask),
+                                                ctx->ac.i32_0, "");
+               ctx->abi.persp_centroid =
+                       LLVMBuildSelect(ctx->ac.builder, sel,
+                                       ac_get_arg(&ctx->ac, ctx->args->ac.persp_center),
+                                       ctx->abi.persp_centroid, "");
+               ctx->abi.linear_centroid =
+                       LLVMBuildSelect(ctx->ac.builder, sel,
+                                       ac_get_arg(&ctx->ac, ctx->args->ac.linear_center),
+                                       ctx->abi.linear_centroid, "");
        }
 }
 
@@ -2339,9 +2284,9 @@ si_llvm_init_export_args(struct radv_shader_context *ctx,
        bool is_16bit = ac_get_type_size(LLVMTypeOf(values[0])) == 2;
        if (ctx->stage == MESA_SHADER_FRAGMENT) {
                unsigned index = target - V_008DFC_SQ_EXP_MRT;
-               unsigned col_format = (ctx->options->key.fs.col_format >> (4 * index)) & 0xf;
-               bool is_int8 = (ctx->options->key.fs.is_int8 >> index) & 1;
-               bool is_int10 = (ctx->options->key.fs.is_int10 >> index) & 1;
+               unsigned col_format = (ctx->args->options->key.fs.col_format >> (4 * index)) & 0xf;
+               bool is_int8 = (ctx->args->options->key.fs.is_int8 >> index) & 1;
+               bool is_int10 = (ctx->args->options->key.fs.is_int10 >> index) & 1;
                unsigned chan;
 
                LLVMValueRef (*packf)(struct ac_llvm_context *ctx, LLVMValueRef args[2]) = NULL;
@@ -2546,9 +2491,10 @@ radv_emit_streamout(struct radv_shader_context *ctx, unsigned stream)
        int i;
 
        /* Get bits [22:16], i.e. (so_param >> 16) & 127; */
-       assert(ctx->streamout_config);
+       assert(ctx->args->streamout_config.used);
        LLVMValueRef so_vtx_count =
-               ac_build_bfe(&ctx->ac, ctx->streamout_config,
+               ac_build_bfe(&ctx->ac,
+                            ac_get_arg(&ctx->ac, ctx->args->streamout_config),
                             LLVMConstInt(ctx->ac.i32, 16, false),
                             LLVMConstInt(ctx->ac.i32, 7, false), false);
 
@@ -2569,7 +2515,8 @@ radv_emit_streamout(struct radv_shader_context *ctx, unsigned stream)
                 *                (streamout_write_index + thread_id)*stride[buffer_id] +
                 *                attrib_offset
                 */
-               LLVMValueRef so_write_index = ctx->streamout_write_idx;
+               LLVMValueRef so_write_index =
+                       ac_get_arg(&ctx->ac, ctx->args->streamout_write_idx);
 
                /* Compute (streamout_write_index + thread_id). */
                so_write_index =
@@ -2580,10 +2527,10 @@ radv_emit_streamout(struct radv_shader_context *ctx, unsigned stream)
                 */
                LLVMValueRef so_write_offset[4] = {};
                LLVMValueRef so_buffers[4] = {};
-               LLVMValueRef buf_ptr = ctx->streamout_buffers;
+               LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->args->streamout_buffers);
 
                for (i = 0; i < 4; i++) {
-                       uint16_t stride = ctx->shader_info->so.strides[i];
+                       uint16_t stride = ctx->args->shader_info->so.strides[i];
 
                        if (!stride)
                                continue;
@@ -2594,7 +2541,8 @@ radv_emit_streamout(struct radv_shader_context *ctx, unsigned stream)
                        so_buffers[i] = ac_build_load_to_sgpr(&ctx->ac,
                                                              buf_ptr, offset);
 
-                       LLVMValueRef so_offset = ctx->streamout_offset[i];
+                       LLVMValueRef so_offset =
+                               ac_get_arg(&ctx->ac, ctx->args->streamout_offset[i]);
 
                        so_offset = LLVMBuildMul(ctx->ac.builder, so_offset,
                                                 LLVMConstInt(ctx->ac.i32, 4, false), "");
@@ -2607,10 +2555,10 @@ radv_emit_streamout(struct radv_shader_context *ctx, unsigned stream)
                }
 
                /* Write streamout data. */
-               for (i = 0; i < ctx->shader_info->so.num_outputs; i++) {
+               for (i = 0; i < ctx->args->shader_info->so.num_outputs; i++) {
                        struct radv_shader_output_values shader_out = {};
                        struct radv_stream_output *output =
-                               &ctx->shader_info->so.outputs[i];
+                               &ctx->args->shader_info->so.outputs[i];
 
                        if (stream != output->stream)
                                continue;
@@ -2735,7 +2683,7 @@ radv_llvm_export_vs(struct radv_shader_context *ctx,
                if (outinfo->writes_layer == true)
                        pos_args[1].out[2] = layer_value;
                if (outinfo->writes_viewport_index == true) {
-                       if (ctx->options->chip_class >= GFX9) {
+                       if (ctx->args->options->chip_class >= GFX9) {
                                /* GFX9 has the layer in out.z[10:0] and the viewport
                                 * index in out.z[19:16].
                                 */
@@ -2797,7 +2745,7 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
        struct radv_shader_output_values *outputs;
        unsigned noutput = 0;
 
-       if (ctx->options->key.has_multiview_view_index) {
+       if (ctx->args->options->key.has_multiview_view_index) {
                LLVMValueRef* tmp_out = &ctx->abi.outputs[ac_llvm_reg_index_soa(VARYING_SLOT_LAYER, 0)];
                if(!*tmp_out) {
                        for(unsigned i = 0; i < 4; ++i)
@@ -2805,7 +2753,8 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
                                            ac_build_alloca_undef(&ctx->ac, ctx->ac.f32, "");
                }
 
-               LLVMBuildStore(ctx->ac.builder, ac_to_float(&ctx->ac, ctx->abi.view_index),  *tmp_out);
+               LLVMValueRef view_index = ac_get_arg(&ctx->ac, ctx->args->ac.view_index);
+               LLVMBuildStore(ctx->ac.builder, ac_to_float(&ctx->ac, view_index), *tmp_out);
                ctx->output_mask |= 1ull << VARYING_SLOT_LAYER;
        }
 
@@ -2813,9 +2762,9 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
               sizeof(outinfo->vs_output_param_offset));
        outinfo->pos_exports = 0;
 
-       if (!ctx->options->use_ngg_streamout &&
-           ctx->shader_info->so.num_outputs &&
-           !ctx->is_gs_copy_shader) {
+       if (!ctx->args->options->use_ngg_streamout &&
+           ctx->args->shader_info->so.num_outputs &&
+           !ctx->args->is_gs_copy_shader) {
                /* The GS copy shader emission already emits streamout. */
                radv_emit_streamout(ctx, 0);
        }
@@ -2832,16 +2781,16 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
                outputs[noutput].slot_index = i == VARYING_SLOT_CLIP_DIST1;
 
                if (ctx->stage == MESA_SHADER_VERTEX &&
-                   !ctx->is_gs_copy_shader) {
+                   !ctx->args->is_gs_copy_shader) {
                        outputs[noutput].usage_mask =
-                               ctx->shader_info->vs.output_usage_mask[i];
+                               ctx->args->shader_info->vs.output_usage_mask[i];
                } else if (ctx->stage == MESA_SHADER_TESS_EVAL) {
                        outputs[noutput].usage_mask =
-                               ctx->shader_info->tes.output_usage_mask[i];
+                               ctx->args->shader_info->tes.output_usage_mask[i];
                } else {
-                       assert(ctx->is_gs_copy_shader);
+                       assert(ctx->args->is_gs_copy_shader);
                        outputs[noutput].usage_mask =
-                               ctx->shader_info->gs.output_usage_mask[i];
+                               ctx->args->shader_info->gs.output_usage_mask[i];
                }
 
                for (unsigned j = 0; j < 4; j++) {
@@ -2857,7 +2806,8 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
                outputs[noutput].slot_name = VARYING_SLOT_PRIMITIVE_ID;
                outputs[noutput].slot_index = 0;
                outputs[noutput].usage_mask = 0x1;
-               outputs[noutput].values[0] = ctx->vs_prim_id;
+               outputs[noutput].values[0] =
+                       ac_get_arg(&ctx->ac, ctx->args->vs_prim_id);
                for (unsigned j = 1; j < 4; j++)
                        outputs[noutput].values[j] = ctx->ac.f32_0;
                noutput++;
@@ -2878,7 +2828,9 @@ handle_es_outputs_post(struct radv_shader_context *ctx,
        if (ctx->ac.chip_class  >= GFX9) {
                unsigned itemsize_dw = outinfo->esgs_itemsize / 4;
                LLVMValueRef vertex_idx = ac_get_thread_id(&ctx->ac);
-               LLVMValueRef wave_idx = ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 24, 4);
+               LLVMValueRef wave_idx =
+                       ac_unpack_param(&ctx->ac,
+                                       ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 24, 4);
                vertex_idx = LLVMBuildOr(ctx->ac.builder, vertex_idx,
                                         LLVMBuildMul(ctx->ac.builder, wave_idx,
                                                      LLVMConstInt(ctx->ac.i32,
@@ -2898,11 +2850,11 @@ handle_es_outputs_post(struct radv_shader_context *ctx,
 
                if (ctx->stage == MESA_SHADER_VERTEX) {
                        output_usage_mask =
-                               ctx->shader_info->vs.output_usage_mask[i];
+                               ctx->args->shader_info->vs.output_usage_mask[i];
                } else {
                        assert(ctx->stage == MESA_SHADER_TESS_EVAL);
                        output_usage_mask =
-                               ctx->shader_info->tes.output_usage_mask[i];
+                               ctx->args->shader_info->tes.output_usage_mask[i];
                }
 
                param_index = shader_io_get_unique_index(i);
@@ -2932,7 +2884,8 @@ handle_es_outputs_post(struct radv_shader_context *ctx,
                                ac_build_buffer_store_dword(&ctx->ac,
                                                            ctx->esgs_ring,
                                                            out_val, 1,
-                                                           NULL, ctx->es2gs_offset,
+                                                           NULL,
+                                                           ac_get_arg(&ctx->ac, ctx->args->es2gs_offset),
                                                            (4 * param_index + j) * 4,
                                                            ac_glc | ac_slc, true);
                        }
@@ -2944,7 +2897,7 @@ static void
 handle_ls_outputs_post(struct radv_shader_context *ctx)
 {
        LLVMValueRef vertex_id = ctx->rel_auto_id;
-       uint32_t num_tcs_inputs = util_last_bit64(ctx->shader_info->vs.ls_outputs_written);
+       uint32_t num_tcs_inputs = util_last_bit64(ctx->args->shader_info->vs.ls_outputs_written);
        LLVMValueRef vertex_dw_stride = LLVMConstInt(ctx->ac.i32, num_tcs_inputs * 4, false);
        LLVMValueRef base_dw_addr = LLVMBuildMul(ctx->ac.builder, vertex_id,
                                                 vertex_dw_stride, "");
@@ -2971,12 +2924,13 @@ handle_ls_outputs_post(struct radv_shader_context *ctx)
 
 static LLVMValueRef get_wave_id_in_tg(struct radv_shader_context *ctx)
 {
-       return ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 24, 4);
+       return ac_unpack_param(&ctx->ac,
+                              ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 24, 4);
 }
 
 static LLVMValueRef get_tgsize(struct radv_shader_context *ctx)
 {
-       return ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 28, 4);
+       return ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 28, 4);
 }
 
 static LLVMValueRef get_thread_id_in_tg(struct radv_shader_context *ctx)
@@ -2990,7 +2944,7 @@ static LLVMValueRef get_thread_id_in_tg(struct radv_shader_context *ctx)
 
 static LLVMValueRef ngg_get_vtx_cnt(struct radv_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
+       return ac_build_bfe(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_tg_info),
                            LLVMConstInt(ctx->ac.i32, 12, false),
                            LLVMConstInt(ctx->ac.i32, 9, false),
                            false);
@@ -2998,7 +2952,7 @@ static LLVMValueRef ngg_get_vtx_cnt(struct radv_shader_context *ctx)
 
 static LLVMValueRef ngg_get_prim_cnt(struct radv_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
+       return ac_build_bfe(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_tg_info),
                            LLVMConstInt(ctx->ac.i32, 22, false),
                            LLVMConstInt(ctx->ac.i32, 9, false),
                            false);
@@ -3006,7 +2960,7 @@ static LLVMValueRef ngg_get_prim_cnt(struct radv_shader_context *ctx)
 
 static LLVMValueRef ngg_get_ordered_id(struct radv_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
+       return ac_build_bfe(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_tg_info),
                            ctx->ac.i32_0,
                            LLVMConstInt(ctx->ac.i32, 11, false),
                            false);
@@ -3017,7 +2971,7 @@ ngg_gs_get_vertex_storage(struct radv_shader_context *ctx)
 {
        unsigned num_outputs = util_bitcount64(ctx->output_mask);
 
-       if (ctx->options->key.has_multiview_view_index)
+       if (ctx->args->options->key.has_multiview_view_index)
                num_outputs++;
 
        LLVMTypeRef elements[2] = {
@@ -3207,7 +3161,7 @@ static void build_streamout_vertex(struct radv_shader_context *ctx,
                                   unsigned stream, LLVMValueRef offset_vtx,
                                   LLVMValueRef vertexptr)
 {
-       struct radv_streamout_info *so = &ctx->shader_info->so;
+       struct radv_streamout_info *so = &ctx->args->shader_info->so;
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef offset[4] = {};
        LLVMValueRef tmp;
@@ -3229,9 +3183,9 @@ static void build_streamout_vertex(struct radv_shader_context *ctx,
 
                for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
                        unsigned output_usage_mask =
-                               ctx->shader_info->gs.output_usage_mask[i];
+                               ctx->args->shader_info->gs.output_usage_mask[i];
                        uint8_t output_stream =
-                               output_stream = ctx->shader_info->gs.output_streams[i];
+                               output_stream = ctx->args->shader_info->gs.output_streams[i];
 
                        if (!(ctx->output_mask & (1ull << i)) ||
                            output_stream != stream)
@@ -3277,7 +3231,7 @@ static void build_streamout_vertex(struct radv_shader_context *ctx,
        } else {
                for (unsigned i = 0; i < so->num_outputs; ++i) {
                        struct radv_stream_output *output =
-                               &ctx->shader_info->so.outputs[i];
+                               &ctx->args->shader_info->so.outputs[i];
 
                        if (stream != output->stream)
                                continue;
@@ -3321,9 +3275,9 @@ struct ngg_streamout {
 static void build_streamout(struct radv_shader_context *ctx,
                            struct ngg_streamout *nggso)
 {
-       struct radv_streamout_info *so = &ctx->shader_info->so;
+       struct radv_streamout_info *so = &ctx->args->shader_info->so;
        LLVMBuilderRef builder = ctx->ac.builder;
-       LLVMValueRef buf_ptr = ctx->streamout_buffers;
+       LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->args->streamout_buffers);
        LLVMValueRef tid = get_thread_id_in_tg(ctx);
        LLVMValueRef cond, tmp, tmp2;
        LLVMValueRef i32_2 = LLVMConstInt(ctx->ac.i32, 2, false);
@@ -3402,7 +3356,7 @@ static void build_streamout(struct radv_shader_context *ctx,
                        unsigned swizzle[4];
                        int unused_stream = -1;
                        for (unsigned stream = 0; stream < 4; ++stream) {
-                               if (!ctx->shader_info->gs.num_stream_output_components[stream]) {
+                               if (!ctx->args->shader_info->gs.num_stream_output_components[stream]) {
                                        unused_stream = stream;
                                        break;
                                }
@@ -3488,7 +3442,7 @@ static void build_streamout(struct radv_shader_context *ctx,
                LLVMValueRef emit_vgpr = ctx->ac.i32_0;
 
                for (unsigned stream = 0; stream < 4; ++stream) {
-                       if (!ctx->shader_info->gs.num_stream_output_components[stream])
+                       if (!ctx->args->shader_info->gs.num_stream_output_components[stream])
                                continue;
 
                        /* Load the number of generated primitives from GDS and
@@ -3551,7 +3505,7 @@ static void build_streamout(struct radv_shader_context *ctx,
 
        if (isgs) {
                for (unsigned stream = 0; stream < 4; ++stream) {
-                       if (!ctx->shader_info->gs.num_stream_output_components[stream])
+                       if (!ctx->args->shader_info->gs.num_stream_output_components[stream])
                                continue;
 
                        primemit_scan[stream].enable_exclusive = true;
@@ -3587,7 +3541,7 @@ static void build_streamout(struct radv_shader_context *ctx,
                }
 
                for (unsigned stream = 0; stream < 4; ++stream) {
-                       if (ctx->shader_info->gs.num_stream_output_components[stream]) {
+                       if (ctx->args->shader_info->gs.num_stream_output_components[stream]) {
                                nggso->emit[stream] = ac_build_readlane(
                                        &ctx->ac, scratch_vgpr,
                                        LLVMConstInt(ctx->ac.i32, scratch_emit_base + stream, false));
@@ -3597,7 +3551,7 @@ static void build_streamout(struct radv_shader_context *ctx,
 
        /* Write out primitive data */
        for (unsigned stream = 0; stream < 4; ++stream) {
-               if (!ctx->shader_info->gs.num_stream_output_components[stream])
+               if (!ctx->args->shader_info->gs.num_stream_output_components[stream])
                        continue;
 
                if (isgs) {
@@ -3635,8 +3589,8 @@ static unsigned ngg_nogs_vertex_size(struct radv_shader_context *ctx)
 {
        unsigned lds_vertex_size = 0;
 
-       if (ctx->shader_info->so.num_outputs)
-               lds_vertex_size = 4 * ctx->shader_info->so.num_outputs + 1;
+       if (ctx->args->shader_info->so.num_outputs)
+               lds_vertex_size = 4 * ctx->args->shader_info->so.num_outputs + 1;
 
        return lds_vertex_size;
 }
@@ -3659,22 +3613,22 @@ static LLVMValueRef ngg_nogs_vertex_ptr(struct radv_shader_context *ctx,
 static void
 handle_ngg_outputs_post_1(struct radv_shader_context *ctx)
 {
-       struct radv_streamout_info *so = &ctx->shader_info->so;
+       struct radv_streamout_info *so = &ctx->args->shader_info->so;
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef vertex_ptr = NULL;
        LLVMValueRef tmp, tmp2;
 
        assert((ctx->stage == MESA_SHADER_VERTEX ||
-               ctx->stage == MESA_SHADER_TESS_EVAL) && !ctx->is_gs_copy_shader);
+               ctx->stage == MESA_SHADER_TESS_EVAL) && !ctx->args->is_gs_copy_shader);
 
-       if (!ctx->shader_info->so.num_outputs)
+       if (!ctx->args->shader_info->so.num_outputs)
                return;
 
        vertex_ptr = ngg_nogs_vertex_ptr(ctx, get_thread_id_in_tg(ctx));
 
        for (unsigned i = 0; i < so->num_outputs; ++i) {
                struct radv_stream_output *output =
-                       &ctx->shader_info->so.outputs[i];
+                       &ctx->args->shader_info->so.outputs[i];
 
                unsigned loc = output->location;
 
@@ -3699,18 +3653,20 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
        LLVMValueRef tmp;
 
        assert((ctx->stage == MESA_SHADER_VERTEX ||
-               ctx->stage == MESA_SHADER_TESS_EVAL) && !ctx->is_gs_copy_shader);
+               ctx->stage == MESA_SHADER_TESS_EVAL) && !ctx->args->is_gs_copy_shader);
 
-       LLVMValueRef prims_in_wave = ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 8, 8);
-       LLVMValueRef vtx_in_wave = ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 0, 8);
+       LLVMValueRef prims_in_wave = ac_unpack_param(&ctx->ac,
+                                                    ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 8, 8);
+       LLVMValueRef vtx_in_wave = ac_unpack_param(&ctx->ac, 
+                                                  ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 0, 8);
        LLVMValueRef is_gs_thread = LLVMBuildICmp(builder, LLVMIntULT,
                                                  ac_get_thread_id(&ctx->ac), prims_in_wave, "");
        LLVMValueRef is_es_thread = LLVMBuildICmp(builder, LLVMIntULT,
                                                  ac_get_thread_id(&ctx->ac), vtx_in_wave, "");
        LLVMValueRef vtxindex[] = {
-               ac_unpack_param(&ctx->ac, ctx->gs_vtx_offset[0], 0, 16),
-               ac_unpack_param(&ctx->ac, ctx->gs_vtx_offset[0], 16, 16),
-               ac_unpack_param(&ctx->ac, ctx->gs_vtx_offset[2], 0, 16),
+               ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_vtx_offset[0]), 0, 16),
+               ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_vtx_offset[0]), 16, 16),
+               ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->gs_vtx_offset[2]), 0, 16),
        };
 
        /* Determine the number of vertices per primitive. */
@@ -3720,7 +3676,7 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
        if (ctx->stage == MESA_SHADER_VERTEX) {
                LLVMValueRef outprim_val =
                        LLVMConstInt(ctx->ac.i32,
-                                    ctx->options->key.vs.outprim, false);
+                                    ctx->args->options->key.vs.outprim, false);
                num_vertices_val = LLVMBuildAdd(builder, outprim_val,
                                                ctx->ac.i32_1, "");
                num_vertices = 3; /* TODO: optimize for points & lines */
@@ -3738,7 +3694,7 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
        }
 
        /* Streamout */
-       if (ctx->shader_info->so.num_outputs) {
+       if (ctx->args->shader_info->so.num_outputs) {
                struct ngg_streamout nggso = {};
 
                nggso.num_vertices = num_vertices_val;
@@ -3754,8 +3710,8 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
         * to the ES thread of the provoking vertex.
         */
        if (ctx->stage == MESA_SHADER_VERTEX &&
-           ctx->options->key.vs_common_out.export_prim_id) {
-               if (ctx->shader_info->so.num_outputs)
+           ctx->args->options->key.vs_common_out.export_prim_id) {
+               if (ctx->args->shader_info->so.num_outputs)
                        ac_build_s_barrier(&ctx->ac);
 
                ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
@@ -3768,7 +3724,7 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
                LLVMValueRef provoking_vtx_index =
                        LLVMBuildExtractElement(builder, indices, provoking_vtx_in_prim, "");
 
-               LLVMBuildStore(builder, ctx->abi.gs_prim_id,
+               LLVMBuildStore(builder, ac_get_arg(&ctx->ac, ctx->args->ac.gs_prim_id),
                               ac_build_gep0(&ctx->ac, ctx->esgs_ring, provoking_vtx_index));
                ac_build_endif(&ctx->ac, 5400);
        }
@@ -3804,7 +3760,8 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
                memcpy(prim.index, vtxindex, sizeof(vtxindex[0]) * 3);
 
                for (unsigned i = 0; i < num_vertices; ++i) {
-                       tmp = LLVMBuildLShr(builder, ctx->abi.gs_invocation_id,
+                       tmp = LLVMBuildLShr(builder,
+                                           ac_get_arg(&ctx->ac, ctx->args->ac.gs_invocation_id),
                                            LLVMConstInt(ctx->ac.i32, 8 + i, false), "");
                        prim.edgeflag[i] = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
                }
@@ -3817,15 +3774,16 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
        ac_build_ifcc(&ctx->ac, is_es_thread, 6002);
        {
                struct radv_vs_output_info *outinfo =
-                       ctx->stage == MESA_SHADER_TESS_EVAL ? &ctx->shader_info->tes.outinfo : &ctx->shader_info->vs.outinfo;
+                       ctx->stage == MESA_SHADER_TESS_EVAL ?
+                       &ctx->args->shader_info->tes.outinfo : &ctx->args->shader_info->vs.outinfo;
 
                /* Exporting the primitive ID is handled below. */
                /* TODO: use the new VS export path */
                handle_vs_outputs_post(ctx, false,
-                                      ctx->options->key.vs_common_out.export_clip_dists,
+                                      ctx->args->options->key.vs_common_out.export_clip_dists,
                                       outinfo);
 
-               if (ctx->options->key.vs_common_out.export_prim_id) {
+               if (ctx->args->options->key.vs_common_out.export_prim_id) {
                        unsigned param_count = outinfo->param_exports;
                        LLVMValueRef values[4];
 
@@ -3838,7 +3796,7 @@ handle_ngg_outputs_post_2(struct radv_shader_context *ctx)
                                values[0] = LLVMBuildLoad(builder, tmp, "");
                        } else {
                                assert(ctx->stage == MESA_SHADER_TESS_EVAL);
-                               values[0] = ctx->abi.tes_patch_id;
+                               values[0] = ac_get_arg(&ctx->ac, ctx->args->ac.tes_patch_id);
                        }
 
                        values[0] = ac_to_float(&ctx->ac, values[0]);
@@ -3899,7 +3857,7 @@ static void gfx10_ngg_gs_emit_epilogue_1(struct radv_shader_context *ctx)
                unsigned num_components;
 
                num_components =
-                       ctx->shader_info->gs.num_stream_output_components[stream];
+                       ctx->args->shader_info->gs.num_stream_output_components[stream];
                if (!num_components)
                        continue;
 
@@ -3935,7 +3893,7 @@ static void gfx10_ngg_gs_emit_epilogue_1(struct radv_shader_context *ctx)
                unsigned num_components;
 
                num_components =
-                       ctx->shader_info->gs.num_stream_output_components[stream];
+                       ctx->args->shader_info->gs.num_stream_output_components[stream];
                if (!num_components)
                        continue;
 
@@ -3967,14 +3925,14 @@ static void gfx10_ngg_gs_emit_epilogue_2(struct radv_shader_context *ctx)
        LLVMValueRef num_emit_threads = ngg_get_prim_cnt(ctx);
 
        /* Streamout */
-       if (ctx->shader_info->so.num_outputs) {
+       if (ctx->args->shader_info->so.num_outputs) {
                struct ngg_streamout nggso = {};
 
                nggso.num_vertices = LLVMConstInt(ctx->ac.i32, verts_per_prim, false);
 
                LLVMValueRef vertexptr = ngg_gs_vertex_ptr(ctx, tid);
                for (unsigned stream = 0; stream < 4; ++stream) {
-                       if (!ctx->shader_info->gs.num_stream_output_components[stream])
+                       if (!ctx->args->shader_info->gs.num_stream_output_components[stream])
                                continue;
 
                        LLVMValueRef gep_idx[3] = {
@@ -4127,8 +4085,8 @@ static void gfx10_ngg_gs_emit_epilogue_2(struct radv_shader_context *ctx)
        tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, vertlive_scan.result_reduce, "");
        ac_build_ifcc(&ctx->ac, tmp, 5145);
        {
-               struct radv_vs_output_info *outinfo = &ctx->shader_info->vs.outinfo;
-               bool export_view_index = ctx->options->key.has_multiview_view_index;
+               struct radv_vs_output_info *outinfo = &ctx->args->shader_info->vs.outinfo;
+               bool export_view_index = ctx->args->options->key.has_multiview_view_index;
                struct radv_shader_output_values *outputs;
                unsigned noutput = 0;
 
@@ -4155,7 +4113,7 @@ static void gfx10_ngg_gs_emit_epilogue_2(struct radv_shader_context *ctx)
                gep_idx[1] = ctx->ac.i32_0;
                for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
                        unsigned output_usage_mask =
-                               ctx->shader_info->gs.output_usage_mask[i];
+                               ctx->args->shader_info->gs.output_usage_mask[i];
                        int length = util_last_bit(output_usage_mask);
 
                        if (!(ctx->output_mask & (1ull << i)))
@@ -4193,14 +4151,15 @@ static void gfx10_ngg_gs_emit_epilogue_2(struct radv_shader_context *ctx)
                        outputs[noutput].slot_name = VARYING_SLOT_LAYER;
                        outputs[noutput].slot_index = 0;
                        outputs[noutput].usage_mask = 0x1;
-                       outputs[noutput].values[0] = ac_to_float(&ctx->ac, ctx->abi.view_index);
+                       outputs[noutput].values[0] =
+                               ac_to_float(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->ac.view_index));
                        for (unsigned j = 1; j < 4; j++)
                                outputs[noutput].values[j] = ctx->ac.f32_0;
                        noutput++;
                }
 
                radv_llvm_export_vs(ctx, outputs, noutput, outinfo,
-                                   ctx->options->key.vs_common_out.export_clip_dists);
+                                   ctx->args->options->key.vs_common_out.export_clip_dists);
                FREE(outputs);
        }
        ac_build_endif(&ctx->ac, 5145);
@@ -4233,9 +4192,9 @@ static void gfx10_ngg_gs_emit_vertex(struct radv_shader_context *ctx,
        unsigned out_idx = 0;
        for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
                unsigned output_usage_mask =
-                       ctx->shader_info->gs.output_usage_mask[i];
+                       ctx->args->shader_info->gs.output_usage_mask[i];
                uint8_t output_stream =
-                       ctx->shader_info->gs.output_streams[i];
+                       ctx->args->shader_info->gs.output_streams[i];
                LLVMValueRef *out_ptr = &addrs[i * 4];
                int length = util_last_bit(output_usage_mask);
 
@@ -4262,7 +4221,7 @@ static void gfx10_ngg_gs_emit_vertex(struct radv_shader_context *ctx,
                        LLVMBuildStore(builder, out_val, ptr);
                }
        }
-       assert(out_idx * 4 <= ctx->shader_info->gs.gsvs_vertex_size);
+       assert(out_idx * 4 <= ctx->args->shader_info->gs.gsvs_vertex_size);
 
        /* Determine and store whether this vertex completed a primitive. */
        const LLVMValueRef curverts = LLVMBuildLoad(builder, ctx->gs_curprim_verts[stream], "");
@@ -4296,15 +4255,16 @@ static void
 write_tess_factors(struct radv_shader_context *ctx)
 {
        unsigned stride, outer_comps, inner_comps;
-       LLVMValueRef invocation_id = ac_unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 8, 5);
-       LLVMValueRef rel_patch_id = ac_unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 0, 8);
+       LLVMValueRef tcs_rel_ids = ac_get_arg(&ctx->ac, ctx->args->ac.tcs_rel_ids);
+       LLVMValueRef invocation_id = ac_unpack_param(&ctx->ac, tcs_rel_ids, 8, 5);
+       LLVMValueRef rel_patch_id = ac_unpack_param(&ctx->ac, tcs_rel_ids, 0, 8);
        unsigned tess_inner_index = 0, tess_outer_index;
        LLVMValueRef lds_base, lds_inner = NULL, lds_outer, byteoffset, buffer;
        LLVMValueRef out[6], vec0, vec1, tf_base, inner[4], outer[4];
        int i;
        ac_emit_barrier(&ctx->ac, ctx->stage);
 
-       switch (ctx->options->key.tcs.primitive_mode) {
+       switch (ctx->args->options->key.tcs.primitive_mode) {
        case GL_ISOLINES:
                stride = 2;
                outer_comps = 2;
@@ -4346,7 +4306,7 @@ write_tess_factors(struct radv_shader_context *ctx)
        }
 
        // LINES reversal
-       if (ctx->options->key.tcs.primitive_mode == GL_ISOLINES) {
+       if (ctx->args->options->key.tcs.primitive_mode == GL_ISOLINES) {
                outer[0] = out[1] = ac_lds_load(&ctx->ac, lds_outer);
                lds_outer = LLVMBuildAdd(ctx->ac.builder, lds_outer,
                                         ctx->ac.i32_1, "");
@@ -4375,12 +4335,12 @@ write_tess_factors(struct radv_shader_context *ctx)
 
 
        buffer = ctx->hs_ring_tess_factor;
-       tf_base = ctx->tess_factor_offset;
+       tf_base = ac_get_arg(&ctx->ac, ctx->args->tess_factor_offset);
        byteoffset = LLVMBuildMul(ctx->ac.builder, rel_patch_id,
                                  LLVMConstInt(ctx->ac.i32, 4 * stride, false), "");
        unsigned tf_offset = 0;
 
-       if (ctx->options->chip_class <= GFX8) {
+       if (ctx->ac.chip_class <= GFX8) {
                ac_build_ifcc(&ctx->ac,
                                LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ,
                                              rel_patch_id, ctx->ac.i32_0, ""), 6504);
@@ -4405,7 +4365,7 @@ write_tess_factors(struct radv_shader_context *ctx)
                                            16 + tf_offset, ac_glc, false);
 
        //store to offchip for TES to read - only if TES reads them
-       if (ctx->options->key.tcs.tes_reads_tess_factors) {
+       if (ctx->args->options->key.tcs.tes_reads_tess_factors) {
                LLVMValueRef inner_vec, outer_vec, tf_outer_offset;
                LLVMValueRef tf_inner_offset;
                unsigned param_outer, param_inner;
@@ -4419,7 +4379,8 @@ write_tess_factors(struct radv_shader_context *ctx)
 
                ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, outer_vec,
                                            outer_comps, tf_outer_offset,
-                                           ctx->oc_lds, 0, ac_glc, false);
+                                           ac_get_arg(&ctx->ac, ctx->args->oc_lds),
+                                           0, ac_glc, false);
                if (inner_comps) {
                        param_inner = shader_io_get_unique_index(VARYING_SLOT_TESS_LEVEL_INNER);
                        tf_inner_offset = get_tcs_tes_buffer_address(ctx, NULL,
@@ -4429,7 +4390,8 @@ write_tess_factors(struct radv_shader_context *ctx)
                                ac_build_gather_values(&ctx->ac, inner, inner_comps);
                        ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, inner_vec,
                                                    inner_comps, tf_inner_offset,
-                                                   ctx->oc_lds, 0, ac_glc, false);
+                                                   ac_get_arg(&ctx->ac, ctx->args->oc_lds),
+                                                   0, ac_glc, false);
                }
        }
        
@@ -4496,15 +4458,15 @@ handle_fs_outputs_post(struct radv_shader_context *ctx)
        }
 
        /* Process depth, stencil, samplemask. */
-       if (ctx->shader_info->ps.writes_z) {
+       if (ctx->args->shader_info->ps.writes_z) {
                depth = ac_to_float(&ctx->ac,
                                    radv_load_output(ctx, FRAG_RESULT_DEPTH, 0));
        }
-       if (ctx->shader_info->ps.writes_stencil) {
+       if (ctx->args->shader_info->ps.writes_stencil) {
                stencil = ac_to_float(&ctx->ac,
                                      radv_load_output(ctx, FRAG_RESULT_STENCIL, 0));
        }
-       if (ctx->shader_info->ps.writes_sample_mask) {
+       if (ctx->args->shader_info->ps.writes_sample_mask) {
                samplemask = ac_to_float(&ctx->ac,
                                         radv_load_output(ctx, FRAG_RESULT_SAMPLE_MASK, 0));
        }
@@ -4513,9 +4475,9 @@ handle_fs_outputs_post(struct radv_shader_context *ctx)
         * exported.
         */
        if (index > 0 &&
-           !ctx->shader_info->ps.writes_z &&
-           !ctx->shader_info->ps.writes_stencil &&
-           !ctx->shader_info->ps.writes_sample_mask) {
+           !ctx->args->shader_info->ps.writes_z &&
+           !ctx->args->shader_info->ps.writes_stencil &&
+           !ctx->args->shader_info->ps.writes_sample_mask) {
                unsigned last = index - 1;
 
                color_args[last].valid_mask = 1; /* whether the EXEC mask is valid */
@@ -4535,7 +4497,7 @@ handle_fs_outputs_post(struct radv_shader_context *ctx)
 static void
 emit_gs_epilogue(struct radv_shader_context *ctx)
 {
-       if (ctx->options->key.vs_common_out.as_ngg) {
+       if (ctx->args->options->key.vs_common_out.as_ngg) {
                gfx10_ngg_gs_emit_epilogue_1(ctx);
                return;
        }
@@ -4554,16 +4516,16 @@ handle_shader_outputs_post(struct ac_shader_abi *abi, unsigned max_outputs,
 
        switch (ctx->stage) {
        case MESA_SHADER_VERTEX:
-               if (ctx->options->key.vs_common_out.as_ls)
+               if (ctx->args->options->key.vs_common_out.as_ls)
                        handle_ls_outputs_post(ctx);
-               else if (ctx->options->key.vs_common_out.as_es)
-                       handle_es_outputs_post(ctx, &ctx->shader_info->vs.es_info);
-               else if (ctx->options->key.vs_common_out.as_ngg)
+               else if (ctx->args->options->key.vs_common_out.as_es)
+                       handle_es_outputs_post(ctx, &ctx->args->shader_info->vs.es_info);
+               else if (ctx->args->options->key.vs_common_out.as_ngg)
                        handle_ngg_outputs_post_1(ctx);
                else
-                       handle_vs_outputs_post(ctx, ctx->options->key.vs_common_out.export_prim_id,
-                                              ctx->options->key.vs_common_out.export_clip_dists,
-                                              &ctx->shader_info->vs.outinfo);
+                       handle_vs_outputs_post(ctx, ctx->args->options->key.vs_common_out.export_prim_id,
+                                              ctx->args->options->key.vs_common_out.export_clip_dists,
+                                              &ctx->args->shader_info->vs.outinfo);
                break;
        case MESA_SHADER_FRAGMENT:
                handle_fs_outputs_post(ctx);
@@ -4575,14 +4537,14 @@ handle_shader_outputs_post(struct ac_shader_abi *abi, unsigned max_outputs,
                handle_tcs_outputs_post(ctx);
                break;
        case MESA_SHADER_TESS_EVAL:
-               if (ctx->options->key.vs_common_out.as_es)
-                       handle_es_outputs_post(ctx, &ctx->shader_info->tes.es_info);
-               else if (ctx->options->key.vs_common_out.as_ngg)
+               if (ctx->args->options->key.vs_common_out.as_es)
+                       handle_es_outputs_post(ctx, &ctx->args->shader_info->tes.es_info);
+               else if (ctx->args->options->key.vs_common_out.as_ngg)
                        handle_ngg_outputs_post_1(ctx);
                else
-                       handle_vs_outputs_post(ctx, ctx->options->key.vs_common_out.export_prim_id,
-                                              ctx->options->key.vs_common_out.export_clip_dists,
-                                              &ctx->shader_info->tes.outinfo);
+                       handle_vs_outputs_post(ctx, ctx->args->options->key.vs_common_out.export_prim_id,
+                                              ctx->args->options->key.vs_common_out.export_clip_dists,
+                                              &ctx->args->shader_info->tes.outinfo);
                break;
        default:
                break;
@@ -4611,15 +4573,15 @@ ac_nir_eliminate_const_vs_outputs(struct radv_shader_context *ctx)
        case MESA_SHADER_GEOMETRY:
                return;
        case MESA_SHADER_VERTEX:
-               if (ctx->options->key.vs_common_out.as_ls ||
-                   ctx->options->key.vs_common_out.as_es)
+               if (ctx->args->options->key.vs_common_out.as_ls ||
+                   ctx->args->options->key.vs_common_out.as_es)
                        return;
-               outinfo = &ctx->shader_info->vs.outinfo;
+               outinfo = &ctx->args->shader_info->vs.outinfo;
                break;
        case MESA_SHADER_TESS_EVAL:
-               if (ctx->options->key.vs_common_out.as_es)
+               if (ctx->args->options->key.vs_common_out.as_es)
                        return;
-               outinfo = &ctx->shader_info->tes.outinfo;
+               outinfo = &ctx->args->shader_info->tes.outinfo;
                break;
        default:
                unreachable("Unhandled shader type");
@@ -4635,9 +4597,9 @@ ac_nir_eliminate_const_vs_outputs(struct radv_shader_context *ctx)
 static void
 ac_setup_rings(struct radv_shader_context *ctx)
 {
-       if (ctx->options->chip_class <= GFX8 &&
+       if (ctx->args->options->chip_class <= GFX8 &&
            (ctx->stage == MESA_SHADER_GEOMETRY ||
-            ctx->options->key.vs_common_out.as_es || ctx->options->key.vs_common_out.as_es)) {
+            ctx->args->options->key.vs_common_out.as_es || ctx->args->options->key.vs_common_out.as_es)) {
                unsigned ring = ctx->stage == MESA_SHADER_GEOMETRY ? RING_ESGS_GS
                                                                   : RING_ESGS_VS;
                LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, ring, false);
@@ -4647,7 +4609,7 @@ ac_setup_rings(struct radv_shader_context *ctx)
                                                       offset);
        }
 
-       if (ctx->is_gs_copy_shader) {
+       if (ctx->args->is_gs_copy_shader) {
                ctx->gsvs_ring[0] =
                        ac_build_load_to_sgpr(&ctx->ac, ctx->ring_offsets,
                                              LLVMConstInt(ctx->ac.i32,
@@ -4678,7 +4640,7 @@ ac_setup_rings(struct radv_shader_context *ctx)
                        LLVMValueRef ring, tmp;
 
                        num_components =
-                               ctx->shader_info->gs.num_stream_output_components[stream];
+                               ctx->args->shader_info->gs.num_stream_output_components[stream];
 
                        if (!num_components)
                                continue;
@@ -4742,22 +4704,40 @@ radv_nir_get_max_workgroup_size(enum chip_class chip_class,
 /* Fixup the HW not emitting the TCS regs if there are no HS threads. */
 static void ac_nir_fixup_ls_hs_input_vgprs(struct radv_shader_context *ctx)
 {
-       LLVMValueRef count = ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 8, 8);
+       LLVMValueRef count =
+               ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 8, 8);
        LLVMValueRef hs_empty = LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, count,
                                              ctx->ac.i32_0, "");
-       ctx->abi.instance_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->rel_auto_id, ctx->abi.instance_id, "");
-       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.tcs_rel_ids, ctx->rel_auto_id, "");
-       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.tcs_patch_id, ctx->abi.vertex_id, "");
+       ctx->abi.instance_id = LLVMBuildSelect(ctx->ac.builder, hs_empty,
+                                              ac_get_arg(&ctx->ac, ctx->args->rel_auto_id),
+                                              ctx->abi.instance_id, "");
+       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty,
+                                          ac_get_arg(&ctx->ac, ctx->args->ac.tcs_rel_ids),
+                                          ctx->rel_auto_id,
+                                          "");
+       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty,
+                                                ac_get_arg(&ctx->ac, ctx->args->ac.tcs_patch_id),
+                                                ctx->abi.vertex_id, "");
 }
 
-static void prepare_gs_input_vgprs(struct radv_shader_context *ctx)
+static void prepare_gs_input_vgprs(struct radv_shader_context *ctx, bool merged)
 {
-       for(int i = 5; i >= 0; --i) {
-               ctx->gs_vtx_offset[i] = ac_unpack_param(&ctx->ac, ctx->gs_vtx_offset[i & ~1],
-                                                       (i & 1) * 16, 16);
-       }
+       if (merged) {
+               for(int i = 5; i >= 0; --i) {
+                       ctx->gs_vtx_offset[i] =
+                               ac_unpack_param(&ctx->ac,
+                                               ac_get_arg(&ctx->ac, ctx->args->gs_vtx_offset[i & ~1]),
+                                                          (i & 1) * 16, 16);
+               }
 
-       ctx->gs_wave_id = ac_unpack_param(&ctx->ac, ctx->merged_wave_info, 16, 8);
+               ctx->gs_wave_id = ac_unpack_param(&ctx->ac,
+                                                 ac_get_arg(&ctx->ac, ctx->args->merged_wave_info),
+                                                 16, 8);
+       } else {
+               for (int i = 0; i < 6; i++)
+                       ctx->gs_vtx_offset[i] = ac_get_arg(&ctx->ac, ctx->args->gs_vtx_offset[i]);
+               ctx->gs_wave_id = ac_get_arg(&ctx->ac, ctx->args->gs_wave_id);
+       }
 }
 
 /* Ensure that the esgs ring is declared.
@@ -4788,9 +4768,13 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                                        const struct radv_nir_compiler_options *options)
 {
        struct radv_shader_context ctx = {0};
-       unsigned i;
-       ctx.options = options;
-       ctx.shader_info = shader_info;
+       struct radv_shader_args args = {0};
+       args.options = options;
+       args.shader_info = shader_info;
+       ctx.args = &args;
+
+       declare_inputs(&args, shaders[shader_count - 1]->info.stage, shader_count >= 2,
+                      shader_count >= 2 ? shaders[shader_count - 2]->info.stage  : MESA_SHADER_VERTEX);
 
        enum ac_float_mode float_mode = AC_FLOAT_MODE_DEFAULT;
 
@@ -4802,15 +4786,10 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                             options->family, float_mode, shader_info->wave_size, 64);
        ctx.context = ctx.ac.context;
 
-       for (i = 0; i < MAX_SETS; i++)
-               shader_info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
-       for (i = 0; i < AC_UD_MAX_UD; i++)
-               shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
-
        ctx.max_workgroup_size = 0;
        for (int i = 0; i < shader_count; ++i) {
                ctx.max_workgroup_size = MAX2(ctx.max_workgroup_size,
-                                             radv_nir_get_max_workgroup_size(ctx.options->chip_class,
+                                             radv_nir_get_max_workgroup_size(args.options->chip_class,
                                                                              shaders[i]->info.stage,
                                                                              shaders[i]));
        }
@@ -4822,8 +4801,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                }
        }
 
-       create_function(&ctx, shaders[shader_count - 1]->info.stage, shader_count >= 2,
-                       shader_count >= 2 ? shaders[shader_count - 2]->info.stage  : MESA_SHADER_VERTEX);
+       create_function(&ctx, shaders[shader_count - 1]->info.stage, shader_count >= 2);
 
        ctx.abi.inputs = &ctx.inputs[0];
        ctx.abi.emit_outputs = handle_shader_outputs_post;
@@ -4835,10 +4813,17 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
        ctx.abi.clamp_shadow_reference = false;
        ctx.abi.robust_buffer_access = options->robust_buffer_access;
 
-       bool is_ngg = is_pre_gs_stage(shaders[0]->info.stage) &&  ctx.options->key.vs_common_out.as_ngg;
+       bool is_ngg = is_pre_gs_stage(shaders[0]->info.stage) &&  args.options->key.vs_common_out.as_ngg;
        if (shader_count >= 2 || is_ngg)
                ac_init_exec_full_mask(&ctx.ac);
 
+       if (args.ac.vertex_id.used)
+               ctx.abi.vertex_id = ac_get_arg(&ctx.ac, args.ac.vertex_id);
+       if (args.rel_auto_id.used)
+               ctx.rel_auto_id = ac_get_arg(&ctx.ac, args.rel_auto_id);
+       if (args.ac.instance_id.used)
+               ctx.abi.instance_id = ac_get_arg(&ctx.ac, args.ac.instance_id);
+
        if (options->has_ls_vgpr_init_bug &&
            shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL)
                ac_nir_fixup_ls_hs_input_vgprs(&ctx);
@@ -4873,7 +4858,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                                ctx.gs_next_vertex[i] =
                                        ac_build_alloca(&ctx.ac, ctx.ac.i32, "");
                        }
-                       if (ctx.options->key.vs_common_out.as_ngg) {
+                       if (args.options->key.vs_common_out.as_ngg) {
                                for (unsigned i = 0; i < 4; ++i) {
                                        ctx.gs_curprim_verts[i] =
                                                ac_build_alloca(&ctx.ac, ctx.ac.i32, "");
@@ -4882,7 +4867,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                                }
 
                                unsigned scratch_size = 8;
-                               if (ctx.shader_info->so.num_outputs)
+                               if (args.shader_info->so.num_outputs)
                                        scratch_size = 44;
 
                                LLVMTypeRef ai32 = LLVMArrayType(ctx.ac.i32, scratch_size);
@@ -4905,7 +4890,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                        ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                        ctx.abi.store_tcs_outputs = store_tcs_output;
                        if (shader_count == 1)
-                               ctx.tcs_num_inputs = ctx.options->key.tcs.num_inputs;
+                               ctx.tcs_num_inputs = args.options->key.tcs.num_inputs;
                        else
                                ctx.tcs_num_inputs = util_last_bit64(shader_info->vs.ls_outputs_written);
                        ctx.tcs_num_patches = get_tcs_num_patches(&ctx);
@@ -4913,7 +4898,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                        ctx.abi.load_tess_varyings = load_tes_input;
                        ctx.abi.load_tess_coord = load_tess_coord;
                        ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
-                       ctx.tcs_num_patches = ctx.options->key.tes.num_patches;
+                       ctx.tcs_num_patches = args.options->key.tes.num_patches;
                } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
                        ctx.abi.load_base_vertex = radv_load_base_vertex;
                } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
@@ -4923,8 +4908,8 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                }
 
                if (shaders[i]->info.stage == MESA_SHADER_VERTEX &&
-                   ctx.options->key.vs_common_out.as_ngg &&
-                   ctx.options->key.vs_common_out.export_prim_id) {
+                   args.options->key.vs_common_out.as_ngg &&
+                   args.options->key.vs_common_out.export_prim_id) {
                        declare_esgs_ring(&ctx);
                }
 
@@ -4932,7 +4917,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
 
                if (i) {
                        if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY &&
-                           ctx.options->key.vs_common_out.as_ngg) {
+                           args.options->key.vs_common_out.as_ngg) {
                                gfx10_ngg_gs_emit_prologue(&ctx);
                                nested_barrier = false;
                        } else {
@@ -4972,7 +4957,10 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                        LLVMBasicBlockRef then_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
                        merge_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
 
-                       LLVMValueRef count = ac_unpack_param(&ctx.ac, ctx.merged_wave_info, 8 * i, 8);
+                       LLVMValueRef count =
+                               ac_unpack_param(&ctx.ac,
+                                               ac_get_arg(&ctx.ac, args.merged_wave_info),
+                                               8 * i, 8);
                        LLVMValueRef thread_id = ac_get_thread_id(&ctx.ac);
                        LLVMValueRef cond = LLVMBuildICmp(ctx.ac.builder, LLVMIntULT,
                                                          thread_id, count, "");
@@ -4985,10 +4973,10 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                        prepare_interp_optimize(&ctx, shaders[i]);
                else if(shaders[i]->info.stage == MESA_SHADER_VERTEX)
                        handle_vs_inputs(&ctx, shaders[i]);
-               else if(shader_count >= 2 && shaders[i]->info.stage == MESA_SHADER_GEOMETRY)
-                       prepare_gs_input_vgprs(&ctx);
+               else if(shaders[i]->info.stage == MESA_SHADER_GEOMETRY)
+                       prepare_gs_input_vgprs(&ctx, shader_count >= 2);
 
-               ac_nir_translate(&ctx.ac, &ctx.abi, shaders[i]);
+               ac_nir_translate(&ctx.ac, &ctx.abi, &args.ac, shaders[i]);
 
                if (shader_count >= 2 || is_ngg) {
                        LLVMBuildBr(ctx.ac.builder, merge_block);
@@ -4998,11 +4986,11 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                /* This needs to be outside the if wrapping the shader body, as sometimes
                 * the HW generates waves with 0 es/vs threads. */
                if (is_pre_gs_stage(shaders[i]->info.stage) &&
-                   ctx.options->key.vs_common_out.as_ngg &&
+                   args.options->key.vs_common_out.as_ngg &&
                    i == shader_count - 1) {
                        handle_ngg_outputs_post_2(&ctx);
                } else if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY &&
-                          ctx.options->key.vs_common_out.as_ngg) {
+                          args.options->key.vs_common_out.as_ngg) {
                        gfx10_ngg_gs_emit_epilogue_2(&ctx);
                }
 
@@ -5028,7 +5016,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
                ac_nir_eliminate_const_vs_outputs(&ctx);
 
        if (options->dump_shader) {
-               ctx.shader_info->private_mem_vgprs =
+               args.shader_info->private_mem_vgprs =
                        ac_count_scratch_private_memory(ctx.main_function);
        }
 
@@ -5152,15 +5140,18 @@ static void
 ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
 {
        LLVMValueRef vtx_offset =
-               LLVMBuildMul(ctx->ac.builder, ctx->abi.vertex_id,
+               LLVMBuildMul(ctx->ac.builder, ac_get_arg(&ctx->ac, ctx->args->ac.vertex_id),
                             LLVMConstInt(ctx->ac.i32, 4, false), "");
        LLVMValueRef stream_id;
 
        /* Fetch the vertex stream ID. */
-       if (!ctx->options->use_ngg_streamout &&
-           ctx->shader_info->so.num_outputs) {
+       if (!ctx->args->options->use_ngg_streamout &&
+           ctx->args->shader_info->so.num_outputs) {
                stream_id =
-                       ac_unpack_param(&ctx->ac, ctx->streamout_config, 24, 2);
+                       ac_unpack_param(&ctx->ac,
+                                       ac_get_arg(&ctx->ac,
+                                                  ctx->args->streamout_config),
+                                       24, 2);
        } else {
                stream_id = ctx->ac.i32_0;
        }
@@ -5174,14 +5165,14 @@ ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
 
        for (unsigned stream = 0; stream < 4; stream++) {
                unsigned num_components =
-                       ctx->shader_info->gs.num_stream_output_components[stream];
+                       ctx->args->shader_info->gs.num_stream_output_components[stream];
                LLVMBasicBlockRef bb;
                unsigned offset;
 
                if (stream > 0 && !num_components)
                        continue;
 
-               if (stream > 0 && !ctx->shader_info->so.num_outputs)
+               if (stream > 0 && !ctx->args->shader_info->so.num_outputs)
                        continue;
 
                bb = LLVMInsertBasicBlockInContext(ctx->ac.context, end_bb, "out");
@@ -5191,9 +5182,9 @@ ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
                offset = 0;
                for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
                        unsigned output_usage_mask =
-                               ctx->shader_info->gs.output_usage_mask[i];
+                               ctx->args->shader_info->gs.output_usage_mask[i];
                        unsigned output_stream =
-                               ctx->shader_info->gs.output_streams[i];
+                               ctx->args->shader_info->gs.output_streams[i];
                        int length = util_last_bit(output_usage_mask);
 
                        if (!(ctx->output_mask & (1ull << i)) ||
@@ -5229,13 +5220,13 @@ ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
                        }
                }
 
-               if (!ctx->options->use_ngg_streamout &&
-                   ctx->shader_info->so.num_outputs)
+               if (!ctx->args->options->use_ngg_streamout &&
+                   ctx->args->shader_info->so.num_outputs)
                        radv_emit_streamout(ctx, stream);
 
                if (stream == 0) {
                        handle_vs_outputs_post(ctx, false, true,
-                                              &ctx->shader_info->vs.outinfo);
+                                              &ctx->args->shader_info->vs.outinfo);
                }
 
                LLVMBuildBr(ctx->ac.builder, end_bb);
@@ -5252,18 +5243,22 @@ radv_compile_gs_copy_shader(struct ac_llvm_compiler *ac_llvm,
                            const struct radv_nir_compiler_options *options)
 {
        struct radv_shader_context ctx = {0};
-       ctx.options = options;
-       ctx.shader_info = shader_info;
+       struct radv_shader_args args = {0};
+       args.options = options;
+       args.shader_info = shader_info;
+       ctx.args = &args;
+
+       args.is_gs_copy_shader = true;
+       declare_inputs(&args, MESA_SHADER_VERTEX, false, MESA_SHADER_VERTEX);
 
        ac_llvm_context_init(&ctx.ac, ac_llvm, options->chip_class,
                             options->family, AC_FLOAT_MODE_DEFAULT, 64, 64);
        ctx.context = ctx.ac.context;
 
-       ctx.is_gs_copy_shader = true;
        ctx.stage = MESA_SHADER_VERTEX;
        ctx.shader = geom_shader;
 
-       create_function(&ctx, MESA_SHADER_VERTEX, false, MESA_SHADER_VERTEX);
+       create_function(&ctx, MESA_SHADER_VERTEX, false);
 
        ac_setup_rings(&ctx);
 
diff --git a/src/amd/vulkan/radv_shader_args.h b/src/amd/vulkan/radv_shader_args.h
new file mode 100644 (file)
index 0000000..5f295b5
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Copyright © 2019 Valve Corporation.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "ac_shader_args.h"
+#include "radv_constants.h"
+#include "util/list.h"
+#include "amd_family.h"
+
+struct radv_shader_args {
+       struct ac_shader_args ac;
+       struct radv_shader_info *shader_info;
+       const struct radv_nir_compiler_options *options;
+
+       struct ac_arg descriptor_sets[MAX_SETS];
+       struct ac_arg ring_offsets;
+
+       struct ac_arg vertex_buffers;
+       struct ac_arg rel_auto_id;
+       struct ac_arg vs_prim_id;
+       struct ac_arg es2gs_offset;
+
+       struct ac_arg oc_lds;
+       struct ac_arg merged_wave_info;
+       struct ac_arg tess_factor_offset;
+       struct ac_arg tes_rel_patch_id;
+       struct ac_arg tes_u;
+       struct ac_arg tes_v;
+
+       /* HW GS */
+       /* On gfx10:
+        *  - bits 0..10: ordered_wave_id
+        *  - bits 12..20: number of vertices in group
+        *  - bits 22..30: number of primitives in group
+        */
+       struct ac_arg gs_tg_info;
+       struct ac_arg gs2vs_offset;
+       struct ac_arg gs_wave_id;
+       struct ac_arg gs_vtx_offset[6];
+
+       /* Streamout */
+       struct ac_arg streamout_buffers;
+       struct ac_arg streamout_write_idx;
+       struct ac_arg streamout_config;
+       struct ac_arg streamout_offset[4];
+
+       bool is_gs_copy_shader;
+};
+
+static inline struct radv_shader_args *
+radv_shader_args_from_ac(struct ac_shader_args *args)
+{
+       struct radv_shader_args *radv_args = NULL;
+       return (struct radv_shader_args *) container_of(args, radv_args, ac);
+}
+
+
index a3551300516205552bc029e319aca5b9a67ba926..458c71d767bcf95d4edddc5d56e5ccdf5eb747c7 100644 (file)
 
 static LLVMValueRef get_wave_id_in_tg(struct si_shader_context *ctx)
 {
-       return si_unpack_param(ctx, ctx->param_merged_wave_info, 24, 4);
+       return si_unpack_param(ctx, ctx->merged_wave_info, 24, 4);
 }
 
 static LLVMValueRef get_tgsize(struct si_shader_context *ctx)
 {
-       return si_unpack_param(ctx, ctx->param_merged_wave_info, 28, 4);
+       return si_unpack_param(ctx, ctx->merged_wave_info, 28, 4);
 }
 
 static LLVMValueRef get_thread_id_in_tg(struct si_shader_context *ctx)
@@ -50,32 +50,22 @@ static LLVMValueRef get_thread_id_in_tg(struct si_shader_context *ctx)
 
 static LLVMValueRef ngg_get_vtx_cnt(struct si_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
-                           LLVMConstInt(ctx->ac.i32, 12, false),
-                           LLVMConstInt(ctx->ac.i32, 9, false),
-                           false);
+       return si_unpack_param(ctx, ctx->gs_tg_info, 12, 9);
 }
 
 static LLVMValueRef ngg_get_prim_cnt(struct si_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
-                           LLVMConstInt(ctx->ac.i32, 22, false),
-                           LLVMConstInt(ctx->ac.i32, 9, false),
-                           false);
+       return si_unpack_param(ctx, ctx->gs_tg_info, 22, 9);
 }
 
 static LLVMValueRef ngg_get_ordered_id(struct si_shader_context *ctx)
 {
-       return ac_build_bfe(&ctx->ac, ctx->gs_tg_info,
-                           ctx->i32_0,
-                           LLVMConstInt(ctx->ac.i32, 11, false),
-                           false);
+       return si_unpack_param(ctx, ctx->gs_tg_info, 0, 11);
 }
 
 static LLVMValueRef ngg_get_query_buf(struct si_shader_context *ctx)
 {
-       LLVMValueRef buf_ptr = LLVMGetParam(ctx->main_fn,
-                                           ctx->param_rw_buffers);
+       LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->rw_buffers);
 
        return ac_build_load_to_sgpr(&ctx->ac, buf_ptr,
                                     LLVMConstInt(ctx->i32, GFX10_GS_QUERY_BUF, false));
@@ -212,7 +202,7 @@ static void build_streamout(struct si_shader_context *ctx,
        struct tgsi_shader_info *info = &ctx->shader->selector->info;
        struct pipe_stream_output_info *so = &ctx->shader->selector->so;
        LLVMBuilderRef builder = ctx->ac.builder;
-       LLVMValueRef buf_ptr = LLVMGetParam(ctx->main_fn, ctx->param_rw_buffers);
+       LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->rw_buffers);
        LLVMValueRef tid = get_thread_id_in_tg(ctx);
        LLVMValueRef tmp, tmp2;
        LLVMValueRef i32_2 = LLVMConstInt(ctx->i32, 2, false);
@@ -583,16 +573,16 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
 
        ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
 
-       LLVMValueRef prims_in_wave = si_unpack_param(ctx, ctx->param_merged_wave_info, 8, 8);
-       LLVMValueRef vtx_in_wave = si_unpack_param(ctx, ctx->param_merged_wave_info, 0, 8);
+       LLVMValueRef prims_in_wave = si_unpack_param(ctx, ctx->merged_wave_info, 8, 8);
+       LLVMValueRef vtx_in_wave = si_unpack_param(ctx, ctx->merged_wave_info, 0, 8);
        LLVMValueRef is_gs_thread = LLVMBuildICmp(builder, LLVMIntULT,
                                                  ac_get_thread_id(&ctx->ac), prims_in_wave, "");
        LLVMValueRef is_es_thread = LLVMBuildICmp(builder, LLVMIntULT,
                                                  ac_get_thread_id(&ctx->ac), vtx_in_wave, "");
        LLVMValueRef vtxindex[] = {
-               si_unpack_param(ctx, ctx->param_gs_vtx01_offset, 0, 16),
-               si_unpack_param(ctx, ctx->param_gs_vtx01_offset, 16, 16),
-               si_unpack_param(ctx, ctx->param_gs_vtx23_offset, 0, 16),
+               si_unpack_param(ctx, ctx->gs_vtx01_offset, 0, 16),
+               si_unpack_param(ctx, ctx->gs_vtx01_offset, 16, 16),
+               si_unpack_param(ctx, ctx->gs_vtx23_offset, 0, 16),
        };
 
        /* Determine the number of vertices per primitive. */
@@ -606,7 +596,7 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
                        num_vertices_val = LLVMConstInt(ctx->i32, 3, 0);
                } else {
                        /* Extract OUTPRIM field. */
-                       tmp = si_unpack_param(ctx, ctx->param_vs_state_bits, 2, 2);
+                       tmp = si_unpack_param(ctx, ctx->vs_state_bits, 2, 2);
                        num_vertices_val = LLVMBuildAdd(builder, tmp, ctx->i32_1, "");
                        num_vertices = 3; /* TODO: optimize for points & lines */
                }
@@ -673,14 +663,14 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
                ac_build_ifcc(&ctx->ac, is_gs_thread, 5400);
                /* Extract the PROVOKING_VTX_INDEX field. */
                LLVMValueRef provoking_vtx_in_prim =
-                       si_unpack_param(ctx, ctx->param_vs_state_bits, 4, 2);
+                       si_unpack_param(ctx, ctx->vs_state_bits, 4, 2);
 
                /* provoking_vtx_index = vtxindex[provoking_vtx_in_prim]; */
                LLVMValueRef indices = ac_build_gather_values(&ctx->ac, vtxindex, 3);
                LLVMValueRef provoking_vtx_index =
                        LLVMBuildExtractElement(builder, indices, provoking_vtx_in_prim, "");
 
-               LLVMBuildStore(builder, ctx->abi.gs_prim_id,
+               LLVMBuildStore(builder, ac_get_arg(&ctx->ac, ctx->args.gs_prim_id),
                               ac_build_gep0(&ctx->ac, ctx->esgs_ring, provoking_vtx_index));
                ac_build_endif(&ctx->ac, 5400);
        }
@@ -690,7 +680,7 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
        /* Update query buffer */
        /* TODO: this won't catch 96-bit clear_buffer via transform feedback. */
        if (!info->properties[TGSI_PROPERTY_VS_BLIT_SGPRS_AMD]) {
-               tmp = si_unpack_param(ctx, ctx->param_vs_state_bits, 6, 1);
+               tmp = si_unpack_param(ctx, ctx->vs_state_bits, 6, 1);
                tmp = LLVMBuildTrunc(builder, tmp, ctx->i1, "");
                ac_build_ifcc(&ctx->ac, tmp, 5029); /* if (STREAMOUT_QUERY_ENABLED) */
                tmp = LLVMBuildICmp(builder, LLVMIntEQ, get_wave_id_in_tg(ctx), ctx->ac.i32_0, "");
@@ -752,7 +742,8 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
                                continue;
                        }
 
-                       tmp = LLVMBuildLShr(builder, ctx->abi.gs_invocation_id,
+                       tmp = LLVMBuildLShr(builder,
+                                           ac_get_arg(&ctx->ac, ctx->args.gs_invocation_id),
                                            LLVMConstInt(ctx->ac.i32, 8 + i, false), "");
                        prim.edgeflag[i] = LLVMBuildTrunc(builder, tmp, ctx->ac.i1, "");
 
@@ -1099,7 +1090,7 @@ void gfx10_ngg_gs_emit_epilogue(struct si_shader_context *ctx)
        }
 
        /* Write shader query data. */
-       tmp = si_unpack_param(ctx, ctx->param_vs_state_bits, 6, 1);
+       tmp = si_unpack_param(ctx, ctx->vs_state_bits, 6, 1);
        tmp = LLVMBuildTrunc(builder, tmp, ctx->i1, "");
        ac_build_ifcc(&ctx->ac, tmp, 5109); /* if (STREAMOUT_QUERY_ENABLED) */
        unsigned num_query_comps = sel->so.num_outputs ? 8 : 4;
index 34e6d344486fbc1c5a4abdcab1043728a61cf812..a52966f2376ee6fa3c9e883f233f83eaf11784a7 100644 (file)
@@ -318,50 +318,51 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        ac_add_function_attr(ctx->ac.context, vs, -1, AC_FUNC_ATTR_ALWAYSINLINE);
        LLVMSetLinkage(vs, LLVMPrivateLinkage);
 
-       LLVMTypeRef const_desc_type;
+       enum ac_arg_type const_desc_type;
        if (ctx->shader->selector->info.const_buffers_declared == 1 &&
            ctx->shader->selector->info.shader_buffers_declared == 0)
-               const_desc_type = ctx->f32;
+               const_desc_type = AC_ARG_CONST_FLOAT_PTR;
        else
-               const_desc_type = ctx->v4i32;
-
-       struct si_function_info fninfo;
-       si_init_function_info(&fninfo);
-
-       LLVMValueRef index_buffers_and_constants, vertex_counter, vb_desc, const_desc;
-       LLVMValueRef base_vertex, start_instance, block_id, local_id, ordered_wave_id;
-       LLVMValueRef restart_index, vp_scale[2], vp_translate[2], smallprim_precision;
-       LLVMValueRef num_prims_udiv_multiplier, num_prims_udiv_terms, sampler_desc;
-       LLVMValueRef last_wave_prim_id, vertex_count_addr;
-
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v4i32),
-                      &index_buffers_and_constants);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &vertex_counter);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &last_wave_prim_id);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &vertex_count_addr);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v4i32),
-                      &vb_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(const_desc_type),
-                      &const_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ac_array_in_const32_addr_space(ctx->v8i32),
-                      &sampler_desc);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &base_vertex);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &start_instance);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &num_prims_udiv_multiplier);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &num_prims_udiv_terms);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &restart_index);
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->f32, &smallprim_precision);
+               const_desc_type = AC_ARG_CONST_DESC_PTR;
+
+       memset(&ctx->args, 0, sizeof(ctx->args));
+
+       struct ac_arg param_index_buffers_and_constants, param_vertex_counter;
+       struct ac_arg param_vb_desc, param_const_desc;
+       struct ac_arg param_base_vertex, param_start_instance;
+       struct ac_arg param_block_id, param_local_id, param_ordered_wave_id;
+       struct ac_arg param_restart_index, param_smallprim_precision;
+       struct ac_arg param_num_prims_udiv_multiplier, param_num_prims_udiv_terms;
+       struct ac_arg param_sampler_desc, param_last_wave_prim_id, param_vertex_count_addr;
+
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                  &param_index_buffers_and_constants);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_vertex_counter);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_last_wave_prim_id);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_vertex_count_addr);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                  &param_vb_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, const_desc_type,
+                  &param_const_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_IMAGE_PTR,
+                  &param_sampler_desc);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_base_vertex);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_start_instance);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_num_prims_udiv_multiplier);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_num_prims_udiv_terms);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_restart_index);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, &param_smallprim_precision);
 
        /* Block ID and thread ID inputs. */
-       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &block_id);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_block_id);
        if (VERTEX_COUNTER_GDS_MODE == 2)
-               add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &ordered_wave_id);
-       add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &local_id);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &param_ordered_wave_id);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &param_local_id);
 
        /* Create the compute shader function. */
        unsigned old_type = ctx->type;
        ctx->type = PIPE_SHADER_COMPUTE;
-       si_create_function(ctx, "prim_discard_cs", NULL, 0, &fninfo, THREADGROUP_SIZE);
+       si_create_function(ctx, "prim_discard_cs", NULL, 0, THREADGROUP_SIZE);
        ctx->type = old_type;
 
        if (VERTEX_COUNTER_GDS_MODE == 1) {
@@ -376,14 +377,14 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        vs_params[num_vs_params++] = LLVMGetUndef(LLVMTypeOf(LLVMGetParam(vs, 0))); /* RW_BUFFERS */
        vs_params[num_vs_params++] = LLVMGetUndef(LLVMTypeOf(LLVMGetParam(vs, 1))); /* BINDLESS */
-       vs_params[num_vs_params++] = const_desc;
-       vs_params[num_vs_params++] = sampler_desc;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_const_desc);
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_sampler_desc);
        vs_params[num_vs_params++] = LLVMConstInt(ctx->i32,
                                        S_VS_STATE_INDEXED(key->opt.cs_indexed), 0);
-       vs_params[num_vs_params++] = base_vertex;
-       vs_params[num_vs_params++] = start_instance;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_base_vertex);
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_start_instance);
        vs_params[num_vs_params++] = ctx->i32_0; /* DrawID */
-       vs_params[num_vs_params++] = vb_desc;
+       vs_params[num_vs_params++] = ac_get_arg(&ctx->ac, param_vb_desc);
 
        vs_params[(param_vertex_id = num_vs_params++)] = NULL; /* VertexID */
        vs_params[(param_instance_id = num_vs_params++)] = NULL; /* InstanceID */
@@ -396,6 +397,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        /* Load descriptors. (load 8 dwords at once) */
        LLVMValueRef input_indexbuf, output_indexbuf, tmp, desc[8];
 
+       LLVMValueRef index_buffers_and_constants = ac_get_arg(&ctx->ac, param_index_buffers_and_constants);
        tmp = LLVMBuildPointerCast(builder, index_buffers_and_constants,
                                   ac_array_in_const32_addr_space(ctx->v8i32), "");
        tmp = ac_build_load_to_sgpr(&ctx->ac, tmp, ctx->i32_0);
@@ -408,12 +410,17 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        /* Compute PrimID and InstanceID. */
        LLVMValueRef global_thread_id =
-               ac_build_imad(&ctx->ac, block_id,
-                             LLVMConstInt(ctx->i32, THREADGROUP_SIZE, 0), local_id);
+               ac_build_imad(&ctx->ac, ac_get_arg(&ctx->ac, param_block_id),
+                             LLVMConstInt(ctx->i32, THREADGROUP_SIZE, 0),
+                             ac_get_arg(&ctx->ac, param_local_id));
        LLVMValueRef prim_id = global_thread_id; /* PrimID within an instance */
        LLVMValueRef instance_id = ctx->i32_0;
 
        if (key->opt.cs_instancing) {
+               LLVMValueRef num_prims_udiv_terms =
+                       ac_get_arg(&ctx->ac, param_num_prims_udiv_terms);
+               LLVMValueRef num_prims_udiv_multiplier =
+                       ac_get_arg(&ctx->ac, param_num_prims_udiv_multiplier);
                /* Unpack num_prims_udiv_terms. */
                LLVMValueRef post_shift = LLVMBuildAnd(builder, num_prims_udiv_terms,
                                                       LLVMConstInt(ctx->i32, 0x1f, 0), "");
@@ -477,6 +484,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                }
        }
 
+       LLVMValueRef ordered_wave_id = ac_get_arg(&ctx->ac, param_ordered_wave_id);
+
        /* Extract the ordered wave ID. */
        if (VERTEX_COUNTER_GDS_MODE == 2) {
                ordered_wave_id = LLVMBuildLShr(builder, ordered_wave_id,
@@ -485,7 +494,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                               LLVMConstInt(ctx->i32, 0xfff, 0), "");
        }
        LLVMValueRef thread_id =
-               LLVMBuildAnd(builder, local_id, LLVMConstInt(ctx->i32, 63, 0), "");
+               LLVMBuildAnd(builder, ac_get_arg(&ctx->ac, param_local_id),
+                            LLVMConstInt(ctx->i32, 63, 0), "");
 
        /* Every other triangle in a strip has a reversed vertex order, so we
         * need to swap vertices of odd primitives to get the correct primitive
@@ -493,6 +503,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
         * restart complicates it, because a strip can start anywhere.
         */
        LLVMValueRef prim_restart_accepted = ctx->i1true;
+       LLVMValueRef vertex_counter = ac_get_arg(&ctx->ac, param_vertex_counter);
 
        if (key->opt.cs_prim_type == PIPE_PRIM_TRIANGLE_STRIP) {
                /* Without primitive restart, odd primitives have reversed orientation.
@@ -520,7 +531,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
                        for (unsigned i = 0; i < 3; i++) {
                                LLVMValueRef not_reset = LLVMBuildICmp(builder, LLVMIntNE, index[i],
-                                                                      restart_index, "");
+                                                                      ac_get_arg(&ctx->ac, param_restart_index),
+                                                                      "");
                                if (i == 0)
                                        index0_is_reset = LLVMBuildNot(builder, not_reset, "");
                                prim_restart_accepted = LLVMBuildAnd(builder, prim_restart_accepted,
@@ -680,6 +692,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        LLVMValueRef vp = ac_build_load_invariant(&ctx->ac, index_buffers_and_constants,
                                                  LLVMConstInt(ctx->i32, 2, 0));
        vp = LLVMBuildBitCast(builder, vp, ctx->v4f32, "");
+       LLVMValueRef vp_scale[2], vp_translate[2];
        vp_scale[0] = ac_llvm_extract_elem(&ctx->ac, vp, 0);
        vp_scale[1] = ac_llvm_extract_elem(&ctx->ac, vp, 1);
        vp_translate[0] = ac_llvm_extract_elem(&ctx->ac, vp, 2);
@@ -699,7 +712,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
 
        LLVMValueRef accepted =
                ac_cull_triangle(&ctx->ac, pos, prim_restart_accepted,
-                                vp_scale, vp_translate, smallprim_precision,
+                                vp_scale, vp_translate,
+                                ac_get_arg(&ctx->ac, param_smallprim_precision),
                                 &options);
 
        LLVMValueRef accepted_threadmask = ac_get_i1_sgpr_mask(&ctx->ac, accepted);
@@ -788,7 +802,8 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
        if (VERTEX_COUNTER_GDS_MODE == 2) {
                ac_build_ifcc(&ctx->ac,
                              LLVMBuildICmp(builder, LLVMIntEQ, global_thread_id,
-                                           last_wave_prim_id, ""), 12606);
+                                           ac_get_arg(&ctx->ac, param_last_wave_prim_id), ""),
+                             12606);
                LLVMValueRef count = LLVMBuildAdd(builder, start, num_prims_accepted, "");
                count = LLVMBuildMul(builder, count,
                                     LLVMConstInt(ctx->i32, vertices_per_prim, 0), "");
@@ -798,7 +813,7 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                 */
                if (ctx->screen->info.chip_class <= GFX8) {
                        LLVMValueRef desc[] = {
-                               vertex_count_addr,
+                               ac_get_arg(&ctx->ac, param_vertex_count_addr),
                                LLVMConstInt(ctx->i32,
                                        S_008F04_BASE_ADDRESS_HI(ctx->screen->info.address32_hi), 0),
                                LLVMConstInt(ctx->i32, 4, 0),
@@ -810,7 +825,9 @@ void si_build_prim_discard_compute_shader(struct si_shader_context *ctx)
                                                    ctx->i32_0, 0, ac_glc | ac_slc, false);
                } else {
                        LLVMBuildStore(builder, count,
-                                      si_expand_32bit_pointer(ctx, vertex_count_addr));
+                                      si_expand_32bit_pointer(ctx,
+                                                              ac_get_arg(&ctx->ac,
+                                                                         param_vertex_count_addr)));
                }
                ac_build_endif(&ctx->ac, 12606);
        } else {
index fba7187204d6788c6bbe15ca67562f1ef2dfbc0f..b63a39efe2da5e67a492a4608030c6759e9bd423 100644 (file)
@@ -105,50 +105,6 @@ static bool is_merged_shader(struct si_shader_context *ctx)
        return ctx->shader->key.as_ngg || is_multi_part_shader(ctx);
 }
 
-void si_init_function_info(struct si_function_info *fninfo)
-{
-       fninfo->num_params = 0;
-       fninfo->num_sgpr_params = 0;
-}
-
-unsigned add_arg_assign(struct si_function_info *fninfo,
-                       enum si_arg_regfile regfile, LLVMTypeRef type,
-                       LLVMValueRef *assign)
-{
-       assert(regfile != ARG_SGPR || fninfo->num_sgpr_params == fninfo->num_params);
-
-       unsigned idx = fninfo->num_params++;
-       assert(idx < ARRAY_SIZE(fninfo->types));
-
-       if (regfile == ARG_SGPR)
-               fninfo->num_sgpr_params = fninfo->num_params;
-
-       fninfo->types[idx] = type;
-       fninfo->assign[idx] = assign;
-       return idx;
-}
-
-static unsigned add_arg(struct si_function_info *fninfo,
-                       enum si_arg_regfile regfile, LLVMTypeRef type)
-{
-       return add_arg_assign(fninfo, regfile, type, NULL);
-}
-
-static void add_arg_assign_checked(struct si_function_info *fninfo,
-                                  enum si_arg_regfile regfile, LLVMTypeRef type,
-                                  LLVMValueRef *assign, unsigned idx)
-{
-       ASSERTED unsigned actual = add_arg_assign(fninfo, regfile, type, assign);
-       assert(actual == idx);
-}
-
-static void add_arg_checked(struct si_function_info *fninfo,
-                           enum si_arg_regfile regfile, LLVMTypeRef type,
-                           unsigned idx)
-{
-       add_arg_assign_checked(fninfo, regfile, type, NULL, idx);
-}
-
 /**
  * Returns a unique index for a per-patch semantic name and index. The index
  * must be less than 32, so that a 32-bit bitmask of used inputs or outputs
@@ -257,10 +213,10 @@ static LLVMValueRef unpack_llvm_param(struct si_shader_context *ctx,
 }
 
 LLVMValueRef si_unpack_param(struct si_shader_context *ctx,
-                            unsigned param, unsigned rshift,
+                            struct ac_arg param, unsigned rshift,
                             unsigned bitwidth)
 {
-       LLVMValueRef value = LLVMGetParam(ctx->main_fn, param);
+       LLVMValueRef value = ac_get_arg(&ctx->ac, param);
 
        return unpack_llvm_param(ctx, value, rshift, bitwidth);
 }
@@ -269,11 +225,10 @@ static LLVMValueRef get_rel_patch_id(struct si_shader_context *ctx)
 {
        switch (ctx->type) {
        case PIPE_SHADER_TESS_CTRL:
-               return unpack_llvm_param(ctx, ctx->abi.tcs_rel_ids, 0, 8);
+               return si_unpack_param(ctx, ctx->args.tcs_rel_ids, 0, 8);
 
        case PIPE_SHADER_TESS_EVAL:
-               return LLVMGetParam(ctx->main_fn,
-                                   ctx->param_tes_rel_patch_id);
+               return ac_get_arg(&ctx->ac, ctx->tes_rel_patch_id);
 
        default:
                assert(0);
@@ -305,7 +260,7 @@ static LLVMValueRef get_rel_patch_id(struct si_shader_context *ctx)
 static LLVMValueRef
 get_tcs_in_patch_stride(struct si_shader_context *ctx)
 {
-       return si_unpack_param(ctx, ctx->param_vs_state_bits, 8, 13);
+       return si_unpack_param(ctx, ctx->vs_state_bits, 8, 13);
 }
 
 static unsigned get_tcs_out_vertex_dw_stride_constant(struct si_shader_context *ctx)
@@ -328,7 +283,7 @@ static LLVMValueRef get_tcs_out_vertex_dw_stride(struct si_shader_context *ctx)
 static LLVMValueRef get_tcs_out_patch_stride(struct si_shader_context *ctx)
 {
        if (ctx->shader->key.mono.u.ff_tcs_inputs_to_copy)
-               return si_unpack_param(ctx, ctx->param_tcs_out_lds_layout, 0, 13);
+               return si_unpack_param(ctx, ctx->tcs_out_lds_layout, 0, 13);
 
        const struct tgsi_shader_info *info = &ctx->shader->selector->info;
        unsigned tcs_out_vertices = info->properties[TGSI_PROPERTY_TCS_VERTICES_OUT];
@@ -343,9 +298,7 @@ static LLVMValueRef
 get_tcs_out_patch0_offset(struct si_shader_context *ctx)
 {
        return LLVMBuildMul(ctx->ac.builder,
-                           si_unpack_param(ctx,
-                                           ctx->param_tcs_out_lds_offsets,
-                                           0, 16),
+                           si_unpack_param(ctx, ctx->tcs_out_lds_offsets, 0, 16),
                            LLVMConstInt(ctx->i32, 4, 0), "");
 }
 
@@ -353,9 +306,7 @@ static LLVMValueRef
 get_tcs_out_patch0_patch_data_offset(struct si_shader_context *ctx)
 {
        return LLVMBuildMul(ctx->ac.builder,
-                           si_unpack_param(ctx,
-                                           ctx->param_tcs_out_lds_offsets,
-                                           16, 16),
+                           si_unpack_param(ctx, ctx->tcs_out_lds_offsets, 16, 16),
                            LLVMConstInt(ctx->i32, 4, 0), "");
 }
 
@@ -399,7 +350,7 @@ static LLVMValueRef get_num_tcs_out_vertices(struct si_shader_context *ctx)
        if (ctx->type == PIPE_SHADER_TESS_CTRL && tcs_out_vertices)
                return LLVMConstInt(ctx->i32, tcs_out_vertices, 0);
 
-       return si_unpack_param(ctx, ctx->param_tcs_offchip_layout, 6, 6);
+       return si_unpack_param(ctx, ctx->tcs_offchip_layout, 6, 6);
 }
 
 static LLVMValueRef get_tcs_in_vertex_dw_stride(struct si_shader_context *ctx)
@@ -417,7 +368,7 @@ static LLVMValueRef get_tcs_in_vertex_dw_stride(struct si_shader_context *ctx)
                        stride = ctx->shader->key.part.tcs.ls->lshs_vertex_stride / 4;
                        return LLVMConstInt(ctx->i32, stride, 0);
                }
-               return si_unpack_param(ctx, ctx->param_vs_state_bits, 24, 8);
+               return si_unpack_param(ctx, ctx->vs_state_bits, 24, 8);
 
        default:
                assert(0);
@@ -460,12 +411,13 @@ void si_llvm_load_input_vs(
                                                    LLVMIntNE, vertex_id,
                                                    ctx->i32_1, "");
 
+               unsigned param_vs_blit_inputs = ctx->vs_blit_inputs.arg_index;
                if (input_index == 0) {
                        /* Position: */
                        LLVMValueRef x1y1 = LLVMGetParam(ctx->main_fn,
-                                                        ctx->param_vs_blit_inputs);
+                                                        param_vs_blit_inputs);
                        LLVMValueRef x2y2 = LLVMGetParam(ctx->main_fn,
-                                                        ctx->param_vs_blit_inputs + 1);
+                                                        param_vs_blit_inputs + 1);
 
                        LLVMValueRef x1 = unpack_sint16(ctx, x1y1, 0);
                        LLVMValueRef y1 = unpack_sint16(ctx, x1y1, 1);
@@ -480,7 +432,7 @@ void si_llvm_load_input_vs(
                        out[0] = LLVMBuildSIToFP(ctx->ac.builder, x, ctx->f32, "");
                        out[1] = LLVMBuildSIToFP(ctx->ac.builder, y, ctx->f32, "");
                        out[2] = LLVMGetParam(ctx->main_fn,
-                                             ctx->param_vs_blit_inputs + 2);
+                                             param_vs_blit_inputs + 2);
                        out[3] = ctx->ac.f32_1;
                        return;
                }
@@ -491,27 +443,27 @@ void si_llvm_load_input_vs(
                if (vs_blit_property == SI_VS_BLIT_SGPRS_POS_COLOR) {
                        for (int i = 0; i < 4; i++) {
                                out[i] = LLVMGetParam(ctx->main_fn,
-                                                     ctx->param_vs_blit_inputs + 3 + i);
+                                                     param_vs_blit_inputs + 3 + i);
                        }
                } else {
                        assert(vs_blit_property == SI_VS_BLIT_SGPRS_POS_TEXCOORD);
                        LLVMValueRef x1 = LLVMGetParam(ctx->main_fn,
-                                                      ctx->param_vs_blit_inputs + 3);
+                                                      param_vs_blit_inputs + 3);
                        LLVMValueRef y1 = LLVMGetParam(ctx->main_fn,
-                                                      ctx->param_vs_blit_inputs + 4);
+                                                      param_vs_blit_inputs + 4);
                        LLVMValueRef x2 = LLVMGetParam(ctx->main_fn,
-                                                      ctx->param_vs_blit_inputs + 5);
+                                                      param_vs_blit_inputs + 5);
                        LLVMValueRef y2 = LLVMGetParam(ctx->main_fn,
-                                                      ctx->param_vs_blit_inputs + 6);
+                                                      param_vs_blit_inputs + 6);
 
                        out[0] = LLVMBuildSelect(ctx->ac.builder, sel_x1,
                                                 x1, x2, "");
                        out[1] = LLVMBuildSelect(ctx->ac.builder, sel_y1,
                                                 y1, y2, "");
                        out[2] = LLVMGetParam(ctx->main_fn,
-                                             ctx->param_vs_blit_inputs + 7);
+                                             param_vs_blit_inputs + 7);
                        out[3] = LLVMGetParam(ctx->main_fn,
-                                             ctx->param_vs_blit_inputs + 8);
+                                             param_vs_blit_inputs + 8);
                }
                return;
        }
@@ -524,14 +476,14 @@ void si_llvm_load_input_vs(
        LLVMValueRef tmp;
 
        /* Load the T list */
-       t_list_ptr = LLVMGetParam(ctx->main_fn, ctx->param_vertex_buffers);
+       t_list_ptr = ac_get_arg(&ctx->ac, ctx->vertex_buffers);
 
        t_offset = LLVMConstInt(ctx->i32, input_index, 0);
 
        t_list = ac_build_load_to_sgpr(&ctx->ac, t_list_ptr, t_offset);
 
        vertex_index = LLVMGetParam(ctx->main_fn,
-                                   ctx->param_vertex_index0 +
+                                   ctx->vertex_index0.arg_index +
                                    input_index);
 
        /* Use the open-coded implementation for all loads of doubles and
@@ -661,14 +613,13 @@ LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx,
 
        switch (ctx->type) {
        case PIPE_SHADER_VERTEX:
-               return LLVMGetParam(ctx->main_fn,
-                                   ctx->param_vs_prim_id);
+               return ac_get_arg(&ctx->ac, ctx->vs_prim_id);
        case PIPE_SHADER_TESS_CTRL:
-               return ctx->abi.tcs_patch_id;
+               return ac_get_arg(&ctx->ac, ctx->args.tcs_patch_id);
        case PIPE_SHADER_TESS_EVAL:
-               return ctx->abi.tes_patch_id;
+               return ac_get_arg(&ctx->ac, ctx->args.tes_patch_id);
        case PIPE_SHADER_GEOMETRY:
-               return ctx->abi.gs_prim_id;
+               return ac_get_arg(&ctx->ac, ctx->args.gs_prim_id);
        default:
                assert(0);
                return ctx->i32_0;
@@ -853,7 +804,7 @@ static LLVMValueRef get_tcs_tes_buffer_address(struct si_shader_context *ctx,
        LLVMValueRef param_stride, constant16;
 
        vertices_per_patch = get_num_tcs_out_vertices(ctx);
-       num_patches = si_unpack_param(ctx, ctx->param_tcs_offchip_layout, 0, 6);
+       num_patches = si_unpack_param(ctx, ctx->tcs_offchip_layout, 0, 6);
        total_vertices = LLVMBuildMul(ctx->ac.builder, vertices_per_patch,
                                      num_patches, "");
 
@@ -872,7 +823,7 @@ static LLVMValueRef get_tcs_tes_buffer_address(struct si_shader_context *ctx,
 
        if (!vertex_index) {
                LLVMValueRef patch_data_offset =
-                          si_unpack_param(ctx, ctx->param_tcs_offchip_layout, 12, 20);
+                          si_unpack_param(ctx, ctx->tcs_offchip_layout, 12, 20);
 
                base_addr = LLVMBuildAdd(ctx->ac.builder, base_addr,
                                         patch_data_offset, "");
@@ -1065,9 +1016,10 @@ static LLVMValueRef get_tess_ring_descriptor(struct si_shader_context *ctx,
                                             enum si_tess_ring ring)
 {
        LLVMBuilderRef builder = ctx->ac.builder;
-       unsigned param = ring == TESS_OFFCHIP_RING_TES ? ctx->param_tes_offchip_addr :
-                                                        ctx->param_tcs_out_lds_layout;
-       LLVMValueRef addr = LLVMGetParam(ctx->main_fn, param);
+       LLVMValueRef addr = ac_get_arg(&ctx->ac,
+                                      ring == TESS_OFFCHIP_RING_TES ?
+                                      ctx->tes_offchip_addr :
+                                      ctx->tcs_out_lds_layout);
 
        /* TCS only receives high 13 bits of the address. */
        if (ring == TESS_OFFCHIP_RING_TCS || ring == TCS_FACTOR_RING) {
@@ -1215,7 +1167,7 @@ static LLVMValueRef fetch_input_tes(
        LLVMValueRef base, addr;
        unsigned swizzle = (swizzle_in & 0xffff);
 
-       base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
        addr = get_tcs_tes_buffer_address_from_reg(ctx, NULL, reg);
 
        return buffer_load(bld_base, tgsi2llvmtype(bld_base, type), swizzle,
@@ -1241,7 +1193,7 @@ LLVMValueRef si_nir_load_input_tes(struct ac_shader_abi *abi,
 
        driver_location = driver_location / 4;
 
-       base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
 
        if (!param_index) {
                param_index = LLVMConstInt(ctx->i32, const_index, 0);
@@ -1336,7 +1288,7 @@ static void store_output_tcs(struct lp_build_tgsi_context *bld_base,
 
        buffer = get_tess_ring_descriptor(ctx, TESS_OFFCHIP_RING_TCS);
 
-       base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
        buf_addr = get_tcs_tes_buffer_address_from_reg(ctx, reg, NULL);
 
        uint32_t writemask = reg->Register.WriteMask;
@@ -1445,7 +1397,7 @@ static void si_nir_store_output_tcs(struct ac_shader_abi *abi,
 
        buffer = get_tess_ring_descriptor(ctx, TESS_OFFCHIP_RING_TCS);
 
-       base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
 
        addr = get_tcs_tes_buffer_address_from_generic_indices(ctx, vertex_index,
                                                               param_index, driver_location,
@@ -1528,16 +1480,16 @@ LLVMValueRef si_llvm_load_input_gs(struct ac_shader_abi *abi,
 
                switch (index / 2) {
                case 0:
-                       vtx_offset = si_unpack_param(ctx, ctx->param_gs_vtx01_offset,
-                                                 index % 2 ? 16 : 0, 16);
+                       vtx_offset = si_unpack_param(ctx, ctx->gs_vtx01_offset,
+                                                    index % 2 ? 16 : 0, 16);
                        break;
                case 1:
-                       vtx_offset = si_unpack_param(ctx, ctx->param_gs_vtx23_offset,
-                                                 index % 2 ? 16 : 0, 16);
+                       vtx_offset = si_unpack_param(ctx, ctx->gs_vtx23_offset,
+                                                    index % 2 ? 16 : 0, 16);
                        break;
                case 2:
-                       vtx_offset = si_unpack_param(ctx, ctx->param_gs_vtx45_offset,
-                                                 index % 2 ? 16 : 0, 16);
+                       vtx_offset = si_unpack_param(ctx, ctx->gs_vtx45_offset,
+                                                    index % 2 ? 16 : 0, 16);
                        break;
                default:
                        assert(0);
@@ -1575,7 +1527,8 @@ LLVMValueRef si_llvm_load_input_gs(struct ac_shader_abi *abi,
        }
 
        /* Get the vertex offset parameter on GFX6. */
-       LLVMValueRef gs_vtx_offset = ctx->gs_vtx_offset[vtx_offset_param];
+       LLVMValueRef gs_vtx_offset = ac_get_arg(&ctx->ac,
+                                               ctx->gs_vtx_offset[vtx_offset_param]);
 
        vtx_offset = LLVMBuildMul(ctx->ac.builder, gs_vtx_offset,
                                  LLVMConstInt(ctx->i32, 4, 0), "");
@@ -1829,7 +1782,7 @@ void si_llvm_load_input_fs(
        interp_fs_input(ctx, input_index, semantic_name,
                        semantic_index, 0, /* this param is unused */
                        shader->selector->info.colors_read, interp_param,
-                       ctx->abi.prim_mask,
+                       ac_get_arg(&ctx->ac, ctx->args.prim_mask),
                        LLVMGetParam(main_fn, SI_PARAM_FRONT_FACE),
                        &out[0]);
 }
@@ -1845,7 +1798,7 @@ static void declare_input_fs(
 
 LLVMValueRef si_get_sample_id(struct si_shader_context *ctx)
 {
-       return si_unpack_param(ctx, SI_PARAM_ANCILLARY, 8, 4);
+       return si_unpack_param(ctx, ctx->args.ancillary, 8, 4);
 }
 
 static LLVMValueRef get_base_vertex(struct ac_shader_abi *abi)
@@ -1856,14 +1809,15 @@ static LLVMValueRef get_base_vertex(struct ac_shader_abi *abi)
         * (for direct draws) or the CP (for indirect draws) is the
         * first vertex ID, but GLSL expects 0 to be returned.
         */
-       LLVMValueRef vs_state = LLVMGetParam(ctx->main_fn,
-                                            ctx->param_vs_state_bits);
+       LLVMValueRef vs_state = ac_get_arg(&ctx->ac,
+                                          ctx->vs_state_bits);
        LLVMValueRef indexed;
 
        indexed = LLVMBuildLShr(ctx->ac.builder, vs_state, ctx->i32_1, "");
        indexed = LLVMBuildTrunc(ctx->ac.builder, indexed, ctx->i1, "");
 
-       return LLVMBuildSelect(ctx->ac.builder, indexed, ctx->abi.base_vertex,
+       return LLVMBuildSelect(ctx->ac.builder, indexed,
+                              ac_get_arg(&ctx->ac, ctx->args.base_vertex),
                               ctx->i32_0, "");
 }
 
@@ -1888,7 +1842,7 @@ static LLVMValueRef get_block_size(struct ac_shader_abi *abi)
 
                result = ac_build_gather_values(&ctx->ac, values, 3);
        } else {
-               result = LLVMGetParam(ctx->main_fn, ctx->param_block_size);
+               result = ac_get_arg(&ctx->ac, ctx->block_size);
        }
 
        return result;
@@ -1908,7 +1862,7 @@ static LLVMValueRef buffer_load_const(struct si_shader_context *ctx,
 static LLVMValueRef load_sample_position(struct ac_shader_abi *abi, LLVMValueRef sample_id)
 {
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
-       LLVMValueRef desc = LLVMGetParam(ctx->main_fn, ctx->param_rw_buffers);
+       LLVMValueRef desc = ac_get_arg(&ctx->ac, ctx->rw_buffers);
        LLVMValueRef buf_index = LLVMConstInt(ctx->i32, SI_PS_CONST_SAMPLE_POSITIONS, 0);
        LLVMValueRef resource = ac_build_load_to_sgpr(&ctx->ac, desc, buf_index);
 
@@ -1929,15 +1883,15 @@ static LLVMValueRef load_sample_position(struct ac_shader_abi *abi, LLVMValueRef
 static LLVMValueRef load_sample_mask_in(struct ac_shader_abi *abi)
 {
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
-       return ac_to_integer(&ctx->ac, abi->sample_coverage);
+       return ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args.sample_coverage));
 }
 
 static LLVMValueRef si_load_tess_coord(struct ac_shader_abi *abi)
 {
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
        LLVMValueRef coord[4] = {
-               LLVMGetParam(ctx->main_fn, ctx->param_tes_u),
-               LLVMGetParam(ctx->main_fn, ctx->param_tes_v),
+               ac_get_arg(&ctx->ac, ctx->tes_u),
+               ac_get_arg(&ctx->ac, ctx->tes_v),
                ctx->ac.f32_0,
                ctx->ac.f32_0
        };
@@ -1959,7 +1913,7 @@ static LLVMValueRef load_tess_level(struct si_shader_context *ctx,
 
        int param = si_shader_io_get_unique_index_patch(semantic_name, 0);
 
-       base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
        addr = get_tcs_tes_buffer_address(ctx, get_rel_patch_id(ctx), NULL,
                                          LLVMConstInt(ctx->i32, param, 0));
 
@@ -1975,7 +1929,7 @@ static LLVMValueRef load_tess_level_default(struct si_shader_context *ctx,
        int i, offset;
 
        slot = LLVMConstInt(ctx->i32, SI_HS_CONST_DEFAULT_TESS_LEVELS, 0);
-       buf = LLVMGetParam(ctx->main_fn, ctx->param_rw_buffers);
+       buf = ac_get_arg(&ctx->ac, ctx->rw_buffers);
        buf = ac_build_load_to_sgpr(&ctx->ac, buf, slot);
        offset = semantic_name == TGSI_SEMANTIC_TESS_DEFAULT_INNER_LEVEL ? 4 : 0;
 
@@ -2025,7 +1979,7 @@ static LLVMValueRef si_load_patch_vertices_in(struct ac_shader_abi *abi)
 {
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
        if (ctx->type == PIPE_SHADER_TESS_CTRL)
-               return si_unpack_param(ctx, ctx->param_tcs_out_lds_layout, 13, 6);
+               return si_unpack_param(ctx, ctx->tcs_out_lds_layout, 13, 6);
        else if (ctx->type == PIPE_SHADER_TESS_EVAL)
                return get_num_tcs_out_vertices(ctx);
        else
@@ -2048,7 +2002,7 @@ void si_load_system_value(struct si_shader_context *ctx,
        case TGSI_SEMANTIC_VERTEXID:
                value = LLVMBuildAdd(ctx->ac.builder,
                                     ctx->abi.vertex_id,
-                                    ctx->abi.base_vertex, "");
+                                    ac_get_arg(&ctx->ac, ctx->args.base_vertex), "");
                break;
 
        case TGSI_SEMANTIC_VERTEXID_NOBASE:
@@ -2062,23 +2016,23 @@ void si_load_system_value(struct si_shader_context *ctx,
                break;
 
        case TGSI_SEMANTIC_BASEINSTANCE:
-               value = ctx->abi.start_instance;
+               value = ac_get_arg(&ctx->ac, ctx->args.start_instance);
                break;
 
        case TGSI_SEMANTIC_DRAWID:
-               value = ctx->abi.draw_id;
+               value = ac_get_arg(&ctx->ac, ctx->args.draw_id);
                break;
 
        case TGSI_SEMANTIC_INVOCATIONID:
                if (ctx->type == PIPE_SHADER_TESS_CTRL) {
-                       value = unpack_llvm_param(ctx, ctx->abi.tcs_rel_ids, 8, 5);
+                       value = si_unpack_param(ctx, ctx->args.tcs_rel_ids, 8, 5);
                } else if (ctx->type == PIPE_SHADER_GEOMETRY) {
                        if (ctx->screen->info.chip_class >= GFX10) {
                                value = LLVMBuildAnd(ctx->ac.builder,
-                                                    ctx->abi.gs_invocation_id,
+                                                    ac_get_arg(&ctx->ac, ctx->args.gs_invocation_id),
                                                     LLVMConstInt(ctx->i32, 127, 0), "");
                        } else {
-                               value = ctx->abi.gs_invocation_id;
+                               value = ac_get_arg(&ctx->ac, ctx->args.gs_invocation_id);
                        }
                } else {
                        assert(!"INVOCATIONID not implemented");
@@ -2099,7 +2053,7 @@ void si_load_system_value(struct si_shader_context *ctx,
        }
 
        case TGSI_SEMANTIC_FACE:
-               value = ctx->abi.front_face;
+               value = ac_get_arg(&ctx->ac, ctx->args.front_face);
                break;
 
        case TGSI_SEMANTIC_SAMPLEID:
@@ -2149,7 +2103,7 @@ void si_load_system_value(struct si_shader_context *ctx,
                break;
 
        case TGSI_SEMANTIC_GRID_SIZE:
-               value = ctx->abi.num_work_groups;
+               value = ac_get_arg(&ctx->ac, ctx->args.num_work_groups);
                break;
 
        case TGSI_SEMANTIC_BLOCK_SIZE:
@@ -2162,8 +2116,8 @@ void si_load_system_value(struct si_shader_context *ctx,
 
                for (int i = 0; i < 3; i++) {
                        values[i] = ctx->i32_0;
-                       if (ctx->abi.workgroup_ids[i]) {
-                               values[i] = ctx->abi.workgroup_ids[i];
+                       if (ctx->args.workgroup_ids[i].used) {
+                               values[i] = ac_get_arg(&ctx->ac, ctx->args.workgroup_ids[i]);
                        }
                }
                value = ac_build_gather_values(&ctx->ac, values, 3);
@@ -2171,7 +2125,7 @@ void si_load_system_value(struct si_shader_context *ctx,
        }
 
        case TGSI_SEMANTIC_THREAD_ID:
-               value = ctx->abi.local_invocation_ids;
+               value = ac_get_arg(&ctx->ac, ctx->args.local_invocation_ids);
                break;
 
        case TGSI_SEMANTIC_HELPER_INVOCATION:
@@ -2226,7 +2180,7 @@ void si_load_system_value(struct si_shader_context *ctx,
        }
 
        case TGSI_SEMANTIC_CS_USER_DATA_AMD:
-               value = LLVMGetParam(ctx->main_fn, ctx->param_cs_user_data);
+               value = ac_get_arg(&ctx->ac, ctx->cs_user_data);
                break;
 
        default:
@@ -2268,7 +2222,7 @@ void si_tgsi_declare_compute_memory(struct si_shader_context *ctx,
 static LLVMValueRef load_const_buffer_desc_fast_path(struct si_shader_context *ctx)
 {
        LLVMValueRef ptr =
-               LLVMGetParam(ctx->main_fn, ctx->param_const_and_shader_buffers);
+               ac_get_arg(&ctx->ac, ctx->const_and_shader_buffers);
        struct si_shader_selector *sel = ctx->shader->selector;
 
        /* Do the bounds checking with a descriptor, because
@@ -2308,8 +2262,8 @@ static LLVMValueRef load_const_buffer_desc_fast_path(struct si_shader_context *c
 
 static LLVMValueRef load_const_buffer_desc(struct si_shader_context *ctx, int i)
 {
-       LLVMValueRef list_ptr = LLVMGetParam(ctx->main_fn,
-                                            ctx->param_const_and_shader_buffers);
+       LLVMValueRef list_ptr = ac_get_arg(&ctx->ac,
+                                          ctx->const_and_shader_buffers);
 
        return ac_build_load_to_sgpr(&ctx->ac, list_ptr,
                                     LLVMConstInt(ctx->i32, si_get_constbuf_slot(i), 0));
@@ -2320,7 +2274,7 @@ static LLVMValueRef load_ubo(struct ac_shader_abi *abi, LLVMValueRef index)
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
        struct si_shader_selector *sel = ctx->shader->selector;
 
-       LLVMValueRef ptr = LLVMGetParam(ctx->main_fn, ctx->param_const_and_shader_buffers);
+       LLVMValueRef ptr = ac_get_arg(&ctx->ac, ctx->const_and_shader_buffers);
 
        if (sel->info.const_buffers_declared == 1 &&
            sel->info.shader_buffers_declared == 0) {
@@ -2338,8 +2292,8 @@ static LLVMValueRef
 load_ssbo(struct ac_shader_abi *abi, LLVMValueRef index, bool write)
 {
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
-       LLVMValueRef rsrc_ptr = LLVMGetParam(ctx->main_fn,
-                                            ctx->param_const_and_shader_buffers);
+       LLVMValueRef rsrc_ptr = ac_get_arg(&ctx->ac,
+                                          ctx->const_and_shader_buffers);
 
        index = si_llvm_bound_index(ctx, index, ctx->num_shader_buffers);
        index = LLVMBuildSub(ctx->ac.builder,
@@ -2401,7 +2355,7 @@ static LLVMValueRef fetch_constant(
        buf = reg->Dimension.Index;
 
        if (reg->Dimension.Indirect) {
-               LLVMValueRef ptr = LLVMGetParam(ctx->main_fn, ctx->param_const_and_shader_buffers);
+               LLVMValueRef ptr = ac_get_arg(&ctx->ac, ctx->const_and_shader_buffers);
                LLVMValueRef index;
                index = si_get_bounded_indirect_index(ctx, &reg->DimIndirect,
                                                      reg->Dimension.Index,
@@ -2605,7 +2559,7 @@ static void si_llvm_emit_clipvertex(struct si_shader_context *ctx,
        unsigned chan;
        unsigned const_chan;
        LLVMValueRef base_elt;
-       LLVMValueRef ptr = LLVMGetParam(ctx->main_fn, ctx->param_rw_buffers);
+       LLVMValueRef ptr = ac_get_arg(&ctx->ac, ctx->rw_buffers);
        LLVMValueRef constbuf_index = LLVMConstInt(ctx->i32,
                                                   SI_VS_CONST_CLIP_PLANES, 0);
        LLVMValueRef const_resource = ac_build_load_to_sgpr(&ctx->ac, ptr, constbuf_index);
@@ -2725,7 +2679,7 @@ static void si_llvm_emit_streamout(struct si_shader_context *ctx,
 
        /* Get bits [22:16], i.e. (so_param >> 16) & 127; */
        LLVMValueRef so_vtx_count =
-               si_unpack_param(ctx, ctx->param_streamout_config, 16, 7);
+               si_unpack_param(ctx, ctx->streamout_config, 16, 7);
 
        LLVMValueRef tid = ac_get_thread_id(&ctx->ac);
 
@@ -2745,8 +2699,8 @@ static void si_llvm_emit_streamout(struct si_shader_context *ctx,
                  */
 
                LLVMValueRef so_write_index =
-                       LLVMGetParam(ctx->main_fn,
-                                    ctx->param_streamout_write_index);
+                       ac_get_arg(&ctx->ac,
+                                  ctx->streamout_write_index);
 
                /* Compute (streamout_write_index + thread_id). */
                so_write_index = LLVMBuildAdd(builder, so_write_index, tid, "");
@@ -2755,8 +2709,8 @@ static void si_llvm_emit_streamout(struct si_shader_context *ctx,
                 * enabled buffer. */
                LLVMValueRef so_write_offset[4] = {};
                LLVMValueRef so_buffers[4];
-               LLVMValueRef buf_ptr = LLVMGetParam(ctx->main_fn,
-                                                   ctx->param_rw_buffers);
+               LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac,
+                                                 ctx->rw_buffers);
 
                for (i = 0; i < 4; i++) {
                        if (!so->stride[i])
@@ -2767,8 +2721,8 @@ static void si_llvm_emit_streamout(struct si_shader_context *ctx,
 
                        so_buffers[i] = ac_build_load_to_sgpr(&ctx->ac, buf_ptr, offset);
 
-                       LLVMValueRef so_offset = LLVMGetParam(ctx->main_fn,
-                                                             ctx->param_streamout_offset[i]);
+                       LLVMValueRef so_offset = ac_get_arg(&ctx->ac,
+                                                           ctx->streamout_offset[i]);
                        so_offset = LLVMBuildMul(builder, so_offset, LLVMConstInt(ctx->i32, 4, 0), "");
 
                        so_write_offset[i] = ac_build_imad(&ctx->ac, so_write_index,
@@ -2882,7 +2836,7 @@ static void si_vertex_color_clamping(struct si_shader_context *ctx,
                return;
 
        /* The state is in the first bit of the user SGPR. */
-       LLVMValueRef cond = LLVMGetParam(ctx->main_fn, ctx->param_vs_state_bits);
+       LLVMValueRef cond = ac_get_arg(&ctx->ac, ctx->vs_state_bits);
        cond = LLVMBuildTrunc(ctx->ac.builder, cond, ctx->i1, "");
 
        ac_build_ifcc(&ctx->ac, cond, 6502);
@@ -3087,9 +3041,9 @@ static void si_copy_tcs_inputs(struct lp_build_tgsi_context *bld_base)
        LLVMValueRef lds_vertex_stride, lds_base;
        uint64_t inputs;
 
-       invocation_id = unpack_llvm_param(ctx, ctx->abi.tcs_rel_ids, 8, 5);
+       invocation_id = si_unpack_param(ctx, ctx->args.tcs_rel_ids, 8, 5);
        buffer = get_tess_ring_descriptor(ctx, TESS_OFFCHIP_RING_TCS);
-       buffer_offset = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+       buffer_offset = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
 
        lds_vertex_stride = get_tcs_in_vertex_dw_stride(ctx);
        lds_base = get_tcs_in_current_patch_offset(ctx);
@@ -3222,8 +3176,8 @@ static void si_write_tess_factors(struct lp_build_tgsi_context *bld_base,
        buffer = get_tess_ring_descriptor(ctx, TCS_FACTOR_RING);
 
        /* Get the offset. */
-       tf_base = LLVMGetParam(ctx->main_fn,
-                              ctx->param_tcs_factor_offset);
+       tf_base = ac_get_arg(&ctx->ac,
+                            ctx->tcs_factor_offset);
        byteoffset = LLVMBuildMul(ctx->ac.builder, rel_patch_id,
                                  LLVMConstInt(ctx->i32, 4 * stride, 0), "");
 
@@ -3260,7 +3214,7 @@ static void si_write_tess_factors(struct lp_build_tgsi_context *bld_base,
                unsigned param_outer, param_inner;
 
                buf = get_tess_ring_descriptor(ctx, TESS_OFFCHIP_RING_TCS);
-               base = LLVMGetParam(ctx->main_fn, ctx->param_tcs_offchip_offset);
+               base = ac_get_arg(&ctx->ac, ctx->tcs_offchip_offset);
 
                param_outer = si_shader_io_get_unique_index_patch(
                                      TGSI_SEMANTIC_TESSOUTER, 0);
@@ -3294,19 +3248,19 @@ static void si_write_tess_factors(struct lp_build_tgsi_context *bld_base,
 
 static LLVMValueRef
 si_insert_input_ret(struct si_shader_context *ctx, LLVMValueRef ret,
-                   unsigned param, unsigned return_index)
+                   struct ac_arg param, unsigned return_index)
 {
        return LLVMBuildInsertValue(ctx->ac.builder, ret,
-                                   LLVMGetParam(ctx->main_fn, param),
+                                   ac_get_arg(&ctx->ac, param),
                                    return_index, "");
 }
 
 static LLVMValueRef
 si_insert_input_ret_float(struct si_shader_context *ctx, LLVMValueRef ret,
-                         unsigned param, unsigned return_index)
+                         struct ac_arg param, unsigned return_index)
 {
        LLVMBuilderRef builder = ctx->ac.builder;
-       LLVMValueRef p = LLVMGetParam(ctx->main_fn, param);
+       LLVMValueRef p = ac_get_arg(&ctx->ac, param);
 
        return LLVMBuildInsertValue(builder, ret,
                                    ac_to_float(&ctx->ac, p),
@@ -3315,10 +3269,10 @@ si_insert_input_ret_float(struct si_shader_context *ctx, LLVMValueRef ret,
 
 static LLVMValueRef
 si_insert_input_ptr(struct si_shader_context *ctx, LLVMValueRef ret,
-                   unsigned param, unsigned return_index)
+                   struct ac_arg param, unsigned return_index)
 {
        LLVMBuilderRef builder = ctx->ac.builder;
-       LLVMValueRef ptr = LLVMGetParam(ctx->main_fn, param);
+       LLVMValueRef ptr = ac_get_arg(&ctx->ac, param);
        ptr = LLVMBuildPtrToInt(builder, ptr, ctx->i32, "");
        return LLVMBuildInsertValue(builder, ret, ptr, return_index, "");
 }
@@ -3336,7 +3290,7 @@ static void si_llvm_emit_tcs_epilogue(struct ac_shader_abi *abi,
        si_copy_tcs_inputs(bld_base);
 
        rel_patch_id = get_rel_patch_id(ctx);
-       invocation_id = unpack_llvm_param(ctx, ctx->abi.tcs_rel_ids, 8, 5);
+       invocation_id = si_unpack_param(ctx, ctx->args.tcs_rel_ids, 8, 5);
        tf_lds_offset = get_tcs_out_current_patch_data_offset(ctx);
 
        if (ctx->screen->info.chip_class >= GFX9) {
@@ -3366,23 +3320,23 @@ static void si_llvm_emit_tcs_epilogue(struct ac_shader_abi *abi,
        unsigned vgpr;
 
        if (ctx->screen->info.chip_class >= GFX9) {
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_layout,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_layout,
                                          8 + GFX9_SGPR_TCS_OFFCHIP_LAYOUT);
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_out_lds_layout,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_out_lds_layout,
                                          8 + GFX9_SGPR_TCS_OUT_LAYOUT);
                /* Tess offchip and tess factor offsets are at the beginning. */
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_offset, 2);
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_factor_offset, 4);
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_offset, 2);
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_factor_offset, 4);
                vgpr = 8 + GFX9_SGPR_TCS_OUT_LAYOUT + 1;
        } else {
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_layout,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_layout,
                                          GFX6_SGPR_TCS_OFFCHIP_LAYOUT);
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_out_lds_layout,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_out_lds_layout,
                                          GFX6_SGPR_TCS_OUT_LAYOUT);
                /* Tess offchip and tess factor offsets are after user SGPRs. */
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_offset,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_offset,
                                          GFX6_TCS_NUM_USER_SGPR);
-               ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_factor_offset,
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_factor_offset,
                                          GFX6_TCS_NUM_USER_SGPR + 1);
                vgpr = GFX6_TCS_NUM_USER_SGPR + 2;
        }
@@ -3420,35 +3374,37 @@ static void si_set_ls_return_value_for_tcs(struct si_shader_context *ctx)
 {
        LLVMValueRef ret = ctx->return_value;
 
-       ret = si_insert_input_ptr(ctx, ret, 0, 0);
-       ret = si_insert_input_ptr(ctx, ret, 1, 1);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_offset, 2);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_merged_wave_info, 3);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_factor_offset, 4);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_merged_scratch_offset, 5);
+       ret = si_insert_input_ptr(ctx, ret, ctx->other_const_and_shader_buffers, 0);
+       ret = si_insert_input_ptr(ctx, ret, ctx->other_samplers_and_images, 1);
+       ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_offset, 2);
+       ret = si_insert_input_ret(ctx, ret, ctx->merged_wave_info, 3);
+       ret = si_insert_input_ret(ctx, ret, ctx->tcs_factor_offset, 4);
+       ret = si_insert_input_ret(ctx, ret, ctx->merged_scratch_offset, 5);
 
-       ret = si_insert_input_ptr(ctx, ret, ctx->param_rw_buffers,
+       ret = si_insert_input_ptr(ctx, ret, ctx->rw_buffers,
                                  8 + SI_SGPR_RW_BUFFERS);
        ret = si_insert_input_ptr(ctx, ret,
-                                 ctx->param_bindless_samplers_and_images,
+                                 ctx->bindless_samplers_and_images,
                                  8 + SI_SGPR_BINDLESS_SAMPLERS_AND_IMAGES);
 
-       ret = si_insert_input_ret(ctx, ret, ctx->param_vs_state_bits,
+       ret = si_insert_input_ret(ctx, ret, ctx->vs_state_bits,
                                  8 + SI_SGPR_VS_STATE_BITS);
 
-       ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_offchip_layout,
+       ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_layout,
                                  8 + GFX9_SGPR_TCS_OFFCHIP_LAYOUT);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_out_lds_offsets,
+       ret = si_insert_input_ret(ctx, ret, ctx->tcs_out_lds_offsets,
                                  8 + GFX9_SGPR_TCS_OUT_OFFSETS);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_tcs_out_lds_layout,
+       ret = si_insert_input_ret(ctx, ret, ctx->tcs_out_lds_layout,
                                  8 + GFX9_SGPR_TCS_OUT_LAYOUT);
 
        unsigned vgpr = 8 + GFX9_TCS_NUM_USER_SGPR;
        ret = LLVMBuildInsertValue(ctx->ac.builder, ret,
-                                  ac_to_float(&ctx->ac, ctx->abi.tcs_patch_id),
+                                  ac_to_float(&ctx->ac,
+                                              ac_get_arg(&ctx->ac, ctx->args.tcs_patch_id)),
                                   vgpr++, "");
        ret = LLVMBuildInsertValue(ctx->ac.builder, ret,
-                                  ac_to_float(&ctx->ac, ctx->abi.tcs_rel_ids),
+                                  ac_to_float(&ctx->ac,
+                                              ac_get_arg(&ctx->ac, ctx->args.tcs_rel_ids)),
                                   vgpr++, "");
        ctx->return_value = ret;
 }
@@ -3456,25 +3412,24 @@ static void si_set_ls_return_value_for_tcs(struct si_shader_context *ctx)
 /* Pass GS inputs from ES to GS on GFX9. */
 static void si_set_es_return_value_for_gs(struct si_shader_context *ctx)
 {
-       LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef ret = ctx->return_value;
 
-       ret = si_insert_input_ptr(ctx, ret, 0, 0);
-       ret = si_insert_input_ptr(ctx, ret, 1, 1);
+       ret = si_insert_input_ptr(ctx, ret, ctx->other_const_and_shader_buffers, 0);
+       ret = si_insert_input_ptr(ctx, ret, ctx->other_samplers_and_images, 1);
        if (ctx->shader->key.as_ngg)
-               ret = LLVMBuildInsertValue(builder, ret, ctx->gs_tg_info, 2, "");
+               ret = si_insert_input_ptr(ctx, ret, ctx->gs_tg_info, 2);
        else
-               ret = si_insert_input_ret(ctx, ret, ctx->param_gs2vs_offset, 2);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_merged_wave_info, 3);
-       ret = si_insert_input_ret(ctx, ret, ctx->param_merged_scratch_offset, 5);
+               ret = si_insert_input_ret(ctx, ret, ctx->gs2vs_offset, 2);
+       ret = si_insert_input_ret(ctx, ret, ctx->merged_wave_info, 3);
+       ret = si_insert_input_ret(ctx, ret, ctx->merged_scratch_offset, 5);
 
-       ret = si_insert_input_ptr(ctx, ret, ctx->param_rw_buffers,
+       ret = si_insert_input_ptr(ctx, ret, ctx->rw_buffers,
                                  8 + SI_SGPR_RW_BUFFERS);
        ret = si_insert_input_ptr(ctx, ret,
-                                 ctx->param_bindless_samplers_and_images,
+                                 ctx->bindless_samplers_and_images,
                                  8 + SI_SGPR_BINDLESS_SAMPLERS_AND_IMAGES);
        if (ctx->screen->use_ngg) {
-               ret = si_insert_input_ptr(ctx, ret, ctx->param_vs_state_bits,
+               ret = si_insert_input_ptr(ctx, ret, ctx->vs_state_bits,
                                          8 + SI_SGPR_VS_STATE_BITS);
        }
 
@@ -3484,10 +3439,11 @@ static void si_set_es_return_value_for_gs(struct si_shader_context *ctx)
        else
                vgpr = 8 + GFX9_TESGS_NUM_USER_SGPR;
 
-       for (unsigned i = 0; i < 5; i++) {
-               unsigned param = ctx->param_gs_vtx01_offset + i;
-               ret = si_insert_input_ret_float(ctx, ret, param, vgpr++);
-       }
+       ret = si_insert_input_ret_float(ctx, ret, ctx->gs_vtx01_offset, vgpr++);
+       ret = si_insert_input_ret_float(ctx, ret, ctx->gs_vtx23_offset, vgpr++);
+       ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_prim_id, vgpr++);
+       ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_invocation_id, vgpr++);
+       ret = si_insert_input_ret_float(ctx, ret, ctx->gs_vtx45_offset, vgpr++);
        ctx->return_value = ret;
 }
 
@@ -3499,8 +3455,7 @@ static void si_llvm_emit_ls_epilogue(struct ac_shader_abi *abi,
        struct si_shader *shader = ctx->shader;
        struct tgsi_shader_info *info = &shader->selector->info;
        unsigned i, chan;
-       LLVMValueRef vertex_id = LLVMGetParam(ctx->main_fn,
-                                             ctx->param_rel_auto_id);
+       LLVMValueRef vertex_id = ac_get_arg(&ctx->ac, ctx->rel_auto_id);
        LLVMValueRef vertex_dw_stride = get_tcs_in_vertex_dw_stride(ctx);
        LLVMValueRef base_dw_addr = LLVMBuildMul(ctx->ac.builder, vertex_id,
                                                 vertex_dw_stride, "");
@@ -3554,8 +3509,6 @@ static void si_llvm_emit_es_epilogue(struct ac_shader_abi *abi,
        struct si_shader_context *ctx = si_shader_context_from_abi(abi);
        struct si_shader *es = ctx->shader;
        struct tgsi_shader_info *info = &es->selector->info;
-       LLVMValueRef soffset = LLVMGetParam(ctx->main_fn,
-                                           ctx->param_es2gs_offset);
        LLVMValueRef lds_base = NULL;
        unsigned chan;
        int i;
@@ -3563,7 +3516,7 @@ static void si_llvm_emit_es_epilogue(struct ac_shader_abi *abi,
        if (ctx->screen->info.chip_class >= GFX9 && info->num_outputs) {
                unsigned itemsize_dw = es->selector->esgs_itemsize / 4;
                LLVMValueRef vertex_idx = ac_get_thread_id(&ctx->ac);
-               LLVMValueRef wave_idx = si_unpack_param(ctx, ctx->param_merged_wave_info, 24, 4);
+               LLVMValueRef wave_idx = si_unpack_param(ctx, ctx->merged_wave_info, 24, 4);
                vertex_idx = LLVMBuildOr(ctx->ac.builder, vertex_idx,
                                         LLVMBuildMul(ctx->ac.builder, wave_idx,
                                                      LLVMConstInt(ctx->i32, ctx->ac.wave_size, false), ""), "");
@@ -3598,7 +3551,8 @@ static void si_llvm_emit_es_epilogue(struct ac_shader_abi *abi,
 
                        ac_build_buffer_store_dword(&ctx->ac,
                                                    ctx->esgs_ring,
-                                                   out_val, 1, NULL, soffset,
+                                                   out_val, 1, NULL,
+                                                   ac_get_arg(&ctx->ac, ctx->es2gs_offset),
                                                    (4 * param + chan) * 4,
                                                    ac_glc | ac_slc, true);
                }
@@ -3611,9 +3565,9 @@ static void si_llvm_emit_es_epilogue(struct ac_shader_abi *abi,
 static LLVMValueRef si_get_gs_wave_id(struct si_shader_context *ctx)
 {
        if (ctx->screen->info.chip_class >= GFX9)
-               return si_unpack_param(ctx, ctx->param_merged_wave_info, 16, 8);
+               return si_unpack_param(ctx, ctx->merged_wave_info, 16, 8);
        else
-               return LLVMGetParam(ctx->main_fn, ctx->param_gs_wave_id);
+               return ac_get_arg(&ctx->ac, ctx->gs_wave_id);
 }
 
 static void emit_gs_epilogue(struct si_shader_context *ctx)
@@ -4000,7 +3954,7 @@ static void build_interp_intrinsic(const struct lp_build_tgsi_action *action,
        int input_base, input_array_size;
        int chan;
        int i;
-       LLVMValueRef prim_mask = ctx->abi.prim_mask;
+       LLVMValueRef prim_mask = ac_get_arg(&ctx->ac, ctx->args.prim_mask);
        LLVMValueRef array_idx, offset_x = NULL, offset_y = NULL;
        int interp_param_idx;
        unsigned interp;
@@ -4276,8 +4230,7 @@ static void si_llvm_emit_vertex(struct ac_shader_abi *abi,
 
        struct tgsi_shader_info *info = &ctx->shader->selector->info;
        struct si_shader *shader = ctx->shader;
-       LLVMValueRef soffset = LLVMGetParam(ctx->main_fn,
-                                           ctx->param_gs2vs_offset);
+       LLVMValueRef soffset = ac_get_arg(&ctx->ac, ctx->gs2vs_offset);
        LLVMValueRef gs_next_vertex;
        LLVMValueRef can_emit;
        unsigned chan, offset;
@@ -4408,75 +4361,47 @@ static void si_llvm_emit_barrier(const struct lp_build_tgsi_action *action,
 void si_create_function(struct si_shader_context *ctx,
                        const char *name,
                        LLVMTypeRef *returns, unsigned num_returns,
-                       struct si_function_info *fninfo,
                        unsigned max_workgroup_size)
 {
-       int i;
-
-       si_llvm_create_func(ctx, name, returns, num_returns,
-                           fninfo->types, fninfo->num_params);
+       si_llvm_create_func(ctx, name, returns, num_returns);
        ctx->return_value = LLVMGetUndef(ctx->return_type);
 
-       for (i = 0; i < fninfo->num_sgpr_params; ++i) {
-               LLVMValueRef P = LLVMGetParam(ctx->main_fn, i);
-
-               /* The combination of:
-                * - noalias
-                * - dereferenceable
-                * - invariant.load
-                * allows the optimization passes to move loads and reduces
-                * SGPR spilling significantly.
-                */
-               ac_add_function_attr(ctx->ac.context, ctx->main_fn, i + 1,
-                                    AC_FUNC_ATTR_INREG);
-
-               if (LLVMGetTypeKind(LLVMTypeOf(P)) == LLVMPointerTypeKind) {
-                       ac_add_function_attr(ctx->ac.context, ctx->main_fn, i + 1,
-                                            AC_FUNC_ATTR_NOALIAS);
-                       ac_add_attr_dereferenceable(P, UINT64_MAX);
-               }
-       }
-
-       for (i = 0; i < fninfo->num_params; ++i) {
-               if (fninfo->assign[i])
-                       *fninfo->assign[i] = LLVMGetParam(ctx->main_fn, i);
-       }
-
        if (ctx->screen->info.address32_hi) {
                ac_llvm_add_target_dep_function_attr(ctx->main_fn,
                                                     "amdgpu-32bit-address-high-bits",
                                                     ctx->screen->info.address32_hi);
        }
 
-       ac_llvm_set_workgroup_size(ctx->main_fn, max_workgroup_size);
-
        LLVMAddTargetDependentFunctionAttr(ctx->main_fn,
                                           "no-signed-zeros-fp-math",
                                           "true");
+
+       ac_llvm_set_workgroup_size(ctx->main_fn, max_workgroup_size);
 }
 
 static void declare_streamout_params(struct si_shader_context *ctx,
-                                    struct pipe_stream_output_info *so,
-                                    struct si_function_info *fninfo)
+                                    struct pipe_stream_output_info *so)
 {
-       if (ctx->screen->use_ngg_streamout)
+       if (ctx->screen->use_ngg_streamout) {
+               if (ctx->type == PIPE_SHADER_TESS_EVAL)
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
                return;
+       }
 
        /* Streamout SGPRs. */
        if (so->num_outputs) {
-               if (ctx->type != PIPE_SHADER_TESS_EVAL)
-                       ctx->param_streamout_config = add_arg(fninfo, ARG_SGPR, ctx->ac.i32);
-               else
-                       ctx->param_streamout_config = fninfo->num_params - 1;
-
-               ctx->param_streamout_write_index = add_arg(fninfo, ARG_SGPR, ctx->ac.i32);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->streamout_config);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->streamout_write_index);
+       } else if (ctx->type == PIPE_SHADER_TESS_EVAL) {
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
        }
+
        /* A streamout buffer offset is loaded if the stride is non-zero. */
        for (int i = 0; i < 4; i++) {
                if (!so->stride[i])
                        continue;
 
-               ctx->param_streamout_offset[i] = add_arg(fninfo, ARG_SGPR, ctx->ac.i32);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->streamout_offset[i]);
        }
 }
 
@@ -4518,129 +4443,118 @@ static unsigned si_get_max_workgroup_size(const struct si_shader *shader)
 }
 
 static void declare_const_and_shader_buffers(struct si_shader_context *ctx,
-                                            struct si_function_info *fninfo,
                                             bool assign_params)
 {
-       LLVMTypeRef const_shader_buf_type;
+       enum ac_arg_type const_shader_buf_type;
 
        if (ctx->shader->selector->info.const_buffers_declared == 1 &&
            ctx->shader->selector->info.shader_buffers_declared == 0)
-               const_shader_buf_type = ctx->f32;
+               const_shader_buf_type = AC_ARG_CONST_FLOAT_PTR;
        else
-               const_shader_buf_type = ctx->v4i32;
+               const_shader_buf_type = AC_ARG_CONST_DESC_PTR;
 
-       unsigned const_and_shader_buffers =
-               add_arg(fninfo, ARG_SGPR,
-                       ac_array_in_const32_addr_space(const_shader_buf_type));
-
-       if (assign_params)
-               ctx->param_const_and_shader_buffers = const_and_shader_buffers;
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, const_shader_buf_type,
+                  assign_params ? &ctx->const_and_shader_buffers :
+                  &ctx->other_const_and_shader_buffers);
 }
 
 static void declare_samplers_and_images(struct si_shader_context *ctx,
-                                       struct si_function_info *fninfo,
                                        bool assign_params)
 {
-       unsigned samplers_and_images =
-               add_arg(fninfo, ARG_SGPR,
-                       ac_array_in_const32_addr_space(ctx->v8i32));
-
-       if (assign_params)
-               ctx->param_samplers_and_images = samplers_and_images;
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_IMAGE_PTR,
+                  assign_params ? &ctx->samplers_and_images :
+                  &ctx->other_samplers_and_images);
 }
 
 static void declare_per_stage_desc_pointers(struct si_shader_context *ctx,
-                                           struct si_function_info *fninfo,
                                            bool assign_params)
 {
-       declare_const_and_shader_buffers(ctx, fninfo, assign_params);
-       declare_samplers_and_images(ctx, fninfo, assign_params);
+       declare_const_and_shader_buffers(ctx, assign_params);
+       declare_samplers_and_images(ctx, assign_params);
 }
 
-static void declare_global_desc_pointers(struct si_shader_context *ctx,
-                                        struct si_function_info *fninfo)
+static void declare_global_desc_pointers(struct si_shader_context *ctx)
 {
-       ctx->param_rw_buffers = add_arg(fninfo, ARG_SGPR,
-               ac_array_in_const32_addr_space(ctx->v4i32));
-       ctx->param_bindless_samplers_and_images = add_arg(fninfo, ARG_SGPR,
-               ac_array_in_const32_addr_space(ctx->v8i32));
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                  &ctx->rw_buffers);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_IMAGE_PTR,
+                  &ctx->bindless_samplers_and_images);
 }
 
-static void declare_vs_specific_input_sgprs(struct si_shader_context *ctx,
-                                           struct si_function_info *fninfo)
+static void declare_vs_specific_input_sgprs(struct si_shader_context *ctx)
 {
-       ctx->param_vs_state_bits = add_arg(fninfo, ARG_SGPR, ctx->i32);
-       add_arg_assign(fninfo, ARG_SGPR, ctx->i32, &ctx->abi.base_vertex);
-       add_arg_assign(fninfo, ARG_SGPR, ctx->i32, &ctx->abi.start_instance);
-       add_arg_assign(fninfo, ARG_SGPR, ctx->i32, &ctx->abi.draw_id);
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->vs_state_bits);
+       if (!ctx->shader->is_gs_copy_shader) {
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->args.base_vertex);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->args.start_instance);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->args.draw_id);
+       }
 }
 
 static void declare_vs_input_vgprs(struct si_shader_context *ctx,
-                                  struct si_function_info *fninfo,
                                   unsigned *num_prolog_vgprs)
 {
        struct si_shader *shader = ctx->shader;
 
-       add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.vertex_id);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.vertex_id);
        if (shader->key.as_ls) {
-               ctx->param_rel_auto_id = add_arg(fninfo, ARG_VGPR, ctx->i32);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->rel_auto_id);
                if (ctx->screen->info.chip_class >= GFX10) {
-                       add_arg(fninfo, ARG_VGPR, ctx->i32); /* user VGPR */
-                       add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.instance_id);
+                       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* user VGPR */
+                       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.instance_id);
                } else {
-                       add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.instance_id);
-                       add_arg(fninfo, ARG_VGPR, ctx->i32); /* unused */
+                       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.instance_id);
+                       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* unused */
                }
        } else if (ctx->screen->info.chip_class >= GFX10) {
-               add_arg(fninfo, ARG_VGPR, ctx->i32); /* user vgpr */
-               ctx->param_vs_prim_id = add_arg(fninfo, ARG_VGPR, ctx->i32); /* user vgpr or PrimID (legacy) */
-               add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.instance_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* user VGPR */
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT,
+                          &ctx->vs_prim_id); /* user vgpr or PrimID (legacy) */
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.instance_id);
        } else {
-               add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.instance_id);
-               ctx->param_vs_prim_id = add_arg(fninfo, ARG_VGPR, ctx->i32);
-               add_arg(fninfo, ARG_VGPR, ctx->i32); /* unused */
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.instance_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->vs_prim_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL); /* unused */
        }
 
        if (!shader->is_gs_copy_shader) {
                /* Vertex load indices. */
-               ctx->param_vertex_index0 = fninfo->num_params;
-               for (unsigned i = 0; i < shader->selector->info.num_inputs; i++)
-                       add_arg(fninfo, ARG_VGPR, ctx->i32);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->vertex_index0);
+               for (unsigned i = 1; i < shader->selector->info.num_inputs; i++)
+                       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL);
                *num_prolog_vgprs += shader->selector->info.num_inputs;
        }
 }
 
 static void declare_vs_blit_inputs(struct si_shader_context *ctx,
-                                  struct si_function_info *fninfo,
                                   unsigned vs_blit_property)
 {
-       ctx->param_vs_blit_inputs = fninfo->num_params;
-       add_arg(fninfo, ARG_SGPR, ctx->i32); /* i16 x1, y1 */
-       add_arg(fninfo, ARG_SGPR, ctx->i32); /* i16 x2, y2 */
-       add_arg(fninfo, ARG_SGPR, ctx->f32); /* depth */
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT,
+                  &ctx->vs_blit_inputs); /* i16 x1, y1 */
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); /* i16 x1, y1 */
+       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* depth */
 
        if (vs_blit_property == SI_VS_BLIT_SGPRS_POS_COLOR) {
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* color0 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* color1 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* color2 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* color3 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* color0 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* color1 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* color2 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* color3 */
        } else if (vs_blit_property == SI_VS_BLIT_SGPRS_POS_TEXCOORD) {
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.x1 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.y1 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.x2 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.y2 */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.z */
-               add_arg(fninfo, ARG_SGPR, ctx->f32); /* texcoord.w */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.x1 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.y1 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.x2 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.y2 */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.z */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_FLOAT, NULL); /* texcoord.w */
        }
 }
 
-static void declare_tes_input_vgprs(struct si_shader_context *ctx,
-                                   struct si_function_info *fninfo)
+static void declare_tes_input_vgprs(struct si_shader_context *ctx)
 {
-       ctx->param_tes_u = add_arg(fninfo, ARG_VGPR, ctx->f32);
-       ctx->param_tes_v = add_arg(fninfo, ARG_VGPR, ctx->f32);
-       ctx->param_tes_rel_patch_id = add_arg(fninfo, ARG_VGPR, ctx->i32);
-       add_arg_assign(fninfo, ARG_VGPR, ctx->i32, &ctx->abi.tes_patch_id);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &ctx->tes_u);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT, &ctx->tes_v);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->tes_rel_patch_id);
+       ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.tes_patch_id);
 }
 
 enum {
@@ -4649,10 +4563,19 @@ enum {
        SI_SHADER_MERGED_VERTEX_OR_TESSEVAL_GEOMETRY,
 };
 
+static void add_arg_checked(struct ac_shader_args *args,
+                           enum ac_arg_regfile file,
+                           unsigned registers, enum ac_arg_type type,
+                           struct ac_arg *arg,
+                           unsigned idx)
+{
+       assert(args->arg_count == idx);
+       ac_add_arg(args, file, registers, type, arg);
+}
+
 static void create_function(struct si_shader_context *ctx)
 {
        struct si_shader *shader = ctx->shader;
-       struct si_function_info fninfo;
        LLVMTypeRef returns[16+32*4];
        unsigned i, num_return_sgprs;
        unsigned num_returns = 0;
@@ -4661,7 +4584,7 @@ static void create_function(struct si_shader_context *ctx)
        unsigned vs_blit_property =
                shader->selector->info.properties[TGSI_PROPERTY_VS_BLIT_SGPRS_AMD];
 
-       si_init_function_info(&fninfo);
+       memset(&ctx->args, 0, sizeof(ctx->args));
 
        /* Set MERGED shaders. */
        if (ctx->screen->info.chip_class >= GFX9) {
@@ -4671,42 +4594,37 @@ static void create_function(struct si_shader_context *ctx)
                        type = SI_SHADER_MERGED_VERTEX_OR_TESSEVAL_GEOMETRY;
        }
 
-       LLVMTypeRef v3i32 = LLVMVectorType(ctx->i32, 3);
-
        switch (type) {
        case PIPE_SHADER_VERTEX:
-               declare_global_desc_pointers(ctx, &fninfo);
+               declare_global_desc_pointers(ctx);
 
                if (vs_blit_property) {
-                       declare_vs_blit_inputs(ctx, &fninfo, vs_blit_property);
+                       declare_vs_blit_inputs(ctx, vs_blit_property);
 
                        /* VGPRs */
-                       declare_vs_input_vgprs(ctx, &fninfo, &num_prolog_vgprs);
+                       declare_vs_input_vgprs(ctx, &num_prolog_vgprs);
                        break;
                }
 
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
-               declare_vs_specific_input_sgprs(ctx, &fninfo);
-               ctx->param_vertex_buffers = add_arg(&fninfo, ARG_SGPR,
-                       ac_array_in_const32_addr_space(ctx->v4i32));
+               declare_per_stage_desc_pointers(ctx, true);
+               declare_vs_specific_input_sgprs(ctx); 
+               if (!shader->is_gs_copy_shader) {
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                                  &ctx->vertex_buffers);
+               }
 
                if (shader->key.as_es) {
-                       ctx->param_es2gs_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                  &ctx->es2gs_offset);
                } else if (shader->key.as_ls) {
                        /* no extra parameters */
                } else {
-                       if (shader->is_gs_copy_shader) {
-                               fninfo.num_params = ctx->param_vs_state_bits + 1;
-                               fninfo.num_sgpr_params = fninfo.num_params;
-                       }
-
                        /* The locations of the other parameters are assigned dynamically. */
-                       declare_streamout_params(ctx, &shader->selector->so,
-                                                &fninfo);
+                       declare_streamout_params(ctx, &shader->selector->so);
                }
 
                /* VGPRs */
-               declare_vs_input_vgprs(ctx, &fninfo, &num_prolog_vgprs);
+               declare_vs_input_vgprs(ctx, &num_prolog_vgprs);
 
                /* Return values */
                if (shader->key.opt.vs_as_prim_discard_cs) {
@@ -4716,18 +4634,18 @@ static void create_function(struct si_shader_context *ctx)
                break;
 
        case PIPE_SHADER_TESS_CTRL: /* GFX6-GFX8 */
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
-               ctx->param_tcs_offchip_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_out_lds_offsets = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_out_lds_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_vs_state_bits = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_offchip_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_factor_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx, true);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_layout);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_out_lds_offsets);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_out_lds_layout);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->vs_state_bits);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_factor_offset);
 
                /* VGPRs */
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.tcs_patch_id);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.tcs_rel_ids);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.tcs_patch_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.tcs_rel_ids);
 
                /* param_tcs_offchip_offset and param_tcs_factor_offset are
                 * placed after the user SGPRs.
@@ -4741,33 +4659,31 @@ static void create_function(struct si_shader_context *ctx)
        case SI_SHADER_MERGED_VERTEX_TESSCTRL:
                /* Merged stages have 8 system SGPRs at the beginning. */
                /* SPI_SHADER_USER_DATA_ADDR_LO/HI_HS */
-               declare_per_stage_desc_pointers(ctx, &fninfo,
+               declare_per_stage_desc_pointers(ctx,
                                                ctx->type == PIPE_SHADER_TESS_CTRL);
-               ctx->param_tcs_offchip_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_merged_wave_info = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_factor_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_merged_scratch_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               add_arg(&fninfo, ARG_SGPR, ctx->i32); /* unused */
-               add_arg(&fninfo, ARG_SGPR, ctx->i32); /* unused */
-
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo,
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->merged_wave_info);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_factor_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->merged_scratch_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); /* unused */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); /* unused */
+
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx,
                                                ctx->type == PIPE_SHADER_VERTEX);
-               declare_vs_specific_input_sgprs(ctx, &fninfo);
+               declare_vs_specific_input_sgprs(ctx);
 
-               ctx->param_tcs_offchip_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_out_lds_offsets = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_out_lds_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_vertex_buffers = add_arg(&fninfo, ARG_SGPR,
-                       ac_array_in_const32_addr_space(ctx->v4i32));
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_layout);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_out_lds_offsets);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_out_lds_layout);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR, &ctx->vertex_buffers);
 
                /* VGPRs (first TCS, then VS) */
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.tcs_patch_id);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.tcs_rel_ids);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.tcs_patch_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.tcs_rel_ids);
 
                if (ctx->type == PIPE_SHADER_VERTEX) {
-                       declare_vs_input_vgprs(ctx, &fninfo,
-                                              &num_prolog_vgprs);
+                       declare_vs_input_vgprs(ctx, &num_prolog_vgprs);
 
                        /* LS return values are inputs to the TCS main shader part. */
                        for (i = 0; i < 8 + GFX9_TCS_NUM_USER_SGPR; i++)
@@ -4791,56 +4707,55 @@ static void create_function(struct si_shader_context *ctx)
        case SI_SHADER_MERGED_VERTEX_OR_TESSEVAL_GEOMETRY:
                /* Merged stages have 8 system SGPRs at the beginning. */
                /* SPI_SHADER_USER_DATA_ADDR_LO/HI_GS */
-               declare_per_stage_desc_pointers(ctx, &fninfo,
+               declare_per_stage_desc_pointers(ctx,
                                                ctx->type == PIPE_SHADER_GEOMETRY);
 
                if (ctx->shader->key.as_ngg)
-                       add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &ctx->gs_tg_info);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->gs_tg_info);
                else
-                       ctx->param_gs2vs_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->gs2vs_offset);
 
-               ctx->param_merged_wave_info = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_offchip_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_merged_scratch_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               add_arg(&fninfo, ARG_SGPR, ctx->i32); /* unused (SPI_SHADER_PGM_LO/HI_GS << 8) */
-               add_arg(&fninfo, ARG_SGPR, ctx->i32); /* unused (SPI_SHADER_PGM_LO/HI_GS >> 24) */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->merged_wave_info);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->merged_scratch_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); /* unused (SPI_SHADER_PGM_LO/HI_GS << 8) */
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL); /* unused (SPI_SHADER_PGM_LO/HI_GS >> 24) */
 
-               declare_global_desc_pointers(ctx, &fninfo);
+               declare_global_desc_pointers(ctx);
                if (ctx->type != PIPE_SHADER_VERTEX || !vs_blit_property) {
-                       declare_per_stage_desc_pointers(ctx, &fninfo,
+                       declare_per_stage_desc_pointers(ctx,
                                                        (ctx->type == PIPE_SHADER_VERTEX ||
                                                         ctx->type == PIPE_SHADER_TESS_EVAL));
                }
 
                if (ctx->type == PIPE_SHADER_VERTEX) {
                        if (vs_blit_property)
-                               declare_vs_blit_inputs(ctx, &fninfo, vs_blit_property);
+                               declare_vs_blit_inputs(ctx, vs_blit_property);
                        else
-                               declare_vs_specific_input_sgprs(ctx, &fninfo);
+                               declare_vs_specific_input_sgprs(ctx);
                } else {
-                       ctx->param_vs_state_bits = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-                       ctx->param_tcs_offchip_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-                       ctx->param_tes_offchip_addr = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->vs_state_bits);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_layout);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tes_offchip_addr);
                        /* Declare as many input SGPRs as the VS has. */
                }
 
                if (ctx->type == PIPE_SHADER_VERTEX) {
-                       ctx->param_vertex_buffers = add_arg(&fninfo, ARG_SGPR,
-                               ac_array_in_const32_addr_space(ctx->v4i32));
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_CONST_DESC_PTR,
+                                  &ctx->vertex_buffers);
                }
 
                /* VGPRs (first GS, then VS/TES) */
-               ctx->param_gs_vtx01_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
-               ctx->param_gs_vtx23_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_prim_id);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_invocation_id);
-               ctx->param_gs_vtx45_offset = add_arg(&fninfo, ARG_VGPR, ctx->i32);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx01_offset);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx23_offset);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.gs_prim_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.gs_invocation_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx45_offset);
 
                if (ctx->type == PIPE_SHADER_VERTEX) {
-                       declare_vs_input_vgprs(ctx, &fninfo,
-                                              &num_prolog_vgprs);
+                       declare_vs_input_vgprs(ctx, &num_prolog_vgprs);
                } else if (ctx->type == PIPE_SHADER_TESS_EVAL) {
-                       declare_tes_input_vgprs(ctx, &fninfo);
+                       declare_tes_input_vgprs(ctx);
                }
 
                if (ctx->shader->key.as_es &&
@@ -4862,91 +4777,92 @@ static void create_function(struct si_shader_context *ctx)
                break;
 
        case PIPE_SHADER_TESS_EVAL:
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
-               ctx->param_vs_state_bits = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tcs_offchip_layout = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_tes_offchip_addr = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx, true);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->vs_state_bits);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_layout);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tes_offchip_addr);
 
                if (shader->key.as_es) {
-                       ctx->param_tcs_offchip_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-                       add_arg(&fninfo, ARG_SGPR, ctx->i32);
-                       ctx->param_es2gs_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_offset);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->es2gs_offset);
                } else {
-                       add_arg(&fninfo, ARG_SGPR, ctx->i32);
-                       declare_streamout_params(ctx, &shader->selector->so,
-                                                &fninfo);
-                       ctx->param_tcs_offchip_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+                       declare_streamout_params(ctx, &shader->selector->so);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->tcs_offchip_offset);
                }
 
                /* VGPRs */
-               declare_tes_input_vgprs(ctx, &fninfo);
+               declare_tes_input_vgprs(ctx);
                break;
 
        case PIPE_SHADER_GEOMETRY:
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
-               ctx->param_gs2vs_offset = add_arg(&fninfo, ARG_SGPR, ctx->i32);
-               ctx->param_gs_wave_id = add_arg(&fninfo, ARG_SGPR, ctx->i32);
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx, true);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->gs2vs_offset);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, &ctx->gs_wave_id);
 
                /* VGPRs */
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[0]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[1]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_prim_id);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[2]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[3]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[4]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->gs_vtx_offset[5]);
-               add_arg_assign(&fninfo, ARG_VGPR, ctx->i32, &ctx->abi.gs_invocation_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[0]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[1]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.gs_prim_id);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[2]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[3]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[4]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->gs_vtx_offset[5]);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, &ctx->args.gs_invocation_id);
                break;
 
        case PIPE_SHADER_FRAGMENT:
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
-               add_arg_checked(&fninfo, ARG_SGPR, ctx->f32, SI_PARAM_ALPHA_REF);
-               add_arg_assign_checked(&fninfo, ARG_SGPR, ctx->i32,
-                                      &ctx->abi.prim_mask, SI_PARAM_PRIM_MASK);
-
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.persp_sample, SI_PARAM_PERSP_SAMPLE);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.persp_center, SI_PARAM_PERSP_CENTER);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.persp_centroid, SI_PARAM_PERSP_CENTROID);
-               add_arg_checked(&fninfo, ARG_VGPR, v3i32, SI_PARAM_PERSP_PULL_MODEL);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.linear_sample, SI_PARAM_LINEAR_SAMPLE);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.linear_center, SI_PARAM_LINEAR_CENTER);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->v2i32,
-                                      &ctx->abi.linear_centroid, SI_PARAM_LINEAR_CENTROID);
-               add_arg_checked(&fninfo, ARG_VGPR, ctx->f32, SI_PARAM_LINE_STIPPLE_TEX);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->f32,
-                                      &ctx->abi.frag_pos[0], SI_PARAM_POS_X_FLOAT);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->f32,
-                                      &ctx->abi.frag_pos[1], SI_PARAM_POS_Y_FLOAT);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->f32,
-                                      &ctx->abi.frag_pos[2], SI_PARAM_POS_Z_FLOAT);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->f32,
-                                      &ctx->abi.frag_pos[3], SI_PARAM_POS_W_FLOAT);
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->i32,
-                                      &ctx->abi.front_face, SI_PARAM_FRONT_FACE);
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx, true);
+               add_arg_checked(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL,
+                               SI_PARAM_ALPHA_REF);
+               add_arg_checked(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT,
+                               &ctx->args.prim_mask, SI_PARAM_PRIM_MASK);
+
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT, &ctx->args.persp_sample,
+                               SI_PARAM_PERSP_SAMPLE);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT,
+                               &ctx->args.persp_center, SI_PARAM_PERSP_CENTER);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT,
+                               &ctx->args.persp_centroid, SI_PARAM_PERSP_CENTROID);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 3, AC_ARG_INT,
+                               NULL, SI_PARAM_PERSP_PULL_MODEL);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT, 
+                               &ctx->args.linear_sample, SI_PARAM_LINEAR_SAMPLE);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT,
+                               &ctx->args.linear_center, SI_PARAM_LINEAR_CENTER);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 2, AC_ARG_INT,
+                               &ctx->args.linear_centroid, SI_PARAM_LINEAR_CENTROID);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 3, AC_ARG_FLOAT,
+                               NULL, SI_PARAM_LINE_STIPPLE_TEX);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT,
+                               &ctx->args.frag_pos[0], SI_PARAM_POS_X_FLOAT);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT,
+                               &ctx->args.frag_pos[1], SI_PARAM_POS_Y_FLOAT);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT,
+                               &ctx->args.frag_pos[2], SI_PARAM_POS_Z_FLOAT);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT,
+                               &ctx->args.frag_pos[3], SI_PARAM_POS_W_FLOAT);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT,
+                               &ctx->args.front_face, SI_PARAM_FRONT_FACE);
                shader->info.face_vgpr_index = 20;
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->i32,
-                                      &ctx->abi.ancillary, SI_PARAM_ANCILLARY);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT,
+                               &ctx->args.ancillary, SI_PARAM_ANCILLARY);
                shader->info.ancillary_vgpr_index = 21;
-               add_arg_assign_checked(&fninfo, ARG_VGPR, ctx->f32,
-                                      &ctx->abi.sample_coverage, SI_PARAM_SAMPLE_COVERAGE);
-               add_arg_checked(&fninfo, ARG_VGPR, ctx->i32, SI_PARAM_POS_FIXED_PT);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT,
+                               &ctx->args.sample_coverage, SI_PARAM_SAMPLE_COVERAGE);
+               add_arg_checked(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT,
+                               &ctx->pos_fixed_pt, SI_PARAM_POS_FIXED_PT);
 
                /* Color inputs from the prolog. */
                if (shader->selector->info.colors_read) {
                        unsigned num_color_elements =
                                util_bitcount(shader->selector->info.colors_read);
 
-                       assert(fninfo.num_params + num_color_elements <= ARRAY_SIZE(fninfo.types));
                        for (i = 0; i < num_color_elements; i++)
-                               add_arg(&fninfo, ARG_VGPR, ctx->f32);
+                               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_FLOAT, NULL);
 
                        num_prolog_vgprs += num_color_elements;
                }
@@ -4972,35 +4888,38 @@ static void create_function(struct si_shader_context *ctx)
                break;
 
        case PIPE_SHADER_COMPUTE:
-               declare_global_desc_pointers(ctx, &fninfo);
-               declare_per_stage_desc_pointers(ctx, &fninfo, true);
+               declare_global_desc_pointers(ctx);
+               declare_per_stage_desc_pointers(ctx, true);
                if (shader->selector->info.uses_grid_size)
-                       add_arg_assign(&fninfo, ARG_SGPR, v3i32, &ctx->abi.num_work_groups);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 3, AC_ARG_INT,
+                                  &ctx->args.num_work_groups);
                if (shader->selector->info.uses_block_size &&
                    shader->selector->info.properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH] == 0)
-                       ctx->param_block_size = add_arg(&fninfo, ARG_SGPR, v3i32);
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, 3, AC_ARG_INT, &ctx->block_size);
 
                unsigned cs_user_data_dwords =
                        shader->selector->info.properties[TGSI_PROPERTY_CS_USER_DATA_COMPONENTS_AMD];
                if (cs_user_data_dwords) {
-                       ctx->param_cs_user_data = add_arg(&fninfo, ARG_SGPR,
-                                                         LLVMVectorType(ctx->i32, cs_user_data_dwords));
+                       ac_add_arg(&ctx->args, AC_ARG_SGPR, cs_user_data_dwords, AC_ARG_INT,
+                                  &ctx->cs_user_data);
                }
 
                for (i = 0; i < 3; i++) {
-                       ctx->abi.workgroup_ids[i] = NULL;
-                       if (shader->selector->info.uses_block_id[i])
-                               add_arg_assign(&fninfo, ARG_SGPR, ctx->i32, &ctx->abi.workgroup_ids[i]);
+                       if (shader->selector->info.uses_block_id[i]) {
+                               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT,
+                                          &ctx->args.workgroup_ids[i]);
+                       }
                }
 
-               add_arg_assign(&fninfo, ARG_VGPR, v3i32, &ctx->abi.local_invocation_ids);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 3, AC_ARG_INT,
+                          &ctx->args.local_invocation_ids);
                break;
        default:
                assert(0 && "unimplemented shader");
                return;
        }
 
-       si_create_function(ctx, "main", returns, num_returns, &fninfo,
+       si_create_function(ctx, "main", returns, num_returns,
                           si_get_max_workgroup_size(shader));
 
        /* Reserve register locations for VGPR inputs the PS prolog may need. */
@@ -5018,14 +4937,8 @@ static void create_function(struct si_shader_context *ctx)
                                                     S_0286D0_POS_FIXED_PT_ENA(1));
        }
 
-       shader->info.num_input_sgprs = 0;
-       shader->info.num_input_vgprs = 0;
-
-       for (i = 0; i < fninfo.num_sgpr_params; ++i)
-               shader->info.num_input_sgprs += ac_get_type_size(fninfo.types[i]) / 4;
-
-       for (; i < fninfo.num_params; ++i)
-               shader->info.num_input_vgprs += ac_get_type_size(fninfo.types[i]) / 4;
+       shader->info.num_input_sgprs = ctx->args.num_sgprs_used;
+       shader->info.num_input_vgprs = ctx->args.num_vgprs_used;
 
        assert(shader->info.num_input_vgprs >= num_prolog_vgprs);
        shader->info.num_input_vgprs -= num_prolog_vgprs;
@@ -5045,6 +4958,17 @@ static void create_function(struct si_shader_context *ctx)
                        ac_declare_lds_as_pointer(&ctx->ac);
                }
        }
+
+       /* Unlike radv, we override these arguments in the prolog, so to the
+        * API shader they appear as normal arguments.
+        */
+       if (ctx->type == PIPE_SHADER_VERTEX) {
+               ctx->abi.vertex_id = ac_get_arg(&ctx->ac, ctx->args.vertex_id);
+               ctx->abi.instance_id = ac_get_arg(&ctx->ac, ctx->args.instance_id);
+       } else if (ctx->type == PIPE_SHADER_FRAGMENT) {
+               ctx->abi.persp_centroid = ac_get_arg(&ctx->ac, ctx->args.persp_centroid);
+               ctx->abi.linear_centroid = ac_get_arg(&ctx->ac, ctx->args.linear_centroid);
+       }
 }
 
 /* Ensure that the esgs ring is declared.
@@ -5075,8 +4999,7 @@ static void preload_ring_buffers(struct si_shader_context *ctx)
 {
        LLVMBuilderRef builder = ctx->ac.builder;
 
-       LLVMValueRef buf_ptr = LLVMGetParam(ctx->main_fn,
-                                           ctx->param_rw_buffers);
+       LLVMValueRef buf_ptr = ac_get_arg(&ctx->ac, ctx->rw_buffers);
 
        if (ctx->shader->key.as_es || ctx->type == PIPE_SHADER_GEOMETRY) {
                if (ctx->screen->info.chip_class <= GFX8) {
@@ -5188,7 +5111,7 @@ static void preload_ring_buffers(struct si_shader_context *ctx)
 
 static void si_llvm_emit_polygon_stipple(struct si_shader_context *ctx,
                                         LLVMValueRef param_rw_buffers,
-                                        unsigned param_pos_fixed_pt)
+                                        struct ac_arg param_pos_fixed_pt)
 {
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMValueRef slot, desc, offset, row, bit, address[2];
@@ -5747,7 +5670,7 @@ si_generate_gs_copy_shader(struct si_screen *sscreen,
        LLVMValueRef stream_id;
 
        if (!sscreen->use_ngg_streamout && gs_selector->so.num_outputs)
-               stream_id = si_unpack_param(&ctx, ctx.param_streamout_config, 24, 2);
+               stream_id = si_unpack_param(&ctx, ctx.streamout_config, 24, 2);
        else
                stream_id = ctx.i32_0;
 
@@ -6051,10 +5974,10 @@ static void si_optimize_vs_outputs(struct si_shader_context *ctx)
 }
 
 static void si_init_exec_from_input(struct si_shader_context *ctx,
-                                   unsigned param, unsigned bitoffset)
+                                   struct ac_arg param, unsigned bitoffset)
 {
        LLVMValueRef args[] = {
-               LLVMGetParam(ctx->main_fn, param),
+               ac_get_arg(&ctx->ac, param),
                LLVMConstInt(ctx->i32, bitoffset, 0),
        };
        ac_build_intrinsic(&ctx->ac,
@@ -6233,7 +6156,7 @@ static bool si_compile_tgsi_main(struct si_shader_context *ctx,
                     (ctx->type == PIPE_SHADER_VERTEX &&
                      !si_vs_needs_prolog(sel, &shader->key.part.vs.prolog)))) {
                        si_init_exec_from_input(ctx,
-                                               ctx->param_merged_wave_info, 0);
+                                               ctx->merged_wave_info, 0);
                } else if (ctx->type == PIPE_SHADER_TESS_CTRL ||
                           ctx->type == PIPE_SHADER_GEOMETRY ||
                           (shader->key.as_ngg && !shader->key.as_es)) {
@@ -6255,10 +6178,10 @@ static bool si_compile_tgsi_main(struct si_shader_context *ctx,
                                }
 
                                /* Number of patches / primitives */
-                               num_threads = si_unpack_param(ctx, ctx->param_merged_wave_info, 8, 8);
+                               num_threads = si_unpack_param(ctx, ctx->merged_wave_info, 8, 8);
                        } else {
                                /* Number of vertices */
-                               num_threads = si_unpack_param(ctx, ctx->param_merged_wave_info, 0, 8);
+                               num_threads = si_unpack_param(ctx, ctx->merged_wave_info, 0, 8);
                                nested_barrier = false;
                        }
 
@@ -6535,12 +6458,11 @@ static void si_build_gs_prolog_function(struct si_shader_context *ctx,
                                        union si_shader_part_key *key)
 {
        unsigned num_sgprs, num_vgprs;
-       struct si_function_info fninfo;
        LLVMBuilderRef builder = ctx->ac.builder;
        LLVMTypeRef returns[48];
        LLVMValueRef func, ret;
 
-       si_init_function_info(&fninfo);
+       memset(&ctx->args, 0, sizeof(ctx->args));
 
        if (ctx->screen->info.chip_class >= GFX9) {
                if (key->gs_prolog.states.gfx9_prev_is_vs)
@@ -6554,18 +6476,18 @@ static void si_build_gs_prolog_function(struct si_shader_context *ctx,
        }
 
        for (unsigned i = 0; i < num_sgprs; ++i) {
-               add_arg(&fninfo, ARG_SGPR, ctx->i32);
+               ac_add_arg(&ctx->args, AC_ARG_SGPR, 1, AC_ARG_INT, NULL);
                returns[i] = ctx->i32;
        }
 
        for (unsigned i = 0; i < num_vgprs; ++i) {
-               add_arg(&fninfo, ARG_VGPR, ctx->i32);
+               ac_add_arg(&ctx->args, AC_ARG_VGPR, 1, AC_ARG_INT, NULL);
                returns[num_sgprs + i] = ctx->f32;
        }
 
        /* Create the function. */
        si_create_function(ctx, "gs_prolog", returns, num_sgprs + num_vgprs,
-                          &fninfo, 0);
+                          0);
        func = ctx->main_fn;
 
        /* Set the full EXEC mask for the prolog, because we are only