ac/nir: fix translation of nir_op_frcp for doubles
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index 2ae656693fc63e54f0dcf998a0942b97ed2a9008..7b348d97f0a086346e393be605c91f8fb2dd1589 100644 (file)
@@ -32,6 +32,7 @@
 #include <llvm-c/Transforms/Scalar.h>
 #include "ac_shader_abi.h"
 #include "ac_shader_info.h"
+#include "ac_shader_util.h"
 #include "ac_exp_param.h"
 
 enum radeon_llvm_calling_convention {
@@ -93,7 +94,7 @@ struct nir_to_llvm_context {
        LLVMValueRef push_constants;
        LLVMValueRef view_index;
        LLVMValueRef num_work_groups;
-       LLVMValueRef workgroup_ids;
+       LLVMValueRef workgroup_ids[3];
        LLVMValueRef local_invocation_ids;
        LLVMValueRef tg_size;
 
@@ -110,10 +111,7 @@ struct nir_to_llvm_context {
        LLVMValueRef oc_lds;
        LLVMValueRef merged_wave_info;
        LLVMValueRef tess_factor_offset;
-       LLVMValueRef tcs_patch_id;
-       LLVMValueRef tcs_rel_ids;
        LLVMValueRef tes_rel_patch_id;
-       LLVMValueRef tes_patch_id;
        LLVMValueRef tes_u;
        LLVMValueRef tes_v;
 
@@ -122,7 +120,6 @@ struct nir_to_llvm_context {
        LLVMValueRef gs2vs_offset;
        LLVMValueRef gs_wave_id;
        LLVMValueRef gs_vtx_offset[6];
-       LLVMValueRef gs_prim_id, gs_invocation_id;
 
        LLVMValueRef esgs_ring;
        LLVMValueRef gsvs_ring;
@@ -150,6 +147,9 @@ struct nir_to_llvm_context {
        unsigned tes_primitive_mode;
        uint64_t tess_outputs_written;
        uint64_t tess_patch_outputs_written;
+
+       uint32_t tcs_patch_outputs_read;
+       uint64_t tcs_outputs_read;
 };
 
 static inline struct nir_to_llvm_context *
@@ -159,6 +159,28 @@ nir_to_llvm_context_from_abi(struct ac_shader_abi *abi)
        return container_of(abi, ctx, abi);
 }
 
+static LLVMTypeRef
+nir2llvmtype(struct ac_nir_context *ctx,
+            const struct glsl_type *type)
+{
+       switch (glsl_get_base_type(glsl_without_array(type))) {
+       case GLSL_TYPE_UINT:
+       case GLSL_TYPE_INT:
+               return ctx->ac.i32;
+       case GLSL_TYPE_UINT64:
+       case GLSL_TYPE_INT64:
+               return ctx->ac.i64;
+       case GLSL_TYPE_DOUBLE:
+               return ctx->ac.f64;
+       case GLSL_TYPE_FLOAT:
+               return ctx->ac.f32;
+       default:
+               assert(!"Unsupported type in nir2llvmtype()");
+               break;
+       }
+       return 0;
+}
+
 static LLVMValueRef get_sampler_desc(struct ac_nir_context *ctx,
                                     const nir_deref_var *deref,
                                     enum ac_descriptor_type desc_type,
@@ -227,58 +249,40 @@ struct arg_info {
        LLVMValueRef *assign[MAX_ARGS];
        unsigned array_params_mask;
        uint8_t count;
-       uint8_t user_sgpr_count;
        uint8_t sgpr_count;
-       uint8_t num_user_sgprs_used;
        uint8_t num_sgprs_used;
        uint8_t num_vgprs_used;
 };
 
-static inline void
-add_argument(struct arg_info *info,
-            LLVMTypeRef type, LLVMValueRef *param_ptr)
+enum ac_arg_regfile {
+       ARG_SGPR,
+       ARG_VGPR,
+};
+
+static void
+add_arg(struct arg_info *info, enum ac_arg_regfile regfile, LLVMTypeRef type,
+       LLVMValueRef *param_ptr)
 {
        assert(info->count < MAX_ARGS);
+
        info->assign[info->count] = param_ptr;
        info->types[info->count] = type;
        info->count++;
-}
-
-static inline void
-add_sgpr_argument(struct arg_info *info,
-                 LLVMTypeRef type, LLVMValueRef *param_ptr)
-{
-       add_argument(info, type, param_ptr);
-       info->num_sgprs_used += ac_get_type_size(type) / 4;
-       info->sgpr_count++;
-}
-
-static inline void
-add_user_sgpr_argument(struct arg_info *info,
-                      LLVMTypeRef type,
-                      LLVMValueRef *param_ptr)
-{
-       add_sgpr_argument(info, type, param_ptr);
-       info->num_user_sgprs_used += ac_get_type_size(type) / 4;
-       info->user_sgpr_count++;
-}
 
-static inline void
-add_vgpr_argument(struct arg_info *info,
-                 LLVMTypeRef type,
-                 LLVMValueRef *param_ptr)
-{
-       add_argument(info, type, param_ptr);
-       info->num_vgprs_used += ac_get_type_size(type) / 4;
+       if (regfile == ARG_SGPR) {
+               info->num_sgprs_used += ac_get_type_size(type) / 4;
+               info->sgpr_count++;
+       } else {
+               assert(regfile == ARG_VGPR);
+               info->num_vgprs_used += ac_get_type_size(type) / 4;
+       }
 }
 
 static inline void
-add_user_sgpr_array_argument(struct arg_info *info,
-                            LLVMTypeRef type,
-                            LLVMValueRef *param_ptr)
+add_array_arg(struct arg_info *info, LLVMTypeRef type, LLVMValueRef *param_ptr)
 {
        info->array_params_mask |= (1 << info->count);
-       add_user_sgpr_argument(info, type, param_ptr);
+       add_arg(info, ARG_SGPR, type, param_ptr);
 }
 
 static void assign_arguments(LLVMValueRef main_function,
@@ -397,7 +401,7 @@ static LLVMValueRef get_rel_patch_id(struct nir_to_llvm_context *ctx)
 {
        switch (ctx->stage) {
        case MESA_SHADER_TESS_CTRL:
-               return unpack_param(&ctx->ac, ctx->tcs_rel_ids, 0, 8);
+               return unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 0, 8);
        case MESA_SHADER_TESS_EVAL:
                return ctx->tes_rel_patch_id;
                break;
@@ -497,29 +501,37 @@ get_tcs_out_current_patch_data_offset(struct nir_to_llvm_context *ctx)
                            "");
 }
 
-static void set_userdata_location(struct ac_userdata_info *ud_info, uint8_t *sgpr_idx, uint8_t num_sgprs)
+static void
+set_loc(struct ac_userdata_info *ud_info, uint8_t *sgpr_idx, uint8_t num_sgprs,
+       uint32_t indirect_offset)
 {
        ud_info->sgpr_idx = *sgpr_idx;
        ud_info->num_sgprs = num_sgprs;
-       ud_info->indirect = false;
-       ud_info->indirect_offset = 0;
+       ud_info->indirect = indirect_offset > 0;
+       ud_info->indirect_offset = indirect_offset;
        *sgpr_idx += num_sgprs;
 }
 
-static void set_userdata_location_shader(struct nir_to_llvm_context *ctx,
-                                        int idx, uint8_t *sgpr_idx, uint8_t num_sgprs)
+static void
+set_loc_shader(struct nir_to_llvm_context *ctx, int idx, uint8_t *sgpr_idx,
+              uint8_t num_sgprs)
 {
-       set_userdata_location(&ctx->shader_info->user_sgprs_locs.shader_data[idx], sgpr_idx, num_sgprs);
-}
+       struct ac_userdata_info *ud_info =
+               &ctx->shader_info->user_sgprs_locs.shader_data[idx];
+       assert(ud_info);
 
+       set_loc(ud_info, sgpr_idx, num_sgprs, 0);
+}
 
-static void set_userdata_location_indirect(struct ac_userdata_info *ud_info, uint8_t sgpr_idx, uint8_t num_sgprs,
-                                          uint32_t indirect_offset)
+static void
+set_loc_desc(struct nir_to_llvm_context *ctx, int idx,  uint8_t *sgpr_idx,
+            uint32_t indirect_offset)
 {
-       ud_info->sgpr_idx = sgpr_idx;
-       ud_info->num_sgprs = num_sgprs;
-       ud_info->indirect = true;
-       ud_info->indirect_offset = indirect_offset;
+       struct ac_userdata_info *ud_info =
+               &ctx->shader_info->user_sgprs_locs.descriptor_sets[idx];
+       assert(ud_info);
+
+       set_loc(ud_info, sgpr_idx, 2, indirect_offset);
 }
 
 struct user_sgpr_info {
@@ -529,19 +541,20 @@ struct user_sgpr_info {
 };
 
 static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
+                               gl_shader_stage stage,
                                struct user_sgpr_info *user_sgpr_info)
 {
        memset(user_sgpr_info, 0, sizeof(struct user_sgpr_info));
 
        /* until we sort out scratch/global buffers always assign ring offsets for gs/vs/es */
-       if (ctx->stage == MESA_SHADER_GEOMETRY ||
-           ctx->stage == MESA_SHADER_VERTEX ||
-           ctx->stage == MESA_SHADER_TESS_CTRL ||
-           ctx->stage == MESA_SHADER_TESS_EVAL ||
+       if (stage == MESA_SHADER_GEOMETRY ||
+           stage == MESA_SHADER_VERTEX ||
+           stage == MESA_SHADER_TESS_CTRL ||
+           stage == MESA_SHADER_TESS_EVAL ||
            ctx->is_gs_copy_shader)
                user_sgpr_info->need_ring_offsets = true;
 
-       if (ctx->stage == MESA_SHADER_FRAGMENT &&
+       if (stage == MESA_SHADER_FRAGMENT &&
            ctx->shader_info->info.ps.needs_sample_positions)
                user_sgpr_info->need_ring_offsets = true;
 
@@ -550,9 +563,11 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
                user_sgpr_info->sgpr_count += 2;
        }
 
-       switch (ctx->stage) {
+       /* FIXME: fix the number of user sgprs for merged shaders on GFX9 */
+       switch (stage) {
        case MESA_SHADER_COMPUTE:
-               user_sgpr_info->sgpr_count += ctx->shader_info->info.cs.grid_components_used;
+               if (ctx->shader_info->info.cs.uses_grid_size)
+                       user_sgpr_info->sgpr_count += 3;
                break;
        case MESA_SHADER_FRAGMENT:
                user_sgpr_info->sgpr_count += ctx->shader_info->info.ps.needs_sample_positions;
@@ -582,10 +597,12 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
                break;
        }
 
-       if (ctx->shader_info->info.needs_push_constants)
+       if (ctx->shader_info->info.loads_push_constants)
                user_sgpr_info->sgpr_count += 2;
 
-       uint32_t remaining_sgprs = 16 - user_sgpr_info->sgpr_count;
+       uint32_t available_sgprs = ctx->options->chip_class >= GFX9 ? 32 : 16;
+       uint32_t remaining_sgprs = available_sgprs - user_sgpr_info->sgpr_count;
+
        if (remaining_sgprs / 2 < util_bitcount(ctx->shader_info->info.desc_set_used_mask)) {
                user_sgpr_info->sgpr_count += 2;
                user_sgpr_info->indirect_all_descriptor_sets = true;
@@ -595,16 +612,19 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
 }
 
 static void
-radv_define_common_user_sgprs_phase1(struct nir_to_llvm_context *ctx,
-                                     gl_shader_stage stage,
-                                     bool has_previous_stage,
-                                     gl_shader_stage previous_stage,
-                                     const struct user_sgpr_info *user_sgpr_info,
-                                     struct arg_info *args,
-                                     LLVMValueRef *desc_sets)
-{
-       unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0;
+declare_global_input_sgprs(struct nir_to_llvm_context *ctx,
+                          gl_shader_stage stage,
+                          bool has_previous_stage,
+                          gl_shader_stage previous_stage,
+                          const struct user_sgpr_info *user_sgpr_info,
+                          struct arg_info *args,
+                          LLVMValueRef *desc_sets)
+{
+       LLVMTypeRef type = const_array(ctx->ac.i8, 1024 * 1024);
+       unsigned num_sets = ctx->options->layout ?
+                           ctx->options->layout->num_sets : 0;
        unsigned stage_mask = 1 << stage;
+
        if (has_previous_stage)
                stage_mask |= 1 << previous_stage;
 
@@ -612,47 +632,98 @@ radv_define_common_user_sgprs_phase1(struct nir_to_llvm_context *ctx,
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
                for (unsigned i = 0; i < num_sets; ++i) {
                        if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               add_user_sgpr_array_argument(args, const_array(ctx->ac.i8, 1024 * 1024), &ctx->descriptor_sets[i]);
+                               add_array_arg(args, type,
+                                             &ctx->descriptor_sets[i]);
                        }
                }
-       } else
-               add_user_sgpr_array_argument(args, const_array(const_array(ctx->ac.i8, 1024 * 1024), 32), desc_sets);
+       } else {
+               add_array_arg(args, const_array(type, 32), desc_sets);
+       }
 
-       if (ctx->shader_info->info.needs_push_constants) {
+       if (ctx->shader_info->info.loads_push_constants) {
                /* 1 for push constants and dynamic descriptors */
-               add_user_sgpr_array_argument(args, const_array(ctx->ac.i8, 1024 * 1024), &ctx->push_constants);
+               add_array_arg(args, type, &ctx->push_constants);
+       }
+}
+
+static void
+declare_vs_specific_input_sgprs(struct nir_to_llvm_context *ctx,
+                               gl_shader_stage stage,
+                               bool has_previous_stage,
+                               gl_shader_stage previous_stage,
+                               struct arg_info *args)
+{
+       if (!ctx->is_gs_copy_shader &&
+           (stage == MESA_SHADER_VERTEX ||
+            (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
+               if (ctx->shader_info->info.vs.has_vertex_buffers) {
+                       add_arg(args, ARG_SGPR, const_array(ctx->ac.v4i32, 16),
+                               &ctx->vertex_buffers);
+               }
+               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.base_vertex);
+               add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.start_instance);
+               if (ctx->shader_info->info.vs.needs_draw_id) {
+                       add_arg(args, ARG_SGPR, ctx->ac.i32, &ctx->abi.draw_id);
+               }
+       }
+}
+
+static void
+declare_vs_input_vgprs(struct nir_to_llvm_context *ctx, struct arg_info *args)
+{
+       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.vertex_id);
+       if (!ctx->is_gs_copy_shader) {
+               if (ctx->options->key.vs.as_ls) {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+               } else {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
+               }
+               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
        }
 }
 
 static void
-radv_define_common_user_sgprs_phase2(struct nir_to_llvm_context *ctx,
-                                     gl_shader_stage stage,
-                                     bool has_previous_stage,
-                                     gl_shader_stage previous_stage,
-                                     const struct user_sgpr_info *user_sgpr_info,
-                                    LLVMValueRef desc_sets,
-                                     uint8_t *user_sgpr_idx)
-{
-       unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0;
+declare_tes_input_vgprs(struct nir_to_llvm_context *ctx, struct arg_info *args)
+{
+       add_arg(args, ARG_VGPR, ctx->ac.f32, &ctx->tes_u);
+       add_arg(args, ARG_VGPR, ctx->ac.f32, &ctx->tes_v);
+       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->tes_rel_patch_id);
+       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.tes_patch_id);
+}
+
+static void
+set_global_input_locs(struct nir_to_llvm_context *ctx, gl_shader_stage stage,
+                     bool has_previous_stage, gl_shader_stage previous_stage,
+                     const struct user_sgpr_info *user_sgpr_info,
+                     LLVMValueRef desc_sets, uint8_t *user_sgpr_idx)
+{
+       unsigned num_sets = ctx->options->layout ?
+                           ctx->options->layout->num_sets : 0;
        unsigned stage_mask = 1 << stage;
+
        if (has_previous_stage)
                stage_mask |= 1 << previous_stage;
 
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
                for (unsigned i = 0; i < num_sets; ++i) {
                        if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               set_userdata_location(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], user_sgpr_idx, 2);
+                               set_loc_desc(ctx, i, user_sgpr_idx, 0);
                        } else
                                ctx->descriptor_sets[i] = NULL;
                }
        } else {
-               uint32_t desc_sgpr_idx = *user_sgpr_idx;
-               set_userdata_location_shader(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS, user_sgpr_idx, 2);
+               set_loc_shader(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS,
+                              user_sgpr_idx, 2);
 
                for (unsigned i = 0; i < num_sets; ++i) {
                        if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               set_userdata_location_indirect(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], desc_sgpr_idx, 2, i * 8);
-                               ctx->descriptor_sets[i] = ac_build_load_to_sgpr(&ctx->ac, desc_sets, LLVMConstInt(ctx->ac.i32, i, false));
+                               set_loc_desc(ctx, i, user_sgpr_idx, i * 8);
+                               ctx->descriptor_sets[i] =
+                                       ac_build_load_to_sgpr(&ctx->ac,
+                                                             desc_sets,
+                                                             LLVMConstInt(ctx->ac.i32, i, false));
 
                        } else
                                ctx->descriptor_sets[i] = NULL;
@@ -660,48 +731,34 @@ radv_define_common_user_sgprs_phase2(struct nir_to_llvm_context *ctx,
                ctx->shader_info->need_indirect_descriptor_sets = true;
        }
 
-       if (ctx->shader_info->info.needs_push_constants) {
-               set_userdata_location_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2);
-       }
-}
-
-static void
-radv_define_vs_user_sgprs_phase1(struct nir_to_llvm_context *ctx,
-                                 gl_shader_stage stage,
-                                 bool has_previous_stage,
-                                 gl_shader_stage previous_stage,
-                                 struct arg_info *args)
-{
-       if (!ctx->is_gs_copy_shader && (stage == MESA_SHADER_VERTEX || (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
-               if (ctx->shader_info->info.vs.has_vertex_buffers)
-                       add_user_sgpr_argument(args, const_array(ctx->ac.v4i32, 16), &ctx->vertex_buffers); /* vertex buffers */
-               add_user_sgpr_argument(args, ctx->ac.i32, &ctx->abi.base_vertex); // base vertex
-               add_user_sgpr_argument(args, ctx->ac.i32, &ctx->abi.start_instance);// start instance
-               if (ctx->shader_info->info.vs.needs_draw_id)
-                       add_user_sgpr_argument(args, ctx->ac.i32, &ctx->abi.draw_id); // draw id
+       if (ctx->shader_info->info.loads_push_constants) {
+               set_loc_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2);
        }
 }
 
 static void
-radv_define_vs_user_sgprs_phase2(struct nir_to_llvm_context *ctx,
-                                 gl_shader_stage stage,
-                                 bool has_previous_stage,
-                                 gl_shader_stage previous_stage,
-                                 uint8_t *user_sgpr_idx)
-{
-       if (!ctx->is_gs_copy_shader && (stage == MESA_SHADER_VERTEX || (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
+set_vs_specific_input_locs(struct nir_to_llvm_context *ctx,
+                          gl_shader_stage stage, bool has_previous_stage,
+                          gl_shader_stage previous_stage,
+                          uint8_t *user_sgpr_idx)
+{
+       if (!ctx->is_gs_copy_shader &&
+           (stage == MESA_SHADER_VERTEX ||
+            (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
                if (ctx->shader_info->info.vs.has_vertex_buffers) {
-                       set_userdata_location_shader(ctx, AC_UD_VS_VERTEX_BUFFERS, user_sgpr_idx, 2);
+                       set_loc_shader(ctx, AC_UD_VS_VERTEX_BUFFERS,
+                                      user_sgpr_idx, 2);
                }
+
                unsigned vs_num = 2;
                if (ctx->shader_info->info.vs.needs_draw_id)
                        vs_num++;
 
-               set_userdata_location_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, user_sgpr_idx, vs_num);
+               set_loc_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE,
+                              user_sgpr_idx, vs_num);
        }
 }
 
-
 static void create_function(struct nir_to_llvm_context *ctx,
                             gl_shader_stage stage,
                             bool has_previous_stage,
@@ -712,175 +769,260 @@ static void create_function(struct nir_to_llvm_context *ctx,
        struct arg_info args = {};
        LLVMValueRef desc_sets;
 
-       allocate_user_sgprs(ctx, &user_sgpr_info);
+       allocate_user_sgprs(ctx, stage, &user_sgpr_info);
 
        if (user_sgpr_info.need_ring_offsets && !ctx->options->supports_spill) {
-               add_user_sgpr_argument(&args, const_array(ctx->ac.v4i32, 16), &ctx->ring_offsets); /* address of rings */
+               add_arg(&args, ARG_SGPR, const_array(ctx->ac.v4i32, 16),
+                       &ctx->ring_offsets);
        }
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-               if (ctx->shader_info->info.cs.grid_components_used)
-                       add_user_sgpr_argument(&args, LLVMVectorType(ctx->ac.i32, ctx->shader_info->info.cs.grid_components_used), &ctx->num_work_groups); /* grid size */
-               add_sgpr_argument(&args, ctx->ac.v3i32, &ctx->workgroup_ids);
-               add_sgpr_argument(&args, ctx->ac.i32, &ctx->tg_size);
-               add_vgpr_argument(&args, ctx->ac.v3i32, &ctx->local_invocation_ids);
+               declare_global_input_sgprs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_info,
+                                          &args, &desc_sets);
+
+               if (ctx->shader_info->info.cs.uses_grid_size) {
+                       add_arg(&args, ARG_SGPR, ctx->ac.v3i32,
+                               &ctx->num_work_groups);
+               }
+
+               for (int i = 0; i < 3; i++) {
+                       ctx->workgroup_ids[i] = NULL;
+                       if (ctx->shader_info->info.cs.uses_block_id[i]) {
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->workgroup_ids[i]);
+                       }
+               }
+
+               if (ctx->shader_info->info.cs.uses_local_invocation_idx)
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->tg_size);
+               add_arg(&args, ARG_VGPR, ctx->ac.v3i32,
+                       &ctx->local_invocation_ids);
                break;
        case MESA_SHADER_VERTEX:
-               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+               declare_global_input_sgprs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_info,
+                                          &args, &desc_sets);
+               declare_vs_specific_input_sgprs(ctx, stage, has_previous_stage,
+                                               previous_stage, &args);
+
                if (ctx->shader_info->info.needs_multiview_view_index || (!ctx->options->key.vs.as_es && !ctx->options->key.vs.as_ls && ctx->options->key.has_multiview_view_index))
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->view_index);
                if (ctx->options->key.vs.as_es)
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->es2gs_offset); // es2gs offset
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->es2gs_offset);
                else if (ctx->options->key.vs.as_ls)
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->ls_out_layout); // ls out layout
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.vertex_id); // vertex id
-               if (!ctx->is_gs_copy_shader) {
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->rel_auto_id); // rel auto id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->vs_prim_id); // vs prim id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.instance_id); // instance id
-               }
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->ls_out_layout);
+
+               declare_vs_input_vgprs(ctx, &args);
                break;
        case MESA_SHADER_TESS_CTRL:
                if (has_previous_stage) {
                        // First 6 system regs
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->oc_lds); // param oc lds
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->merged_wave_info); // merged wave info
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->tess_factor_offset); // tess factor offset
-
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // scratch offset
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // unknown
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // unknown
-
-                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-                       radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->ls_out_layout); // ls out layout
-
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_offchip_layout); // tcs offchip layout
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_out_offsets); // tcs out offsets
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_out_layout); // tcs out layout
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_in_layout); // tcs in layout
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->merged_wave_info);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tess_factor_offset);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // scratch offset
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+
+                       declare_global_input_sgprs(ctx, stage,
+                                                  has_previous_stage,
+                                                  previous_stage,
+                                                  &user_sgpr_info, &args,
+                                                  &desc_sets);
+                       declare_vs_specific_input_sgprs(ctx, stage,
+                                                       has_previous_stage,
+                                                       previous_stage, &args);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->ls_out_layout);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_offchip_layout);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_out_offsets);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_out_layout);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_in_layout);
                        if (ctx->shader_info->info.needs_multiview_view_index)
