radeonsi: move more LLVM functions into si_shader_llvm.c
author: Marek Olšák <marek.olsak@amd.com>
Wed, 15 Jan 2020 23:41:06 +0000 (18:41 -0500)
committer: Marge Bot <eric+marge@anholt.net>
Thu, 23 Jan 2020 19:10:21 +0000 (19:10 +0000)
Reviewed-by: Timothy Arceri <tarceri@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/3421>

src/gallium/drivers/radeonsi/si_shader.c
src/gallium/drivers/radeonsi/si_shader_internal.h
src/gallium/drivers/radeonsi/si_shader_llvm.c

index 0fc616b4865baa81c0de8330b4628713f88c0da0..b3739533a9c096b5b780af079ecc4a21f8e49378 100644 (file)
@@ -44,7 +44,7 @@ static const char scratch_rsrc_dword1_symbol[] =
 static void si_dump_shader_key(const struct si_shader *shader, FILE *f);
 
 /** Whether the shader runs as a combination of multiple API shaders */
-static bool is_multi_part_shader(struct si_shader_context *ctx)
+bool si_is_multi_part_shader(struct si_shader_context *ctx)
 {
        if (ctx->screen->info.chip_class <= GFX8)
                return false;
@@ -58,7 +58,7 @@ static bool is_multi_part_shader(struct si_shader_context *ctx)
 /** Whether the shader runs on a merged HW stage (LSHS or ESGS) */
 bool si_is_merged_shader(struct si_shader_context *ctx)
 {
-       return ctx->shader->key.as_ngg || is_multi_part_shader(ctx);
+       return ctx->shader->key.as_ngg || si_is_multi_part_shader(ctx);
 }
 
 /**
@@ -145,105 +145,6 @@ unsigned si_shader_io_get_unique_index(unsigned semantic_name, unsigned index,
        }
 }
 
-/**
- * Get the value of a shader input parameter and extract a bitfield.
- */
-static LLVMValueRef unpack_llvm_param(struct si_shader_context *ctx,
-                                     LLVMValueRef value, unsigned rshift,
-                                     unsigned bitwidth)
-{
-       if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMFloatTypeKind)
-               value = ac_to_integer(&ctx->ac, value);
-
-       if (rshift)
-               value = LLVMBuildLShr(ctx->ac.builder, value,
-                                     LLVMConstInt(ctx->ac.i32, rshift, 0), "");
-
-       if (rshift + bitwidth < 32) {
-               unsigned mask = (1 << bitwidth) - 1;
-               value = LLVMBuildAnd(ctx->ac.builder, value,
-                                    LLVMConstInt(ctx->ac.i32, mask, 0), "");
-       }
-
-       return value;
-}
-
-LLVMValueRef si_unpack_param(struct si_shader_context *ctx,
-                            struct ac_arg param, unsigned rshift,
-                            unsigned bitwidth)
-{
-       LLVMValueRef value = ac_get_arg(&ctx->ac, param);
-
-       return unpack_llvm_param(ctx, value, rshift, bitwidth);
-}
-
-LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx,
-                                unsigned swizzle)
-{
-       if (swizzle > 0)
-               return ctx->ac.i32_0;
-
-       switch (ctx->type) {
-       case PIPE_SHADER_VERTEX:
-               return ac_get_arg(&ctx->ac, ctx->vs_prim_id);
-       case PIPE_SHADER_TESS_CTRL:
-               return ac_get_arg(&ctx->ac, ctx->args.tcs_patch_id);
-       case PIPE_SHADER_TESS_EVAL:
-               return ac_get_arg(&ctx->ac, ctx->args.tes_patch_id);
-       case PIPE_SHADER_GEOMETRY:
-               return ac_get_arg(&ctx->ac, ctx->args.gs_prim_id);
-       default:
-               assert(0);
-               return ctx->ac.i32_0;
-       }
-}
-
-static LLVMValueRef get_block_size(struct ac_shader_abi *abi)
-{
-       struct si_shader_context *ctx = si_shader_context_from_abi(abi);
-
-       LLVMValueRef values[3];
-       LLVMValueRef result;
-       unsigned i;
-       unsigned *properties = ctx->shader->selector->info.properties;
-
-       if (properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH] != 0) {
-               unsigned sizes[3] = {
-                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH],
-                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_HEIGHT],
-                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_DEPTH]
-               };
-
-               for (i = 0; i < 3; ++i)
-                       values[i] = LLVMConstInt(ctx->ac.i32, sizes[i], 0);
-
-               result = ac_build_gather_values(&ctx->ac, values, 3);
-       } else {
-               result = ac_get_arg(&ctx->ac, ctx->block_size);
-       }
-
-       return result;
-}
-
-void si_declare_compute_memory(struct si_shader_context *ctx)
-{
-       struct si_shader_selector *sel = ctx->shader->selector;
-       unsigned lds_size = sel->info.properties[TGSI_PROPERTY_CS_LOCAL_SIZE];
-
-       LLVMTypeRef i8p = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);
-       LLVMValueRef var;
-
-       assert(!ctx->ac.lds);
-
-       var = LLVMAddGlobalInAddressSpace(ctx->ac.module,
-                                         LLVMArrayType(ctx->ac.i8, lds_size),
-                                         "compute_lds",
-                                         AC_ADDR_SPACE_LDS);
-       LLVMSetAlignment(var, 64 * 1024);
-
-       ctx->ac.lds = LLVMBuildBitCast(ctx->ac.builder, var, i8p, "");
-}
-
 static void si_dump_streamout(struct pipe_stream_output_info *so)
 {
        unsigned i;
@@ -291,7 +192,7 @@ static void declare_streamout_params(struct si_shader_context *ctx,
        }
 }
 
