nir: Get rid of nir_shader::stage
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index 1df97b59a2e2b1d62ce6cf7e252590b798bdb682..61ffe91eafd001d7d3e56ceb1c746f4dcc6260b1 100644 (file)
@@ -108,6 +108,7 @@ struct nir_to_llvm_context {
        LLVMValueRef tcs_out_layout;
        LLVMValueRef tcs_in_layout;
        LLVMValueRef oc_lds;
+       LLVMValueRef merged_wave_info;
        LLVMValueRef tess_factor_offset;
        LLVMValueRef tcs_patch_id;
        LLVMValueRef tcs_rel_ids;
@@ -627,36 +628,133 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
        }
 }
 
-static void create_function(struct nir_to_llvm_context *ctx)
+static void
+radv_define_common_user_sgprs_phase1(struct nir_to_llvm_context *ctx,
+                                     gl_shader_stage stage,
+                                     bool has_previous_stage,
+                                     gl_shader_stage previous_stage,
+                                     const struct user_sgpr_info *user_sgpr_info,
+                                     struct arg_info *args,
+                                     LLVMValueRef *desc_sets)
 {
        unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0;
-       uint8_t user_sgpr_idx;
-       struct user_sgpr_info user_sgpr_info;
-       struct arg_info args = {};
-       LLVMValueRef desc_sets;
-
-       allocate_user_sgprs(ctx, &user_sgpr_info);
-       if (user_sgpr_info.need_ring_offsets && !ctx->options->supports_spill) {
-               add_user_sgpr_argument(&args, const_array(ctx->v4i32, 16), &ctx->ring_offsets); /* address of rings */
-       }
+       unsigned stage_mask = 1 << stage;
+       if (has_previous_stage)
+               stage_mask |= 1 << previous_stage;
 
        /* 1 for each descriptor set */
-       if (!user_sgpr_info.indirect_all_descriptor_sets) {
+       if (!user_sgpr_info->indirect_all_descriptor_sets) {
                for (unsigned i = 0; i < num_sets; ++i) {
-                       if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) {
-                               add_user_sgpr_array_argument(&args, const_array(ctx->i8, 1024 * 1024), &ctx->descriptor_sets[i]);
+                       if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
+                               add_user_sgpr_array_argument(args, const_array(ctx->i8, 1024 * 1024), &ctx->descriptor_sets[i]);
                        }
                }
        } else
-               add_user_sgpr_array_argument(&args, const_array(const_array(ctx->i8, 1024 * 1024), 32), &desc_sets);
+               add_user_sgpr_array_argument(args, const_array(const_array(ctx->i8, 1024 * 1024), 32), desc_sets);
 
        if (ctx->shader_info->info.needs_push_constants) {
                /* 1 for push constants and dynamic descriptors */
-               add_user_sgpr_array_argument(&args, const_array(ctx->i8, 1024 * 1024), &ctx->push_constants);
+               add_user_sgpr_array_argument(args, const_array(ctx->i8, 1024 * 1024), &ctx->push_constants);
        }
+}
 