-                               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
-
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->tcs_patch_id); // patch id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->tcs_rel_ids); // rel ids;
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.vertex_id); // vertex id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->rel_auto_id); // rel auto id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->vs_prim_id); // vs prim id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.instance_id); // instance id
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->view_index);
+
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.tcs_patch_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.tcs_rel_ids);
+
+                       declare_vs_input_vgprs(ctx, &args);
                } else {
-                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_offchip_layout); // tcs offchip layout
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_out_offsets); // tcs out offsets
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_out_layout); // tcs out layout
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_in_layout); // tcs in layout
+                       declare_global_input_sgprs(ctx, stage,
+                                                  has_previous_stage,
+                                                  previous_stage,
+                                                  &user_sgpr_info, &args,
+                                                  &desc_sets);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_offchip_layout);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_out_offsets);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_out_layout);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tcs_in_layout);
                        if (ctx->shader_info->info.needs_multiview_view_index)
-                               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->oc_lds); // param oc lds
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->tess_factor_offset); // tess factor offset
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->tcs_patch_id); // patch id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->tcs_rel_ids); // rel ids;
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->view_index);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->tess_factor_offset);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.tcs_patch_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.tcs_rel_ids);
                }
                break;
        case MESA_SHADER_TESS_EVAL:
-               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_offchip_layout); // tcs offchip layout
+               declare_global_input_sgprs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_info,
+                                          &args, &desc_sets);
+
+               add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->tcs_offchip_layout);
                if (ctx->shader_info->info.needs_multiview_view_index || (!ctx->options->key.tes.as_es && ctx->options->key.has_multiview_view_index))
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->view_index);
+
                if (ctx->options->key.tes.as_es) {
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->oc_lds); // OC LDS
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); //
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->es2gs_offset); // es2gs offset
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->es2gs_offset);
                } else {
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); //
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->oc_lds); // OC LDS
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
                }
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->tes_u); // tes_u
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->tes_v); // tes_v
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->tes_rel_patch_id); // tes rel patch id
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->tes_patch_id); // tes patch id
+               declare_tes_input_vgprs(ctx, &args);
                break;
        case MESA_SHADER_GEOMETRY:
                if (has_previous_stage) {
                        // First 6 system regs
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->gs2vs_offset); // tess factor offset
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->merged_wave_info); // merged wave info
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->oc_lds); // param oc lds
-
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // scratch offset
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // unknown
-                       add_sgpr_argument(&args, ctx->ac.i32, NULL); // unknown
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->gs2vs_offset);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->merged_wave_info);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->oc_lds);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // scratch offset
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
+
+                       declare_global_input_sgprs(ctx, stage,
+                                                  has_previous_stage,
+                                                  previous_stage,
+                                                  &user_sgpr_info, &args,
+                                                  &desc_sets);
+
+                       if (previous_stage == MESA_SHADER_TESS_EVAL) {
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->tcs_offchip_layout);
+                       } else {
+                               declare_vs_specific_input_sgprs(ctx, stage,
+                                                               has_previous_stage,
+                                                               previous_stage,
+                                                               &args);
+                       }
 