-static unsigned si_get_max_workgroup_size(const struct si_shader *shader)
+unsigned si_get_max_workgroup_size(const struct si_shader *shader)
 {
        switch (shader->selector->type) {
        case PIPE_SHADER_VERTEX:
@@ -1531,7 +1432,7 @@ static bool si_build_main_function(struct si_shader_context *ctx,
                si_llvm_init_ps_callbacks(ctx);
                break;
        case PIPE_SHADER_COMPUTE:
-               ctx->abi.load_local_group_size = get_block_size;
+               ctx->abi.load_local_group_size = si_llvm_get_block_size;
                break;
        default:
                assert(!"Unsupported shader type");
@@ -1776,287 +1677,6 @@ static void si_get_vs_prolog_key(const struct si_shader_info *info,
                shader_out->info.uses_instanceid = true;
 }
 
-/**
- * Given a list of shader part functions, build a wrapper function that
- * runs them in sequence to form a monolithic shader.
- */
-void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef *parts,
-                              unsigned num_parts, unsigned main_part,
-                              unsigned next_shader_first_part)
-{
-       LLVMBuilderRef builder = ctx->ac.builder;
-       /* PS epilog has one arg per color component; gfx9 merged shader
-        * prologs need to forward 40 SGPRs.
-        */
-       LLVMValueRef initial[AC_MAX_ARGS], out[AC_MAX_ARGS];
-       LLVMTypeRef function_type;
-       unsigned num_first_params;
-       unsigned num_out, initial_num_out;
-       ASSERTED unsigned num_out_sgpr; /* used in debug checks */
-       ASSERTED unsigned initial_num_out_sgpr; /* used in debug checks */
-       unsigned num_sgprs, num_vgprs;
-       unsigned gprs;
-
-       memset(&ctx->args, 0, sizeof(ctx->args));
-
-       for (unsigned i = 0; i < num_parts; ++i) {
-               ac_add_function_attr(ctx->ac.context, parts[i], -1,
-                                    AC_FUNC_ATTR_ALWAYSINLINE);
-               LLVMSetLinkage(parts[i], LLVMPrivateLinkage);
-       }
-
-       /* The parameters of the wrapper function correspond to those of the
-        * first part in terms of SGPRs and VGPRs, but we use the types of the
-        * main part to get the right types. This is relevant for the
-        * dereferenceable attribute on descriptor table pointers.
-        */
-       num_sgprs = 0;
-       num_vgprs = 0;
-
-       function_type = LLVMGetElementType(LLVMTypeOf(parts[0]));
-       num_first_params = LLVMCountParamTypes(function_type);
-
-       for (unsigned i = 0; i < num_first_params; ++i) {
-               LLVMValueRef param = LLVMGetParam(parts[0], i);
-
-               if (ac_is_sgpr_param(param)) {
-                       assert(num_vgprs == 0);
-                       num_sgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
-               } else {
-                       num_vgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
-               }
-       }
-
-       gprs = 0;
-       while (gprs < num_sgprs + num_vgprs) {
-               LLVMValueRef param = LLVMGetParam(parts[main_part], ctx->args.arg_count);
-               LLVMTypeRef type = LLVMTypeOf(param);
-               unsigned size = ac_get_type_size(type) / 4;
-
-               /* This is going to get casted anyways, so we don't have to
-                * have the exact same type. But we do have to preserve the
-                * pointer-ness so that LLVM knows about it.
-                */
-               enum ac_arg_type arg_type = AC_ARG_INT;
-               if (LLVMGetTypeKind(type) == LLVMPointerTypeKind) {
-                       type = LLVMGetElementType(type);
-
-                       if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
-                               if (LLVMGetVectorSize(type) == 4)
-                                       arg_type = AC_ARG_CONST_DESC_PTR;
-                               else if (LLVMGetVectorSize(type) == 8)
-                                       arg_type = AC_ARG_CONST_IMAGE_PTR;
-                               else
-                                       assert(0);
-                       } else if (type == ctx->ac.f32) {
-                               arg_type = AC_ARG_CONST_FLOAT_PTR;
-                       } else {
-                               assert(0);
-                       }
-               }
-
-               ac_add_arg(&ctx->args, gprs < num_sgprs ? AC_ARG_SGPR : AC_ARG_VGPR,
-                          size, arg_type, NULL);
-
-               assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
-               assert(gprs + size <= num_sgprs + num_vgprs &&
-                      (gprs >= num_sgprs || gprs + size <= num_sgprs));
-
-               gprs += size;
-       }
-
-       /* Prepare the return type. */
-       unsigned num_returns = 0;
-       LLVMTypeRef returns[AC_MAX_ARGS], last_func_type, return_type;
-
-       last_func_type = LLVMGetElementType(LLVMTypeOf(parts[num_parts - 1]));
-       return_type = LLVMGetReturnType(last_func_type);
-
-       switch (LLVMGetTypeKind(return_type)) {
-       case LLVMStructTypeKind:
-               num_returns = LLVMCountStructElementTypes(return_type);
-               assert(num_returns <= ARRAY_SIZE(returns));
-               LLVMGetStructElementTypes(return_type, returns);
-               break;
-       case LLVMVoidTypeKind:
-               break;
-       default:
-               unreachable("unexpected type");
-       }
-
-       si_llvm_create_func(ctx, "wrapper", returns, num_returns,
-                           si_get_max_workgroup_size(ctx->shader));
-
-       if (si_is_merged_shader(ctx))
-               ac_init_exec_full_mask(&ctx->ac);
-
-       /* Record the arguments of the function as if they were an output of
-        * a previous part.
-        */
-       num_out = 0;
-       num_out_sgpr = 0;
-
-       for (unsigned i = 0; i < ctx->args.arg_count; ++i) {
-               LLVMValueRef param = LLVMGetParam(ctx->main_fn, i);
-               LLVMTypeRef param_type = LLVMTypeOf(param);
-               LLVMTypeRef out_type = ctx->args.args[i].file == AC_ARG_SGPR ? ctx->ac.i32 : ctx->ac.f32;
-               unsigned size = ac_get_type_size(param_type) / 4;
-
-               if (size == 1) {
-                       if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-                               param = LLVMBuildPtrToInt(builder, param, ctx->ac.i32, "");
-                               param_type = ctx->ac.i32;
-                       }
-
-                       if (param_type != out_type)
-                               param = LLVMBuildBitCast(builder, param, out_type, "");
-                       out[num_out++] = param;
-               } else {
-                       LLVMTypeRef vector_type = LLVMVectorType(out_type, size);
-
-                       if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-                               param = LLVMBuildPtrToInt(builder, param, ctx->ac.i64, "");
-                               param_type = ctx->ac.i64;
-                       }
-
-                       if (param_type != vector_type)
-                               param = LLVMBuildBitCast(builder, param, vector_type, "");
-
-                       for (unsigned j = 0; j < size; ++j)
-                               out[num_out++] = LLVMBuildExtractElement(
-                                       builder, param, LLVMConstInt(ctx->ac.i32, j, 0), "");
-               }
-
-               if (ctx->args.args[i].file == AC_ARG_SGPR)
-                       num_out_sgpr = num_out;
-       }
-
-       memcpy(initial, out, sizeof(out));
-       initial_num_out = num_out;
-       initial_num_out_sgpr = num_out_sgpr;
-
-       /* Now chain the parts. */
-       LLVMValueRef ret = NULL;
-       for (unsigned part = 0; part < num_parts; ++part) {
-               LLVMValueRef in[AC_MAX_ARGS];
-               LLVMTypeRef ret_type;
-               unsigned out_idx = 0;
-               unsigned num_params = LLVMCountParams(parts[part]);
-
-               /* Merged shaders are executed conditionally depending
-                * on the number of enabled threads passed in the input SGPRs. */
-               if (is_multi_part_shader(ctx) && part == 0) {
-                       LLVMValueRef ena, count = initial[3];
-
-                       count = LLVMBuildAnd(builder, count,
-                                            LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
-                       ena = LLVMBuildICmp(builder, LLVMIntULT,
-                                           ac_get_thread_id(&ctx->ac), count, "");
-                       ac_build_ifcc(&ctx->ac, ena, 6506);
-               }
-
-               /* Derive arguments for the next part from outputs of the
-                * previous one.
-                */
-               for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
-                       LLVMValueRef param;
-                       LLVMTypeRef param_type;
-                       bool is_sgpr;
-                       unsigned param_size;
-                       LLVMValueRef arg = NULL;
-
-                       param = LLVMGetParam(parts[part], param_idx);
-                       param_type = LLVMTypeOf(param);
-                       param_size = ac_get_type_size(param_type) / 4;
-                       is_sgpr = ac_is_sgpr_param(param);
-
-                       if (is_sgpr) {
-                               ac_add_function_attr(ctx->ac.context, parts[part],
-                                                    param_idx + 1, AC_FUNC_ATTR_INREG);
-                       } else if (out_idx < num_out_sgpr) {
-                               /* Skip returned SGPRs the current part doesn't
-                                * declare on the input. */
-                               out_idx = num_out_sgpr;
-                       }
-
-                       assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));
-
-                       if (param_size == 1)
-                               arg = out[out_idx];
-                       else
-                               arg = ac_build_gather_values(&ctx->ac, &out[out_idx], param_size);
-
-                       if (LLVMTypeOf(arg) != param_type) {
-                               if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
-                                       if (LLVMGetPointerAddressSpace(param_type) ==
-                                           AC_ADDR_SPACE_CONST_32BIT) {
-                                               arg = LLVMBuildBitCast(builder, arg, ctx->ac.i32, "");
-                                               arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
-                                       } else {
-                                               arg = LLVMBuildBitCast(builder, arg, ctx->ac.i64, "");
-                                               arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
-                                       }
-                               } else {
-                                       arg = LLVMBuildBitCast(builder, arg, param_type, "");
-                               }
-                       }
-
-                       in[param_idx] = arg;
-                       out_idx += param_size;
-               }
-
-               ret = ac_build_call(&ctx->ac, parts[part], in, num_params);
-
-               if (is_multi_part_shader(ctx) &&
-                   part + 1 == next_shader_first_part) {
-                       ac_build_endif(&ctx->ac, 6506);
-
-                       /* The second half of the merged shader should use
-                        * the inputs from the toplevel (wrapper) function,
-                        * not the return value from the last call.
-                        *
-                        * That's because the last call was executed condi-
-                        * tionally, so we can't consume it in the main
-                        * block.
-                        */
-                       memcpy(out, initial, sizeof(initial));
-                       num_out = initial_num_out;
-                       num_out_sgpr = initial_num_out_sgpr;
-                       continue;
-               }
-
-               /* Extract the returned GPRs. */
-               ret_type = LLVMTypeOf(ret);
-               num_out = 0;
-               num_out_sgpr = 0;
-
-               if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
-                       assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);
-
-                       unsigned ret_size = LLVMCountStructElementTypes(ret_type);
-
-                       for (unsigned i = 0; i < ret_size; ++i) {
-                               LLVMValueRef val =
-                                       LLVMBuildExtractValue(builder, ret, i, "");
-
-                               assert(num_out < ARRAY_SIZE(out));
-                               out[num_out++] = val;
-
-                               if (LLVMTypeOf(val) == ctx->ac.i32) {
-                                       assert(num_out_sgpr + 1 == num_out);
-                                       num_out_sgpr = num_out;
-                               }
-                       }
-               }
-       }
-
-       /* Return the value from the last part. */
-       if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
-               LLVMBuildRetVoid(builder);
-       else
-               LLVMBuildRet(builder, ret);
-}
-
 static bool si_should_optimize_less(struct ac_llvm_compiler *compiler,
                                    struct si_shader_selector *sel)
 {
index d9b84ab13024a2caa0bed9cfcd8f0c19b42fe7a1..542466ee2056c5232022dd3fc584c5b094eb4e86 100644 (file)
@@ -196,23 +196,14 @@ si_shader_context_from_abi(struct ac_shader_abi *abi)
        return container_of(abi, ctx, abi);
 }
 