-       switch (ctx->stage) {
+static void
+radv_define_common_user_sgprs_phase2(struct nir_to_llvm_context *ctx,
+                                     gl_shader_stage stage,
+                                     bool has_previous_stage,
+                                     gl_shader_stage previous_stage,
+                                     const struct user_sgpr_info *user_sgpr_info,
+                                    LLVMValueRef desc_sets,
+                                     uint8_t *user_sgpr_idx)
+{
+       unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0;
+       unsigned stage_mask = 1 << stage;
+       if (has_previous_stage)
+               stage_mask |= 1 << previous_stage;
+
+       if (!user_sgpr_info->indirect_all_descriptor_sets) {
+               for (unsigned i = 0; i < num_sets; ++i) {
+                       if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
+                               set_userdata_location(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], user_sgpr_idx, 2);
+                       } else
+                               ctx->descriptor_sets[i] = NULL;
+               }
+       } else {
+               uint32_t desc_sgpr_idx = *user_sgpr_idx;
+               set_userdata_location_shader(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS, user_sgpr_idx, 2);
+
+               for (unsigned i = 0; i < num_sets; ++i) {
+                       if (ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
+                               set_userdata_location_indirect(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], desc_sgpr_idx, 2, i * 8);
+                               ctx->descriptor_sets[i] = ac_build_load_to_sgpr(&ctx->ac, desc_sets, LLVMConstInt(ctx->i32, i, false));
+
+                       } else
+                               ctx->descriptor_sets[i] = NULL;
+               }
+               ctx->shader_info->need_indirect_descriptor_sets = true;
+       }
+
+       if (ctx->shader_info->info.needs_push_constants) {
+               set_userdata_location_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2);
+       }
+}
+
+static void
+radv_define_vs_user_sgprs_phase1(struct nir_to_llvm_context *ctx,
+                                 gl_shader_stage stage,
+                                 bool has_previous_stage,
+                                 gl_shader_stage previous_stage,
+                                 struct arg_info *args)
+{
+       if (!ctx->is_gs_copy_shader && (stage == MESA_SHADER_VERTEX || (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
+               if (ctx->shader_info->info.vs.has_vertex_buffers)
+                       add_user_sgpr_argument(args, const_array(ctx->v4i32, 16), &ctx->vertex_buffers); /* vertex buffers */
+               add_user_sgpr_argument(args, ctx->i32, &ctx->abi.base_vertex); // base vertex
+               add_user_sgpr_argument(args, ctx->i32, &ctx->abi.start_instance);// start instance
+               if (ctx->shader_info->info.vs.needs_draw_id)
+                       add_user_sgpr_argument(args, ctx->i32, &ctx->abi.draw_id); // draw id
+       }
+}
+
+static void
+radv_define_vs_user_sgprs_phase2(struct nir_to_llvm_context *ctx,
+                                 gl_shader_stage stage,
+                                 bool has_previous_stage,
+                                 gl_shader_stage previous_stage,
+                                 uint8_t *user_sgpr_idx)
+{
+       if (!ctx->is_gs_copy_shader && (stage == MESA_SHADER_VERTEX || (has_previous_stage && previous_stage == MESA_SHADER_VERTEX))) {
+               if (ctx->shader_info->info.vs.has_vertex_buffers) {
+                       set_userdata_location_shader(ctx, AC_UD_VS_VERTEX_BUFFERS, user_sgpr_idx, 2);
+               }
+               unsigned vs_num = 2;
+               if (ctx->shader_info->info.vs.needs_draw_id)
+                       vs_num++;
+
+               set_userdata_location_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, user_sgpr_idx, vs_num);
+       }
+}
+
+
+static void create_function(struct nir_to_llvm_context *ctx,
+                            gl_shader_stage stage,
+                            bool has_previous_stage,
+                            gl_shader_stage previous_stage)
+{
+       uint8_t user_sgpr_idx;
+       struct user_sgpr_info user_sgpr_info;
+       struct arg_info args = {};
+       LLVMValueRef desc_sets;
+
+       allocate_user_sgprs(ctx, &user_sgpr_info);
+
+       if (user_sgpr_info.need_ring_offsets && !ctx->options->supports_spill) {
+               add_user_sgpr_argument(&args, const_array(ctx->v4i32, 16), &ctx->ring_offsets); /* address of rings */
+       }
+
+       switch (stage) {
        case MESA_SHADER_COMPUTE:
+               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
                if (ctx->shader_info->info.cs.grid_components_used)
                        add_user_sgpr_argument(&args, LLVMVectorType(ctx->i32, ctx->shader_info->info.cs.grid_components_used), &ctx->num_work_groups); /* grid size */
                add_sgpr_argument(&args, LLVMVectorType(ctx->i32, 3), &ctx->workgroup_ids);
@@ -664,14 +762,8 @@ static void create_function(struct nir_to_llvm_context *ctx)
                add_vgpr_argument(&args, LLVMVectorType(ctx->i32, 3), &ctx->local_invocation_ids);
                break;
        case MESA_SHADER_VERTEX:
-               if (!ctx->is_gs_copy_shader) {
-                       if (ctx->shader_info->info.vs.has_vertex_buffers)
-                               add_user_sgpr_argument(&args, const_array(ctx->v4i32, 16), &ctx->vertex_buffers); /* vertex buffers */
-                       add_user_sgpr_argument(&args, ctx->i32, &ctx->abi.base_vertex); // base vertex
-                       add_user_sgpr_argument(&args, ctx->i32, &ctx->abi.start_instance);// start instance
-                       if (ctx->shader_info->info.vs.needs_draw_id)
-                               add_user_sgpr_argument(&args, ctx->i32, &ctx->abi.draw_id); // draw id
-               }
+               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
                if (ctx->shader_info->info.needs_multiview_view_index || (!ctx->options->key.vs.as_es && !ctx->options->key.vs.as_ls && ctx->options->key.has_multiview_view_index))
                        add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
                if (ctx->options->key.vs.as_es)
@@ -686,18 +778,49 @@ static void create_function(struct nir_to_llvm_context *ctx)
                }
                break;
        case MESA_SHADER_TESS_CTRL:
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_offsets); // tcs out offsets
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_layout); // tcs out layout
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_in_layout); // tcs in layout
-               if (ctx->shader_info->info.needs_multiview_view_index)
-                       add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
-               add_sgpr_argument(&args, ctx->i32, &ctx->oc_lds); // param oc lds
-               add_sgpr_argument(&args, ctx->i32, &ctx->tess_factor_offset); // tess factor offset
-               add_vgpr_argument(&args, ctx->i32, &ctx->tcs_patch_id); // patch id
-               add_vgpr_argument(&args, ctx->i32, &ctx->tcs_rel_ids); // rel ids;
+               if (has_previous_stage) {
+                       // First 6 system regs
+                       add_sgpr_argument(&args, ctx->i32, &ctx->oc_lds); // param oc lds
+                       add_sgpr_argument(&args, ctx->i32, &ctx->merged_wave_info); // merged wave info
+                       add_sgpr_argument(&args, ctx->i32, &ctx->tess_factor_offset); // tess factor offset
+
+                       add_sgpr_argument(&args, ctx->i32, NULL); // scratch offset
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->ls_out_layout); // ls out layout
+
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_offsets); // tcs out offsets
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_layout); // tcs out layout
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_in_layout); // tcs in layout
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+
+                       add_vgpr_argument(&args, ctx->i32, &ctx->tcs_patch_id); // patch id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->tcs_rel_ids); // rel ids;
+                       add_vgpr_argument(&args, ctx->i32, &ctx->abi.vertex_id); // vertex id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->rel_auto_id); // rel auto id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->vs_prim_id); // vs prim id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->abi.instance_id); // instance id
+               } else {
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_offsets); // tcs out offsets
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_out_layout); // tcs out layout
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_in_layout); // tcs in layout
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+                       add_sgpr_argument(&args, ctx->i32, &ctx->oc_lds); // param oc lds
+                       add_sgpr_argument(&args, ctx->i32, &ctx->tess_factor_offset); // tess factor offset
+                       add_vgpr_argument(&args, ctx->i32, &ctx->tcs_patch_id); // patch id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->tcs_rel_ids); // rel ids;
+               }
                break;
        case MESA_SHADER_TESS_EVAL:
+               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
                add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
                if (ctx->shader_info->info.needs_multiview_view_index || (!ctx->options->key.tes.as_es && ctx->options->key.has_multiview_view_index))
                        add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