-                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-                       if (previous_stage == MESA_SHADER_TESS_EVAL)
-                               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->tcs_offchip_layout); // tcs offchip layout
-                       else
-                               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->gsvs_ring_stride); // gsvs stride
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->gsvs_num_entries); // gsvs num entires
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->gsvs_ring_stride);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->gsvs_num_entries);
                        if (ctx->shader_info->info.needs_multiview_view_index)
-                               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
-
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[0]); // vtx01
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[2]); // vtx23
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_prim_id); // prim id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_invocation_id);
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[4]);
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->view_index);
+
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[0]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[2]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.gs_prim_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.gs_invocation_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[4]);
 
                        if (previous_stage == MESA_SHADER_VERTEX) {
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.vertex_id); // vertex id
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->rel_auto_id); // rel auto id
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->vs_prim_id); // vs prim id
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.instance_id); // instance id
+                               declare_vs_input_vgprs(ctx, &args);
                        } else {
-                               add_vgpr_argument(&args, ctx->ac.f32, &ctx->tes_u); // tes_u
-                               add_vgpr_argument(&args, ctx->ac.f32, &ctx->tes_v); // tes_v
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->tes_rel_patch_id); // tes rel patch id
-                               add_vgpr_argument(&args, ctx->ac.i32, &ctx->tes_patch_id); // tes patch id
+                               declare_tes_input_vgprs(ctx, &args);
                        }
                } else {
-                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-                       radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->gsvs_ring_stride); // gsvs stride
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->gsvs_num_entries); // gsvs num entires
+                       declare_global_input_sgprs(ctx, stage,
+                                                  has_previous_stage,
+                                                  previous_stage,
+                                                  &user_sgpr_info, &args,
+                                                  &desc_sets);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->gsvs_ring_stride);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->gsvs_num_entries);
                        if (ctx->shader_info->info.needs_multiview_view_index)
-                               add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->view_index);
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->gs2vs_offset); // gs2vs offset
-                       add_sgpr_argument(&args, ctx->ac.i32, &ctx->gs_wave_id); // wave id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[0]); // vtx0
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[1]); // vtx1
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_prim_id); // prim id
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[2]);
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[3]);
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[4]);
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_vtx_offset[5]);
-                       add_vgpr_argument(&args, ctx->ac.i32, &ctx->gs_invocation_id);
+                               add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                                       &ctx->view_index);
+
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->gs2vs_offset);
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->gs_wave_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[0]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[1]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.gs_prim_id);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[2]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[3]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[4]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->gs_vtx_offset[5]);
+                       add_arg(&args, ARG_VGPR, ctx->ac.i32,
+                               &ctx->abi.gs_invocation_id);
                }
                break;
        case MESA_SHADER_FRAGMENT:
-               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+               declare_global_input_sgprs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_info,
+                                          &args, &desc_sets);
+
                if (ctx->shader_info->info.ps.needs_sample_positions)
-                       add_user_sgpr_argument(&args, ctx->ac.i32, &ctx->sample_pos_offset); /* sample position offset */
-               add_sgpr_argument(&args, ctx->ac.i32, &ctx->prim_mask); /* prim mask */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->persp_sample); /* persp sample */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->persp_center); /* persp center */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->persp_centroid); /* persp centroid */
-               add_vgpr_argument(&args, ctx->ac.v3i32, NULL); /* persp pull model */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->linear_sample); /* linear sample */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->linear_center); /* linear center */
-               add_vgpr_argument(&args, ctx->ac.v2i32, &ctx->linear_centroid); /* linear centroid */
-               add_vgpr_argument(&args, ctx->ac.f32, NULL);  /* line stipple tex */
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->abi.frag_pos[0]);  /* pos x float */
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->abi.frag_pos[1]);  /* pos y float */
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->abi.frag_pos[2]);  /* pos z float */
-               add_vgpr_argument(&args, ctx->ac.f32, &ctx->abi.frag_pos[3]);  /* pos w float */
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.front_face);  /* front face */
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.ancillary);  /* ancillary */
-               add_vgpr_argument(&args, ctx->ac.i32, &ctx->abi.sample_coverage);  /* sample coverage */
-               add_vgpr_argument(&args, ctx->ac.i32, NULL);  /* fixed pt */
+                       add_arg(&args, ARG_SGPR, ctx->ac.i32,
+                               &ctx->sample_pos_offset);
+
+               add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->prim_mask);
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->persp_sample);
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->persp_center);
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->persp_centroid);
+               add_arg(&args, ARG_VGPR, ctx->ac.v3i32, NULL); /* persp pull model */
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->linear_sample);
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->linear_center);
+               add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->linear_centroid);
+               add_arg(&args, ARG_VGPR, ctx->ac.f32, NULL);  /* line stipple tex */
+               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[0]);
+               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[1]);
+               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[2]);
+               add_arg(&args, ARG_VGPR, ctx->ac.f32, &ctx->abi.frag_pos[3]);
+               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.front_face);
+               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.ancillary);
+               add_arg(&args, ARG_VGPR, ctx->ac.i32, &ctx->abi.sample_coverage);
+               add_arg(&args, ARG_VGPR, ctx->ac.i32, NULL);  /* fixed pt */
                break;
        default:
                unreachable("Shader stage not implemented");
@@ -906,7 +1048,8 @@ static void create_function(struct nir_to_llvm_context *ctx,
        user_sgpr_idx = 0;
 
        if (ctx->options->supports_spill || user_sgpr_info.need_ring_offsets) {
-               set_userdata_location_shader(ctx, AC_UD_SCRATCH_RING_OFFSETS, &user_sgpr_idx, 2);
+               set_loc_shader(ctx, AC_UD_SCRATCH_RING_OFFSETS,
+                              &user_sgpr_idx, 2);
                if (ctx->options->supports_spill) {
                        ctx->ring_offsets = ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.implicit.buffer.ptr",
                                                               LLVMPointerType(ctx->ac.i8, CONST_ADDR_SPACE),
@@ -921,54 +1064,66 @@ static void create_function(struct nir_to_llvm_context *ctx,
        if (has_previous_stage)
                user_sgpr_idx = 0;
 
-       radv_define_common_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, desc_sets, &user_sgpr_idx);
+       set_global_input_locs(ctx, stage, has_previous_stage, previous_stage,
+                             &user_sgpr_info, desc_sets, &user_sgpr_idx);
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               if (ctx->shader_info->info.cs.grid_components_used) {
-                       set_userdata_location_shader(ctx, AC_UD_CS_GRID_SIZE, &user_sgpr_idx, ctx->shader_info->info.cs.grid_components_used);
+               if (ctx->shader_info->info.cs.uses_grid_size) {
+                       set_loc_shader(ctx, AC_UD_CS_GRID_SIZE,
+                                      &user_sgpr_idx, 3);
                }
                break;
        case MESA_SHADER_VERTEX:
-               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+               set_vs_specific_input_locs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_idx);
                if (ctx->view_index)
-                       set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                if (ctx->options->key.vs.as_ls) {
-                       set_userdata_location_shader(ctx, AC_UD_VS_LS_TCS_IN_LAYOUT, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_VS_LS_TCS_IN_LAYOUT,
+                                      &user_sgpr_idx, 1);
                }
                if (ctx->options->key.vs.as_ls)
                        ac_declare_lds_as_pointer(&ctx->ac);
                break;
        case MESA_SHADER_TESS_CTRL:
-               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+               set_vs_specific_input_locs(ctx, stage, has_previous_stage,
+                                          previous_stage, &user_sgpr_idx);
                if (has_previous_stage)
-                       set_userdata_location_shader(ctx, AC_UD_VS_LS_TCS_IN_LAYOUT, &user_sgpr_idx, 1);
-               set_userdata_location_shader(ctx, AC_UD_TCS_OFFCHIP_LAYOUT, &user_sgpr_idx, 4);
+                       set_loc_shader(ctx, AC_UD_VS_LS_TCS_IN_LAYOUT,
+                                      &user_sgpr_idx, 1);
+               set_loc_shader(ctx, AC_UD_TCS_OFFCHIP_LAYOUT, &user_sgpr_idx, 4);
                if (ctx->view_index)
-                       set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                ac_declare_lds_as_pointer(&ctx->ac);
                break;
        case MESA_SHADER_TESS_EVAL:
-               set_userdata_location_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT, &user_sgpr_idx, 1);
+               set_loc_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT, &user_sgpr_idx, 1);
                if (ctx->view_index)
-                       set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_GEOMETRY:
                if (has_previous_stage) {
                        if (previous_stage == MESA_SHADER_VERTEX)
-                               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+                               set_vs_specific_input_locs(ctx, stage,
+                                                          has_previous_stage,
+                                                          previous_stage,
+                                                          &user_sgpr_idx);
                        else
-                               set_userdata_location_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT, &user_sgpr_idx, 1);
+                               set_loc_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT,
+                                              &user_sgpr_idx, 1);
                }
