radeonsi: get rid of get_interp_param
[mesa.git] / src / gallium / drivers / radeonsi / si_shader.c
index 81c361e7ea7b4cedd73339d04f1056b24f618d09..f3b94d7ff76fc3b8243251695a52cd3b32700012 100644 (file)
@@ -40,6 +40,7 @@
 #include "tgsi/tgsi_util.h"
 #include "tgsi/tgsi_dump.h"
 
+#include "ac_llvm_util.h"
 #include "si_shader_internal.h"
 #include "si_pipe.h"
 #include "sid.h"
@@ -70,6 +71,17 @@ static void si_llvm_emit_barrier(const struct lp_build_tgsi_action *action,
 static void si_dump_shader_key(unsigned shader, union si_shader_key *key,
                               FILE *f);
 
+static void si_build_vs_prolog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key);
+static void si_build_vs_epilog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key);
+static void si_build_tcs_epilog_function(struct si_shader_context *ctx,
+                                        union si_shader_part_key *key);
+static void si_build_ps_prolog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key);
+static void si_build_ps_epilog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key);
+
 /* Ideally pass the sample mask input to the PS epilog as v13, which
  * is its usual location, so that the shader doesn't have to add v_mov.
  */
@@ -1171,45 +1183,6 @@ static int lookup_interp_param_index(unsigned interpolate, unsigned location)
        }
 }
 
-/* This shouldn't be used by explicit INTERP opcodes. */
-static unsigned select_interp_param(struct si_shader_context *ctx,
-                                   unsigned param)
-{
-       if (!ctx->no_prolog)
-               return param;
-
-       if (ctx->shader->key.ps.prolog.force_persp_sample_interp) {
-               switch (param) {
-               case SI_PARAM_PERSP_CENTROID:
-               case SI_PARAM_PERSP_CENTER:
-                       return SI_PARAM_PERSP_SAMPLE;
-               }
-       }
-       if (ctx->shader->key.ps.prolog.force_linear_sample_interp) {
-               switch (param) {
-               case SI_PARAM_LINEAR_CENTROID:
-               case SI_PARAM_LINEAR_CENTER:
-                       return SI_PARAM_LINEAR_SAMPLE;
-               }
-       }
-       if (ctx->shader->key.ps.prolog.force_persp_center_interp) {
-               switch (param) {
-               case SI_PARAM_PERSP_CENTROID:
-               case SI_PARAM_PERSP_SAMPLE:
-                       return SI_PARAM_PERSP_CENTER;
-               }
-       }
-       if (ctx->shader->key.ps.prolog.force_linear_center_interp) {
-               switch (param) {
-               case SI_PARAM_LINEAR_CENTROID:
-               case SI_PARAM_LINEAR_SAMPLE:
-                       return SI_PARAM_LINEAR_CENTER;
-               }
-       }
-
-       return param;
-}
-
 /**
  * Interpolate a fragment shader input.
  *
@@ -1327,56 +1300,6 @@ static void interp_fs_input(struct si_shader_context *ctx,
        }
 }
 
-/* LLVMGetParam with bc_optimize resolved. */
-static LLVMValueRef get_interp_param(struct si_shader_context *ctx,
-                                    int interp_param_idx)
-{
-       LLVMBuilderRef builder = ctx->gallivm.builder;
-       LLVMValueRef main_fn = ctx->main_fn;
-       LLVMValueRef param = NULL;
-
-       /* Handle PRIM_MASK[31] (bc_optimize). */
-       if (ctx->no_prolog &&
-           ((ctx->shader->key.ps.prolog.bc_optimize_for_persp &&
-             interp_param_idx == SI_PARAM_PERSP_CENTROID) ||
-            (ctx->shader->key.ps.prolog.bc_optimize_for_linear &&
-             interp_param_idx == SI_PARAM_LINEAR_CENTROID))) {
-               /* The shader should do: if (PRIM_MASK[31]) CENTROID = CENTER;
-                * The hw doesn't compute CENTROID if the whole wave only
-                * contains fully-covered quads.
-                */
-               LLVMValueRef bc_optimize =
-                       LLVMGetParam(main_fn, SI_PARAM_PRIM_MASK);
-               bc_optimize = LLVMBuildLShr(builder,
-                                           bc_optimize,
-                                           LLVMConstInt(ctx->i32, 31, 0), "");
-               bc_optimize = LLVMBuildTrunc(builder, bc_optimize, ctx->i1, "");
-
-               if (ctx->shader->key.ps.prolog.bc_optimize_for_persp &&
-                   interp_param_idx == SI_PARAM_PERSP_CENTROID) {
-                       param = LLVMBuildSelect(builder, bc_optimize,
-                                               LLVMGetParam(main_fn,
-                                                            SI_PARAM_PERSP_CENTER),
-                                               LLVMGetParam(main_fn,
-                                                            SI_PARAM_PERSP_CENTROID),
-                                               "");
-               }
-               if (ctx->shader->key.ps.prolog.bc_optimize_for_linear &&
-                   interp_param_idx == SI_PARAM_LINEAR_CENTROID) {
-                       param = LLVMBuildSelect(builder, bc_optimize,
-                                               LLVMGetParam(main_fn,
-                                                            SI_PARAM_LINEAR_CENTER),
-                                               LLVMGetParam(main_fn,
-                                                            SI_PARAM_LINEAR_CENTROID),
-                                               "");
-               }
-       }
-
-       if (!param)
-               param = LLVMGetParam(main_fn, interp_param_idx);
-       return param;
-}
-
 static void declare_input_fs(
        struct si_shader_context *radeon_bld,
        unsigned input_index,
@@ -1412,9 +1335,7 @@ static void declare_input_fs(
        if (interp_param_idx == -1)
                return;
        else if (interp_param_idx) {
-               interp_param_idx = select_interp_param(ctx,
-                                                      interp_param_idx);
-               interp_param = get_interp_param(ctx, interp_param_idx);
+               interp_param = LLVMGetParam(ctx->main_fn, interp_param_idx);
        }
 
        if (decl->Semantic.Name == TGSI_SEMANTIC_COLOR &&
@@ -2473,6 +2394,10 @@ handle_semantic:
        }
 }
 
+/**
+ * Forward all outputs from the vertex shader to the TES. This is only used
+ * for the fixed function TCS.
+ */
 static void si_copy_tcs_inputs(struct lp_build_tgsi_context *bld_base)
 {
        struct si_shader_context *ctx = si_shader_context(bld_base);
@@ -2627,6 +2552,8 @@ static void si_llvm_emit_tcs_epilogue(struct lp_build_tgsi_context *bld_base)
        struct si_shader_context *ctx = si_shader_context(bld_base);
        LLVMValueRef rel_patch_id, invocation_id, tf_lds_offset;
 
+       si_copy_tcs_inputs(bld_base);
+
        rel_patch_id = get_rel_patch_id(ctx);
        invocation_id = unpack_param(ctx, SI_PARAM_REL_IDS, 8, 5);
        tf_lds_offset = get_tcs_out_current_patch_data_offset(ctx);
@@ -2669,7 +2596,6 @@ static void si_llvm_emit_tcs_epilogue(struct lp_build_tgsi_context *bld_base)
                return;
        }
 
-       si_copy_tcs_inputs(bld_base);
        si_write_tess_factors(bld_base, rel_patch_id, invocation_id, tf_lds_offset);
 }
 
@@ -5155,7 +5081,7 @@ static void build_interp_intrinsic(const struct lp_build_tgsi_action *action,
        if (interp_param_idx == -1)
                return;
        else if (interp_param_idx)
-               interp_param = get_interp_param(ctx, interp_param_idx);
+               interp_param = LLVMGetParam(ctx->main_fn, interp_param_idx);
        else
                interp_param = NULL;
 
@@ -5491,6 +5417,7 @@ static void create_function(struct si_shader_context *ctx)
        LLVMTypeRef returns[16+32*4];
        unsigned i, last_sgpr, num_params, num_return_sgprs;
        unsigned num_returns = 0;
+       unsigned num_prolog_vgprs = 0;
 
        v3i32 = LLVMVectorType(ctx->i32, 3);
 
@@ -5541,6 +5468,8 @@ static void create_function(struct si_shader_context *ctx)
 
                        for (i = 0; i < shader->selector->info.num_inputs; i++)
                                params[num_params++] = ctx->i32;
+
+                       num_prolog_vgprs += shader->selector->info.num_inputs;
                }
 
                if (!ctx->no_epilog &&
@@ -5640,6 +5569,7 @@ static void create_function(struct si_shader_context *ctx)
                params[SI_PARAM_POS_Z_FLOAT] = ctx->f32;
                params[SI_PARAM_POS_W_FLOAT] = ctx->f32;
                params[SI_PARAM_FRONT_FACE] = ctx->i32;
+               shader->info.face_vgpr_index = 20;
                params[SI_PARAM_ANCILLARY] = ctx->i32;
                params[SI_PARAM_SAMPLE_COVERAGE] = ctx->f32;
                params[SI_PARAM_POS_FIXED_PT] = ctx->i32;
@@ -5654,6 +5584,8 @@ static void create_function(struct si_shader_context *ctx)
                                assert(num_params + num_color_elements <= ARRAY_SIZE(params));
                                for (i = 0; i < num_color_elements; i++)
                                        params[num_params++] = ctx->f32;
+
+                               num_prolog_vgprs += num_color_elements;
                        }
                }
 