@@ -715,22 +838,64 @@ static void create_function(struct nir_to_llvm_context *ctx)
                add_vgpr_argument(&args, ctx->i32, &ctx->tes_patch_id); // tes patch id
                break;
        case MESA_SHADER_GEOMETRY:
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entires
-               if (ctx->shader_info->info.needs_multiview_view_index)
-                       add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
-               add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // gs2vs offset
-               add_sgpr_argument(&args, ctx->i32, &ctx->gs_wave_id); // wave id
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx0
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[1]); // vtx1
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[3]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[5]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+               if (has_previous_stage) {
+                       // First 6 system regs
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // tess factor offset
+                       add_sgpr_argument(&args, ctx->i32, &ctx->merged_wave_info); // merged wave info
+                       add_sgpr_argument(&args, ctx->i32, &ctx->oc_lds); // param oc lds
+
+                       add_sgpr_argument(&args, ctx->i32, NULL); // scratch offset
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       if (previous_stage == MESA_SHADER_TESS_EVAL)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
+                       else
+                               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entires
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx01
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]); // vtx23
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
+
+                       if (previous_stage == MESA_SHADER_VERTEX) {
+                               add_vgpr_argument(&args, ctx->i32, &ctx->abi.vertex_id); // vertex id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->rel_auto_id); // rel auto id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->vs_prim_id); // vs prim id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->abi.instance_id); // instance id
+                       } else {
+                               add_vgpr_argument(&args, ctx->f32, &ctx->tes_u); // tes_u
+                               add_vgpr_argument(&args, ctx->f32, &ctx->tes_v); // tes_v
+                               add_vgpr_argument(&args, ctx->i32, &ctx->tes_rel_patch_id); // tes rel patch id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->tes_patch_id); // tes patch id
+                       }
+               } else {
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entires
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // gs2vs offset
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs_wave_id); // wave id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx0
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[1]); // vtx1
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[3]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[5]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+               }
                break;
        case MESA_SHADER_FRAGMENT:
+               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
                if (ctx->shader_info->info.ps.needs_sample_positions)
                        add_user_sgpr_argument(&args, ctx->i32, &ctx->sample_pos_offset); /* sample position offset */
                add_sgpr_argument(&args, ctx->i32, &ctx->prim_mask); /* prim mask */
@@ -759,14 +924,12 @@ static void create_function(struct nir_to_llvm_context *ctx)
            ctx->context, ctx->module, ctx->builder, NULL, 0, &args,
            ctx->max_workgroup_size,
            ctx->options->unsafe_math);
-       set_llvm_calling_convention(ctx->main_function, ctx->stage);
+       set_llvm_calling_convention(ctx->main_function, stage);
 
 
        ctx->shader_info->num_input_vgprs = 0;
-       ctx->shader_info->num_input_sgprs = ctx->shader_info->num_user_sgprs =
-         ctx->options->supports_spill ? 2 : 0;
+       ctx->shader_info->num_input_sgprs = ctx->options->supports_spill ? 2 : 0;
 
-       ctx->shader_info->num_user_sgprs += args.num_user_sgprs_used;
        ctx->shader_info->num_input_sgprs += args.num_sgprs_used;
 
        if (ctx->stage != MESA_SHADER_FRAGMENT)
@@ -786,50 +949,22 @@ static void create_function(struct nir_to_llvm_context *ctx)
                                                             const_array(ctx->v4i32, 16), "");
                }
        }
+       
+       /* For merged shaders the user SGPRs start at 8, with 8 system SGPRs in front (including
+        * the rw_buffers at s0/s1. With user SGPR0 = s8, lets restart the count from 0 */
+       if (has_previous_stage)
+               user_sgpr_idx = 0;
 
-       if (!user_sgpr_info.indirect_all_descriptor_sets) {
-               for (unsigned i = 0; i < num_sets; ++i) {
-                       if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) {
-                               set_userdata_location(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], &user_sgpr_idx, 2);
-                       } else
-                               ctx->descriptor_sets[i] = NULL;
-               }
-       } else {
-               uint32_t desc_sgpr_idx = user_sgpr_idx;
-               set_userdata_location_shader(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS, &user_sgpr_idx, 2);
-
-               for (unsigned i = 0; i < num_sets; ++i) {
-                       if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) {
-                               set_userdata_location_indirect(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], desc_sgpr_idx, 2, i * 8);
-                               ctx->descriptor_sets[i] = ac_build_load_to_sgpr(&ctx->ac, desc_sets, LLVMConstInt(ctx->i32, i, false));
-
-                       } else
-                               ctx->descriptor_sets[i] = NULL;
-               }
-               ctx->shader_info->need_indirect_descriptor_sets = true;
-       }
-
-       if (ctx->shader_info->info.needs_push_constants) {
-               set_userdata_location_shader(ctx, AC_UD_PUSH_CONSTANTS, &user_sgpr_idx, 2);
-       }
+       radv_define_common_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, desc_sets, &user_sgpr_idx);
 
-       switch (ctx->stage) {
+       switch (stage) {
        case MESA_SHADER_COMPUTE:
                if (ctx->shader_info->info.cs.grid_components_used) {
                        set_userdata_location_shader(ctx, AC_UD_CS_GRID_SIZE, &user_sgpr_idx, ctx->shader_info->info.cs.grid_components_used);
                }
                break;
        case MESA_SHADER_VERTEX:
-               if (!ctx->is_gs_copy_shader) {
-                       if (ctx->shader_info->info.vs.has_vertex_buffers) {
-                               set_userdata_location_shader(ctx, AC_UD_VS_VERTEX_BUFFERS, &user_sgpr_idx, 2);
-                       }
-                       unsigned vs_num = 2;
-                       if (ctx->shader_info->info.vs.needs_draw_id)
-                               vs_num++;
-
-                       set_userdata_location_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, &user_sgpr_idx, vs_num);
-               }
+               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
                if (ctx->view_index)
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                if (ctx->options->key.vs.as_ls) {
@@ -839,6 +974,9 @@ static void create_function(struct nir_to_llvm_context *ctx)
                        declare_tess_lds(ctx);
                break;
        case MESA_SHADER_TESS_CTRL:
+               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+               if (has_previous_stage)
+                       set_userdata_location_shader(ctx, AC_UD_VS_LS_TCS_IN_LAYOUT, &user_sgpr_idx, 1);
                set_userdata_location_shader(ctx, AC_UD_TCS_OFFCHIP_LAYOUT, &user_sgpr_idx, 4);
                if (ctx->view_index)
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
@@ -850,9 +988,17 @@ static void create_function(struct nir_to_llvm_context *ctx)
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_GEOMETRY:
+               if (has_previous_stage) {
+                       if (previous_stage == MESA_SHADER_VERTEX)
+                               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+                       else
+                               set_userdata_location_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT, &user_sgpr_idx, 1);
+               }
                set_userdata_location_shader(ctx, AC_UD_GS_VS_RING_STRIDE_ENTRIES, &user_sgpr_idx, 2);
                if (ctx->view_index)
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
+               if (has_previous_stage)
+                       declare_tess_lds(ctx);
                break;
        case MESA_SHADER_FRAGMENT:
                if (ctx->shader_info->info.ps.needs_sample_positions) {
@@ -862,6 +1008,8 @@ static void create_function(struct nir_to_llvm_context *ctx)
        default:
                unreachable("Shader stage not implemented");
        }
+
+       ctx->shader_info->num_user_sgprs = user_sgpr_idx;
 }
 
 static void setup_types(struct nir_to_llvm_context *ctx)