-               set_userdata_location_shader(ctx, AC_UD_GS_VS_RING_STRIDE_ENTRIES, &user_sgpr_idx, 2);
+               set_loc_shader(ctx, AC_UD_GS_VS_RING_STRIDE_ENTRIES,
+                              &user_sgpr_idx, 2);
                if (ctx->view_index)
-                       set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                if (has_previous_stage)
                        ac_declare_lds_as_pointer(&ctx->ac);
                break;
        case MESA_SHADER_FRAGMENT:
                if (ctx->shader_info->info.ps.needs_sample_positions) {
-                       set_userdata_location_shader(ctx, AC_UD_PS_SAMPLE_POS_OFFSET, &user_sgpr_idx, 1);
+                       set_loc_shader(ctx, AC_UD_PS_SAMPLE_POS_OFFSET,
+                                      &user_sgpr_idx, 1);
                }
                break;
        default:
@@ -978,32 +1133,10 @@ static void create_function(struct nir_to_llvm_context *ctx,
        ctx->shader_info->num_user_sgprs = user_sgpr_idx;
 }
 
-static int get_llvm_num_components(LLVMValueRef value)
-{
-       LLVMTypeRef type = LLVMTypeOf(value);
-       unsigned num_components = LLVMGetTypeKind(type) == LLVMVectorTypeKind
-                                     ? LLVMGetVectorSize(type)
-                                     : 1;
-       return num_components;
-}
-
-static LLVMValueRef llvm_extract_elem(struct ac_llvm_context *ac,
-                                     LLVMValueRef value,
-                                     int index)
-{
-       int count = get_llvm_num_components(value);
-
-       if (count == 1)
-               return value;
-
-       return LLVMBuildExtractElement(ac->builder, value,
-                                      LLVMConstInt(ac->i32, index, false), "");
-}
-
 static LLVMValueRef trim_vector(struct ac_llvm_context *ctx,
                                 LLVMValueRef value, unsigned count)
 {
-       unsigned num_components = get_llvm_num_components(value);
+       unsigned num_components = ac_get_llvm_num_components(value);
        if (count == num_components)
                return value;
 
@@ -1112,7 +1245,7 @@ static LLVMValueRef emit_int_cmp(struct ac_llvm_context *ctx,
        LLVMValueRef result = LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
        return LLVMBuildSelect(ctx->builder, result,
                               LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
-                              LLVMConstInt(ctx->i32, 0, false), "");
+                              ctx->i32_0, "");
 }
 
 static LLVMValueRef emit_float_cmp(struct ac_llvm_context *ctx,
@@ -1125,7 +1258,7 @@ static LLVMValueRef emit_float_cmp(struct ac_llvm_context *ctx,
        result = LLVMBuildFCmp(ctx->builder, pred, src0, src1, "");
        return LLVMBuildSelect(ctx->builder, result,
                               LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
-                              LLVMConstInt(ctx->i32, 0, false), "");
+                              ctx->i32_0, "");
 }
 
 static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx,
@@ -1402,23 +1535,13 @@ static LLVMValueRef emit_bitfield_insert(struct ac_llvm_context *ctx,
 static LLVMValueRef emit_pack_half_2x16(struct ac_llvm_context *ctx,
                                        LLVMValueRef src0)
 {
-       LLVMValueRef const16 = LLVMConstInt(ctx->i32, 16, false);
-       int i;
        LLVMValueRef comp[2];
 
        src0 = ac_to_float(ctx, src0);
        comp[0] = LLVMBuildExtractElement(ctx->builder, src0, ctx->i32_0, "");
        comp[1] = LLVMBuildExtractElement(ctx->builder, src0, ctx->i32_1, "");
-       for (i = 0; i < 2; i++) {
-               comp[i] = LLVMBuildFPTrunc(ctx->builder, comp[i], ctx->f16, "");
-               comp[i] = LLVMBuildBitCast(ctx->builder, comp[i], ctx->i16, "");
-               comp[i] = LLVMBuildZExt(ctx->builder, comp[i], ctx->i32, "");
-       }
 
-       comp[1] = LLVMBuildShl(ctx->builder, comp[1], const16, "");
-       comp[0] = LLVMBuildOr(ctx->builder, comp[0], comp[1], "");
-
-       return comp[0];
+       return ac_build_cvt_pkrtz_f16(ctx, comp);
 }
 
 static LLVMValueRef emit_unpack_half_2x16(struct ac_llvm_context *ctx,
@@ -1592,7 +1715,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_frcp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, src[0]);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      src[0]);
                break;
        case nir_op_iand:
                result = LLVMBuildAnd(ctx->ac.builder, src[0], src[1], "");
@@ -1719,7 +1843,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        case nir_op_frsq:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.sqrt",
                                              ac_to_float_type(&ctx->ac, def_type), src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, result);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      result);
                break;
        case nir_op_fpow:
                result = emit_intrin_2f_param(&ctx->ac, "llvm.pow",
@@ -1788,6 +1913,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f64:
+               src[0] = ac_to_float(&ctx->ac, src[0]);
                result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f32:
@@ -1975,7 +2101,7 @@ get_buffer_size(struct ac_nir_context *ctx, LLVMValueRef descriptor, bool in_ele
                 */
                LLVMValueRef stride =
                        LLVMBuildExtractElement(ctx->ac.builder, descriptor,
-                                               LLVMConstInt(ctx->ac.i32, 1, false), "");
+                                               ctx->ac.i32_1, "");
                stride = LLVMBuildLShr(ctx->ac.builder, stride,
                                       LLVMConstInt(ctx->ac.i32, 16, false), "");
                stride = LLVMBuildAnd(ctx->ac.builder, stride,
@@ -2126,7 +2252,7 @@ static LLVMValueRef build_tex_intrinsic(struct ac_nir_context *ctx,
                return ac_build_buffer_load_format(&ctx->ac,
                                                   args->resource,
                                                   args->addr,
-                                                  LLVMConstInt(ctx->ac.i32, 0, false),
+                                                  ctx->ac.i32_0,
                                                   true);
        }
 
@@ -2137,7 +2263,9 @@ static LLVMValueRef build_tex_intrinsic(struct ac_nir_context *ctx,
        case nir_texop_txf:
        case nir_texop_txf_ms:
        case nir_texop_samples_identical:
-               args->opcode = instr->sampler_dim == GLSL_SAMPLER_DIM_MS ? ac_image_load : ac_image_load_mip;
+               args->opcode = lod_is_zero ||
+                              instr->sampler_dim == GLSL_SAMPLER_DIM_MS ?
+                                       ac_image_load : ac_image_load_mip;
                args->compare = false;
                args->offset = false;
                break;
@@ -2213,7 +2341,18 @@ static LLVMValueRef visit_vulkan_resource_index(struct nir_to_llvm_context *ctx,
        desc_ptr = cast_ptr(ctx, desc_ptr, ctx->ac.v4i32);
        LLVMSetMetadata(desc_ptr, ctx->ac.uniform_md_kind, ctx->ac.empty_md);
 
-       return LLVMBuildLoad(ctx->builder, desc_ptr, "");
+       return desc_ptr;
+}
+
+static LLVMValueRef visit_vulkan_resource_reindex(struct nir_to_llvm_context *ctx,
+                                                  nir_intrinsic_instr *instr)
+{
+       LLVMValueRef ptr = get_src(ctx->nir, instr->src[0]);
+       LLVMValueRef index = get_src(ctx->nir, instr->src[1]);
+
+       LLVMValueRef result = LLVMBuildGEP(ctx->builder, ptr, &index, 1, "");
+       LLVMSetMetadata(result, ctx->ac.uniform_md_kind, ctx->ac.empty_md);
+       return result;
 }
 
 static LLVMValueRef visit_load_push_constant(struct nir_to_llvm_context *ctx,
@@ -2233,9 +2372,9 @@ static LLVMValueRef visit_load_push_constant(struct nir_to_llvm_context *ctx,
 static LLVMValueRef visit_get_buffer_size(struct ac_nir_context *ctx,
                                           const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef desc = get_src(ctx, instr->src[0]);
+       LLVMValueRef ptr = get_src(ctx, instr->src[0]);
 
-       return get_buffer_size(ctx, desc, false);
+       return get_buffer_size(ctx, LLVMBuildLoad(ctx->ac.builder, ptr, ""), false);
 }
 static void visit_store_ssbo(struct ac_nir_context *ctx,
                              nir_intrinsic_instr *instr)
@@ -2251,7 +2390,7 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
 
        params[1] = ctx->abi->load_ssbo(ctx->abi,
                                        get_src(ctx, instr->src[1]), true);
-       params[2] = LLVMConstInt(ctx->ac.i32, 0, false); /* vindex */
+       params[2] = ctx->ac.i32_0; /* vindex */
        params[4] = ctx->ac.i1false;  /* glc */
        params[5] = ctx->ac.i1false;  /* slc */
 
@@ -2302,7 +2441,7 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
 
                } else {
                        assert(count == 1);
-                       if (get_llvm_num_components(base_data) > 1)
+                       if (ac_get_llvm_num_components(base_data) > 1)
                                data = LLVMBuildExtractElement(ctx->ac.builder, base_data,
                                                               LLVMConstInt(ctx->ac.i32, start, false), "");
                        else
@@ -2329,13 +2468,13 @@ static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
        int arg_count = 0;
 
        if (instr->intrinsic == nir_intrinsic_ssbo_atomic_comp_swap) {
-               params[arg_count++] = llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[3]), 0);
+               params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[3]), 0);
        }
-       params[arg_count++] = llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[2]), 0);
+       params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[2]), 0);
        params[arg_count++] = ctx->abi->load_ssbo(ctx->abi,
                                                 get_src(ctx, instr->src[0]),
                                                 true);
-       params[arg_count++] = LLVMConstInt(ctx->ac.i32, 0, false); /* vindex */
+       params[arg_count++] = ctx->ac.i32_0; /* vindex */
        params[arg_count++] = get_src(ctx, instr->src[1]);      /* voffset */
        params[arg_count++] = LLVMConstInt(ctx->ac.i1, 0, false);  /* slc */
 
@@ -2411,7 +2550,7 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
                        ctx->abi->load_ssbo(ctx->abi,
                                            get_src(ctx, instr->src[0]),
                                            false),
-                       LLVMConstInt(ctx->ac.i32, 0, false),
+                       ctx->ac.i32_0,
                        offset,
                        ctx->ac.i1false,
                        ctx->ac.i1false,
@@ -2443,7 +2582,7 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
 static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
                                           const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef results[8], ret;
+       LLVMValueRef ret;
        LLVMValueRef rsrc = get_src(ctx, instr->src[0]);
        LLVMValueRef offset = get_src(ctx, instr->src[1]);
        int num_components = instr->num_components;
@@ -2454,20 +2593,9 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
        if (instr->dest.ssa.bit_size == 64)
                num_components *= 2;
 
-       for (unsigned i = 0; i < num_components; ++i) {
-               LLVMValueRef params[] = {
-                       rsrc,
-                       LLVMBuildAdd(ctx->ac.builder, LLVMConstInt(ctx->ac.i32, 4 * i, 0),
-                                    offset, "")
-               };
-               results[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.load.const.v4i32", ctx->ac.f32,
-                                               params, 2,
-                                               AC_FUNC_ATTR_READNONE |
-                                               AC_FUNC_ATTR_LEGACY);
-       }
-
+       ret = ac_build_buffer_load(&ctx->ac, rsrc, num_components, NULL, offset,
+                                  NULL, 0, false, false, true, true);
 