@@ -5736,12 +5668,11 @@ static void create_function(struct si_shader_context *ctx)
        for (i = 0; i <= last_sgpr; ++i)
                shader->info.num_input_sgprs += llvm_get_type_size(params[i]) / 4;
 
-       /* Unused fragment shader inputs are eliminated by the compiler,
-        * so we don't know yet how many there will be.
-        */
-       if (ctx->type != PIPE_SHADER_FRAGMENT)
-               for (; i < num_params; ++i)
-                       shader->info.num_input_vgprs += llvm_get_type_size(params[i]) / 4;
+       for (; i < num_params; ++i)
+               shader->info.num_input_vgprs += llvm_get_type_size(params[i]) / 4;
+
+       assert(shader->info.num_input_vgprs >= num_prolog_vgprs);
+       shader->info.num_input_vgprs -= num_prolog_vgprs;
 
        if (!ctx->screen->has_ds_bpermute &&
            bld_base->info &&
@@ -6758,6 +6689,386 @@ static bool si_compile_tgsi_main(struct si_shader_context *ctx,
        return true;
 }
 
+/**
+ * Compute the VS prolog key, which contains all the information needed to
+ * build the VS prolog function, and set shader->info bits where needed.
+ */
+static void si_get_vs_prolog_key(struct si_shader *shader,
+                                union si_shader_part_key *key)
+{
+       struct tgsi_shader_info *info = &shader->selector->info;
+
+       memset(key, 0, sizeof(*key));
+       key->vs_prolog.states = shader->key.vs.prolog;
+       key->vs_prolog.num_input_sgprs = shader->info.num_input_sgprs;
+       key->vs_prolog.last_input = MAX2(1, info->num_inputs) - 1;
+
+       /* Set the instanceID flag. */
+       for (unsigned i = 0; i < info->num_inputs; i++)
+               if (key->vs_prolog.states.instance_divisors[i])
+                       shader->info.uses_instanceid = true;
+}
+
+/**
+ * Compute the VS epilog key, which contains all the information needed to
+ * build the VS epilog function, and set the PrimitiveID output offset.
+ */
+static void si_get_vs_epilog_key(struct si_shader *shader,
+                                struct si_vs_epilog_bits *states,
+                                union si_shader_part_key *key)
+{
+       memset(key, 0, sizeof(*key));
+       key->vs_epilog.states = *states;
+
+       /* Set up the PrimitiveID output. */
+       if (shader->key.vs.epilog.export_prim_id) {
+               unsigned index = shader->selector->info.num_outputs;
+               unsigned offset = shader->info.nr_param_exports++;
+
+               key->vs_epilog.prim_id_param_offset = offset;
+               assert(index < ARRAY_SIZE(shader->info.vs_output_param_offset));
+               shader->info.vs_output_param_offset[index] = offset;
+       }
+}
+
+/**
+ * Compute the PS prolog key, which contains all the information needed to
+ * build the PS prolog function, and set related bits in shader->config.
+ */
+static void si_get_ps_prolog_key(struct si_shader *shader,
+                                union si_shader_part_key *key,
+                                bool separate_prolog)
+{
+       struct tgsi_shader_info *info = &shader->selector->info;
+
+       memset(key, 0, sizeof(*key));
+       key->ps_prolog.states = shader->key.ps.prolog;
+       key->ps_prolog.colors_read = info->colors_read;
+       key->ps_prolog.num_input_sgprs = shader->info.num_input_sgprs;
+       key->ps_prolog.num_input_vgprs = shader->info.num_input_vgprs;
+       key->ps_prolog.wqm = info->uses_derivatives &&
+               (key->ps_prolog.colors_read ||
+                key->ps_prolog.states.force_persp_sample_interp ||
+                key->ps_prolog.states.force_linear_sample_interp ||
+                key->ps_prolog.states.force_persp_center_interp ||
+                key->ps_prolog.states.force_linear_center_interp ||
+                key->ps_prolog.states.bc_optimize_for_persp ||
+                key->ps_prolog.states.bc_optimize_for_linear);
+
+       if (info->colors_read) {
+               unsigned *color = shader->selector->color_attr_index;
+
+               if (shader->key.ps.prolog.color_two_side) {
+                       /* BCOLORs are stored after the last input. */
+                       key->ps_prolog.num_interp_inputs = info->num_inputs;
+                       key->ps_prolog.face_vgpr_index = shader->info.face_vgpr_index;
+                       shader->config.spi_ps_input_ena |= S_0286CC_FRONT_FACE_ENA(1);
+               }
+
+               for (unsigned i = 0; i < 2; i++) {
+                       unsigned interp = info->input_interpolate[color[i]];
+                       unsigned location = info->input_interpolate_loc[color[i]];
+
+                       if (!(info->colors_read & (0xf << i*4)))
+                               continue;
+
+                       key->ps_prolog.color_attr_index[i] = color[i];
+
+                       if (shader->key.ps.prolog.flatshade_colors &&
+                           interp == TGSI_INTERPOLATE_COLOR)
+                               interp = TGSI_INTERPOLATE_CONSTANT;
+
+                       switch (interp) {
+                       case TGSI_INTERPOLATE_CONSTANT:
+                               key->ps_prolog.color_interp_vgpr_index[i] = -1;
+                               break;
+                       case TGSI_INTERPOLATE_PERSPECTIVE:
+                       case TGSI_INTERPOLATE_COLOR:
+                               /* Force the interpolation location for colors here. */
+                               if (shader->key.ps.prolog.force_persp_sample_interp)
+                                       location = TGSI_INTERPOLATE_LOC_SAMPLE;
+                               if (shader->key.ps.prolog.force_persp_center_interp)
+                                       location = TGSI_INTERPOLATE_LOC_CENTER;
+
+                               switch (location) {
+                               case TGSI_INTERPOLATE_LOC_SAMPLE:
+                                       key->ps_prolog.color_interp_vgpr_index[i] = 0;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_PERSP_SAMPLE_ENA(1);
+                                       break;
+                               case TGSI_INTERPOLATE_LOC_CENTER:
+                                       key->ps_prolog.color_interp_vgpr_index[i] = 2;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_PERSP_CENTER_ENA(1);
+                                       break;
+                               case TGSI_INTERPOLATE_LOC_CENTROID:
+                                       key->ps_prolog.color_interp_vgpr_index[i] = 4;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_PERSP_CENTROID_ENA(1);
+                                       break;
+                               default:
+                                       assert(0);
+                               }
+                               break;
+                       case TGSI_INTERPOLATE_LINEAR:
+                               /* Force the interpolation location for colors here. */
+                               if (shader->key.ps.prolog.force_linear_sample_interp)
+                                       location = TGSI_INTERPOLATE_LOC_SAMPLE;
+                               if (shader->key.ps.prolog.force_linear_center_interp)
+                                       location = TGSI_INTERPOLATE_LOC_CENTER;
+
+                               /* The VGPR assignment for non-monolithic shaders
+                                * works because InitialPSInputAddr is set on the
+                                * main shader and PERSP_PULL_MODEL is never used.
+                                */
+                               switch (location) {
+                               case TGSI_INTERPOLATE_LOC_SAMPLE:
+                                       key->ps_prolog.color_interp_vgpr_index[i] =
+                                               separate_prolog ? 6 : 9;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_LINEAR_SAMPLE_ENA(1);
+                                       break;
+                               case TGSI_INTERPOLATE_LOC_CENTER:
+                                       key->ps_prolog.color_interp_vgpr_index[i] =
+                                               separate_prolog ? 8 : 11;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_LINEAR_CENTER_ENA(1);
+                                       break;
+                               case TGSI_INTERPOLATE_LOC_CENTROID:
+                                       key->ps_prolog.color_interp_vgpr_index[i] =
+                                               separate_prolog ? 10 : 13;
+                                       shader->config.spi_ps_input_ena |=
+                                               S_0286CC_LINEAR_CENTROID_ENA(1);
+                                       break;
+                               default:
+                                       assert(0);
+                               }
+                               break;
+                       default:
+                               assert(0);
+                       }
+               }
+       }
+}
+
+/**
+ * Check whether a PS prolog is required based on the key.
+ */
+static bool si_need_ps_prolog(const union si_shader_part_key *key)
+{
+       return key->ps_prolog.colors_read ||
+              key->ps_prolog.states.force_persp_sample_interp ||
+              key->ps_prolog.states.force_linear_sample_interp ||
+              key->ps_prolog.states.force_persp_center_interp ||
+              key->ps_prolog.states.force_linear_center_interp ||
+              key->ps_prolog.states.bc_optimize_for_persp ||
+              key->ps_prolog.states.bc_optimize_for_linear ||
+              key->ps_prolog.states.poly_stipple;
+}
+
+/**
+ * Compute the PS epilog key, which contains all the information needed to
+ * build the PS epilog function.
+ */
+static void si_get_ps_epilog_key(struct si_shader *shader,
+                                union si_shader_part_key *key)
+{
+       struct tgsi_shader_info *info = &shader->selector->info;
+       memset(key, 0, sizeof(*key));
+       key->ps_epilog.colors_written = info->colors_written;
+       key->ps_epilog.writes_z = info->writes_z;
+       key->ps_epilog.writes_stencil = info->writes_stencil;
+       key->ps_epilog.writes_samplemask = info->writes_samplemask;
+       key->ps_epilog.states = shader->key.ps.epilog;
+}
+
+/**
+ * Given a list of shader part functions, build a wrapper function that
+ * runs them in sequence to form a monolithic shader.
+ */
+static void si_build_wrapper_function(struct si_shader_context *ctx,
+                                     LLVMValueRef *parts,
+                                     unsigned num_parts,
+                                     unsigned main_part)
+{
+       struct gallivm_state *gallivm = &ctx->gallivm;
+       LLVMBuilderRef builder = ctx->gallivm.builder;
+       /* PS epilog has one arg per color component */
+       LLVMTypeRef param_types[48];
+       LLVMValueRef out[48];
+       LLVMTypeRef function_type;
+       unsigned num_params;
+       unsigned num_out_sgpr, num_out;
+       unsigned num_sgprs, num_vgprs;
+       unsigned last_sgpr_param;
+       unsigned gprs;
+
+       for (unsigned i = 0; i < num_parts; ++i) {
+               LLVMAddFunctionAttr(parts[i], LLVMAlwaysInlineAttribute);
+               LLVMSetLinkage(parts[i], LLVMPrivateLinkage);
+       }
+
+       /* The parameters of the wrapper function correspond to those of the
+        * first part in terms of SGPRs and VGPRs, but we use the types of the
+        * main part to get the right types. This is relevant for the
+        * dereferenceable attribute on descriptor table pointers.
+        */
+       num_sgprs = 0;
+       num_vgprs = 0;
+
+       function_type = LLVMGetElementType(LLVMTypeOf(parts[0]));
+       num_params = LLVMCountParamTypes(function_type);
+
+       for (unsigned i = 0; i < num_params; ++i) {
+               LLVMValueRef param = LLVMGetParam(parts[0], i);
+
+               if (ac_is_sgpr_param(param)) {
+                       assert(num_vgprs == 0);
+                       num_sgprs += llvm_get_type_size(LLVMTypeOf(param)) / 4;
+               } else {
+                       num_vgprs += llvm_get_type_size(LLVMTypeOf(param)) / 4;
+               }
+       }
+       assert(num_vgprs + num_sgprs <= ARRAY_SIZE(param_types));
+
+       num_params = 0;
+       last_sgpr_param = 0;
+       gprs = 0;
+       while (gprs < num_sgprs + num_vgprs) {
+               LLVMValueRef param = LLVMGetParam(parts[main_part], num_params);
+               unsigned size;
+
+               param_types[num_params] = LLVMTypeOf(param);
+               if (gprs < num_sgprs)
+                       last_sgpr_param = num_params;
+               size = llvm_get_type_size(param_types[num_params]) / 4;
+               num_params++;
+
+               assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
+               assert(gprs + size <= num_sgprs + num_vgprs &&
+                      (gprs >= num_sgprs || gprs + size <= num_sgprs));
+
+               gprs += size;
+       }
+
+       si_create_function(ctx, "wrapper", NULL, 0, param_types, num_params, last_sgpr_param);
+
+       /* Record the arguments of the function as if they were an output of
+        * a previous part.
+        */
+       num_out = 0;
+       num_out_sgpr = 0;
+
+       for (unsigned i = 0; i < num_params; ++i) {
+               LLVMValueRef param = LLVMGetParam(ctx->main_fn, i);
+               LLVMTypeRef param_type = LLVMTypeOf(param);
+               LLVMTypeRef out_type = i <= last_sgpr_param ? ctx->i32 : ctx->f32;
+               unsigned size = llvm_get_type_size(param_type) / 4;
+
+               if (size == 1) {
+                       if (param_type != out_type)
+                               param = LLVMBuildBitCast(builder, param, out_type, "");
+                       out[num_out++] = param;
+               } else {
+                       LLVMTypeRef vector_type = LLVMVectorType(out_type, size);
+
+                       if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
+                               param = LLVMBuildPtrToInt(builder, param, ctx->i64, "");
+                               param_type = ctx->i64;
+                       }
+
+                       if (param_type != vector_type)
+                               param = LLVMBuildBitCast(builder, param, vector_type, "");
+
+                       for (unsigned j = 0; j < size; ++j)
+                               out[num_out++] = LLVMBuildExtractElement(
+                                       builder, param, LLVMConstInt(ctx->i32, j, 0), "");
+               }
+
+               if (i <= last_sgpr_param)
+                       num_out_sgpr = num_out;
+       }
+
+       /* Now chain the parts. */
+       for (unsigned part = 0; part < num_parts; ++part) {
+               LLVMValueRef in[48];
+               LLVMValueRef ret;
+               LLVMTypeRef ret_type;
+               unsigned out_idx = 0;
+
+               num_params = LLVMCountParams(parts[part]);
+               assert(num_params <= ARRAY_SIZE(param_types));
+
+               /* Derive arguments for the next part from outputs of the
+                * previous one.
+                */
+               for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
+                       LLVMValueRef param;
+                       LLVMTypeRef param_type;
+                       bool is_sgpr;
+                       unsigned param_size;
+                       LLVMValueRef arg = NULL;
+
+                       param = LLVMGetParam(parts[part], param_idx);
+                       param_type = LLVMTypeOf(param);
+                       param_size = llvm_get_type_size(param_type) / 4;
+                       is_sgpr = ac_is_sgpr_param(param);
+
+                       if (is_sgpr) {
+                               LLVMRemoveAttribute(param, LLVMByValAttribute);
+                               LLVMAddAttribute(param, LLVMInRegAttribute);
+                       }
+
+                       assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));
+                       assert(is_sgpr || out_idx >= num_out_sgpr);
+
+                       if (param_size == 1)
+                               arg = out[out_idx];
+                       else
+                               arg = lp_build_gather_values(gallivm, &out[out_idx], param_size);
+
+                       if (LLVMTypeOf(arg) != param_type) {
+                               if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
+                                       arg = LLVMBuildBitCast(builder, arg, ctx->i64, "");
+                                       arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
+                               } else {
+                                       arg = LLVMBuildBitCast(builder, arg, param_type, "");
+                               }
+                       }
+
+                       in[param_idx] = arg;
+                       out_idx += param_size;
+               }
+
+               ret = LLVMBuildCall(builder, parts[part], in, num_params, "");
+               ret_type = LLVMTypeOf(ret);
+
+               /* Extract the returned GPRs. */
+               num_out = 0;
+               num_out_sgpr = 0;
+
+               if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
+                       assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);
+
+                       unsigned ret_size = LLVMCountStructElementTypes(ret_type);
+
+                       for (unsigned i = 0; i < ret_size; ++i) {
+                               LLVMValueRef val =
+                                       LLVMBuildExtractValue(builder, ret, i, "");
+
+                               out[num_out++] = val;
+
+                               if (LLVMTypeOf(val) == ctx->i32) {
+                                       assert(num_out_sgpr + 1 == num_out);
+                                       num_out_sgpr = num_out;
+                               }
+                       }
+               }
+       }
+
+       LLVMBuildRetVoid(builder);
+}
+
 int si_compile_tgsi_shader(struct si_screen *sscreen,
                           LLVMTargetMachineRef tm,
                           struct si_shader *shader,
@@ -6783,6 +7094,14 @@ int si_compile_tgsi_shader(struct si_screen *sscreen,
        ctx.no_epilog = is_monolithic;
        ctx.separate_prolog = !is_monolithic;
 
+       if (ctx.type == PIPE_SHADER_VERTEX ||
+           ctx.type == PIPE_SHADER_TESS_CTRL ||
+           ctx.type == PIPE_SHADER_TESS_EVAL ||
+           ctx.type == PIPE_SHADER_FRAGMENT) {
+               ctx.no_prolog = false;
+               ctx.no_epilog = false;
+       }
+
        memset(shader->info.vs_output_param_offset, 0xff,
               sizeof(shader->info.vs_output_param_offset));
 
@@ -6796,6 +7115,79 @@ int si_compile_tgsi_shader(struct si_screen *sscreen,
                return -1;
        }
 
+       if (is_monolithic && ctx.type == PIPE_SHADER_VERTEX) {
+               LLVMValueRef parts[3];
+               bool need_prolog;
+               bool need_epilog;
+
+               need_prolog = sel->info.num_inputs;
+               need_epilog = !shader->key.vs.as_es && !shader->key.vs.as_ls;
+
+               parts[need_prolog ? 1 : 0] = ctx.main_fn;
+
+               if (need_prolog) {
+                       union si_shader_part_key prolog_key;
+                       si_get_vs_prolog_key(shader, &prolog_key);
+                       si_build_vs_prolog_function(&ctx, &prolog_key);
+                       parts[0] = ctx.main_fn;
+               }
+
+               if (need_epilog) {
+                       union si_shader_part_key epilog_key;
+                       si_get_vs_epilog_key(shader, &shader->key.vs.epilog, &epilog_key);
+                       si_build_vs_epilog_function(&ctx, &epilog_key);
+                       parts[need_prolog ? 2 : 1] = ctx.main_fn;
+               }
+
+               si_build_wrapper_function(&ctx, parts, 1 + need_prolog + need_epilog,
+                                         need_prolog ? 1 : 0);
+       } else if (is_monolithic && ctx.type == PIPE_SHADER_TESS_CTRL) {
+               LLVMValueRef parts[2];
+               union si_shader_part_key epilog_key;
+
+               parts[0] = ctx.main_fn;
+
+               memset(&epilog_key, 0, sizeof(epilog_key));
+               epilog_key.tcs_epilog.states = shader->key.tcs.epilog;
+               si_build_tcs_epilog_function(&ctx, &epilog_key);
+               parts[1] = ctx.main_fn;
+
+               si_build_wrapper_function(&ctx, parts, 2, 0);
+       } else if (is_monolithic && ctx.type == PIPE_SHADER_TESS_EVAL &&
+                  !shader->key.tes.as_es) {
+               LLVMValueRef parts[2];
+               union si_shader_part_key epilog_key;
+
+               parts[0] = ctx.main_fn;
+
+               si_get_vs_epilog_key(shader, &shader->key.tes.epilog, &epilog_key);
+               si_build_vs_epilog_function(&ctx, &epilog_key);
+               parts[1] = ctx.main_fn;
+
+               si_build_wrapper_function(&ctx, parts, 2, 0);
+       } else if (is_monolithic && ctx.type == PIPE_SHADER_FRAGMENT) {
+               LLVMValueRef parts[3];
+               union si_shader_part_key prolog_key;
+               union si_shader_part_key epilog_key;
+               bool need_prolog;
+
+               si_get_ps_prolog_key(shader, &prolog_key, false);
+               need_prolog = si_need_ps_prolog(&prolog_key);
+
+               parts[need_prolog ? 1 : 0] = ctx.main_fn;
+
+               if (need_prolog) {
+                       si_build_ps_prolog_function(&ctx, &prolog_key);
+                       parts[0] = ctx.main_fn;
+               }
+
+               si_get_ps_epilog_key(shader, &epilog_key);
+               si_build_ps_epilog_function(&ctx, &epilog_key);
+               parts[need_prolog ? 2 : 1] = ctx.main_fn;
+
+               si_build_wrapper_function(&ctx, parts, need_prolog ? 3 : 2, need_prolog ? 1 : 0);
+       }
+
        mod = bld_base->base.gallivm->module;
 
        /* Dump LLVM IR before any optimization passes */
@@ -6969,11 +7361,11 @@ si_get_shader_part(struct si_screen *sscreen,
 }
 
 /**
- * Create a vertex shader prolog.
+ * Build the vertex shader prolog function.
  *
  * The inputs are the same as VS (a lot of SGPRs and 4 VGPR system values).
  * All inputs are returned unmodified. The vertex load indices are
- * stored after them, which will used by the API VS for fetching inputs.
+ * stored after them, which will be used by the API VS for fetching inputs.
  *
  * For example, the expected outputs for instance_divisors[] = {0, 1, 2} are:
  *   input_v0,
@@ -6984,24 +7376,16 @@ si_get_shader_part(struct si_screen *sscreen,
  *   (InstanceID + StartInstance),
  *   (InstanceID / 2 + StartInstance)
  */
-static bool si_compile_vs_prolog(struct si_screen *sscreen,
-                                LLVMTargetMachineRef tm,
-                                struct pipe_debug_callback *debug,
-                                struct si_shader_part *out)
+static void si_build_vs_prolog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key)
 {
-       union si_shader_part_key *key = &out->key;
-       struct si_shader shader = {};
-       struct si_shader_context ctx;
-       struct gallivm_state *gallivm = &ctx.gallivm;
+       struct gallivm_state *gallivm = &ctx->gallivm;
        LLVMTypeRef *params, *returns;
        LLVMValueRef ret, func;
        int last_sgpr, num_params, num_returns, i;
-       bool status = true;
 
-       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
-       ctx.type = PIPE_SHADER_VERTEX;
-       ctx.param_vertex_id = key->vs_prolog.num_input_sgprs;
-       ctx.param_instance_id = key->vs_prolog.num_input_sgprs + 3;
+       ctx->param_vertex_id = key->vs_prolog.num_input_sgprs;
+       ctx->param_instance_id = key->vs_prolog.num_input_sgprs + 3;
 
        /* 4 preloaded VGPRs + vertex load indices as prolog outputs */
        params = alloca((key->vs_prolog.num_input_sgprs + 4) *
@@ -7015,37 +7399,37 @@ static bool si_compile_vs_prolog(struct si_screen *sscreen,
        /* Declare input and output SGPRs. */
        num_params = 0;
        for (i = 0; i < key->vs_prolog.num_input_sgprs; i++) {
-               params[num_params++] = ctx.i32;
-               returns[num_returns++] = ctx.i32;
+               params[num_params++] = ctx->i32;
+               returns[num_returns++] = ctx->i32;
        }
        last_sgpr = num_params - 1;
 
        /* 4 preloaded VGPRs (outputs must be floats) */
        for (i = 0; i < 4; i++) {
-               params[num_params++] = ctx.i32;
-               returns[num_returns++] = ctx.f32;
+               params[num_params++] = ctx->i32;
+               returns[num_returns++] = ctx->f32;
        }
 
        /* Vertex load indices. */
        for (i = 0; i <= key->vs_prolog.last_input; i++)
-               returns[num_returns++] = ctx.f32;
+               returns[num_returns++] = ctx->f32;
 
        /* Create the function. */
-       si_create_function(&ctx, "vs_prolog", returns, num_returns, params,
+       si_create_function(ctx, "vs_prolog", returns, num_returns, params,
                           num_params, last_sgpr);
-       func = ctx.main_fn;
+       func = ctx->main_fn;
 
        /* Copy inputs to outputs. This should be no-op, as the registers match,
         * but it will prevent the compiler from overwriting them unintentionally.
         */
-       ret = ctx.return_value;
+       ret = ctx->return_value;
        for (i = 0; i < key->vs_prolog.num_input_sgprs; i++) {
                LLVMValueRef p = LLVMGetParam(func, i);
                ret = LLVMBuildInsertValue(gallivm->builder, ret, p, i, "");
        }
        for (i = num_params - 4; i < num_params; i++) {
                LLVMValueRef p = LLVMGetParam(func, i);
-               p = LLVMBuildBitCast(gallivm->builder, p, ctx.f32, "");
+               p = LLVMBuildBitCast(gallivm->builder, p, ctx->f32, "");
                ret = LLVMBuildInsertValue(gallivm->builder, ret, p, i, "");
        }
 
@@ -7056,25 +7440,46 @@ static bool si_compile_vs_prolog(struct si_screen *sscreen,
 
                if (divisor) {
                        /* InstanceID / Divisor + StartInstance */
-                       index = get_instance_index_for_fetch(&ctx,
+                       index = get_instance_index_for_fetch(ctx,
                                                             SI_SGPR_START_INSTANCE,
                                                             divisor);
                } else {
                        /* VertexID + BaseVertex */
                        index = LLVMBuildAdd(gallivm->builder,
-                                            LLVMGetParam(func, ctx.param_vertex_id),
+                                            LLVMGetParam(func, ctx->param_vertex_id),
                                             LLVMGetParam(func, SI_SGPR_BASE_VERTEX), "");
                }
 
-               index = LLVMBuildBitCast(gallivm->builder, index, ctx.f32, "");
+               index = LLVMBuildBitCast(gallivm->builder, index, ctx->f32, "");
                ret = LLVMBuildInsertValue(gallivm->builder, ret, index,
                                           num_params++, "");
        }
 
-       /* Compile. */
-       si_llvm_build_ret(&ctx, ret);
-       si_llvm_finalize_module(&ctx,
-               r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_VERTEX));
+       si_llvm_build_ret(ctx, ret);
+}
+
+/**
+ * Create a vertex shader prolog.
+ */
+static bool si_compile_vs_prolog(struct si_screen *sscreen,
+                                LLVMTargetMachineRef tm,
+                                struct pipe_debug_callback *debug,
+                                struct si_shader_part *out)
+{
+       union si_shader_part_key *key = &out->key;
+       struct si_shader shader = {};
+       struct si_shader_context ctx;
+       struct gallivm_state *gallivm = &ctx.gallivm;
+       bool status = true;
+
+       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
+       ctx.type = PIPE_SHADER_VERTEX;
+
+       si_build_vs_prolog_function(&ctx, key);
+
+       /* Compile. */
+       si_llvm_finalize_module(&ctx,
+               r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_VERTEX));
 
        if (si_compile_llvm(sscreen, &out->binary, &out->config, tm,
                            gallivm->module, debug, ctx.type,
@@ -7086,7 +7491,7 @@ static bool si_compile_vs_prolog(struct si_screen *sscreen,
 }
 
 /**
- * Compile the vertex shader epilog. This is also used by the tessellation
+ * Build the vertex shader epilog function. This is also used by the tessellation
  * evaluation shader compiled as VS.
  *
  * The input is PrimitiveID.
@@ -7094,21 +7499,13 @@ static bool si_compile_vs_prolog(struct si_screen *sscreen,
  * If PrimitiveID is required by the pixel shader, export it.
  * Otherwise, do nothing.
  */
-static bool si_compile_vs_epilog(struct si_screen *sscreen,
-                                LLVMTargetMachineRef tm,
-                                struct pipe_debug_callback *debug,
-                                struct si_shader_part *out)
+static void si_build_vs_epilog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key)
 {
-       union si_shader_part_key *key = &out->key;
-       struct si_shader_context ctx;
-       struct gallivm_state *gallivm = &ctx.gallivm;
-       struct lp_build_tgsi_context *bld_base = &ctx.soa.bld_base;
+       struct gallivm_state *gallivm = &ctx->gallivm;
+       struct lp_build_tgsi_context *bld_base = &ctx->soa.bld_base;
        LLVMTypeRef params[5];
        int num_params, i;
-       bool status = true;
-
-       si_init_shader_ctx(&ctx, sscreen, NULL, tm);
-       ctx.type = PIPE_SHADER_VERTEX;
 
        /* Declare input VGPRs. */
        num_params = key->vs_epilog.states.export_prim_id ?
@@ -7116,10 +7513,10 @@ static bool si_compile_vs_epilog(struct si_screen *sscreen,
        assert(num_params <= ARRAY_SIZE(params));
 
        for (i = 0; i < num_params; i++)
-               params[i] = ctx.f32;
+               params[i] = ctx->f32;
 
        /* Create the function. */
-       si_create_function(&ctx, "vs_epilog", NULL, 0, params, num_params, -1);
+       si_create_function(ctx, "vs_epilog", NULL, 0, params, num_params, -1);
 
        /* Emit exports. */
        if (key->vs_epilog.states.export_prim_id) {
@@ -7133,7 +7530,7 @@ static bool si_compile_vs_epilog(struct si_screen *sscreen,
                args[3] = lp_build_const_int32(base->gallivm, V_008DFC_SQ_EXP_PARAM +
                                               key->vs_epilog.prim_id_param_offset);
                args[4] = uint->zero; /* COMPR flag (0 = 32-bit export) */
-               args[5] = LLVMGetParam(ctx.main_fn,
+               args[5] = LLVMGetParam(ctx->main_fn,
                                       VS_EPILOG_PRIMID_LOC); /* X */
                args[6] = base->undef; /* Y */
                args[7] = base->undef; /* Z */
@@ -7144,8 +7541,29 @@ static bool si_compile_vs_epilog(struct si_screen *sscreen,
                                   args, 9, 0);
        }
 
-       /* Compile. */
        LLVMBuildRetVoid(gallivm->builder);
+}
+
+/**
+ * Compile the vertex shader epilog. This is also used by the tessellation
+ * evaluation shader compiled as VS.
+ */
+static bool si_compile_vs_epilog(struct si_screen *sscreen,
+                                LLVMTargetMachineRef tm,
+                                struct pipe_debug_callback *debug,
+                                struct si_shader_part *out)
+{
+       union si_shader_part_key *key = &out->key;
+       struct si_shader_context ctx;
+       struct gallivm_state *gallivm = &ctx.gallivm;
+       bool status = true;
+
+       si_init_shader_ctx(&ctx, sscreen, NULL, tm);
+       ctx.type = PIPE_SHADER_VERTEX;
+
+       si_build_vs_epilog_function(&ctx, key);
+
+       /* Compile. */
        si_llvm_finalize_module(&ctx,
                r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_VERTEX));
 
@@ -7169,18 +7587,7 @@ static bool si_get_vs_epilog(struct si_screen *sscreen,
 {
        union si_shader_part_key epilog_key;
 
-       memset(&epilog_key, 0, sizeof(epilog_key));
-       epilog_key.vs_epilog.states = *states;
-
-       /* Set up the PrimitiveID output. */
-       if (shader->key.vs.epilog.export_prim_id) {
-               unsigned index = shader->selector->info.num_outputs;
-               unsigned offset = shader->info.nr_param_exports++;
-
-               epilog_key.vs_epilog.prim_id_param_offset = offset;
-               assert(index < ARRAY_SIZE(shader->info.vs_output_param_offset));
-               shader->info.vs_output_param_offset[index] = offset;
-       }
+       si_get_vs_epilog_key(shader, states, &epilog_key);
 
        shader->epilog = si_get_shader_part(sscreen, &sscreen->vs_epilogs,
                                            &epilog_key, tm, debug,
@@ -7198,13 +7605,9 @@ static bool si_shader_select_vs_parts(struct si_screen *sscreen,
 {
        struct tgsi_shader_info *info = &shader->selector->info;
        union si_shader_part_key prolog_key;
-       unsigned i;
 
        /* Get the prolog. */
-       memset(&prolog_key, 0, sizeof(prolog_key));
-       prolog_key.vs_prolog.states = shader->key.vs.prolog;
-       prolog_key.vs_prolog.num_input_sgprs = shader->info.num_input_sgprs;
-       prolog_key.vs_prolog.last_input = MAX2(1, info->num_inputs) - 1;
+       si_get_vs_prolog_key(shader, &prolog_key);
 
        /* The prolog is a no-op if there are no inputs. */
        if (info->num_inputs) {
@@ -7222,11 +7625,6 @@ static bool si_shader_select_vs_parts(struct si_screen *sscreen,
                              &shader->key.vs.epilog))
                return false;
 
-       /* Set the instanceID flag. */
-       for (i = 0; i < info->num_inputs; i++)
-               if (prolog_key.vs_prolog.states.instance_divisors[i])
-                       shader->info.uses_instanceid = true;
-
        return true;
 }
 
@@ -7246,6 +7644,51 @@ static bool si_shader_select_tes_parts(struct si_screen *sscreen,
                                &shader->key.tes.epilog);
 }
 
+/**
+ * Compile the TCS epilog function. This writes tesselation factors to memory
+ * based on the output primitive type of the tesselator (determined by TES).
+ */
+static void si_build_tcs_epilog_function(struct si_shader_context *ctx,
+                                        union si_shader_part_key *key)
+{
+       struct gallivm_state *gallivm = &ctx->gallivm;
+       struct lp_build_tgsi_context *bld_base = &ctx->soa.bld_base;
+       LLVMTypeRef params[16];
+       LLVMValueRef func;
+       int last_sgpr, num_params;
+
+       /* Declare inputs. Only RW_BUFFERS and TESS_FACTOR_OFFSET are used. */
+       params[SI_PARAM_RW_BUFFERS] = const_array(ctx->v16i8, SI_NUM_RW_BUFFERS);
+       params[SI_PARAM_CONST_BUFFERS] = ctx->i64;
+       params[SI_PARAM_SAMPLERS] = ctx->i64;
+       params[SI_PARAM_IMAGES] = ctx->i64;
+       params[SI_PARAM_SHADER_BUFFERS] = ctx->i64;
+       params[SI_PARAM_TCS_OFFCHIP_LAYOUT] = ctx->i32;
+       params[SI_PARAM_TCS_OUT_OFFSETS] = ctx->i32;
+       params[SI_PARAM_TCS_OUT_LAYOUT] = ctx->i32;
+       params[SI_PARAM_TCS_IN_LAYOUT] = ctx->i32;
+       params[ctx->param_oc_lds = SI_PARAM_TCS_OC_LDS] = ctx->i32;
+       params[SI_PARAM_TESS_FACTOR_OFFSET] = ctx->i32;
+       last_sgpr = SI_PARAM_TESS_FACTOR_OFFSET;
+       num_params = last_sgpr + 1;
+
+       params[num_params++] = ctx->i32; /* patch index within the wave (REL_PATCH_ID) */
+       params[num_params++] = ctx->i32; /* invocation ID within the patch */
+       params[num_params++] = ctx->i32; /* LDS offset where tess factors should be loaded from */
+
+       /* Create the function. */
+       si_create_function(ctx, "tcs_epilog", NULL, 0, params, num_params, last_sgpr);
+       declare_tess_lds(ctx);
+       func = ctx->main_fn;
+
+       si_write_tess_factors(bld_base,
+                             LLVMGetParam(func, last_sgpr + 1),
+                             LLVMGetParam(func, last_sgpr + 2),
+                             LLVMGetParam(func, last_sgpr + 3));
+
+       LLVMBuildRetVoid(gallivm->builder);
+}
+
 /**
  * Compile the TCS epilog. This writes tesselation factors to memory based on
  * the output primitive type of the tesselator (determined by TES).
@@ -7259,47 +7702,15 @@ static bool si_compile_tcs_epilog(struct si_screen *sscreen,
        struct si_shader shader = {};
        struct si_shader_context ctx;
        struct gallivm_state *gallivm = &ctx.gallivm;
-       struct lp_build_tgsi_context *bld_base = &ctx.soa.bld_base;
-       LLVMTypeRef params[16];
-       LLVMValueRef func;
-       int last_sgpr, num_params;
        bool status = true;
 
        si_init_shader_ctx(&ctx, sscreen, &shader, tm);
        ctx.type = PIPE_SHADER_TESS_CTRL;
        shader.key.tcs.epilog = key->tcs_epilog.states;
 
-       /* Declare inputs. Only RW_BUFFERS and TESS_FACTOR_OFFSET are used. */
-       params[SI_PARAM_RW_BUFFERS] = const_array(ctx.v16i8, SI_NUM_RW_BUFFERS);
-       params[SI_PARAM_CONST_BUFFERS] = ctx.i64;
-       params[SI_PARAM_SAMPLERS] = ctx.i64;
-       params[SI_PARAM_IMAGES] = ctx.i64;
-       params[SI_PARAM_SHADER_BUFFERS] = ctx.i64;
-       params[SI_PARAM_TCS_OFFCHIP_LAYOUT] = ctx.i32;
-       params[SI_PARAM_TCS_OUT_OFFSETS] = ctx.i32;
-       params[SI_PARAM_TCS_OUT_LAYOUT] = ctx.i32;
-       params[SI_PARAM_TCS_IN_LAYOUT] = ctx.i32;
-       params[ctx.param_oc_lds = SI_PARAM_TCS_OC_LDS] = ctx.i32;
-       params[SI_PARAM_TESS_FACTOR_OFFSET] = ctx.i32;
-       last_sgpr = SI_PARAM_TESS_FACTOR_OFFSET;
-       num_params = last_sgpr + 1;
-
-       params[num_params++] = ctx.i32; /* patch index within the wave (REL_PATCH_ID) */
-       params[num_params++] = ctx.i32; /* invocation ID within the patch */
-       params[num_params++] = ctx.i32; /* LDS offset where tess factors should be loaded from */
-
-       /* Create the function. */
-       si_create_function(&ctx, "tcs_epilog", NULL, 0, params, num_params, last_sgpr);
-       declare_tess_lds(&ctx);
-       func = ctx.main_fn;
-
-       si_write_tess_factors(bld_base,
-                             LLVMGetParam(func, last_sgpr + 1),
-                             LLVMGetParam(func, last_sgpr + 2),
-                             LLVMGetParam(func, last_sgpr + 3));
+       si_build_tcs_epilog_function(&ctx, key);
 
        /* Compile. */
-       LLVMBuildRetVoid(gallivm->builder);
        si_llvm_finalize_module(&ctx,
                r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_TESS_CTRL));
 
@@ -7333,7 +7744,7 @@ static bool si_shader_select_tcs_parts(struct si_screen *sscreen,
 }
 
 /**
- * Compile the pixel shader prolog. This handles:
+ * Build the pixel shader prolog function. This handles:
  * - two-side color selection and interpolation
  * - overriding interpolation parameters for the API PS
  * - polygon stippling
@@ -7342,23 +7753,15 @@ static bool si_shader_select_tcs_parts(struct si_screen *sscreen,
  * overriden by other states. (e.g. per-sample interpolation)
  * Interpolated colors are stored after the preloaded VGPRs.
  */
-static bool si_compile_ps_prolog(struct si_screen *sscreen,
-                                LLVMTargetMachineRef tm,
-                                struct pipe_debug_callback *debug,
-                                struct si_shader_part *out)
+static void si_build_ps_prolog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key)
 {
-       union si_shader_part_key *key = &out->key;
-       struct si_shader shader = {};
-       struct si_shader_context ctx;
-       struct gallivm_state *gallivm = &ctx.gallivm;
+       struct gallivm_state *gallivm = &ctx->gallivm;
        LLVMTypeRef *params;
        LLVMValueRef ret, func;
        int last_sgpr, num_params, num_returns, i, num_color_channels;
-       bool status = true;
 
-       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
-       ctx.type = PIPE_SHADER_FRAGMENT;
-       shader.key.ps.prolog = key->ps_prolog.states;
+       assert(si_need_ps_prolog(key));
 
        /* Number of inputs + 8 color elements. */
        params = alloca((key->ps_prolog.num_input_sgprs +
@@ -7368,27 +7771,27 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
        /* Declare inputs. */
        num_params = 0;
        for (i = 0; i < key->ps_prolog.num_input_sgprs; i++)
-               params[num_params++] = ctx.i32;
+               params[num_params++] = ctx->i32;
        last_sgpr = num_params - 1;
 
        for (i = 0; i < key->ps_prolog.num_input_vgprs; i++)
-               params[num_params++] = ctx.f32;
+               params[num_params++] = ctx->f32;
 
        /* Declare outputs (same as inputs + add colors if needed) */
        num_returns = num_params;
        num_color_channels = util_bitcount(key->ps_prolog.colors_read);
        for (i = 0; i < num_color_channels; i++)
-               params[num_returns++] = ctx.f32;
+               params[num_returns++] = ctx->f32;
 
        /* Create the function. */
-       si_create_function(&ctx, "ps_prolog", params, num_returns, params,
+       si_create_function(ctx, "ps_prolog", params, num_returns, params,
                           num_params, last_sgpr);
-       func = ctx.main_fn;
+       func = ctx->main_fn;
 
        /* Copy inputs to outputs. This should be no-op, as the registers match,
         * but it will prevent the compiler from overwriting them unintentionally.
         */
-       ret = ctx.return_value;
+       ret = ctx->return_value;
        for (i = 0; i < num_params; i++) {
                LLVMValueRef p = LLVMGetParam(func, i);
                ret = LLVMBuildInsertValue(gallivm->builder, ret, p, i, "");
@@ -7405,11 +7808,11 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
                ptr[0] = LLVMGetParam(func, SI_SGPR_RW_BUFFERS);
                ptr[1] = LLVMGetParam(func, SI_SGPR_RW_BUFFERS_HI);
                list = lp_build_gather_values(gallivm, ptr, 2);
-               list = LLVMBuildBitCast(gallivm->builder, list, ctx.i64, "");
+               list = LLVMBuildBitCast(gallivm->builder, list, ctx->i64, "");
                list = LLVMBuildIntToPtr(gallivm->builder, list,
-                                         const_array(ctx.v16i8, SI_NUM_RW_BUFFERS), "");
+                                         const_array(ctx->v16i8, SI_NUM_RW_BUFFERS), "");
 
-               si_llvm_emit_polygon_stipple(&ctx, list, pos);
+               si_llvm_emit_polygon_stipple(ctx, list, pos);
        }
 
        if (key->ps_prolog.states.bc_optimize_for_persp ||
@@ -7425,9 +7828,9 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
                 */
                bc_optimize = LLVMGetParam(func, SI_PS_NUM_USER_SGPR);
                bc_optimize = LLVMBuildLShr(gallivm->builder, bc_optimize,
-                                           LLVMConstInt(ctx.i32, 31, 0), "");
+                                           LLVMConstInt(ctx->i32, 31, 0), "");
                bc_optimize = LLVMBuildTrunc(gallivm->builder, bc_optimize,
-                                            ctx.i1, "");
+                                            ctx->i1, "");
 
                if (key->ps_prolog.states.bc_optimize_for_persp) {
                        /* Read PERSP_CENTER. */
@@ -7552,7 +7955,7 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
                                                          interp_vgpr + 1, "");
                        interp_ij = lp_build_gather_values(gallivm, interp, 2);
                        interp_ij = LLVMBuildBitCast(gallivm->builder, interp_ij,
-                                                    ctx.v2i32, "");
+                                                    ctx->v2i32, "");
                }
 
                /* Use the absolute location of the input. */
@@ -7560,10 +7963,10 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
 
                if (key->ps_prolog.states.color_two_side) {
                        face = LLVMGetParam(func, face_vgpr);
-                       face = LLVMBuildBitCast(gallivm->builder, face, ctx.i32, "");
+                       face = LLVMBuildBitCast(gallivm->builder, face, ctx->i32, "");
                }
 
-               interp_fs_input(&ctx,
+               interp_fs_input(ctx,
                                key->ps_prolog.color_attr_index[i],
                                TGSI_SEMANTIC_COLOR, i,
                                key->ps_prolog.num_interp_inputs,
@@ -7583,8 +7986,30 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
                                                   "amdgpu-ps-wqm-outputs", "");
        }
 
+       si_llvm_build_ret(ctx, ret);
+}
+
+/**
+ * Compile the pixel shader prolog.
+ */
+static bool si_compile_ps_prolog(struct si_screen *sscreen,
+                                LLVMTargetMachineRef tm,
+                                struct pipe_debug_callback *debug,
+                                struct si_shader_part *out)
+{
+       union si_shader_part_key *key = &out->key;
+       struct si_shader shader = {};
+       struct si_shader_context ctx;
+       struct gallivm_state *gallivm = &ctx.gallivm;
+       bool status = true;
+
+       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
+       ctx.type = PIPE_SHADER_FRAGMENT;
+       shader.key.ps.prolog = key->ps_prolog.states;
+
+       si_build_ps_prolog_function(&ctx, key);
+
        /* Compile. */
-       si_llvm_build_ret(&ctx, ret);
        si_llvm_finalize_module(&ctx,
                r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_FRAGMENT));
 
@@ -7598,36 +8023,26 @@ static bool si_compile_ps_prolog(struct si_screen *sscreen,
 }
 
 /**
- * Compile the pixel shader epilog. This handles everything that must be
+ * Build the pixel shader epilog function. This handles everything that must be
  * emulated for pixel shader exports. (alpha-test, format conversions, etc)
  */
-static bool si_compile_ps_epilog(struct si_screen *sscreen,
-                                LLVMTargetMachineRef tm,
-                                struct pipe_debug_callback *debug,
-                                struct si_shader_part *out)
+static void si_build_ps_epilog_function(struct si_shader_context *ctx,
+                                       union si_shader_part_key *key)
 {
-       union si_shader_part_key *key = &out->key;
-       struct si_shader shader = {};
-       struct si_shader_context ctx;
-       struct gallivm_state *gallivm = &ctx.gallivm;
-       struct lp_build_tgsi_context *bld_base = &ctx.soa.bld_base;
+       struct gallivm_state *gallivm = &ctx->gallivm;
+       struct lp_build_tgsi_context *bld_base = &ctx->soa.bld_base;
        LLVMTypeRef params[16+8*4+3];
        LLVMValueRef depth = NULL, stencil = NULL, samplemask = NULL;
        int last_sgpr, num_params, i;
-       bool status = true;
        struct si_ps_exports exp = {};
 
-       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
-       ctx.type = PIPE_SHADER_FRAGMENT;
-       shader.key.ps.epilog = key->ps_epilog.states;
-
        /* Declare input SGPRs. */
-       params[SI_PARAM_RW_BUFFERS] = ctx.i64;
-       params[SI_PARAM_CONST_BUFFERS] = ctx.i64;
-       params[SI_PARAM_SAMPLERS] = ctx.i64;
-       params[SI_PARAM_IMAGES] = ctx.i64;
-       params[SI_PARAM_SHADER_BUFFERS] = ctx.i64;
-       params[SI_PARAM_ALPHA_REF] = ctx.f32;
+       params[SI_PARAM_RW_BUFFERS] = ctx->i64;
+       params[SI_PARAM_CONST_BUFFERS] = ctx->i64;
+       params[SI_PARAM_SAMPLERS] = ctx->i64;
+       params[SI_PARAM_IMAGES] = ctx->i64;
+       params[SI_PARAM_SHADER_BUFFERS] = ctx->i64;
+       params[SI_PARAM_ALPHA_REF] = ctx->f32;
        last_sgpr = SI_PARAM_ALPHA_REF;
 
        /* Declare input VGPRs. */
@@ -7643,12 +8058,12 @@ static bool si_compile_ps_epilog(struct si_screen *sscreen,
        assert(num_params <= ARRAY_SIZE(params));
 
        for (i = last_sgpr + 1; i < num_params; i++)
-               params[i] = ctx.f32;
+               params[i] = ctx->f32;
 
        /* Create the function. */
-       si_create_function(&ctx, "ps_epilog", NULL, 0, params, num_params, last_sgpr);
+       si_create_function(ctx, "ps_epilog", NULL, 0, params, num_params, last_sgpr);
        /* Disable elimination of unused inputs. */
-       si_llvm_add_attribute(ctx.main_fn,
+       si_llvm_add_attribute(ctx->main_fn,
                                  "InitialPSInputAddr", 0xffffff);
 
        /* Process colors. */
@@ -7681,7 +8096,7 @@ static bool si_compile_ps_epilog(struct si_screen *sscreen,
                int mrt = u_bit_scan(&colors_written);
 
                for (i = 0; i < 4; i++)
-                       color[i] = LLVMGetParam(ctx.main_fn, vgpr++);
+                       color[i] = LLVMGetParam(ctx->main_fn, vgpr++);
 
                si_export_mrt_color(bld_base, color, mrt,
                                    num_params - 1,
@@ -7690,11 +8105,11 @@ static bool si_compile_ps_epilog(struct si_screen *sscreen,
 
        /* Process depth, stencil, samplemask. */
        if (key->ps_epilog.writes_z)
-               depth = LLVMGetParam(ctx.main_fn, vgpr++);
+               depth = LLVMGetParam(ctx->main_fn, vgpr++);
        if (key->ps_epilog.writes_stencil)
-               stencil = LLVMGetParam(ctx.main_fn, vgpr++);
+               stencil = LLVMGetParam(ctx->main_fn, vgpr++);
        if (key->ps_epilog.writes_samplemask)
-               samplemask = LLVMGetParam(ctx.main_fn, vgpr++);
+               samplemask = LLVMGetParam(ctx->main_fn, vgpr++);
 
        if (depth || stencil || samplemask)
                si_export_mrt_z(bld_base, depth, stencil, samplemask, &exp);
@@ -7702,10 +8117,34 @@ static bool si_compile_ps_epilog(struct si_screen *sscreen,
                si_export_null(bld_base);
 
        if (exp.num)
-               si_emit_ps_exports(&ctx, &exp);
+               si_emit_ps_exports(ctx, &exp);
 
        /* Compile. */
        LLVMBuildRetVoid(gallivm->builder);
+}
+
+
+/**
+ * Compile the pixel shader epilog to a binary for concatenation.
+ */
+static bool si_compile_ps_epilog(struct si_screen *sscreen,
+                                LLVMTargetMachineRef tm,
+                                struct pipe_debug_callback *debug,
+                                struct si_shader_part *out)
+{
+       union si_shader_part_key *key = &out->key;
+       struct si_shader shader = {};
+       struct si_shader_context ctx;
+       struct gallivm_state *gallivm = &ctx.gallivm;
+       bool status = true;
+
+       si_init_shader_ctx(&ctx, sscreen, &shader, tm);
+       ctx.type = PIPE_SHADER_FRAGMENT;
+       shader.key.ps.epilog = key->ps_epilog.states;
+
+       si_build_ps_epilog_function(&ctx, key);
+
+       /* Compile. */
        si_llvm_finalize_module(&ctx,
                r600_extra_shader_checks(&sscreen->b, PIPE_SHADER_FRAGMENT));
 
@@ -7726,123 +8165,14 @@ static bool si_shader_select_ps_parts(struct si_screen *sscreen,
                                      struct si_shader *shader,
                                      struct pipe_debug_callback *debug)
 {
-       struct tgsi_shader_info *info = &shader->selector->info;
        union si_shader_part_key prolog_key;
        union si_shader_part_key epilog_key;
-       unsigned i;
 
        /* Get the prolog. */
-       memset(&prolog_key, 0, sizeof(prolog_key));
-       prolog_key.ps_prolog.states = shader->key.ps.prolog;
-       prolog_key.ps_prolog.colors_read = info->colors_read;
-       prolog_key.ps_prolog.num_input_sgprs = shader->info.num_input_sgprs;
-       prolog_key.ps_prolog.num_input_vgprs = shader->info.num_input_vgprs;
-       prolog_key.ps_prolog.wqm = info->uses_derivatives &&
-               (prolog_key.ps_prolog.colors_read ||
-                prolog_key.ps_prolog.states.force_persp_sample_interp ||
-                prolog_key.ps_prolog.states.force_linear_sample_interp ||
-                prolog_key.ps_prolog.states.force_persp_center_interp ||
-                prolog_key.ps_prolog.states.force_linear_center_interp ||
-                prolog_key.ps_prolog.states.bc_optimize_for_persp ||
-                prolog_key.ps_prolog.states.bc_optimize_for_linear);
-
-       if (info->colors_read) {
-               unsigned *color = shader->selector->color_attr_index;
-
-               if (shader->key.ps.prolog.color_two_side) {
-                       /* BCOLORs are stored after the last input. */
-                       prolog_key.ps_prolog.num_interp_inputs = info->num_inputs;
-                       prolog_key.ps_prolog.face_vgpr_index = shader->info.face_vgpr_index;
-                       shader->config.spi_ps_input_ena |= S_0286CC_FRONT_FACE_ENA(1);
-               }
-
-               for (i = 0; i < 2; i++) {
-                       unsigned interp = info->input_interpolate[color[i]];
-                       unsigned location = info->input_interpolate_loc[color[i]];
-
-                       if (!(info->colors_read & (0xf << i*4)))
-                               continue;
-
-                       prolog_key.ps_prolog.color_attr_index[i] = color[i];
-
-                       if (shader->key.ps.prolog.flatshade_colors &&
-                           interp == TGSI_INTERPOLATE_COLOR)
-                               interp = TGSI_INTERPOLATE_CONSTANT;
-
-                       switch (interp) {
-                       case TGSI_INTERPOLATE_CONSTANT:
-                               prolog_key.ps_prolog.color_interp_vgpr_index[i] = -1;
-                               break;
-                       case TGSI_INTERPOLATE_PERSPECTIVE:
-                       case TGSI_INTERPOLATE_COLOR:
-                               /* Force the interpolation location for colors here. */
-                               if (shader->key.ps.prolog.force_persp_sample_interp)
-                                       location = TGSI_INTERPOLATE_LOC_SAMPLE;
-                               if (shader->key.ps.prolog.force_persp_center_interp)
-                                       location = TGSI_INTERPOLATE_LOC_CENTER;
-
-                               switch (location) {
-                               case TGSI_INTERPOLATE_LOC_SAMPLE:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 0;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_PERSP_SAMPLE_ENA(1);
-                                       break;
-                               case TGSI_INTERPOLATE_LOC_CENTER:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 2;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_PERSP_CENTER_ENA(1);
-                                       break;
-                               case TGSI_INTERPOLATE_LOC_CENTROID:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 4;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_PERSP_CENTROID_ENA(1);
-                                       break;
-                               default:
-                                       assert(0);
-                               }
-                               break;
-                       case TGSI_INTERPOLATE_LINEAR:
-                               /* Force the interpolation location for colors here. */
-                               if (shader->key.ps.prolog.force_linear_sample_interp)
-                                       location = TGSI_INTERPOLATE_LOC_SAMPLE;
-                               if (shader->key.ps.prolog.force_linear_center_interp)
-                                       location = TGSI_INTERPOLATE_LOC_CENTER;
-
-                               switch (location) {
-                               case TGSI_INTERPOLATE_LOC_SAMPLE:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 6;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_LINEAR_SAMPLE_ENA(1);
-                                       break;
-                               case TGSI_INTERPOLATE_LOC_CENTER:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 8;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_LINEAR_CENTER_ENA(1);
-                                       break;
-                               case TGSI_INTERPOLATE_LOC_CENTROID:
-                                       prolog_key.ps_prolog.color_interp_vgpr_index[i] = 10;
-                                       shader->config.spi_ps_input_ena |=
-                                               S_0286CC_LINEAR_CENTROID_ENA(1);
-                                       break;
-                               default:
-                                       assert(0);
-                               }
-                               break;
-                       default:
-                               assert(0);
-                       }
-               }
-       }
+       si_get_ps_prolog_key(shader, &prolog_key, true);
 
        /* The prolog is a no-op if these aren't set. */
-       if (prolog_key.ps_prolog.colors_read ||
-           prolog_key.ps_prolog.states.force_persp_sample_interp ||
-           prolog_key.ps_prolog.states.force_linear_sample_interp ||
-           prolog_key.ps_prolog.states.force_persp_center_interp ||
-           prolog_key.ps_prolog.states.force_linear_center_interp ||
-           prolog_key.ps_prolog.states.bc_optimize_for_persp ||
-           prolog_key.ps_prolog.states.bc_optimize_for_linear ||
-           prolog_key.ps_prolog.states.poly_stipple) {
+       if (si_need_ps_prolog(&prolog_key)) {
                shader->prolog =
                        si_get_shader_part(sscreen, &sscreen->ps_prologs,
                                           &prolog_key, tm, debug,
@@ -7852,12 +8182,7 @@ static bool si_shader_select_ps_parts(struct si_screen *sscreen,
        }
 
        /* Get the epilog. */
-       memset(&epilog_key, 0, sizeof(epilog_key));
-       epilog_key.ps_epilog.colors_written = info->colors_written;
-       epilog_key.ps_epilog.writes_z = info->writes_z;
-       epilog_key.ps_epilog.writes_stencil = info->writes_stencil;
-       epilog_key.ps_epilog.writes_samplemask = info->writes_samplemask;
-       epilog_key.ps_epilog.states = shader->key.ps.epilog;
+       si_get_ps_epilog_key(shader, &epilog_key);
 
        shader->epilog =
                si_get_shader_part(sscreen, &sscreen->ps_epilogs,