@@ -2779,19 +2927,19 @@ store_tcs_output(struct nir_to_llvm_context *ctx,
        buf_addr = get_tcs_tes_buffer_address_params(ctx, param, const_index, is_compact,
                                                     vertex_index, indir_index);
 
+       bool is_tess_factor = false;
+       if (instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
+           instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER)
+               is_tess_factor = true;
+
        unsigned base = is_compact ? const_index : 0;
        for (unsigned chan = 0; chan < 8; chan++) {
-               bool is_tess_factor = false;
                if (!(writemask & (1 << chan)))
                        continue;
                LLVMValueRef value = llvm_extract_elem(&ctx->ac, src, chan);
 
                lds_store(ctx, dw_addr, value);
 
-               if (instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
-                   instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER)
-                       is_tess_factor = true;
-
                if (!is_tess_factor && writemask != 0xF)
                        ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, value, 1,
                                                    buf_addr, ctx->oc_lds,
@@ -2860,21 +3008,27 @@ load_gs_input(struct nir_to_llvm_context *ctx,
 
        param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
        for (unsigned i = 0; i < instr->num_components; i++) {
-
-               args[0] = ctx->esgs_ring;
-               args[1] = vtx_offset;
-               args[2] = LLVMConstInt(ctx->i32, (param * 4 + i + const_index) * 256, false);
-               args[3] = ctx->i32zero;
-               args[4] = ctx->i32one; /* OFFEN */
-               args[5] = ctx->i32zero; /* IDXEN */
-               args[6] = ctx->i32one; /* GLC */
-               args[7] = ctx->i32zero; /* SLC */
-               args[8] = ctx->i32zero; /* TFE */
-
-               value[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.buffer.load.dword.i32.i32",
-                                             ctx->i32, args, 9,
-                                             AC_FUNC_ATTR_READONLY |
-                                             AC_FUNC_ATTR_LEGACY);
+               if (ctx->ac.chip_class >= GFX9) {
+                       LLVMValueRef dw_addr = ctx->gs_vtx_offset[vtx_offset_param];
+                       dw_addr = LLVMBuildAdd(ctx->ac.builder, dw_addr,
+                                              LLVMConstInt(ctx->ac.i32, param * 4 + i, 0), "");
+                       value[i] = lds_load(ctx, dw_addr);
+               } else {
+                       args[0] = ctx->esgs_ring;
+                       args[1] = vtx_offset;
+                       args[2] = LLVMConstInt(ctx->i32, (param * 4 + i + const_index) * 256, false);
+                       args[3] = ctx->i32zero;
+                       args[4] = ctx->i32one; /* OFFEN */
+                       args[5] = ctx->i32zero; /* IDXEN */
+                       args[6] = ctx->i32one; /* GLC */
+                       args[7] = ctx->i32zero; /* SLC */
+                       args[8] = ctx->i32zero; /* TFE */
+
+                       value[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.buffer.load.dword.i32.i32",
+                                                     ctx->i32, args, 9,
+                                                     AC_FUNC_ATTR_READONLY |
+                                                     AC_FUNC_ATTR_LEGACY);
+               }
        }
        result = ac_build_gather_values(&ctx->ac, value, instr->num_components);
 
@@ -5151,7 +5305,9 @@ static LLVMValueRef si_build_alloca_undef(struct ac_llvm_context *ac,
 
 static void
 scan_shader_output_decl(struct nir_to_llvm_context *ctx,
-                       struct nir_variable *variable)
+                       struct nir_variable *variable,
+                       struct nir_shader *shader,
+                       gl_shader_stage stage)
 {
        int idx = variable->data.location + variable->data.index;
        unsigned attrib_count = glsl_count_attribute_slots(variable->type, false);
@@ -5160,22 +5316,23 @@ scan_shader_output_decl(struct nir_to_llvm_context *ctx,
        variable->data.driver_location = idx * 4;
 
        /* tess ctrl has it's own load/store paths for outputs */
-       if (ctx->stage == MESA_SHADER_TESS_CTRL)
+       if (stage == MESA_SHADER_TESS_CTRL)
                return;
 
        mask_attribs = ((1ull << attrib_count) - 1) << idx;
-       if (ctx->stage == MESA_SHADER_VERTEX ||
-           ctx->stage == MESA_SHADER_TESS_EVAL ||
-           ctx->stage == MESA_SHADER_GEOMETRY) {
+       if (stage == MESA_SHADER_VERTEX ||
+           stage == MESA_SHADER_TESS_EVAL ||
+           stage == MESA_SHADER_GEOMETRY) {
                if (idx == VARYING_SLOT_CLIP_DIST0) {
-                       int length = ctx->num_output_clips + ctx->num_output_culls;
-                       if (ctx->stage == MESA_SHADER_VERTEX) {
-                               ctx->shader_info->vs.outinfo.clip_dist_mask = (1 << ctx->num_output_clips) - 1;
-                               ctx->shader_info->vs.outinfo.cull_dist_mask = (1 << ctx->num_output_culls) - 1;
+                       int length = shader->info.clip_distance_array_size +
+                                    shader->info.cull_distance_array_size;
+                       if (stage == MESA_SHADER_VERTEX) {
+                               ctx->shader_info->vs.outinfo.clip_dist_mask = (1 << shader->info.clip_distance_array_size) - 1;
+                               ctx->shader_info->vs.outinfo.cull_dist_mask = (1 << shader->info.cull_distance_array_size) - 1;
                        }
-                       if (ctx->stage == MESA_SHADER_TESS_EVAL) {
-                               ctx->shader_info->tes.outinfo.clip_dist_mask = (1 << ctx->num_output_clips) - 1;
-                               ctx->shader_info->tes.outinfo.cull_dist_mask = (1 << ctx->num_output_culls) - 1;
+                       if (stage == MESA_SHADER_TESS_EVAL) {
+                               ctx->shader_info->tes.outinfo.clip_dist_mask = (1 << shader->info.clip_distance_array_size) - 1;
+                               ctx->shader_info->tes.outinfo.cull_dist_mask = (1 << shader->info.cull_distance_array_size) - 1;
                        }
 
                        if (length > 4)
@@ -5703,8 +5860,9 @@ handle_es_outputs_post(struct nir_to_llvm_context *ctx,
 {
        int j;
        uint64_t max_output_written = 0;
+       LLVMValueRef lds_base = NULL;
+
        for (unsigned i = 0; i < RADEON_LLVM_MAX_OUTPUTS; ++i) {
-               LLVMValueRef *out_ptr = &ctx->nir->outputs[i * 4];
                int param_index;
                int length = 4;
 
@@ -5717,20 +5875,60 @@ handle_es_outputs_post(struct nir_to_llvm_context *ctx,
                param_index = shader_io_get_unique_index(i);
 
                max_output_written = MAX2(param_index + (length > 4), max_output_written);
+       }
+
+       outinfo->esgs_itemsize = (max_output_written + 1) * 16;
 
+       if (ctx->ac.chip_class  >= GFX9) {
+               unsigned itemsize_dw = outinfo->esgs_itemsize / 4;
+               LLVMValueRef vertex_idx = ac_get_thread_id(&ctx->ac);
+               LLVMValueRef wave_idx = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                                    LLVMConstInt(ctx->ac.i32, 24, false),
+                                                    LLVMConstInt(ctx->ac.i32, 4, false), false);
+               vertex_idx = LLVMBuildOr(ctx->ac.builder, vertex_idx,
+                                        LLVMBuildMul(ctx->ac.builder, wave_idx,
+                                                     LLVMConstInt(ctx->i32, 64, false), ""), "");
+               lds_base = LLVMBuildMul(ctx->ac.builder, vertex_idx,
+                                       LLVMConstInt(ctx->i32, itemsize_dw, 0), "");
+       }
+
+       for (unsigned i = 0; i < RADEON_LLVM_MAX_OUTPUTS; ++i) {
+               LLVMValueRef dw_addr;
+               LLVMValueRef *out_ptr = &ctx->nir->outputs[i * 4];
+               int param_index;
+               int length = 4;
+
+               if (!(ctx->output_mask & (1ull << i)))
+                       continue;
+
+               if (i == VARYING_SLOT_CLIP_DIST0)
+                       length = ctx->num_output_clips + ctx->num_output_culls;
+
+               param_index = shader_io_get_unique_index(i);
+
+               if (lds_base) {
+                       dw_addr = LLVMBuildAdd(ctx->builder, lds_base,
+                                              LLVMConstInt(ctx->i32, param_index * 4, false),
+                                              "");
+               }
                for (j = 0; j < length; j++) {
                        LLVMValueRef out_val = LLVMBuildLoad(ctx->builder, out_ptr[j], "");
                        out_val = LLVMBuildBitCast(ctx->builder, out_val, ctx->i32, "");
 
-                       ac_build_buffer_store_dword(&ctx->ac,
-                                              ctx->esgs_ring,
-                                              out_val, 1,
-                                              NULL, ctx->es2gs_offset,
-                                              (4 * param_index + j) * 4,
-                                              1, 1, true, true);
+                       if (ctx->ac.chip_class  >= GFX9) {
+                               lds_store(ctx, dw_addr,
+                                         LLVMBuildLoad(ctx->builder, out_ptr[j], ""));
+                               dw_addr = LLVMBuildAdd(ctx->builder, dw_addr, ctx->i32one, "");
+                       } else {
+                               ac_build_buffer_store_dword(&ctx->ac,
+                                                           ctx->esgs_ring,
+                                                           out_val, 1,
+                                                           NULL, ctx->es2gs_offset,
+                                                           (4 * param_index + j) * 4,
+                                                           1, 1, true, true);
+                       }
                }
        }
-       outinfo->esgs_itemsize = (max_output_written + 1) * 16;
 }
 
 static void
@@ -5942,26 +6140,31 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
        tf_base = ctx->tess_factor_offset;
        byteoffset = LLVMBuildMul(ctx->builder, rel_patch_id,
                                  LLVMConstInt(ctx->i32, 4 * stride, false), "");
+       unsigned tf_offset = 0;
 
-       ac_nir_build_if(&inner_if_ctx, ctx,
-                   LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                 rel_patch_id, ctx->i32zero, ""));
+       if (ctx->options->chip_class <= VI) {
+               ac_nir_build_if(&inner_if_ctx, ctx,
+                               LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                                             rel_patch_id, ctx->i32zero, ""));
 
-       /* Store the dynamic HS control word. */
-       ac_build_buffer_store_dword(&ctx->ac, buffer,
-                                   LLVMConstInt(ctx->i32, 0x80000000, false),
-                                   1, ctx->i32zero, tf_base,
-                                   0, 1, 0, true, false);
-       ac_nir_build_endif(&inner_if_ctx);
+               /* Store the dynamic HS control word. */
+               ac_build_buffer_store_dword(&ctx->ac, buffer,
+                                           LLVMConstInt(ctx->i32, 0x80000000, false),
+                                           1, ctx->i32zero, tf_base,
+                                           0, 1, 0, true, false);
+               tf_offset += 4;
+
+               ac_nir_build_endif(&inner_if_ctx);
+       }
 
        /* Store the tessellation factors. */
        ac_build_buffer_store_dword(&ctx->ac, buffer, vec0,
                                    MIN2(stride, 4), byteoffset, tf_base,
-                                   4, 1, 0, true, false);
+                                   tf_offset, 1, 0, true, false);
        if (vec1)
                ac_build_buffer_store_dword(&ctx->ac, buffer, vec1,
                                            stride - 4, byteoffset, tf_base,
-                                           20, 1, 0, true, false);
+                                           16 + tf_offset, 1, 0, true, false);
 
        //TODO store to offchip for TES to read - only if TES reads them
        if (1) {
@@ -6250,7 +6453,7 @@ static unsigned
 ac_nir_get_max_workgroup_size(enum chip_class chip_class,
                              const struct nir_shader *nir)
 {
-       switch (nir->stage) {
+       switch (nir->info.stage) {
        case MESA_SHADER_TESS_CTRL:
                return chip_class >= CIK ? 128 : 64;
        case MESA_SHADER_GEOMETRY:
@@ -6267,6 +6470,33 @@ ac_nir_get_max_workgroup_size(enum chip_class chip_class,
        return max_workgroup_size;
 }
 
+/* Fixup the HW not emitting the TCS regs if there are no HS threads. */
+static void ac_nir_fixup_ls_hs_input_vgprs(struct nir_to_llvm_context *ctx)
+{
+       LLVMValueRef count = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                         LLVMConstInt(ctx->ac.i32, 8, false),
+                                         LLVMConstInt(ctx->ac.i32, 8, false), false);
+       LLVMValueRef hs_empty = LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, count,
+                                             LLVMConstInt(ctx->ac.i32, 0, false), "");
+       ctx->abi.instance_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->rel_auto_id, ctx->abi.instance_id, "");
+       ctx->vs_prim_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.vertex_id, ctx->vs_prim_id, "");
+       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_rel_ids, ctx->rel_auto_id, "");
+       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_patch_id, ctx->abi.vertex_id, "");
+}
+
+static void prepare_gs_input_vgprs(struct nir_to_llvm_context *ctx)
+{
+       for(int i = 5; i >= 0; --i) {
+               ctx->gs_vtx_offset[i] = ac_build_bfe(&ctx->ac, ctx->gs_vtx_offset[i & ~1],
+                                                    LLVMConstInt(ctx->ac.i32, (i & 1) * 16, false),
+                                                    LLVMConstInt(ctx->ac.i32, 16, false), false);
+       }
+
+       ctx->gs_wave_id = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                      LLVMConstInt(ctx->ac.i32, 16, false),
+                                      LLVMConstInt(ctx->ac.i32, 8, false), false);
+}
+
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
                      struct nir_shader *nir, struct nir_to_llvm_context *nctx)
 {
@@ -6280,7 +6510,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
        if (nctx)
                nctx->nir = &ctx;
 
-       ctx.stage = nir->stage;
+       ctx.stage = nir->info.stage;
 
        ctx.main_function = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
 
@@ -6298,7 +6528,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 
        setup_locals(&ctx, func);
 
-       if (nir->stage == MESA_SHADER_COMPUTE)
+       if (nir->info.stage == MESA_SHADER_COMPUTE)
                setup_shared(&ctx, nir);
 
        visit_cf_list(&ctx, &func->impl->body);
@@ -6318,7 +6548,8 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 
 static
 LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
-                                       struct nir_shader *nir,
+                                       struct nir_shader *const *shaders,
+                                       int shader_count,
                                        struct ac_shader_variant_info *shader_info,
                                        const struct ac_nir_compiler_options *options)
 {
@@ -6331,11 +6562,6 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
        ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class);
        ctx.ac.module = ctx.module;
-
-       memset(shader_info, 0, sizeof(*shader_info));
-
-       ac_nir_shader_info_pass(nir, options, &shader_info->info);
-
        LLVMSetTarget(ctx.module, options->supports_spill ? "amdgcn-mesa-mesa3d" : "amdgcn--");
 
        LLVMTargetDataRef data_layout = LLVMCreateTargetDataLayout(tm);
@@ -6345,72 +6571,118 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
        LLVMDisposeMessage(data_layout_str);
 
        setup_types(&ctx);
-
        ctx.builder = LLVMCreateBuilderInContext(ctx.context);
        ctx.ac.builder = ctx.builder;
-       ctx.stage = nir->stage;
-       ctx.max_workgroup_size = ac_nir_get_max_workgroup_size(ctx.options->chip_class, nir);
+
+       memset(shader_info, 0, sizeof(*shader_info));
+
+       for(int i = 0; i < shader_count; ++i)
+               ac_nir_shader_info_pass(shaders[i], options, &shader_info->info);
 
        for (i = 0; i < AC_UD_MAX_SETS; i++)
                shader_info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
        for (i = 0; i < AC_UD_MAX_UD; i++)
                shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
 
-       create_function(&ctx);
+       ctx.max_workgroup_size = ac_nir_get_max_workgroup_size(ctx.options->chip_class, shaders[0]);
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, "gs_next_vertex");
+       create_function(&ctx, shaders[shader_count - 1]->info.stage, shader_count >= 2,
+                       shader_count >= 2 ? shaders[shader_count - 2]->info.stage  : MESA_SHADER_VERTEX);
 
-               ctx.gs_max_out_vertices = nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_EVAL) {
-               ctx.tes_primitive_mode = nir->info.tess.primitive_mode;
-       } else if (nir->stage == MESA_SHADER_VERTEX) {
-               if (shader_info->info.vs.needs_instance_id) {
-                       ctx.shader_info->vs.vgpr_comp_cnt =
-                               MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+       ctx.abi.inputs = &ctx.inputs[0];
+       ctx.abi.emit_outputs = handle_shader_outputs_post;
+       ctx.abi.load_ssbo = radv_load_ssbo;
+       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
+
+       if (shader_count >= 2)
+               ac_init_exec_full_mask(&ctx.ac);
+
+       if (ctx.ac.chip_class == GFX9 &&
+           shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL)
+               ac_nir_fixup_ls_hs_input_vgprs(&ctx);
+
+       for(int i = 0; i < shader_count; ++i) {
+               ctx.stage = shaders[i]->info.stage;
+               ctx.output_mask = 0;
+               ctx.tess_outputs_written = 0;
+               ctx.num_output_clips = shaders[i]->info.clip_distance_array_size;
+               ctx.num_output_culls = shaders[i]->info.cull_distance_array_size;
+
+               if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
+                       ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, "gs_next_vertex");
+
+                       ctx.gs_max_out_vertices = shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
+                       ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
+               } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
+                       if (shader_info->info.vs.needs_instance_id) {
+                               ctx.shader_info->vs.vgpr_comp_cnt =
+                                       MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+                       }
+               } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
+                       shader_info->fs.can_discard = shaders[i]->info.fs.uses_discard;
                }