-       ret = ac_build_gather_values(&ctx->ac, results, num_components);
        return LLVMBuildBitCast(ctx->ac.builder, ret,
                                get_def_type(ctx, &instr->dest.ssa), "");
 }
@@ -2681,57 +2809,33 @@ get_dw_address(struct nir_to_llvm_context *ctx,
 }
 
 static LLVMValueRef
-build_varying_gather_values(struct ac_llvm_context *ctx, LLVMValueRef *values,
-                           unsigned value_count, unsigned component)
-{
-       LLVMValueRef vec = NULL;
-
-       if (value_count == 1) {
-               return values[component];
-       } else if (!value_count)
-               unreachable("value_count is 0");
-
-       for (unsigned i = component; i < value_count + component; i++) {
-               LLVMValueRef value = values[i];
-
-               if (!i)
-                       vec = LLVMGetUndef( LLVMVectorType(LLVMTypeOf(value), value_count));
-               LLVMValueRef index = LLVMConstInt(ctx->i32, i - component, false);
-               vec = LLVMBuildInsertElement(ctx->builder, vec, value, index, "");
-       }
-       return vec;
-}
-
-static LLVMValueRef
-load_tcs_input(struct nir_to_llvm_context *ctx,
-              nir_intrinsic_instr *instr)
+load_tcs_input(struct ac_shader_abi *abi,
+              LLVMValueRef vertex_index,
+              LLVMValueRef indir_index,
+              unsigned const_index,
+              unsigned location,
+              unsigned driver_location,
+              unsigned component,
+              unsigned num_components,
+              bool is_patch,
+              bool is_compact)
 {
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
        LLVMValueRef dw_addr, stride;
-       unsigned const_index;
-       LLVMValueRef vertex_index;
-       LLVMValueRef indir_index;
-       unsigned param;
        LLVMValueRef value[4], result;
-       const bool per_vertex = nir_is_per_vertex_io(instr->variables[0]->var, ctx->stage);
-       const bool is_compact = instr->variables[0]->var->data.compact;
-       param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
-       get_deref_offset(ctx->nir, instr->variables[0],
-                        false, NULL, per_vertex ? &vertex_index : NULL,
-                        &const_index, &indir_index);
+       unsigned param = shader_io_get_unique_index(location);
 
        stride = unpack_param(&ctx->ac, ctx->tcs_in_layout, 13, 8);
        dw_addr = get_tcs_in_current_patch_offset(ctx);
        dw_addr = get_dw_address(ctx, dw_addr, param, const_index, is_compact, vertex_index, stride,
                                 indir_index);
 
-       unsigned comp = instr->variables[0]->var->data.location_frac;
-       for (unsigned i = 0; i < instr->num_components + comp; i++) {
+       for (unsigned i = 0; i < num_components + component; i++) {
                value[i] = ac_lds_load(&ctx->ac, dw_addr);
                dw_addr = LLVMBuildAdd(ctx->builder, dw_addr,
                                       ctx->ac.i32_1, "");
        }
-       result = build_varying_gather_values(&ctx->ac, value, instr->num_components, comp);
-       result = LLVMBuildBitCast(ctx->builder, result, get_def_type(ctx->nir, &instr->dest.ssa), "");
+       result = ac_build_varying_gather_values(&ctx->ac, value, num_components, component);
        return result;
 }
 
@@ -2769,65 +2873,73 @@ load_tcs_output(struct nir_to_llvm_context *ctx,
                dw_addr = LLVMBuildAdd(ctx->builder, dw_addr,
                                       ctx->ac.i32_1, "");
        }
-       result = build_varying_gather_values(&ctx->ac, value, instr->num_components, comp);
+       result = ac_build_varying_gather_values(&ctx->ac, value, instr->num_components, comp);
        result = LLVMBuildBitCast(ctx->builder, result, get_def_type(ctx->nir, &instr->dest.ssa), "");
        return result;
 }
 
 static void
-store_tcs_output(struct nir_to_llvm_context *ctx,
-                nir_intrinsic_instr *instr,
+store_tcs_output(struct ac_shader_abi *abi,
+                LLVMValueRef vertex_index,
+                LLVMValueRef param_index,
+                unsigned const_index,
+                unsigned location,
+                unsigned driver_location,
                 LLVMValueRef src,
+                unsigned component,
+                bool is_patch,
+                bool is_compact,
                 unsigned writemask)
 {
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
        LLVMValueRef dw_addr;
        LLVMValueRef stride = NULL;
        LLVMValueRef buf_addr = NULL;
-       LLVMValueRef vertex_index = NULL;
-       LLVMValueRef indir_index = NULL;
-       unsigned const_index = 0;
        unsigned param;
-       const unsigned comp = instr->variables[0]->var->data.location_frac;
-       const bool per_vertex = nir_is_per_vertex_io(instr->variables[0]->var, ctx->stage);
-       const bool is_compact = instr->variables[0]->var->data.compact;
+       bool store_lds = true;
 
-       get_deref_offset(ctx->nir, instr->variables[0],
-                        false, NULL, per_vertex ? &vertex_index : NULL,
-                        &const_index, &indir_index);
+       if (is_patch) {
+               if (!(ctx->tcs_patch_outputs_read & (1U << (location - VARYING_SLOT_PATCH0))))
+                       store_lds = false;
+       } else {
+               if (!(ctx->tcs_outputs_read & (1ULL << location)))
+                       store_lds = false;
+       }
 
-       param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
-       if (instr->variables[0]->var->data.location == VARYING_SLOT_CLIP_DIST0 &&
+       param = shader_io_get_unique_index(location);
+       if (location == VARYING_SLOT_CLIP_DIST0 &&
            is_compact && const_index > 3) {
                const_index -= 3;
                param++;
        }
 
-       if (!instr->variables[0]->var->data.patch) {
+       if (!is_patch) {
                stride = unpack_param(&ctx->ac, ctx->tcs_out_layout, 13, 8);
                dw_addr = get_tcs_out_current_patch_offset(ctx);
        } else {
                dw_addr = get_tcs_out_current_patch_data_offset(ctx);
        }
 
-       mark_tess_output(ctx, instr->variables[0]->var->data.patch, param);
+       mark_tess_output(ctx, is_patch, param);
 
        dw_addr = get_dw_address(ctx, dw_addr, param, const_index, is_compact, vertex_index, stride,
-                                indir_index);
+                                param_index);
        buf_addr = get_tcs_tes_buffer_address_params(ctx, param, const_index, is_compact,
-                                                    vertex_index, indir_index);
+                                                    vertex_index, param_index);
 
        bool is_tess_factor = false;
-       if (instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
-           instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER)
+       if (location == VARYING_SLOT_TESS_LEVEL_INNER ||
+           location == VARYING_SLOT_TESS_LEVEL_OUTER)
                is_tess_factor = true;
 
        unsigned base = is_compact ? const_index : 0;
        for (unsigned chan = 0; chan < 8; chan++) {
                if (!(writemask & (1 << chan)))
                        continue;
-               LLVMValueRef value = llvm_extract_elem(&ctx->ac, src, chan - comp);
+               LLVMValueRef value = ac_llvm_extract_elem(&ctx->ac, src, chan - component);
 
-               ac_lds_store(&ctx->ac, dw_addr, value);
+               if (store_lds || is_tess_factor)
+                       ac_lds_store(&ctx->ac, dw_addr, value);
 
                if (!is_tess_factor && writemask != 0xF)
                        ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, value, 1,
@@ -2846,64 +2958,63 @@ store_tcs_output(struct nir_to_llvm_context *ctx,
 }
 
 static LLVMValueRef
-load_tes_input(struct nir_to_llvm_context *ctx,
-              const nir_intrinsic_instr *instr)
+load_tes_input(struct ac_shader_abi *abi,
+              LLVMValueRef vertex_index,
+              LLVMValueRef param_index,
+              unsigned const_index,
+              unsigned location,
+              unsigned driver_location,
+              unsigned component,
+              unsigned num_components,
+              bool is_patch,
+              bool is_compact)
 {
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
        LLVMValueRef buf_addr;
        LLVMValueRef result;
-       LLVMValueRef vertex_index = NULL;
-       LLVMValueRef indir_index = NULL;
-       unsigned const_index = 0;
-       unsigned param;
-       const bool per_vertex = nir_is_per_vertex_io(instr->variables[0]->var, ctx->stage);
-       const bool is_compact = instr->variables[0]->var->data.compact;
+       unsigned param = shader_io_get_unique_index(location);
 
-       get_deref_offset(ctx->nir, instr->variables[0],
-                        false, NULL, per_vertex ? &vertex_index : NULL,
-                        &const_index, &indir_index);
-       param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
-       if (instr->variables[0]->var->data.location == VARYING_SLOT_CLIP_DIST0 &&
-           is_compact && const_index > 3) {
+       if (location == VARYING_SLOT_CLIP_DIST0 && is_compact && const_index > 3) {
                const_index -= 3;
                param++;
        }
 
-       unsigned comp = instr->variables[0]->var->data.location_frac;
        buf_addr = get_tcs_tes_buffer_address_params(ctx, param, const_index,
-                                                    is_compact, vertex_index, indir_index);
+                                                    is_compact, vertex_index, param_index);
 
-       LLVMValueRef comp_offset = LLVMConstInt(ctx->ac.i32, comp * 4, false);
+       LLVMValueRef comp_offset = LLVMConstInt(ctx->ac.i32, component * 4, false);
        buf_addr = LLVMBuildAdd(ctx->builder, buf_addr, comp_offset, "");
 
-       result = ac_build_buffer_load(&ctx->ac, ctx->hs_ring_tess_offchip, instr->num_components, NULL,
+       result = ac_build_buffer_load(&ctx->ac, ctx->hs_ring_tess_offchip, num_components, NULL,
                                      buf_addr, ctx->oc_lds, is_compact ? (4 * const_index) : 0, 1, 0, true, false);
-       result = trim_vector(&ctx->ac, result, instr->num_components);
-       result = LLVMBuildBitCast(ctx->builder, result, get_def_type(ctx->nir, &instr->dest.ssa), "");
+       result = trim_vector(&ctx->ac, result, num_components);
        return result;
 }
 
 static LLVMValueRef
-load_gs_input(struct nir_to_llvm_context *ctx,
-             nir_intrinsic_instr *instr)
+load_gs_input(struct ac_shader_abi *abi,
+             unsigned location,
+             unsigned driver_location,
+             unsigned component,
+             unsigned num_components,
+             unsigned vertex_index,
+             unsigned const_index,
+             LLVMTypeRef type)
 {
-       LLVMValueRef indir_index, vtx_offset;
-       unsigned const_index;
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+       LLVMValueRef vtx_offset;
        LLVMValueRef args[9];
        unsigned param, vtx_offset_param;
        LLVMValueRef value[4], result;
-       unsigned vertex_index;
-       get_deref_offset(ctx->nir, instr->variables[0],
-                        false, &vertex_index, NULL,
-                        &const_index, &indir_index);
+
        vtx_offset_param = vertex_index;
        assert(vtx_offset_param < 6);
        vtx_offset = LLVMBuildMul(ctx->builder, ctx->gs_vtx_offset[vtx_offset_param],
                                  LLVMConstInt(ctx->ac.i32, 4, false), "");
 
-       param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
+       param = shader_io_get_unique_index(location);
 
-       unsigned comp = instr->variables[0]->var->data.location_frac;
-       for (unsigned i = comp; i < instr->num_components + comp; i++) {
+       for (unsigned i = component; i < num_components + component; i++) {
                if (ctx->ac.chip_class >= GFX9) {
                        LLVMValueRef dw_addr = ctx->gs_vtx_offset[vtx_offset_param];
                        dw_addr = LLVMBuildAdd(ctx->ac.builder, dw_addr,
@@ -2926,7 +3037,7 @@ load_gs_input(struct nir_to_llvm_context *ctx,
                                                      AC_FUNC_ATTR_LEGACY);
                }
        }
-       result = build_varying_gather_values(&ctx->ac, value, instr->num_components, comp);
+       result = ac_build_varying_gather_values(&ctx->ac, value, num_components, component);
 
        return result;
 }
@@ -2980,6 +3091,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
        LLVMValueRef indir_index;
        LLVMValueRef ret;
        unsigned const_index;
+       unsigned stride = instr->variables[0]->var->data.compact ? 1 : 4;
        bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
                     instr->variables[0]->var->data.mode == nir_var_shader_in;
        get_deref_offset(ctx, instr->variables[0], vs_in, NULL, NULL,
@@ -2990,12 +3102,40 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
 
        switch (instr->variables[0]->var->data.mode) {
        case nir_var_shader_in:
-               if (ctx->stage == MESA_SHADER_TESS_CTRL)
-                       return load_tcs_input(ctx->nctx, instr);
-               if (ctx->stage == MESA_SHADER_TESS_EVAL)
-                       return load_tes_input(ctx->nctx, instr);
+               if (ctx->stage == MESA_SHADER_TESS_CTRL ||
+                   ctx->stage == MESA_SHADER_TESS_EVAL) {
+                       LLVMValueRef result;
+                       LLVMValueRef vertex_index = NULL;
+                       LLVMValueRef indir_index = NULL;
+                       unsigned const_index = 0;
+                       unsigned location = instr->variables[0]->var->data.location;
+                       unsigned driver_location = instr->variables[0]->var->data.driver_location;
+                       const bool is_patch =  instr->variables[0]->var->data.patch;
+                       const bool is_compact = instr->variables[0]->var->data.compact;
+
+                       get_deref_offset(ctx, instr->variables[0],
+                                        false, NULL, is_patch ? NULL : &vertex_index,
+                                        &const_index, &indir_index);
+
+                       result = ctx->abi->load_tess_inputs(ctx->abi, vertex_index, indir_index,
+                                                           const_index, location, driver_location,
+                                                           instr->variables[0]->var->data.location_frac,
+                                                           instr->num_components,
+                                                           is_patch, is_compact);
+                       return LLVMBuildBitCast(ctx->ac.builder, result, get_def_type(ctx, &instr->dest.ssa), "");
+               }
+
                if (ctx->stage == MESA_SHADER_GEOMETRY) {
-                       return load_gs_input(ctx->nctx, instr);
+                               LLVMValueRef indir_index;
+                               unsigned const_index, vertex_index;
+                               get_deref_offset(ctx, instr->variables[0],
+                                                false, &vertex_index, NULL,
+                                                &const_index, &indir_index);
+                       return ctx->abi->load_inputs(ctx->abi, instr->variables[0]->var->data.location,
+                                                    instr->variables[0]->var->data.driver_location,
+                                                    instr->variables[0]->var->data.location_frac, ve,
+                                                    vertex_index, const_index,
+                                                    nir2llvmtype(ctx, instr->variables[0]->var->type));
                }
 
                for (unsigned chan = comp; chan < ve + comp; chan++) {
@@ -3006,13 +3146,13 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->abi->inputs + idx + chan, count,
-                                               4, false, true);
+                                               stride, false, true);
 
                                values[chan] = LLVMBuildExtractElement(ctx->ac.builder,
                                                                       tmp_vec,
                                                                       indir_index, "");
                        } else
-                               values[chan] = ctx->abi->inputs[idx + chan + const_index * 4];
+                               values[chan] = ctx->abi->inputs[idx + chan + const_index * stride];
                }
                break;
        case nir_var_local:
@@ -3023,13 +3163,13 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->locals + idx + chan, count,
-                                               4, true, true);
+                                               stride, true, true);
 
                                values[chan] = LLVMBuildExtractElement(ctx->ac.builder,
                                                                       tmp_vec,
                                                                       indir_index, "");
                        } else {
-                               values[chan] = LLVMBuildLoad(ctx->ac.builder, ctx->locals[idx + chan + const_index * 4], "");
+                               values[chan] = LLVMBuildLoad(ctx->ac.builder, ctx->locals[idx + chan + const_index * stride], "");
                        }
                }
                break;
@@ -3052,14 +3192,14 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                                count -= chan / 4;
                                LLVMValueRef tmp_vec = ac_build_gather_values_extended(
                                                &ctx->ac, ctx->outputs + idx + chan, count,
-                                               4, true, true);
+                                               stride, true, true);
 
                                values[chan] = LLVMBuildExtractElement(ctx->ac.builder,
                                                                       tmp_vec,
                                                                       indir_index, "");
                        } else {
                                values[chan] = LLVMBuildLoad(ctx->ac.builder,
-                                                    ctx->outputs[idx + chan + const_index * 4],
+                                                    ctx->outputs[idx + chan + const_index * stride],
                                                     "");
                        }
                }