+bool si_is_multi_part_shader(struct si_shader_context *ctx);
 bool si_is_merged_shader(struct si_shader_context *ctx);
-void si_declare_compute_memory(struct si_shader_context *ctx);
-LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx,
-                                unsigned swizzle);
 void si_add_arg_checked(struct ac_shader_args *args,
                        enum ac_arg_regfile file,
                        unsigned registers, enum ac_arg_type type,
                        struct ac_arg *arg,
                        unsigned idx);
-bool si_nir_build_llvm(struct si_shader_context *ctx, struct nir_shader *nir);
-
-LLVMValueRef si_unpack_param(struct si_shader_context *ctx,
-                            struct ac_arg param, unsigned rshift,
-                            unsigned bitwidth);
-void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef *parts,
-                              unsigned num_parts, unsigned main_part,
-                              unsigned next_shader_first_part);
+unsigned si_get_max_workgroup_size(const struct si_shader *shader);
 bool si_need_ps_prolog(const union si_shader_part_key *key);
 void si_get_ps_prolog_key(struct si_shader *shader,
                          union si_shader_part_key *key,
@@ -276,6 +267,17 @@ void si_llvm_emit_barrier(struct si_shader_context *ctx);
 void si_llvm_declare_esgs_ring(struct si_shader_context *ctx);
 void si_init_exec_from_input(struct si_shader_context *ctx, struct ac_arg param,
                             unsigned bitoffset);
+LLVMValueRef si_unpack_param(struct si_shader_context *ctx,
+                            struct ac_arg param, unsigned rshift,
+                            unsigned bitwidth);
+LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx,
+                                unsigned swizzle);
+LLVMValueRef si_llvm_get_block_size(struct ac_shader_abi *abi);
+void si_llvm_declare_compute_memory(struct si_shader_context *ctx);
+bool si_nir_build_llvm(struct si_shader_context *ctx, struct nir_shader *nir);
+void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef *parts,
+                              unsigned num_parts, unsigned main_part,
+                              unsigned next_shader_first_part);
 
 /* si_shader_llvm_gs.c */
 LLVMValueRef si_is_es_thread(struct si_shader_context *ctx);