-       } else if (nir->stage == MESA_SHADER_FRAGMENT) {
-               shader_info->fs.can_discard = nir->info.fs.uses_discard;
-       }
 
-       ac_setup_rings(&ctx);
+               if (i)
+                       emit_barrier(&ctx);
 
-       ctx.num_output_clips = nir->info.clip_distance_array_size;
-       ctx.num_output_culls = nir->info.cull_distance_array_size;
+               ac_setup_rings(&ctx);
 
-       if (nir->stage == MESA_SHADER_FRAGMENT)
-               handle_fs_inputs(&ctx, nir);
-       else if(nir->stage == MESA_SHADER_VERTEX)
-               handle_vs_inputs(&ctx, nir);
+               LLVMBasicBlockRef merge_block;
+               if (shader_count >= 2) {
+                       LLVMValueRef fn = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
+                       LLVMBasicBlockRef then_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
+                       merge_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
 
-       ctx.abi.inputs = &ctx.inputs[0];
-       ctx.abi.emit_outputs = handle_shader_outputs_post;
-       ctx.abi.load_ssbo = radv_load_ssbo;
-       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
+                       LLVMValueRef count = ac_build_bfe(&ctx.ac, ctx.merged_wave_info,
+                                                         LLVMConstInt(ctx.ac.i32, 8 * i, false),
+                                                         LLVMConstInt(ctx.ac.i32, 8, false), false);
+                       LLVMValueRef thread_id = ac_get_thread_id(&ctx.ac);
+                       LLVMValueRef cond = LLVMBuildICmp(ctx.ac.builder, LLVMIntULT,
+                                                         thread_id, count, "");
+                       LLVMBuildCondBr(ctx.ac.builder, cond, then_block, merge_block);
 
-       nir_foreach_variable(variable, &nir->outputs)
-               scan_shader_output_decl(&ctx, variable);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, then_block);
+               }
 