@@ -3067,7 +3207,7 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
        default:
                unreachable("unhandle variable mode");
        }
-       ret = build_varying_gather_values(&ctx->ac, values, ve, comp);
+       ret = ac_build_varying_gather_values(&ctx->ac, values, ve, comp);
        return LLVMBuildBitCast(ctx->ac.builder, ret, get_def_type(ctx, &instr->dest.ssa), "");
 }
 
@@ -3089,7 +3229,7 @@ visit_store_var(struct ac_nir_context *ctx,
                int old_writemask = writemask;
 
                src = LLVMBuildBitCast(ctx->ac.builder, src,
-                                      LLVMVectorType(ctx->ac.f32, get_llvm_num_components(src) * 2),
+                                      LLVMVectorType(ctx->ac.f32, ac_get_llvm_num_components(src) * 2),
                                       "");
 
                writemask = 0;
@@ -3103,7 +3243,22 @@ visit_store_var(struct ac_nir_context *ctx,
        case nir_var_shader_out:
 
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       store_tcs_output(ctx->nctx, instr, src, writemask);
+                       LLVMValueRef vertex_index = NULL;
+                       LLVMValueRef indir_index = NULL;
+                       unsigned const_index = 0;
+                       const unsigned location = instr->variables[0]->var->data.location;
+                       const unsigned driver_location = instr->variables[0]->var->data.driver_location;
+                       const unsigned comp = instr->variables[0]->var->data.location_frac;
+                       const bool is_patch = instr->variables[0]->var->data.patch;
+                       const bool is_compact = instr->variables[0]->var->data.compact;
+
+                       get_deref_offset(ctx, instr->variables[0],
+                                        false, NULL, is_patch ? NULL : &vertex_index,
+                                        &const_index, &indir_index);
+
+                       ctx->abi->store_tcs_outputs(ctx->abi, vertex_index, indir_index,
+                                                   const_index, location, driver_location,
+                                                   src, comp, is_patch, is_compact, writemask);
                        return;
                }
 
@@ -3112,7 +3267,7 @@ visit_store_var(struct ac_nir_context *ctx,
                        if (!(writemask & (1 << chan)))
                                continue;
 
-                       value = llvm_extract_elem(&ctx->ac, src, chan - comp);
+                       value = ac_llvm_extract_elem(&ctx->ac, src, chan - comp);
 
                        if (instr->variables[0]->var->data.compact)
                                stride = 1;
@@ -3141,7 +3296,7 @@ visit_store_var(struct ac_nir_context *ctx,
                        if (!(writemask & (1 << chan)))
                                continue;
 
-                       value = llvm_extract_elem(&ctx->ac, src, chan);
+                       value = ac_llvm_extract_elem(&ctx->ac, src, chan);
                        if (indir_index) {
                                unsigned count = glsl_count_attribute_slots(
                                        instr->variables[0]->var->type, false);
@@ -3181,8 +3336,8 @@ visit_store_var(struct ac_nir_context *ctx,
                                LLVMValueRef ptr =
                                        LLVMBuildStructGEP(ctx->ac.builder,
                                                           address, chan, "");
-                               LLVMValueRef src = llvm_extract_elem(&ctx->ac, val,
-                                                                    chan);
+                               LLVMValueRef src = ac_llvm_extract_elem(&ctx->ac, val,
+                                                                       chan);
                                src = LLVMBuildBitCast(
                                   ctx->ac.builder, src,
                                   LLVMGetElementType(LLVMTypeOf(ptr)), "");
@@ -3314,7 +3469,7 @@ static LLVMValueRef get_image_coords(struct ac_nir_context *ctx,
                LLVMConstInt(ctx->ac.i32, 2, false), LLVMConstInt(ctx->ac.i32, 3, false),
        };
        LLVMValueRef res;
-       LLVMValueRef sample_index = llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[1]), 0);
+       LLVMValueRef sample_index = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[1]), 0);
 
        int count;
        enum glsl_sampler_dim dim = glsl_get_sampler_dim(type);
@@ -3361,7 +3516,7 @@ static LLVMValueRef get_image_coords(struct ac_nir_context *ctx,
                if (is_ms)
                        count--;
                for (chan = 0; chan < count; ++chan) {
-                       coords[chan] = llvm_extract_elem(&ctx->ac, src0, chan);
+                       coords[chan] = ac_llvm_extract_elem(&ctx->ac, src0, chan);
                }
                if (add_frag_pos) {
                        for (chan = 0; chan < 2; ++chan)
@@ -3644,29 +3799,43 @@ static LLVMValueRef visit_image_size(struct ac_nir_context *ctx,
 #define LGKM_CNT 0x07f
 #define VM_CNT 0xf70
 
-static void emit_waitcnt(struct nir_to_llvm_context *ctx,
-                        unsigned simm16)
+static void emit_membar(struct nir_to_llvm_context *ctx,
+                       const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef args[1] = {
-               LLVMConstInt(ctx->ac.i32, simm16, false),
-       };
-       ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.s.waitcnt",
-                          ctx->ac.voidt, args, 1, 0);
+       unsigned waitcnt = NOOP_WAITCNT;
+
+       switch (instr->intrinsic) {
+       case nir_intrinsic_memory_barrier:
+       case nir_intrinsic_group_memory_barrier:
+               waitcnt &= VM_CNT & LGKM_CNT;
+               break;
+       case nir_intrinsic_memory_barrier_atomic_counter:
+       case nir_intrinsic_memory_barrier_buffer:
+       case nir_intrinsic_memory_barrier_image:
+               waitcnt &= VM_CNT;
+               break;
+       case nir_intrinsic_memory_barrier_shared:
+               waitcnt &= LGKM_CNT;
+               break;
+       default:
+               break;
+       }
+       if (waitcnt != NOOP_WAITCNT)
+               ac_build_waitcnt(&ctx->ac, waitcnt);
 }
 
-static void emit_barrier(struct nir_to_llvm_context *ctx)
+static void emit_barrier(struct ac_llvm_context *ac, gl_shader_stage stage)
 {
        /* SI only (thanks to a hw bug workaround):
         * The real barrier instruction isn’t needed, because an entire patch
         * always fits into a single wave.
         */
-       if (ctx->options->chip_class == SI &&
-           ctx->stage == MESA_SHADER_TESS_CTRL) {
-               emit_waitcnt(ctx, LGKM_CNT & VM_CNT);
+       if (ac->chip_class == SI && stage == MESA_SHADER_TESS_CTRL) {
+               ac_build_waitcnt(ac, LGKM_CNT & VM_CNT);
                return;
        }
-       ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.s.barrier",
-                          ctx->ac.voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
+       ac_build_intrinsic(ac, "llvm.amdgcn.s.barrier",
+                          ac->voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
 }
 
 static void emit_discard_if(struct ac_nir_context *ctx,
@@ -3904,19 +4073,18 @@ static LLVMValueRef visit_interp(struct nir_to_llvm_context *ctx,
                                                              ctx->prim_mask);
                }
        }
-       return build_varying_gather_values(&ctx->ac, result, instr->num_components,
-                                          instr->variables[0]->var->data.location_frac);
+       return ac_build_varying_gather_values(&ctx->ac, result, instr->num_components,
+                                             instr->variables[0]->var->data.location_frac);
 }
 
 static void
-visit_emit_vertex(struct nir_to_llvm_context *ctx,
-                 const nir_intrinsic_instr *instr)
+visit_emit_vertex(struct ac_shader_abi *abi, unsigned stream, LLVMValueRef *addrs)
 {
        LLVMValueRef gs_next_vertex;
        LLVMValueRef can_emit;
        int idx;
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
 
-       assert(instr->const_index[0] == 0);
        /* Write vertex attribute values to GSVS ring */
        gs_next_vertex = LLVMBuildLoad(ctx->builder,
                                       ctx->gs_next_vertex,
@@ -3934,7 +4102,7 @@ visit_emit_vertex(struct nir_to_llvm_context *ctx,
        /* loop num outputs */
        idx = 0;
        for (unsigned i = 0; i < RADEON_LLVM_MAX_OUTPUTS; ++i) {
-               LLVMValueRef *out_ptr = &ctx->nir->outputs[i * 4];
+               LLVMValueRef *out_ptr = &addrs[i * 4];
                int length = 4;
                int slot = idx;
                int slot_inc = 1;
@@ -3980,9 +4148,11 @@ visit_end_primitive(struct nir_to_llvm_context *ctx,
 }
 
 static LLVMValueRef
-visit_load_tess_coord(struct nir_to_llvm_context *ctx,
-                     const nir_intrinsic_instr *instr)
+load_tess_coord(struct ac_shader_abi *abi, LLVMTypeRef type,
+               unsigned num_components)
 {
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+
        LLVMValueRef coord[4] = {
                ctx->tes_u,
                ctx->tes_v,
@@ -3994,9 +4164,15 @@ visit_load_tess_coord(struct nir_to_llvm_context *ctx,
                coord[2] = LLVMBuildFSub(ctx->builder, ctx->ac.f32_1,
                                        LLVMBuildFAdd(ctx->builder, coord[0], coord[1], ""), "");
 
-       LLVMValueRef result = ac_build_gather_values(&ctx->ac, coord, instr->num_components);
-       return LLVMBuildBitCast(ctx->builder, result,
-                               get_def_type(ctx->nir, &instr->dest.ssa), "");
+       LLVMValueRef result = ac_build_gather_values(&ctx->ac, coord, num_components);
+       return LLVMBuildBitCast(ctx->builder, result, type, "");
+}
+
+static LLVMValueRef
+load_patch_vertices_in(struct ac_shader_abi *abi)
+{
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+       return LLVMConstInt(ctx->ac.i32, ctx->options->key.tcs.input_vertices, false);
 }
 
 static void visit_intrinsic(struct ac_nir_context *ctx,
@@ -4006,7 +4182,14 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
 
        switch (instr->intrinsic) {
        case nir_intrinsic_load_work_group_id: {
-               result = ctx->nctx->workgroup_ids;
+               LLVMValueRef values[3];
+
+               for (int i = 0; i < 3; i++) {
+                       values[i] = ctx->nctx->workgroup_ids[i] ?
+                                   ctx->nctx->workgroup_ids[i] : ctx->ac.i32_0;
+               }
+
+               result = ac_build_gather_values(&ctx->ac, values, 3);
                break;
        }
        case nir_intrinsic_load_base_vertex: {
@@ -4032,20 +4215,17 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                break;
        case nir_intrinsic_load_invocation_id:
                if (ctx->stage == MESA_SHADER_TESS_CTRL)
-                       result = unpack_param(&ctx->ac, ctx->nctx->tcs_rel_ids, 8, 5);
+                       result = unpack_param(&ctx->ac, ctx->abi->tcs_rel_ids, 8, 5);
                else
-                       result = ctx->nctx->gs_invocation_id;
+                       result = ctx->abi->gs_invocation_id;
                break;
        case nir_intrinsic_load_primitive_id:
                if (ctx->stage == MESA_SHADER_GEOMETRY) {
-                       ctx->nctx->shader_info->gs.uses_prim_id = true;
-                       result = ctx->nctx->gs_prim_id;
+                       result = ctx->abi->gs_prim_id;
                } else if (ctx->stage == MESA_SHADER_TESS_CTRL) {
-                       ctx->nctx->shader_info->tcs.uses_prim_id = true;
-                       result = ctx->nctx->tcs_patch_id;
+                       result = ctx->abi->tcs_patch_id;
                } else if (ctx->stage == MESA_SHADER_TESS_EVAL) {
-                       ctx->nctx->shader_info->tcs.uses_prim_id = true;
-                       result = ctx->nctx->tes_patch_id;
+                       result = ctx->abi->tes_patch_id;
                } else
                        fprintf(stderr, "Unknown primitive id intrinsic: %d", ctx->stage);
                break;
@@ -4086,6 +4266,9 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_vulkan_resource_index:
                result = visit_vulkan_resource_index(ctx->nctx, instr);
                break;
+       case nir_intrinsic_vulkan_resource_reindex:
+               result = visit_vulkan_resource_reindex(ctx->nctx, instr);
+               break;
        case nir_intrinsic_store_ssbo:
                visit_store_ssbo(ctx, instr);
                break;
@@ -4144,10 +4327,15 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                emit_discard_if(ctx, instr);
                break;
        case nir_intrinsic_memory_barrier:
-               emit_waitcnt(ctx->nctx, VM_CNT);
+       case nir_intrinsic_group_memory_barrier:
+       case nir_intrinsic_memory_barrier_atomic_counter:
+       case nir_intrinsic_memory_barrier_buffer:
+       case nir_intrinsic_memory_barrier_image:
+       case nir_intrinsic_memory_barrier_shared:
+               emit_membar(ctx->nctx, instr);
                break;
        case nir_intrinsic_barrier:
-               emit_barrier(ctx->nctx);
+               emit_barrier(&ctx->ac, ctx->stage);
                break;
        case nir_intrinsic_var_atomic_add:
        case nir_intrinsic_var_atomic_imin:
@@ -4167,16 +4355,27 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = visit_interp(ctx->nctx, instr);
                break;
        case nir_intrinsic_emit_vertex:
-               visit_emit_vertex(ctx->nctx, instr);
+               assert(instr->const_index[0] == 0);
+               ctx->abi->emit_vertex(ctx->abi, 0, ctx->outputs);
                break;
        case nir_intrinsic_end_primitive:
                visit_end_primitive(ctx->nctx, instr);
                break;
-       case nir_intrinsic_load_tess_coord:
-               result = visit_load_tess_coord(ctx->nctx, instr);
+       case nir_intrinsic_load_tess_coord: {
+               LLVMTypeRef type = ctx->nctx ?
+                       get_def_type(ctx->nctx->nir, &instr->dest.ssa) :
+                       NULL;
+               result = ctx->abi->load_tess_coord(ctx->abi, type, instr->num_components);
+               break;
+       }
+       case nir_intrinsic_load_tess_level_outer:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER);
+               break;
+       case nir_intrinsic_load_tess_level_inner:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER);
                break;
        case nir_intrinsic_load_patch_vertices_in:
-               result = LLVMConstInt(ctx->ac.i32, ctx->nctx->options->key.tcs.input_vertices, false);
+               result = ctx->abi->load_patch_vertices_in(ctx->abi);
                break;
        default:
                fprintf(stderr, "Unknown intrinsic: ");
@@ -4190,14 +4389,21 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
 }
 
 static LLVMValueRef radv_load_ssbo(struct ac_shader_abi *abi,
-                                  LLVMValueRef buffer, bool write)
+                                  LLVMValueRef buffer_ptr, bool write)
 {
        struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
 
        if (write && ctx->stage == MESA_SHADER_FRAGMENT)
                ctx->shader_info->fs.writes_memory = true;
 
-       return buffer;
+       return LLVMBuildLoad(ctx->builder, buffer_ptr, "");
+}
+
+static LLVMValueRef radv_load_ubo(struct ac_shader_abi *abi, LLVMValueRef buffer_ptr)
+{
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+
+       return LLVMBuildLoad(ctx->builder, buffer_ptr, "");
 }
 
 static LLVMValueRef radv_get_sampler_desc(struct ac_shader_abi *abi,
@@ -4521,7 +4727,7 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
 
        if (coord)
                for (chan = 0; chan < instr->coord_components; chan++)
-                       coords[chan] = llvm_extract_elem(&ctx->ac, coord, chan);
+                       coords[chan] = ac_llvm_extract_elem(&ctx->ac, coord, chan);
 
        if (offsets && instr->op != nir_texop_txf) {
                LLVMValueRef offset[3], pack;
@@ -4529,8 +4735,8 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
                        offset[chan] = ctx->ac.i32_0;
 
                args.offset = true;
-               for (chan = 0; chan < get_llvm_num_components(offsets); chan++) {
-                       offset[chan] = llvm_extract_elem(&ctx->ac, offsets, chan);
+               for (chan = 0; chan < ac_get_llvm_num_components(offsets); chan++) {
+                       offset[chan] = ac_llvm_extract_elem(&ctx->ac, offsets, chan);
                        offset[chan] = LLVMBuildAnd(ctx->ac.builder, offset[chan],
                                                    LLVMConstInt(ctx->ac.i32, 0x3f, false), "");
                        if (chan)
@@ -4550,7 +4756,7 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
        /* Pack depth comparison value */
        if (instr->is_shadow && comparator) {
                LLVMValueRef z = ac_to_float(&ctx->ac,
-                                            llvm_extract_elem(&ctx->ac, comparator, 0));
+                                            ac_llvm_extract_elem(&ctx->ac, comparator, 0));
 
                /* TC-compatible HTILE on radeonsi promotes Z16 and Z24 to Z32_FLOAT,
                 * so the depth comparison value isn't clamped for Z16 and
@@ -4594,8 +4800,8 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
                }
 
                for (unsigned i = 0; i < num_src_deriv_channels; i++) {
-                       derivs[i] = ac_to_float(&ctx->ac, llvm_extract_elem(&ctx->ac, ddx, i));
-                       derivs[num_dest_deriv_channels + i] = ac_to_float(&ctx->ac, llvm_extract_elem(&ctx->ac, ddy, i));
+                       derivs[i] = ac_to_float(&ctx->ac, ac_llvm_extract_elem(&ctx->ac, ddx, i));
+                       derivs[num_dest_deriv_channels + i] = ac_to_float(&ctx->ac, ac_llvm_extract_elem(&ctx->ac, ddy, i));
                }
                for (unsigned i = num_src_deriv_channels; i < num_dest_deriv_channels; i++) {
                        derivs[i] = ctx->ac.f32_0;
@@ -4894,7 +5100,7 @@ static void visit_if(struct ac_nir_context *ctx, nir_if *if_stmt)
                    ctx->ac.context, fn, "");
 
        LLVMValueRef cond = LLVMBuildICmp(ctx->ac.builder, LLVMIntNE, value,
-                                         LLVMConstInt(ctx->ac.i32, 0, false), "");
+                                         ctx->ac.i32_0, "");
        LLVMBuildCondBr(ctx->ac.builder, cond, if_block, else_block);
 
        LLVMPositionBuilderAtEnd(ctx->ac.builder, if_block);
@@ -4977,8 +5183,13 @@ handle_vs_input_decl(struct nir_to_llvm_context *ctx,
        if (ctx->options->key.vs.instance_rate_inputs & (1u << index)) {
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.instance_id,
                                            ctx->abi.start_instance, "");
-               ctx->shader_info->vs.vgpr_comp_cnt = MAX2(3,
-                                           ctx->shader_info->vs.vgpr_comp_cnt);
+               if (ctx->options->key.vs.as_ls) {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(2, ctx->shader_info->vs.vgpr_comp_cnt);
+               } else {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(1, ctx->shader_info->vs.vgpr_comp_cnt);
+               }
        } else
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.vertex_id,
                                            ctx->abi.base_vertex, "");
@@ -4990,7 +5201,7 @@ handle_vs_input_decl(struct nir_to_llvm_context *ctx,
 
                input = ac_build_buffer_load_format(&ctx->ac, t_list,
                                                    buffer_index,
-                                                   LLVMConstInt(ctx->ac.i32, 0, false),
+                                                   ctx->ac.i32_0,
                                                    true);
 
                for (unsigned chan = 0; chan < 4; chan++) {
@@ -5354,6 +5565,7 @@ setup_locals(struct ac_nir_context *ctx,
        nir_foreach_variable(variable, &func->impl->locals) {
                unsigned attrib_count = glsl_count_attribute_slots(variable->type, false);
                variable->data.driver_location = ctx->num_locals * 4;
+               variable->data.location_frac = 0;
                ctx->num_locals += attrib_count;
        }
        ctx->locals = malloc(4 * ctx->num_locals * sizeof(LLVMValueRef));
@@ -5961,13 +6173,13 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
 {
        unsigned stride, outer_comps, inner_comps;
        struct ac_build_if_state if_ctx, inner_if_ctx;
-       LLVMValueRef invocation_id = unpack_param(&ctx->ac, ctx->tcs_rel_ids, 8, 5);
-       LLVMValueRef rel_patch_id = unpack_param(&ctx->ac, ctx->tcs_rel_ids, 0, 8);
+       LLVMValueRef invocation_id = unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 8, 5);
+       LLVMValueRef rel_patch_id = unpack_param(&ctx->ac, ctx->abi.tcs_rel_ids, 0, 8);
        unsigned tess_inner_index, tess_outer_index;
        LLVMValueRef lds_base, lds_inner, lds_outer, byteoffset, buffer;
        LLVMValueRef out[6], vec0, vec1, tf_base, inner[4], outer[4];
        int i;
-       emit_barrier(ctx);
+       emit_barrier(&ctx->ac, ctx->stage);
 
        switch (ctx->options->key.tcs.primitive_mode) {
        case GL_ISOLINES:
@@ -6013,20 +6225,20 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
        if (ctx->options->key.tcs.primitive_mode == GL_ISOLINES) {
                outer[0] = out[1] = ac_lds_load(&ctx->ac, lds_outer);
                lds_outer = LLVMBuildAdd(ctx->builder, lds_outer,
-                                        LLVMConstInt(ctx->ac.i32, 1, false), "");
+                                        ctx->ac.i32_1, "");
                outer[1] = out[0] = ac_lds_load(&ctx->ac, lds_outer);
        } else {
                for (i = 0; i < outer_comps; i++) {
                        outer[i] = out[i] =
                                ac_lds_load(&ctx->ac, lds_outer);
                        lds_outer = LLVMBuildAdd(ctx->builder, lds_outer,
-                                                LLVMConstInt(ctx->ac.i32, 1, false), "");
+                                                ctx->ac.i32_1, "");
                }
                for (i = 0; i < inner_comps; i++) {
                        inner[i] = out[outer_comps+i] =
                                ac_lds_load(&ctx->ac, lds_inner);
                        lds_inner = LLVMBuildAdd(ctx->builder, lds_inner,
-                                                LLVMConstInt(ctx->ac.i32, 1, false), "");
+                                                ctx->ac.i32_1, "");
                }
        }
 
@@ -6124,44 +6336,13 @@ si_export_mrt_color(struct nir_to_llvm_context *ctx,
 }
 
 static void
-si_export_mrt_z(struct nir_to_llvm_context *ctx,
-               LLVMValueRef depth, LLVMValueRef stencil,
-               LLVMValueRef samplemask)
+radv_export_mrt_z(struct nir_to_llvm_context *ctx,
+                 LLVMValueRef depth, LLVMValueRef stencil,
+                 LLVMValueRef samplemask)
 {
        struct ac_export_args args;
 
-       args.enabled_channels = 0;
-       args.valid_mask = 1;
-       args.done = 1;
-       args.target = V_008DFC_SQ_EXP_MRTZ;
-       args.compr = false;
-
-       args.out[0] = LLVMGetUndef(ctx->ac.f32); /* R, depth */
-       args.out[1] = LLVMGetUndef(ctx->ac.f32); /* G, stencil test val[0:7], stencil op val[8:15] */
-       args.out[2] = LLVMGetUndef(ctx->ac.f32); /* B, sample mask */
-       args.out[3] = LLVMGetUndef(ctx->ac.f32); /* A, alpha to mask */
-
-       if (depth) {
-               args.out[0] = depth;
-               args.enabled_channels |= 0x1;
-       }
-
-       if (stencil) {
-               args.out[1] = stencil;
-               args.enabled_channels |= 0x2;
-       }
-
-       if (samplemask) {
-               args.out[2] = samplemask;
-               args.enabled_channels |= 0x4;
-       }
-
-       /* SI (except OLAND and HAINAN) has a bug that it only looks
-        * at the X writemask component. */
-       if (ctx->options->chip_class == SI &&
-           ctx->options->family != CHIP_OLAND &&
-           ctx->options->family != CHIP_HAINAN)
-               args.enabled_channels |= 0x1;
+       ac_export_mrt_z(&ctx->ac, depth, stencil, samplemask, &args);
 
        ac_build_export(&ctx->ac, &args);
 }
@@ -6209,7 +6390,7 @@ handle_fs_outputs_post(struct nir_to_llvm_context *ctx)
        for (unsigned i = 0; i < index; i++)
                ac_build_export(&ctx->ac, &color_args[i]);
        if (depth || stencil || samplemask)
-               si_export_mrt_z(ctx, depth, stencil, samplemask);
+               radv_export_mrt_z(ctx, depth, stencil, samplemask);
        else if (!index) {
                si_export_mrt_color(ctx, NULL, V_008DFC_SQ_EXP_NULL, true, &color_args[0]);
                ac_build_export(&ctx->ac, &color_args[0]);
@@ -6379,11 +6560,11 @@ static void ac_nir_fixup_ls_hs_input_vgprs(struct nir_to_llvm_context *ctx)
                                          LLVMConstInt(ctx->ac.i32, 8, false),
                                          LLVMConstInt(ctx->ac.i32, 8, false), false);
        LLVMValueRef hs_empty = LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, count,
-                                             LLVMConstInt(ctx->ac.i32, 0, false), "");
+                                             ctx->ac.i32_0, "");
        ctx->abi.instance_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->rel_auto_id, ctx->abi.instance_id, "");
        ctx->vs_prim_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.vertex_id, ctx->vs_prim_id, "");
-       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_rel_ids, ctx->rel_auto_id, "");
-       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_patch_id, ctx->abi.vertex_id, "");
+       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.tcs_rel_ids, ctx->rel_auto_id, "");
+       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.tcs_patch_id, ctx->abi.vertex_id, "");
 }
 
 static void prepare_gs_input_vgprs(struct nir_to_llvm_context *ctx)
@@ -6462,7 +6643,8 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
        ctx.context = LLVMContextCreate();
        ctx.module = LLVMModuleCreateWithNameInContext("shader", ctx.context);
 
-       ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class);
+       ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class,
+                            options->family);
        ctx.ac.module = ctx.module;
        LLVMSetTarget(ctx.module, options->supports_spill ? "amdgcn-mesa-mesa3d" : "amdgcn--");
 
