nir: Get rid of nir_shader::stage
diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c
index c6c56f30b8167dc57a66c67252cbc995f63a9d77..61ffe91eafd001d7d3e56ceb1c746f4dcc6260b1 100644
--- a/src/amd/common/ac_nir_to_llvm.c
+++ b/src/amd/common/ac_nir_to_llvm.c
@@ -838,22 +838,61 @@ static void create_function(struct nir_to_llvm_context *ctx,
                add_vgpr_argument(&args, ctx->i32, &ctx->tes_patch_id); // tes patch id
                break;
        case MESA_SHADER_GEOMETRY:
-               radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
-               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
-               add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entires
-               if (ctx->shader_info->info.needs_multiview_view_index)
-                       add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
-               add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // gs2vs offset
-               add_sgpr_argument(&args, ctx->i32, &ctx->gs_wave_id); // wave id
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx0
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[1]); // vtx1
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[3]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[5]);
-               add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+               if (has_previous_stage) {
+                       // First 6 system regs
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // gs2vs offset
+                       add_sgpr_argument(&args, ctx->i32, &ctx->merged_wave_info); // merged wave info
+                       add_sgpr_argument(&args, ctx->i32, &ctx->oc_lds); // param oc lds
+
+                       add_sgpr_argument(&args, ctx->i32, NULL); // scratch offset
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+                       add_sgpr_argument(&args, ctx->i32, NULL); // unknown
+
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       if (previous_stage == MESA_SHADER_TESS_EVAL)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->tcs_offchip_layout); // tcs offchip layout
+                       else
+                               radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entries
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx01
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]); // vtx23
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
+
+                       if (previous_stage == MESA_SHADER_VERTEX) {
+                               add_vgpr_argument(&args, ctx->i32, &ctx->abi.vertex_id); // vertex id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->rel_auto_id); // rel auto id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->vs_prim_id); // vs prim id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->abi.instance_id); // instance id
+                       } else {
+                               add_vgpr_argument(&args, ctx->f32, &ctx->tes_u); // tes_u
+                               add_vgpr_argument(&args, ctx->f32, &ctx->tes_v); // tes_v
+                               add_vgpr_argument(&args, ctx->i32, &ctx->tes_rel_patch_id); // tes rel patch id
+                               add_vgpr_argument(&args, ctx->i32, &ctx->tes_patch_id); // tes patch id
+                       }
+               } else {
+                       radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
+                       radv_define_vs_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &args);
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_ring_stride); // gsvs stride
+                       add_user_sgpr_argument(&args, ctx->i32, &ctx->gsvs_num_entries); // gsvs num entries
+                       if (ctx->shader_info->info.needs_multiview_view_index)
+                               add_user_sgpr_argument(&args, ctx->i32, &ctx->view_index);
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs2vs_offset); // gs2vs offset
+                       add_sgpr_argument(&args, ctx->i32, &ctx->gs_wave_id); // wave id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[0]); // vtx0
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[1]); // vtx1
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_prim_id); // prim id
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[2]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[3]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[4]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_vtx_offset[5]);
+                       add_vgpr_argument(&args, ctx->i32, &ctx->gs_invocation_id);
+               }
                break;
        case MESA_SHADER_FRAGMENT:
                radv_define_common_user_sgprs_phase1(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_info, &args, &desc_sets);
@@ -949,10 +988,17 @@ static void create_function(struct nir_to_llvm_context *ctx,
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
                break;
        case MESA_SHADER_GEOMETRY:
-               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+               if (has_previous_stage) {
+                       if (previous_stage == MESA_SHADER_VERTEX)
+                               radv_define_vs_user_sgprs_phase2(ctx, stage, has_previous_stage, previous_stage, &user_sgpr_idx);
+                       else
+                               set_userdata_location_shader(ctx, AC_UD_TES_OFFCHIP_LAYOUT, &user_sgpr_idx, 1);
+               }
                set_userdata_location_shader(ctx, AC_UD_GS_VS_RING_STRIDE_ENTRIES, &user_sgpr_idx, 2);
                if (ctx->view_index)
                        set_userdata_location_shader(ctx, AC_UD_VIEW_INDEX, &user_sgpr_idx, 1);
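+               /* Merged ES/GS passes the ES outputs through LDS, so the GS needs
+                * the LDS declaration as well.
+                */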
+               if (has_previous_stage)
+                       declare_tess_lds(ctx);
                break;
        case MESA_SHADER_FRAGMENT:
                if (ctx->shader_info->info.ps.needs_sample_positions) {
@@ -2881,19 +2927,19 @@ store_tcs_output(struct nir_to_llvm_context *ctx,
        buf_addr = get_tcs_tes_buffer_address_params(ctx, param, const_index, is_compact,
                                                     vertex_index, indir_index);
 
+       bool is_tess_factor = false;
+       if (instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
+           instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER)
+               is_tess_factor = true;
+
        unsigned base = is_compact ? const_index : 0;
        for (unsigned chan = 0; chan < 8; chan++) {
-               bool is_tess_factor = false;
                if (!(writemask & (1 << chan)))
                        continue;
                LLVMValueRef value = llvm_extract_elem(&ctx->ac, src, chan);
 
                lds_store(ctx, dw_addr, value);
 
-               if (instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
-                   instr->variables[0]->var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER)
-                       is_tess_factor = true;
-
                if (!is_tess_factor && writemask != 0xF)
                        ac_build_buffer_store_dword(&ctx->ac, ctx->hs_ring_tess_offchip, value, 1,
                                                    buf_addr, ctx->oc_lds,
@@ -2962,21 +3008,27 @@ load_gs_input(struct nir_to_llvm_context *ctx,
 
        param = shader_io_get_unique_index(instr->variables[0]->var->data.location);
        for (unsigned i = 0; i < instr->num_components; i++) {
-
-               args[0] = ctx->esgs_ring;
-               args[1] = vtx_offset;
-               args[2] = LLVMConstInt(ctx->i32, (param * 4 + i + const_index) * 256, false);
-               args[3] = ctx->i32zero;
-               args[4] = ctx->i32one; /* OFFEN */
-               args[5] = ctx->i32zero; /* IDXEN */
-               args[6] = ctx->i32one; /* GLC */
-               args[7] = ctx->i32zero; /* SLC */
-               args[8] = ctx->i32zero; /* TFE */
-
-               value[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.buffer.load.dword.i32.i32",
-                                             ctx->i32, args, 9,
-                                             AC_FUNC_ATTR_READONLY |
-                                             AC_FUNC_ATTR_LEGACY);
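+               /* On GFX9 the ES wrote its outputs to LDS (see handle_es_outputs_post),
+                * so read GS inputs back from LDS instead of the ESGS ring buffer.
+                */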
+               if (ctx->ac.chip_class >= GFX9) {
+                       LLVMValueRef dw_addr = ctx->gs_vtx_offset[vtx_offset_param];
+                       dw_addr = LLVMBuildAdd(ctx->ac.builder, dw_addr,
+                                              LLVMConstInt(ctx->ac.i32, param * 4 + i, 0), "");
+                       value[i] = lds_load(ctx, dw_addr);
+               } else {
+                       args[0] = ctx->esgs_ring;
+                       args[1] = vtx_offset;
+                       args[2] = LLVMConstInt(ctx->i32, (param * 4 + i + const_index) * 256, false);
+                       args[3] = ctx->i32zero;
+                       args[4] = ctx->i32one; /* OFFEN */
+                       args[5] = ctx->i32zero; /* IDXEN */
+                       args[6] = ctx->i32one; /* GLC */
+                       args[7] = ctx->i32zero; /* SLC */
+                       args[8] = ctx->i32zero; /* TFE */
+
+                       value[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.buffer.load.dword.i32.i32",
+                                                     ctx->i32, args, 9,
+                                                     AC_FUNC_ATTR_READONLY |
+                                                     AC_FUNC_ATTR_LEGACY);
+               }
        }
        result = ac_build_gather_values(&ctx->ac, value, instr->num_components);
 
@@ -5808,8 +5860,9 @@ handle_es_outputs_post(struct nir_to_llvm_context *ctx,
 {
        int j;
        uint64_t max_output_written = 0;
+       LLVMValueRef lds_base = NULL;
+
        for (unsigned i = 0; i < RADEON_LLVM_MAX_OUTPUTS; ++i) {
-               LLVMValueRef *out_ptr = &ctx->nir->outputs[i * 4];
                int param_index;
                int length = 4;
 
@@ -5822,20 +5875,60 @@ handle_es_outputs_post(struct nir_to_llvm_context *ctx,
                param_index = shader_io_get_unique_index(i);
 
                max_output_written = MAX2(param_index + (length > 4), max_output_written);
+       }
+
+       outinfo->esgs_itemsize = (max_output_written + 1) * 16;
+
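+       /* GFX9 keeps ES outputs in LDS: derive this vertex's LDS slot from its
+        * thread id and wave index within the merged threadgroup.
+        */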
+       if (ctx->ac.chip_class >= GFX9) {
+               unsigned itemsize_dw = outinfo->esgs_itemsize / 4;
+               LLVMValueRef vertex_idx = ac_get_thread_id(&ctx->ac);
+               LLVMValueRef wave_idx = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                                    LLVMConstInt(ctx->ac.i32, 24, false),
+                                                    LLVMConstInt(ctx->ac.i32, 4, false), false);
+               vertex_idx = LLVMBuildOr(ctx->ac.builder, vertex_idx,
+                                        LLVMBuildMul(ctx->ac.builder, wave_idx,
+                                                     LLVMConstInt(ctx->i32, 64, false), ""), "");
+               lds_base = LLVMBuildMul(ctx->ac.builder, vertex_idx,
+                                       LLVMConstInt(ctx->i32, itemsize_dw, 0), "");
+       }
+
+       for (unsigned i = 0; i < RADEON_LLVM_MAX_OUTPUTS; ++i) {
+               LLVMValueRef dw_addr;
+               LLVMValueRef *out_ptr = &ctx->nir->outputs[i * 4];
+               int param_index;
+               int length = 4;
+
+               if (!(ctx->output_mask & (1ull << i)))
+                       continue;
+
+               if (i == VARYING_SLOT_CLIP_DIST0)
+                       length = ctx->num_output_clips + ctx->num_output_culls;
 
+               param_index = shader_io_get_unique_index(i);
+
+               if (lds_base) {
+                       dw_addr = LLVMBuildAdd(ctx->builder, lds_base,
+                                              LLVMConstInt(ctx->i32, param_index * 4, false),
+                                              "");
+               }
                for (j = 0; j < length; j++) {
                        LLVMValueRef out_val = LLVMBuildLoad(ctx->builder, out_ptr[j], "");
                        out_val = LLVMBuildBitCast(ctx->builder, out_val, ctx->i32, "");
 
-                       ac_build_buffer_store_dword(&ctx->ac,
-                                              ctx->esgs_ring,
-                                              out_val, 1,
-                                              NULL, ctx->es2gs_offset,
-                                              (4 * param_index + j) * 4,
-                                              1, 1, true, true);
+                       if (ctx->ac.chip_class >= GFX9) {
+                               lds_store(ctx, dw_addr,
+                                         LLVMBuildLoad(ctx->builder, out_ptr[j], ""));
+                               dw_addr = LLVMBuildAdd(ctx->builder, dw_addr, ctx->i32one, "");
+                       } else {
+                               ac_build_buffer_store_dword(&ctx->ac,
+                                                           ctx->esgs_ring,
+                                                           out_val, 1,
+                                                           NULL, ctx->es2gs_offset,
+                                                           (4 * param_index + j) * 4,
+                                                           1, 1, true, true);
+                       }
                }
        }
-       outinfo->esgs_itemsize = (max_output_written + 1) * 16;
 }
 
 static void
@@ -6047,26 +6140,31 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
        tf_base = ctx->tess_factor_offset;
        byteoffset = LLVMBuildMul(ctx->builder, rel_patch_id,
                                  LLVMConstInt(ctx->i32, 4 * stride, false), "");
+       unsigned tf_offset = 0;
 
-       ac_nir_build_if(&inner_if_ctx, ctx,
-                   LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                 rel_patch_id, ctx->i32zero, ""));
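+       /* Only SI-VI store the dynamic HS control word; it occupies the first
+        * dword of the TF buffer, so the factor stores shift by 4 bytes there.
+        */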
+       if (ctx->options->chip_class <= VI) {
+               ac_nir_build_if(&inner_if_ctx, ctx,
+                               LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                                             rel_patch_id, ctx->i32zero, ""));
 
-       /* Store the dynamic HS control word. */
-       ac_build_buffer_store_dword(&ctx->ac, buffer,
-                                   LLVMConstInt(ctx->i32, 0x80000000, false),
-                                   1, ctx->i32zero, tf_base,
-                                   0, 1, 0, true, false);
-       ac_nir_build_endif(&inner_if_ctx);
+               /* Store the dynamic HS control word. */
+               ac_build_buffer_store_dword(&ctx->ac, buffer,
+                                           LLVMConstInt(ctx->i32, 0x80000000, false),
+                                           1, ctx->i32zero, tf_base,
+                                           0, 1, 0, true, false);
+               tf_offset += 4;
+
+               ac_nir_build_endif(&inner_if_ctx);
+       }
 
        /* Store the tessellation factors. */
        ac_build_buffer_store_dword(&ctx->ac, buffer, vec0,
                                    MIN2(stride, 4), byteoffset, tf_base,
-                                   4, 1, 0, true, false);
+                                   tf_offset, 1, 0, true, false);
        if (vec1)
                ac_build_buffer_store_dword(&ctx->ac, buffer, vec1,
                                            stride - 4, byteoffset, tf_base,
-                                           20, 1, 0, true, false);
+                                           16 + tf_offset, 1, 0, true, false);
 
        //TODO store to offchip for TES to read - only if TES reads them
        if (1) {
@@ -6355,7 +6453,7 @@ static unsigned
 ac_nir_get_max_workgroup_size(enum chip_class chip_class,
                              const struct nir_shader *nir)
 {
-       switch (nir->stage) {
+       switch (nir->info.stage) {
        case MESA_SHADER_TESS_CTRL:
                return chip_class >= CIK ? 128 : 64;
        case MESA_SHADER_GEOMETRY:
@@ -6372,6 +6470,33 @@ ac_nir_get_max_workgroup_size(enum chip_class chip_class,
        return max_workgroup_size;
 }
 
+/* Fixup the HW not emitting the TCS regs if there are no HS threads. */
+static void ac_nir_fixup_ls_hs_input_vgprs(struct nir_to_llvm_context *ctx)
+{
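+       /* merged_wave_info bits [8:16) hold the number of HS threads in this wave. */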
+       LLVMValueRef count = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                         LLVMConstInt(ctx->ac.i32, 8, false),
+                                         LLVMConstInt(ctx->ac.i32, 8, false), false);
+       LLVMValueRef hs_empty = LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, count,
+                                             LLVMConstInt(ctx->ac.i32, 0, false), "");
+       ctx->abi.instance_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->rel_auto_id, ctx->abi.instance_id, "");
+       ctx->vs_prim_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->abi.vertex_id, ctx->vs_prim_id, "");
+       ctx->rel_auto_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_rel_ids, ctx->rel_auto_id, "");
+       ctx->abi.vertex_id = LLVMBuildSelect(ctx->ac.builder, hs_empty, ctx->tcs_patch_id, ctx->abi.vertex_id, "");
+}
+
+static void prepare_gs_input_vgprs(struct nir_to_llvm_context *ctx)
+{
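+       /* GFX9 packs two 16-bit ES vertex offsets per input VGPR; unpack them
+        * into six separate offsets and extract the GS wave id from
+        * merged_wave_info bits [16:24).
+        */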
+       for(int i = 5; i >= 0; --i) {
+               ctx->gs_vtx_offset[i] = ac_build_bfe(&ctx->ac, ctx->gs_vtx_offset[i & ~1],
+                                                    LLVMConstInt(ctx->ac.i32, (i & 1) * 16, false),
+                                                    LLVMConstInt(ctx->ac.i32, 16, false), false);
+       }
+
+       ctx->gs_wave_id = ac_build_bfe(&ctx->ac, ctx->merged_wave_info,
+                                      LLVMConstInt(ctx->ac.i32, 16, false),
+                                      LLVMConstInt(ctx->ac.i32, 8, false), false);
+}
+
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
                      struct nir_shader *nir, struct nir_to_llvm_context *nctx)
 {
@@ -6385,7 +6510,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
        if (nctx)
                nctx->nir = &ctx;
 
-       ctx.stage = nir->stage;
+       ctx.stage = nir->info.stage;
 
        ctx.main_function = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
 
@@ -6403,7 +6528,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 
        setup_locals(&ctx, func);
 
-       if (nir->stage == MESA_SHADER_COMPUTE)
+       if (nir->info.stage == MESA_SHADER_COMPUTE)
                setup_shared(&ctx, nir);
 
        visit_cf_list(&ctx, &func->impl->body);
@@ -6423,7 +6548,8 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 
 static
 LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
-                                       struct nir_shader *nir,
+                                       struct nir_shader *const *shaders,
+                                       int shader_count,
                                        struct ac_shader_variant_info *shader_info,
                                        const struct ac_nir_compiler_options *options)
 {
@@ -6436,11 +6562,6 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
 
        ac_llvm_context_init(&ctx.ac, ctx.context, options->chip_class);
        ctx.ac.module = ctx.module;
-
-       memset(shader_info, 0, sizeof(*shader_info));
-
-       ac_nir_shader_info_pass(nir, options, &shader_info->info);
-
        LLVMSetTarget(ctx.module, options->supports_spill ? "amdgcn-mesa-mesa3d" : "amdgcn--");
 
        LLVMTargetDataRef data_layout = LLVMCreateTargetDataLayout(tm);
@@ -6450,72 +6571,118 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
        LLVMDisposeMessage(data_layout_str);
 
        setup_types(&ctx);
-
        ctx.builder = LLVMCreateBuilderInContext(ctx.context);
        ctx.ac.builder = ctx.builder;
-       ctx.stage = nir->stage;
-       ctx.max_workgroup_size = ac_nir_get_max_workgroup_size(ctx.options->chip_class, nir);
+
+       memset(shader_info, 0, sizeof(*shader_info));
+
+       for(int i = 0; i < shader_count; ++i)
+               ac_nir_shader_info_pass(shaders[i], options, &shader_info->info);
 
        for (i = 0; i < AC_UD_MAX_SETS; i++)
                shader_info->user_sgprs_locs.descriptor_sets[i].sgpr_idx = -1;
        for (i = 0; i < AC_UD_MAX_UD; i++)
                shader_info->user_sgprs_locs.shader_data[i].sgpr_idx = -1;
 
-       create_function(&ctx, nir->stage, false, MESA_SHADER_VERTEX);
+       ctx.max_workgroup_size = ac_nir_get_max_workgroup_size(ctx.options->chip_class, shaders[0]);
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, "gs_next_vertex");
+       create_function(&ctx, shaders[shader_count - 1]->info.stage, shader_count >= 2,
+                       shader_count >= 2 ? shaders[shader_count - 2]->info.stage : MESA_SHADER_VERTEX);
+
+       ctx.abi.inputs = &ctx.inputs[0];
+       ctx.abi.emit_outputs = handle_shader_outputs_post;
+       ctx.abi.load_ssbo = radv_load_ssbo;
+       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
 
-               ctx.gs_max_out_vertices = nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_EVAL) {
-               ctx.tes_primitive_mode = nir->info.tess.primitive_mode;
-       } else if (nir->stage == MESA_SHADER_VERTEX) {
-               if (shader_info->info.vs.needs_instance_id) {
-                       ctx.shader_info->vs.vgpr_comp_cnt =
-                               MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
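+       /* For merged shaders, make sure all lanes are enabled before the
+        * per-stage thread-count branches below.
+        */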
+       if (shader_count >= 2)
+               ac_init_exec_full_mask(&ctx.ac);
+
+       if (ctx.ac.chip_class == GFX9 &&
+           shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL)
+               ac_nir_fixup_ls_hs_input_vgprs(&ctx);
+
+       for(int i = 0; i < shader_count; ++i) {
+               ctx.stage = shaders[i]->info.stage;
+               ctx.output_mask = 0;
+               ctx.tess_outputs_written = 0;
+               ctx.num_output_clips = shaders[i]->info.clip_distance_array_size;
+               ctx.num_output_culls = shaders[i]->info.cull_distance_array_size;
+
+               if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
+                       ctx.gs_next_vertex = ac_build_alloca(&ctx.ac, ctx.i32, "gs_next_vertex");
+
+                       ctx.gs_max_out_vertices = shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
+                       ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
+               } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
+                       if (shader_info->info.vs.needs_instance_id) {
+                               ctx.shader_info->vs.vgpr_comp_cnt =
+                                       MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+                       }
+               } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
+                       shader_info->fs.can_discard = shaders[i]->info.fs.uses_discard;
                }
-       } else if (nir->stage == MESA_SHADER_FRAGMENT) {
-               shader_info->fs.can_discard = nir->info.fs.uses_discard;
-       }
 
-       ac_setup_rings(&ctx);
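+               /* Barrier between merged stages so the previous stage's LDS stores
+                * are visible before this stage reads them.
+                */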
+               if (i)
+                       emit_barrier(&ctx);
 
-       ctx.num_output_clips = nir->info.clip_distance_array_size;
-       ctx.num_output_culls = nir->info.cull_distance_array_size;
+               ac_setup_rings(&ctx);
 
-       if (nir->stage == MESA_SHADER_FRAGMENT)
-               handle_fs_inputs(&ctx, nir);
-       else if(nir->stage == MESA_SHADER_VERTEX)
-               handle_vs_inputs(&ctx, nir);
+               LLVMBasicBlockRef merge_block;
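+               /* In a merged binary only the threads belonging to stage i may run
+                * its body, so guard it with thread_id < that stage's thread count.
+                */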
+               if (shader_count >= 2) {
+                       LLVMValueRef fn = LLVMGetBasicBlockParent(LLVMGetInsertBlock(ctx.ac.builder));
+                       LLVMBasicBlockRef then_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
+                       merge_block = LLVMAppendBasicBlockInContext(ctx.ac.context, fn, "");
 
-       ctx.abi.inputs = &ctx.inputs[0];
-       ctx.abi.emit_outputs = handle_shader_outputs_post;
-       ctx.abi.load_ssbo = radv_load_ssbo;
-       ctx.abi.load_sampler_desc = radv_get_sampler_desc;
+                       LLVMValueRef count = ac_build_bfe(&ctx.ac, ctx.merged_wave_info,
+                                                         LLVMConstInt(ctx.ac.i32, 8 * i, false),
+                                                         LLVMConstInt(ctx.ac.i32, 8, false), false);
+                       LLVMValueRef thread_id = ac_get_thread_id(&ctx.ac);
+                       LLVMValueRef cond = LLVMBuildICmp(ctx.ac.builder, LLVMIntULT,
+                                                         thread_id, count, "");
+                       LLVMBuildCondBr(ctx.ac.builder, cond, then_block, merge_block);
 
-       nir_foreach_variable(variable, &nir->outputs)
-               scan_shader_output_decl(&ctx, variable, nir, nir->stage);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, then_block);
+               }
 
-       ac_nir_translate(&ctx.ac, &ctx.abi, nir, &ctx);
+               if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT)
+                       handle_fs_inputs(&ctx, shaders[i]);
+               else if(shaders[i]->info.stage == MESA_SHADER_VERTEX)
+                       handle_vs_inputs(&ctx, shaders[i]);
+               else if(shader_count >= 2 && shaders[i]->info.stage == MESA_SHADER_GEOMETRY)
+                       prepare_gs_input_vgprs(&ctx);
 
-       LLVMBuildRetVoid(ctx.builder);
+               nir_foreach_variable(variable, &shaders[i]->outputs)
+                       scan_shader_output_decl(&ctx, variable, shaders[i], shaders[i]->info.stage);
 
-       ac_llvm_finalize_module(&ctx);
+               ac_nir_translate(&ctx.ac, &ctx.abi, shaders[i], &ctx);
 
-       ac_nir_eliminate_const_vs_outputs(&ctx);
+               if (shader_count >= 2) {
+                       LLVMBuildBr(ctx.ac.builder, merge_block);
+                       LLVMPositionBuilderAtEnd(ctx.ac.builder, merge_block);
+               }
 
-       if (nir->stage == MESA_SHADER_GEOMETRY) {
-               unsigned addclip = ctx.num_output_clips + ctx.num_output_culls > 4;
-               shader_info->gs.gsvs_vertex_size = (util_bitcount64(ctx.output_mask) + addclip) * 16;
-               shader_info->gs.max_gsvs_emit_size = shader_info->gs.gsvs_vertex_size *
-                       nir->info.gs.vertices_out;
-       } else if (nir->stage == MESA_SHADER_TESS_CTRL) {
-               shader_info->tcs.outputs_written = ctx.tess_outputs_written;
-               shader_info->tcs.patch_outputs_written = ctx.tess_patch_outputs_written;
-       } else if (nir->stage == MESA_SHADER_VERTEX && ctx.options->key.vs.as_ls) {
-               shader_info->vs.outputs_written = ctx.tess_outputs_written;
+               if (shaders[i]->info.stage == MESA_SHADER_GEOMETRY) {
+                       unsigned addclip = shaders[i]->info.clip_distance_array_size +
+                                       shaders[i]->info.cull_distance_array_size > 4;
+                       shader_info->gs.gsvs_vertex_size = (util_bitcount64(ctx.output_mask) + addclip) * 16;
+                       shader_info->gs.max_gsvs_emit_size = shader_info->gs.gsvs_vertex_size *
+                               shaders[i]->info.gs.vertices_out;
+               } else if (shaders[i]->info.stage == MESA_SHADER_TESS_CTRL) {
+                       shader_info->tcs.outputs_written = ctx.tess_outputs_written;
+                       shader_info->tcs.patch_outputs_written = ctx.tess_patch_outputs_written;
+               } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX && ctx.options->key.vs.as_ls) {
+                       shader_info->vs.outputs_written = ctx.tess_outputs_written;
+               }
        }
 
+       LLVMBuildRetVoid(ctx.builder);
+
+       ac_llvm_finalize_module(&ctx);
+
+       if (shader_count == 1)
+               ac_nir_eliminate_const_vs_outputs(&ctx);
+
        return ctx.module;
 }
 
@@ -6648,7 +6815,7 @@ static void ac_compile_llvm_module(LLVMTargetMachineRef tm,
 static void
 ac_fill_shader_info(struct ac_shader_variant_info *shader_info, struct nir_shader *nir, const struct ac_nir_compiler_options *options)
 {
-        switch (nir->stage) {
+        switch (nir->info.stage) {
         case MESA_SHADER_COMPUTE:
                 for (int i = 0; i < 3; ++i)
                         shader_info->cs.block_size[i] = nir->info.cs.local_size[i];
@@ -6694,10 +6861,10 @@ void ac_compile_nir_shader(LLVMTargetMachineRef tm,
                           bool dump_shader)
 {
 
-       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir[0], shader_info,
+       LLVMModuleRef llvm_module = ac_translate_nir_to_llvm(tm, nir, nir_count, shader_info,
                                                             options);
 
-       ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->stage, dump_shader, options->supports_spill);
+       ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->info.stage, dump_shader, options->supports_spill);
        for (int i = 0; i < nir_count; ++i)
                ac_fill_shader_info(shader_info, nir[i], options);
 }