-       ac_nir_translate(&ctx.ac, &ctx.abi, nir, &ctx);
+               if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT)
+                       handle_fs_inputs(&ctx, shaders[i]);
+               else if(shaders[i]->info.stage == MESA_SHADER_VERTEX)
+                       handle_vs_inputs(&ctx, shaders[i]);
+               else if(shader_count >= 2 && shaders[i]->info.stage == MESA_SHADER_GEOMETRY)
+                       prepare_gs_input_vgprs(&ctx);
 
-       LLVMBuildRetVoid(ctx.builder);
+               nir_foreach_variable(variable, &shaders[i]->outputs)
+                       scan_shader_output_decl(&ctx, variable, shaders[i], shaders[i]->info.stage);
 
-       ac_llvm_finalize_module(&ctx);
+               ac_nir_translate(&ctx.ac, &ctx.abi, shaders[i], &ctx);
 
-       ac_nir_eliminate_const_vs_outputs(&ctx);
+               if (shader_count >= 2) {
+                       LLVMBuildBr(ctx.ac.builder, merge_block);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, merge_block);
+               }
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               unsigned addclip = ctx.num_output_clips + ctx.num_output_culls > 4;
-               shader_info->gs.gsvs_vertex_size = (util_bitcount64(ctx.output_mask) + addclip) * 16;
-               shader_info->gs.max_gsvs_emit_size = shader_info->gs.gsvs_vertex_size *
-                       nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_CTRL) {
-               shader_info->tcs.outputs_written = ctx.tess_outputs_written;
-               shader_info->tcs.patch_outputs_written = ctx.tess_patch_outputs_written;
-       } else if (nir->stage == MESA_SHADER_VERTEX && ctx.options->key.vs.as_ls) {
-               shader_info->vs.outputs_written = ctx.tess_outputs_written;
+               if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
+                       unsigned addclip = shaders[i]->info.clip_distance_array_size +
+                                       shaders[i]->info.cull_distance_array_size > 4;
+                       shader_info->gs.gsvs_vertex_size = (util_bitcount64(ctx.output_mask) + addclip) * 16;
+                       shader_info->gs.max_gsvs_emit_size = shader_info->gs.gsvs_vertex_size *
+                               shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->info.stage == MESA_SHADER_TESS_CTRL) {
+                       shader_info->tcs.outputs_written = ctx.tess_outputs_written;
+                       shader_info->tcs.patch_outputs_written = ctx.tess_patch_outputs_written;
+               } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX && ctx.options->key.vs.as_ls) {
+                       shader_info->vs.outputs_written = ctx.tess_outputs_written;
+               }
        }
 