index d7336ea6d87eb5a2e89838c80fcb030356eab880..4ddcbccfac0a5767bc2042881c24e647eaeac286 100644 (file)
@@ -365,6 +365,105 @@ void si_init_exec_from_input(struct si_shader_context *ctx, struct ac_arg param,
                           ctx->ac.voidt, args, 2, AC_FUNC_ATTR_CONVERGENT);
 }
 
+/**
+ * Get the value of a shader input parameter and extract a bitfield.
+ */
+static LLVMValueRef unpack_llvm_param(struct si_shader_context *ctx,
+                                     LLVMValueRef value, unsigned rshift,
+                                     unsigned bitwidth)
+{
+       if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMFloatTypeKind)
+               value = ac_to_integer(&ctx->ac, value);
+
+       if (rshift)
+               value = LLVMBuildLShr(ctx->ac.builder, value,
+                                     LLVMConstInt(ctx->ac.i32, rshift, 0), "");
+
+       if (rshift + bitwidth < 32) {
+               unsigned mask = (1 << bitwidth) - 1;
+               value = LLVMBuildAnd(ctx->ac.builder, value,
+                                    LLVMConstInt(ctx->ac.i32, mask, 0), "");
+       }
+
+       return value;
+}
+
+LLVMValueRef si_unpack_param(struct si_shader_context *ctx,
+                            struct ac_arg param, unsigned rshift,
+                            unsigned bitwidth)
+{
+       LLVMValueRef value = ac_get_arg(&ctx->ac, param);
+
+       return unpack_llvm_param(ctx, value, rshift, bitwidth);
+}
+
+LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx,
+                                unsigned swizzle)
+{
+       if (swizzle > 0)
+               return ctx->ac.i32_0;
+
+       switch (ctx->type) {
+       case PIPE_SHADER_VERTEX:
+               return ac_get_arg(&ctx->ac, ctx->vs_prim_id);
+       case PIPE_SHADER_TESS_CTRL:
+               return ac_get_arg(&ctx->ac, ctx->args.tcs_patch_id);
+       case PIPE_SHADER_TESS_EVAL:
+               return ac_get_arg(&ctx->ac, ctx->args.tes_patch_id);
+       case PIPE_SHADER_GEOMETRY:
+               return ac_get_arg(&ctx->ac, ctx->args.gs_prim_id);
+       default:
+               assert(0);
+               return ctx->ac.i32_0;
+       }
+}
+
+LLVMValueRef si_llvm_get_block_size(struct ac_shader_abi *abi)
+{
+       struct si_shader_context *ctx = si_shader_context_from_abi(abi);
+
+       LLVMValueRef values[3];
+       LLVMValueRef result;
+       unsigned i;
+       unsigned *properties = ctx->shader->selector->info.properties;
+
+       if (properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH] != 0) {
+               unsigned sizes[3] = {
+                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH],
+                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_HEIGHT],
+                       properties[TGSI_PROPERTY_CS_FIXED_BLOCK_DEPTH]
+               };
+
+               for (i = 0; i < 3; ++i)
+                       values[i] = LLVMConstInt(ctx->ac.i32, sizes[i], 0);
+
+               result = ac_build_gather_values(&ctx->ac, values, 3);
+       } else {
+               result = ac_get_arg(&ctx->ac, ctx->block_size);
+       }
+
+       return result;
+}
+
+void si_llvm_declare_compute_memory(struct si_shader_context *ctx)
+{
+       struct si_shader_selector *sel = ctx->shader->selector;
+       unsigned lds_size = sel->info.properties[TGSI_PROPERTY_CS_LOCAL_SIZE];
+
+       LLVMTypeRef i8p = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);
+       LLVMValueRef var;
+
+       assert(!ctx->ac.lds);
+
+       var = LLVMAddGlobalInAddressSpace(ctx->ac.module,
+                                         LLVMArrayType(ctx->ac.i8, lds_size),
+                                         "compute_lds",
+                                         AC_ADDR_SPACE_LDS);
+       LLVMSetAlignment(var, 64 * 1024);
+
+       ctx->ac.lds = LLVMBuildBitCast(ctx->ac.builder, var, i8p, "");
+}
+
 bool si_nir_build_llvm(struct si_shader_context *ctx, struct nir_shader *nir)
 {
        if (nir->info.stage == MESA_SHADER_VERTEX) {
@@ -417,9 +516,290 @@ bool si_nir_build_llvm(struct si_shader_context *ctx, struct nir_shader *nir)
 
        if (ctx->shader->selector->info.properties[TGSI_PROPERTY_CS_LOCAL_SIZE]) {
                assert(gl_shader_stage_is_compute(nir->info.stage));
-               si_declare_compute_memory(ctx);
+               si_llvm_declare_compute_memory(ctx);
        }
        ac_nir_translate(&ctx->ac, &ctx->abi, &ctx->args, nir);
 
        return true;
 }
+
+/**
+ * Given a list of shader part functions, build a wrapper function that
+ * runs them in sequence to form a monolithic shader.
+ */
+void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef *parts,
+                              unsigned num_parts, unsigned main_part,
+                              unsigned next_shader_first_part)
+{
+       LLVMBuilderRef builder = ctx->ac.builder;
+       /* PS epilog has one arg per color component; gfx9 merged shader
+        * prologs need to forward 40 SGPRs.
+        */
+       LLVMValueRef initial[AC_MAX_ARGS], out[AC_MAX_ARGS];
+       LLVMTypeRef function_type;
+       unsigned num_first_params;
+       unsigned num_out, initial_num_out;
+       ASSERTED unsigned num_out_sgpr; /* used in debug checks */
+       ASSERTED unsigned initial_num_out_sgpr; /* used in debug checks */
+       unsigned num_sgprs, num_vgprs;
+       unsigned gprs;
+
+       memset(&ctx->args, 0, sizeof(ctx->args));
+
+       for (unsigned i = 0; i < num_parts; ++i) {
+               ac_add_function_attr(ctx->ac.context, parts[i], -1,
+                                    AC_FUNC_ATTR_ALWAYSINLINE);
+               LLVMSetLinkage(parts[i], LLVMPrivateLinkage);
+       }
+
+       /* The parameters of the wrapper function correspond to those of the
+        * first part in terms of SGPRs and VGPRs, but we use the types of the
+        * main part to get the right types. This is relevant for the
+        * dereferenceable attribute on descriptor table pointers.
+        */
+       num_sgprs = 0;
+       num_vgprs = 0;
+
+       function_type = LLVMGetElementType(LLVMTypeOf(parts[0]));
+       num_first_params = LLVMCountParamTypes(function_type);
+
+       for (unsigned i = 0; i < num_first_params; ++i) {
+               LLVMValueRef param = LLVMGetParam(parts[0], i);
+
+               if (ac_is_sgpr_param(param)) {
+                       assert(num_vgprs == 0);
+                       num_sgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
+               } else {
+                       num_vgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
+               }
+       }
+
+       gprs = 0;
+       while (gprs < num_sgprs + num_vgprs) {
+               LLVMValueRef param = LLVMGetParam(parts[main_part], ctx->args.arg_count);
+               LLVMTypeRef type = LLVMTypeOf(param);
+               unsigned size = ac_get_type_size(type) / 4;
+
+               /* This is going to get casted anyways, so we don't have to
+                * have the exact same type. But we do have to preserve the
+                * pointer-ness so that LLVM knows about it.
+                */
+               enum ac_arg_type arg_type = AC_ARG_INT;
+               if (LLVMGetTypeKind(type) == LLVMPointerTypeKind) {
+                       type = LLVMGetElementType(type);
+
+                       if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
+                               if (LLVMGetVectorSize(type) == 4)
+                                       arg_type = AC_ARG_CONST_DESC_PTR;
+                               else if (LLVMGetVectorSize(type) == 8)
+                                       arg_type = AC_ARG_CONST_IMAGE_PTR;
+                               else
+                                       assert(0);
+                       } else if (type == ctx->ac.f32) {
+                               arg_type = AC_ARG_CONST_FLOAT_PTR;
+                       } else {
+                               assert(0);
+                       }
+               }
+
+               ac_add_arg(&ctx->args, gprs < num_sgprs ? AC_ARG_SGPR : AC_ARG_VGPR,
+                          size, arg_type, NULL);
+
+               assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
+               assert(gprs + size <= num_sgprs + num_vgprs &&
+                      (gprs >= num_sgprs || gprs + size <= num_sgprs));
+
+               gprs += size;
+       }
+
+       /* Prepare the return type. */
+       unsigned num_returns = 0;
+       LLVMTypeRef returns[AC_MAX_ARGS], last_func_type, return_type;
+
+       last_func_type = LLVMGetElementType(LLVMTypeOf(parts[num_parts - 1]));
+       return_type = LLVMGetReturnType(last_func_type);
+
+       switch (LLVMGetTypeKind(return_type)) {
+       case LLVMStructTypeKind:
+               num_returns = LLVMCountStructElementTypes(return_type);
+               assert(num_returns <= ARRAY_SIZE(returns));
+               LLVMGetStructElementTypes(return_type, returns);
+               break;
+       case LLVMVoidTypeKind:
+               break;
+       default:
+               unreachable("unexpected type");
+       }
+
+       si_llvm_create_func(ctx, "wrapper", returns, num_returns,
+                           si_get_max_workgroup_size(ctx->shader));
+
+       if (si_is_merged_shader(ctx))
+               ac_init_exec_full_mask(&ctx->ac);
+
+       /* Record the arguments of the function as if they were an output of
+        * a previous part.
+        */
+       num_out = 0;
+       num_out_sgpr = 0;
+
+       for (unsigned i = 0; i < ctx->args.arg_count; ++i) {
+               LLVMValueRef param = LLVMGetParam(ctx->main_fn, i);
+               LLVMTypeRef param_type = LLVMTypeOf(param);
+               LLVMTypeRef out_type = ctx->args.args[i].file == AC_ARG_SGPR ? ctx->ac.i32 : ctx->ac.f32;
+               unsigned size = ac_get_type_size(param_type) / 4;
+
+               if (size == 1) {
+                       if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
+                               param = LLVMBuildPtrToInt(builder, param, ctx->ac.i32, "");
+                               param_type = ctx->ac.i32;
+                       }
+
+                       if (param_type != out_type)
+                               param = LLVMBuildBitCast(builder, param, out_type, "");
+                       out[num_out++] = param;
+               } else {
+                       LLVMTypeRef vector_type = LLVMVectorType(out_type, size);
+
+                       if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
+                               param = LLVMBuildPtrToInt(builder, param, ctx->ac.i64, "");
+                               param_type = ctx->ac.i64;
+                       }
+
+                       if (param_type != vector_type)
+                               param = LLVMBuildBitCast(builder, param, vector_type, "");
+
+                       for (unsigned j = 0; j < size; ++j)
+                               out[num_out++] = LLVMBuildExtractElement(
+                                       builder, param, LLVMConstInt(ctx->ac.i32, j, 0), "");
+               }
+
+               if (ctx->args.args[i].file == AC_ARG_SGPR)
+                       num_out_sgpr = num_out;
+       }
+
+       memcpy(initial, out, sizeof(out));
+       initial_num_out = num_out;
+       initial_num_out_sgpr = num_out_sgpr;
+
+       /* Now chain the parts. */
+       LLVMValueRef ret = NULL;
+       for (unsigned part = 0; part < num_parts; ++part) {
+               LLVMValueRef in[AC_MAX_ARGS];
+               LLVMTypeRef ret_type;
+               unsigned out_idx = 0;
+               unsigned num_params = LLVMCountParams(parts[part]);
+
+               /* Merged shaders are executed conditionally depending
+                * on the number of enabled threads passed in the input SGPRs. */
+               if (si_is_multi_part_shader(ctx) && part == 0) {
+                       LLVMValueRef ena, count = initial[3];
+
+                       count = LLVMBuildAnd(builder, count,
+                                            LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
+                       ena = LLVMBuildICmp(builder, LLVMIntULT,
+                                           ac_get_thread_id(&ctx->ac), count, "");
+                       ac_build_ifcc(&ctx->ac, ena, 6506);
+               }
+
+               /* Derive arguments for the next part from outputs of the
+                * previous one.
+                */
+               for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
+                       LLVMValueRef param;
+                       LLVMTypeRef param_type;
+                       bool is_sgpr;
+                       unsigned param_size;
+                       LLVMValueRef arg = NULL;
+
+                       param = LLVMGetParam(parts[part], param_idx);
+                       param_type = LLVMTypeOf(param);
+                       param_size = ac_get_type_size(param_type) / 4;
+                       is_sgpr = ac_is_sgpr_param(param);
+
+                       if (is_sgpr) {
+                               ac_add_function_attr(ctx->ac.context, parts[part],
+                                                    param_idx + 1, AC_FUNC_ATTR_INREG);
+                       } else if (out_idx < num_out_sgpr) {
+                               /* Skip returned SGPRs the current part doesn't
+                                * declare on the input. */
+                               out_idx = num_out_sgpr;
+                       }
+
+                       assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));
+
+                       if (param_size == 1)
+                               arg = out[out_idx];
+                       else
+                               arg = ac_build_gather_values(&ctx->ac, &out[out_idx], param_size);
+
+                       if (LLVMTypeOf(arg) != param_type) {
+                               if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
+                                       if (LLVMGetPointerAddressSpace(param_type) ==
+                                           AC_ADDR_SPACE_CONST_32BIT) {
+                                               arg = LLVMBuildBitCast(builder, arg, ctx->ac.i32, "");
+                                               arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
+                                       } else {
+                                               arg = LLVMBuildBitCast(builder, arg, ctx->ac.i64, "");
+                                               arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
+                                       }
+                               } else {
+                                       arg = LLVMBuildBitCast(builder, arg, param_type, "");
+                               }
+                       }
+
+                       in[param_idx] = arg;
+                       out_idx += param_size;
+               }
+
+               ret = ac_build_call(&ctx->ac, parts[part], in, num_params);
+
+               if (si_is_multi_part_shader(ctx) &&
+                   part + 1 == next_shader_first_part) {
+                       ac_build_endif(&ctx->ac, 6506);
+
+                       /* The second half of the merged shader should use
+                        * the inputs from the toplevel (wrapper) function,
+                        * not the return value from the last call.
+                        *
+                        * That's because the last call was executed condi-
+                        * tionally, so we can't consume it in the main
+                        * block.
+                        */
+                       memcpy(out, initial, sizeof(initial));
+                       num_out = initial_num_out;
+                       num_out_sgpr = initial_num_out_sgpr;
+                       continue;
+               }
+
+               /* Extract the returned GPRs. */
+               ret_type = LLVMTypeOf(ret);
+               num_out = 0;
+               num_out_sgpr = 0;
+
+               if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
+                       assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);
+
+                       unsigned ret_size = LLVMCountStructElementTypes(ret_type);
+
+                       for (unsigned i = 0; i < ret_size; ++i) {
+                               LLVMValueRef val =
+                                       LLVMBuildExtractValue(builder, ret, i, "");
+
+                               assert(num_out < ARRAY_SIZE(out));
+                               out[num_out++] = val;
+
+                               if (LLVMTypeOf(val) == ctx->ac.i32) {
+                                       assert(num_out_sgpr + 1 == num_out);
+                                       num_out_sgpr = num_out;
+                               }
+                       }
+               }
+       }
+
+       /* Return the value from the last part. */
+       if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
+               LLVMBuildRetVoid(builder);
+       else
+               LLVMBuildRet(builder, ret);
+}