@@ -6497,6 +6679,8 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
        ctx.abi.inputs = &ctx.inputs[0];
        ctx.abi.emit_outputs = handle_shader_outputs_post;
+       ctx.abi.emit_vertex = visit_emit_vertex;
+       ctx.abi.load_ubo = radv_load_ubo;
        ctx.abi.load_ssbo = radv_load_ssbo;
        ctx.abi.load_sampler_desc = radv_get_sampler_desc;
        ctx.abi.clamp_shadow_reference = false;
@@ -6517,21 +6701,36 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
                if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
                        ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.ac.i32, "gs_next_vertex");
-
                        ctx.gs_max_out_vertices = shaders[i]->info.gs.vertices_out;
+                       ctx.abi.load_inputs = load_gs_input;
+               } else if (shaders[i]->info.stage == MESA_SHADER_TESS_CTRL) {
+                       ctx.tcs_outputs_read = shaders[i]->info.outputs_read;
+                       ctx.tcs_patch_outputs_read = shaders[i]->info.patch_outputs_read;
+                       ctx.abi.load_tess_inputs = load_tcs_input;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
+                       ctx.abi.store_tcs_outputs = store_tcs_output;
                } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
                        ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
+                       ctx.abi.load_tess_inputs = load_tes_input;
+                       ctx.abi.load_tess_coord = load_tess_coord;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
                        if (shader_info->info.vs.needs_instance_id) {
-                               ctx.shader_info->vs.vgpr_comp_cnt =
-                                       MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+                               if (ctx.ac.chip_class == GFX9 &&
+                                   shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL) {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(2, ctx.shader_info->vs.vgpr_comp_cnt);
+                               } else {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(1, ctx.shader_info->vs.vgpr_comp_cnt);
+                               }
                        }
                } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
                        shader_info->fs.can_discard = shaders[i]->info.fs.uses_discard;
                }
 
                if (i)