+       LLVMBuildRetVoid(ctx.builder);
+
+       ac_llvm_finalize_module(&ctx);
+
+       if (shader_count == 1)
+               ac_nir_eliminate_const_vs_outputs(&ctx);
+
        return ctx.module;
 }
 
@@ -6540,53 +6812,61 @@ static void ac_compile_llvm_module(LLVMTargetMachineRef tm,
                                 shader_info->num_input_sgprs + 3);
 }
 
+static void
+ac_fill_shader_info(struct ac_shader_variant_info *shader_info, struct nir_shader *nir, const struct ac_nir_compiler_options *options)
+{
+        switch (nir->info.stage) {
+        case MESA_SHADER_COMPUTE:
+                for (int i = 0; i < 3; ++i)
+                        shader_info->cs.block_size[i] = nir->info.cs.local_size[i];
+                break;
+        case MESA_SHADER_FRAGMENT:
+                shader_info->fs.early_fragment_test = nir->info.fs.early_fragment_tests;
+                break;
+        case MESA_SHADER_GEOMETRY:
+                shader_info->gs.vertices_in = nir->info.gs.vertices_in;
+                shader_info->gs.vertices_out = nir->info.gs.vertices_out;
+                shader_info->gs.output_prim = nir->info.gs.output_primitive;
+                shader_info->gs.invocations = nir->info.gs.invocations;
+                break;
+        case MESA_SHADER_TESS_EVAL:
+                shader_info->tes.primitive_mode = nir->info.tess.primitive_mode;
+                shader_info->tes.spacing = nir->info.tess.spacing;
+                shader_info->tes.ccw = nir->info.tess.ccw;
+                shader_info->tes.point_mode = nir->info.tess.point_mode;
+                shader_info->tes.as_es = options->key.tes.as_es;
+                break;
+        case MESA_SHADER_TESS_CTRL:
+                shader_info->tcs.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
+                break;
+        case MESA_SHADER_VERTEX:
+                shader_info->vs.as_es = options->key.vs.as_es;
+                shader_info->vs.as_ls = options->key.vs.as_ls;
+                /* in LS mode we need at least 1, invocation id needs 3, handled elsewhere */
+                if (options->key.vs.as_ls)
+                        shader_info->vs.vgpr_comp_cnt = MAX2(1, shader_info->vs.vgpr_comp_cnt);
+                break;
+        default:
+                break;
+        }
+}
+
 void ac_compile_nir_shader(LLVMTargetMachineRef tm,
                            struct ac_shader_binary *binary,
                            struct ac_shader_config *config,
                            struct ac_shader_variant_info *shader_info,
-                           struct nir_shader *nir,
+                           struct nir_shader *const *nir,
+                           int nir_count,
                            const struct ac_nir_compiler_options *options,
                           bool dump_shader)
 {
 
-       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir, shader_info,
+       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir, nir_count, shader_info,
                                                             options);
 