-                       emit_barrier(&ctx);
+                       emit_barrier(&ctx.ac, ctx.stage);
 
                ac_setup_rings(&ctx);
 
@@ -6717,6 +6916,20 @@ static void ac_compile_llvm_module(LLVMTargetMachineRef tm,
        /* +3 for scratch wave offset and VCC */
        config->num_sgprs = MAX2(config->num_sgprs,
                                 shader_info->num_input_sgprs + 3);
+
+       /* Enable 64-bit and 16-bit denormals, because there is no performance
+        * cost.
+        *
+        * If denormals are enabled, all floating-point output modifiers are
+        * ignored.
+        *
+        * Don't enable denormals for 32-bit floats, because:
+        * - Floating-point output modifiers would be ignored by the hw.
+        * - Some opcodes don't support denormals, such as v_mad_f32. We would
+        *   have to stop using those.
+        * - SI & CI would be very slow.
+        */
+       config->float_mode |= V_00B028_FP_64_DENORMS;
 }
 
 static void
@@ -6749,7 +6962,7 @@ ac_fill_shader_info(struct ac_shader_variant_info *shader_info, struct nir_shade
         case MESA_SHADER_VERTEX:
                 shader_info->vs.as_es = options->key.vs.as_es;
                 shader_info->vs.as_ls = options->key.vs.as_ls;
-                /* in LS mode we need at least 1, invocation id needs 3, handled elsewhere */
+                /* in LS mode we need at least 1, invocation id needs 2, handled elsewhere */
                 if (options->key.vs.as_ls)
                         shader_info->vs.vgpr_comp_cnt = MAX2(1, shader_info->vs.vgpr_comp_cnt);
                 break;
@@ -6774,6 +6987,14 @@ void ac_compile_nir_shader(LLVMTargetMachineRef tm,
        ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->info.stage, dump_shader, options->supports_spill);
        for (int i = 0; i < nir_count; ++i)
                ac_fill_shader_info(shader_info, nir[i], options);
+
+       /* Determine the ES type (VS or TES) for the GS on GFX9. */
+       if (options->chip_class == GFX9) {
+               if (nir_count == 2 &&
+                   nir[1]->info.stage == MESA_SHADER_GEOMETRY) {
+                       shader_info->gs.es_type = nir[0]->info.stage;
+               }
+       }
 }
 
 static void
@@ -6839,7 +7060,8 @@ void ac_create_gs_copy_shader(LLVMTargetMachineRef tm,
        ctx.options = options;
        ctx.shader_info = shader_info;
 
-       ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class);
+       ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class,
+                            options->family);
        ctx.ac.module = ctx.module;
 
        ctx.is_gs_copy_shader = true;