-       ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir->stage, dump_shader, options->supports_spill);
-       switch (nir->stage) {
-       case MESA_SHADER_COMPUTE:
-               for (int i = 0; i < 3; ++i)
-                       shader_info->cs.block_size[i] = nir->info.cs.local_size[i];
-               break;
-       case MESA_SHADER_FRAGMENT:
-               shader_info->fs.early_fragment_test = nir->info.fs.early_fragment_tests;
-               break;
-       case MESA_SHADER_GEOMETRY:
-               shader_info->gs.vertices_in = nir->info.gs.vertices_in;
-               shader_info->gs.vertices_out = nir->info.gs.vertices_out;
-               shader_info->gs.output_prim = nir->info.gs.output_primitive;
-               shader_info->gs.invocations = nir->info.gs.invocations;
-               break;
-       case MESA_SHADER_TESS_EVAL:
-               shader_info->tes.primitive_mode = nir->info.tess.primitive_mode;
-               shader_info->tes.spacing = nir->info.tess.spacing;
-               shader_info->tes.ccw = nir->info.tess.ccw;
-               shader_info->tes.point_mode = nir->info.tess.point_mode;
-               shader_info->tes.as_es = options->key.tes.as_es;
-               break;
-       case MESA_SHADER_TESS_CTRL:
-               shader_info->tcs.tcs_vertices_out = nir->info.tess.tcs_vertices_out;
-               break;
-       case MESA_SHADER_VERTEX:
-               shader_info->vs.as_es = options->key.vs.as_es;
-               shader_info->vs.as_ls = options->key.vs.as_ls;
-               /* in LS mode we need at least 1, invocation id needs 3, handled elsewhere */
-               if (options->key.vs.as_ls)
-                       shader_info->vs.vgpr_comp_cnt = MAX2(1, shader_info->vs.vgpr_comp_cnt);
-               break;
-       default:
-               break;
-       }
+       ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->info.stage, dump_shader, options->supports_spill);
+       for (int i = 0; i < nir_count; ++i)
+               ac_fill_shader_info(shader_info, nir[i], options);
 }
 
 static void
@@ -6663,7 +6943,7 @@ void ac_create_gs_copy_shader(LLVMTargetMachineRef tm,
        ctx.ac.builder = ctx.builder;
        ctx.stage = MESA_SHADER_VERTEX;
 
-       create_function(&ctx);
+       create_function(&ctx, MESA_SHADER_VERTEX, false, MESA_SHADER_VERTEX);
 
        ctx.gs_max_out_vertices = geom_shader->info.gs.vertices_out;
        ac_setup_rings(&ctx);
@@ -6679,7 +6959,7 @@ void ac_create_gs_copy_shader(LLVMTargetMachineRef tm,
        ctx.nir = &nir_ctx;
 
        nir_foreach_variable(variable, &geom_shader->outputs) {
-               scan_shader_output_decl(&ctx, variable);
+               scan_shader_output_decl(&ctx, variable, geom_shader, MESA_SHADER_VERTEX);
                handle_shader_output_decl(&nir_ctx, geom_shader, variable);
        }