amd/llvm: Refactor ac_build_scan.
[mesa.git] / src / amd / llvm / ac_llvm_build.c
index 641587051b8452a6da06c5f449bb7c3bfd7ec4bf..47c27893fe50d2af2d98cd9498eeaf16cdcbe4d0 100644 (file)
@@ -432,11 +432,19 @@ ac_build_optimization_barrier(struct ac_llvm_context *ctx,
        } else {
                LLVMTypeRef ftype = LLVMFunctionType(ctx->i32, &ctx->i32, 1, false);
                LLVMValueRef inlineasm = LLVMConstInlineAsm(ftype, code, "=v,0", true, false);
+               LLVMTypeRef type = LLVMTypeOf(*pvgpr);
+               unsigned bitsize = ac_get_elem_bits(ctx, type);
                LLVMValueRef vgpr = *pvgpr;
-               LLVMTypeRef vgpr_type = LLVMTypeOf(vgpr);
-               unsigned vgpr_size = ac_get_type_size(vgpr_type);
+               LLVMTypeRef vgpr_type;
+               unsigned vgpr_size;
                LLVMValueRef vgpr0;
 
+               if (bitsize < 32)
+                       vgpr = LLVMBuildZExt(ctx->builder, vgpr, ctx->i32, "");
+
+               vgpr_type = LLVMTypeOf(vgpr);
+               vgpr_size = ac_get_type_size(vgpr_type);
+
                assert(vgpr_size % 4 == 0);
 
                vgpr = LLVMBuildBitCast(builder, vgpr, LLVMVectorType(ctx->i32, vgpr_size / 4), "");
@@ -445,6 +453,9 @@ ac_build_optimization_barrier(struct ac_llvm_context *ctx,
                vgpr = LLVMBuildInsertElement(builder, vgpr, vgpr0, ctx->i32_0, "");
                vgpr = LLVMBuildBitCast(builder, vgpr, vgpr_type, "");
 
+               if (bitsize < 32)
+                       vgpr = LLVMBuildTrunc(builder, vgpr, type, "");
+
                *pvgpr = vgpr;
        }
 }
@@ -1226,8 +1237,7 @@ ac_build_buffer_store_dword(struct ac_llvm_context *ctx,
                            LLVMValueRef voffset,
                            LLVMValueRef soffset,
                            unsigned inst_offset,
-                           unsigned cache_policy,
-                           bool swizzle_enable_hint)
+                           unsigned cache_policy)
 {
        /* Split 3 channel stores, because only LLVM 9+ support 3-channel
         * intrinsics. */
@@ -1241,12 +1251,10 @@ ac_build_buffer_store_dword(struct ac_llvm_context *ctx,
                v01 = ac_build_gather_values(ctx, v, 2);
 
                ac_build_buffer_store_dword(ctx, rsrc, v01, 2, voffset,
-                                           soffset, inst_offset, cache_policy,
-                                           swizzle_enable_hint);
+                                           soffset, inst_offset, cache_policy);
                ac_build_buffer_store_dword(ctx, rsrc, v[2], 1, voffset,
                                            soffset, inst_offset + 8,
-                                           cache_policy,
-                                           swizzle_enable_hint);
+                                           cache_policy);
                return;
        }
 
@@ -1254,7 +1262,7 @@ ac_build_buffer_store_dword(struct ac_llvm_context *ctx,
         * (voffset is swizzled, but soffset isn't swizzled).
         * llvm.amdgcn.buffer.store doesn't have a separate soffset parameter.
         */
-       if (!swizzle_enable_hint) {
+       if (!(cache_policy & ac_swizzled)) {
                LLVMValueRef offset = soffset;
 
                if (inst_offset)
@@ -3760,6 +3768,11 @@ static LLVMValueRef
 _ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
                     bool exchange_rows, bool bound_ctrl)
 {
+       LLVMTypeRef type = LLVMTypeOf(src);
+       LLVMValueRef result;
+
+       src = LLVMBuildZExt(ctx->builder, src, ctx->i32, "");
+
        LLVMValueRef args[6] = {
                src,
                src,
@@ -3768,10 +3781,13 @@ _ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel
                ctx->i1true, /* fi */
                bound_ctrl ? ctx->i1true : ctx->i1false,
        };
-       return ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
-                                                    : "llvm.amdgcn.permlane16",
-                                 ctx->i32, args, 6,
-                                 AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
+
+       result = ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
+                                                      : "llvm.amdgcn.permlane16",
+                                   ctx->i32, args, 6,
+                                   AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
+
+       return LLVMBuildTrunc(ctx->builder, result, type, "");
 }
 
 static LLVMValueRef
@@ -3782,10 +3798,7 @@ ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
        src = ac_to_integer(ctx, src);
        unsigned bits = LLVMGetIntTypeWidth(LLVMTypeOf(src));
        LLVMValueRef ret;
-       if (bits == 32) {
-               ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
-                                          bound_ctrl);
-       } else {
+       if (bits > 32) {
                assert(bits % 32 == 0);
                LLVMTypeRef vec_type = LLVMVectorType(ctx->i32, bits / 32);
                LLVMValueRef src_vector =
@@ -3804,6 +3817,9 @@ ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
                                                     LLVMConstInt(ctx->i32, i,
                                                                  0), "");
                }
+       } else {
+               ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
+                                          bound_ctrl);
        }
        return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
 }
@@ -3864,12 +3880,27 @@ ac_build_ds_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned mask
 static LLVMValueRef
 ac_build_wwm(struct ac_llvm_context *ctx, LLVMValueRef src)
 {
+       LLVMTypeRef src_type = LLVMTypeOf(src); /* type of src on entry; the result is cast back to it at the end */
+       unsigned bitsize = ac_get_elem_bits(ctx, src_type); /* width reported by ac_get_elem_bits for that type */
        char name[32], type[8];
+       LLVMValueRef ret;
+
+       src = ac_to_integer(ctx, src); /* the widening below operates on this converted value */
+
+       if (bitsize < 32) /* narrower than 32 bits: zero-extend to i32 before calling the intrinsic */
+               src = LLVMBuildZExt(ctx->builder, src, ctx->i32, "");
+
        ac_build_type_name_for_intr(LLVMTypeOf(src), type, sizeof(type));
        snprintf(name, sizeof(name), "llvm.amdgcn.wwm.%s", type);
-       return ac_build_intrinsic(ctx, name, LLVMTypeOf(src),
-                                 (LLVMValueRef []) { src }, 1,
-                                 AC_FUNC_ATTR_READNONE);
+       ret = ac_build_intrinsic(ctx, name, LLVMTypeOf(src),
+                                (LLVMValueRef []) { src }, 1,
+                                AC_FUNC_ATTR_READNONE); /* name suffix and return type both follow the (possibly widened) src */
+
+       if (bitsize < 32) /* undo the zero-extension done above */
+               ret = LLVMBuildTrunc(ctx->builder, ret,
+                                    ac_to_integer_type(ctx, src_type), "");
+
+       return LLVMBuildBitCast(ctx->builder, ret, src_type, ""); /* hand back a value of the caller's original type */
 }
 
 static LLVMValueRef
@@ -3979,6 +4010,7 @@ static LLVMValueRef
 ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs, nir_op op)
 {
        bool _64bit = ac_get_type_size(LLVMTypeOf(lhs)) == 8;
+       bool _32bit = ac_get_type_size(LLVMTypeOf(lhs)) == 4;
        switch (op) {
        case nir_op_iadd: return LLVMBuildAdd(ctx->builder, lhs, rhs, "");
        case nir_op_fadd: return LLVMBuildFAdd(ctx->builder, lhs, rhs, "");
@@ -3991,8 +4023,8 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
                                        LLVMBuildICmp(ctx->builder, LLVMIntULT, lhs, rhs, ""),
                                        lhs, rhs, "");
        case nir_op_fmin: return ac_build_intrinsic(ctx,
-                                       _64bit ? "llvm.minnum.f64" : "llvm.minnum.f32",
-                                       _64bit ? ctx->f64 : ctx->f32,
+                                       _64bit ? "llvm.minnum.f64" : _32bit ? "llvm.minnum.f32" : "llvm.minnum.f16",
+                                       _64bit ? ctx->f64 : _32bit ? ctx->f32 : ctx->f16,
                                        (LLVMValueRef[]){lhs, rhs}, 2, AC_FUNC_ATTR_READNONE);
        case nir_op_imax: return LLVMBuildSelect(ctx->builder,
                                        LLVMBuildICmp(ctx->builder, LLVMIntSGT, lhs, rhs, ""),
@@ -4001,8 +4033,8 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
                                        LLVMBuildICmp(ctx->builder, LLVMIntUGT, lhs, rhs, ""),
                                        lhs, rhs, "");
        case nir_op_fmax: return ac_build_intrinsic(ctx,
-                                       _64bit ? "llvm.maxnum.f64" : "llvm.maxnum.f32",
-                                       _64bit ? ctx->f64 : ctx->f32,
+                                       _64bit ? "llvm.maxnum.f64" : _32bit ? "llvm.maxnum.f32" : "llvm.maxnum.f16",
+                                       _64bit ? ctx->f64 : _32bit ? ctx->f32 : ctx->f16,
                                        (LLVMValueRef[]){lhs, rhs}, 2, AC_FUNC_ATTR_READNONE);
        case nir_op_iand: return LLVMBuildAnd(ctx->builder, lhs, rhs, "");
        case nir_op_ior: return LLVMBuildOr(ctx->builder, lhs, rhs, "");
@@ -4012,11 +4044,80 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
        }
 }
 
+/**
+ * \param src The value to shift.
+ * \param identity The value to use for the first lane.
+ * \param maxprefix specifies that the result only needs to be correct for a
+ *     prefix of this many threads
+ * \return src, shifted 1 lane up, and identity shifted into lane 0.
+ */
+static LLVMValueRef
+ac_wavefront_shift_right_1(struct ac_llvm_context *ctx, LLVMValueRef src,
+                           LLVMValueRef identity, unsigned maxprefix)
+{
+       if (ctx->chip_class >= GFX10) {
+               /* wavefront shift_right by 1 on GFX10 (emulate dpp_wf_sr1) */
+               LLVMValueRef active, tmp1, tmp2;
+               LLVMValueRef tid = ac_get_thread_id(ctx);
+
+               tmp1 = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false); /* NOTE(review): identity presumably fills lanes the DPP shift cannot source - confirm in ac_build_dpp */
+
+               tmp2 = ac_build_permlane16(ctx, src, (uint64_t)~0, true, false); /* sel = all ones; the two flags are presumably exchange_rows = true, bound_ctrl = false */
+
+               if (maxprefix > 32) { /* result has to be correct beyond lane 31 */
+                       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid,
+                                              LLVMConstInt(ctx->i32, 32, false), ""); /* active: tid == 32 */
+
+                       tmp2 = LLVMBuildSelect(ctx->builder, active,
+                                              ac_build_readlane(ctx, src,
+                                                                LLVMConstInt(ctx->i32, 31, false)),
+                                              tmp2, ""); /* lane 32 takes readlane(src, 31) instead of the permlane result */
+
+                       active = LLVMBuildOr(ctx->builder, active,
+                                            LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                                                          LLVMBuildAnd(ctx->builder, tid,
+                                                                       LLVMConstInt(ctx->i32, 0x1f, false), ""),
+                                                          LLVMConstInt(ctx->i32, 0x10, false), ""), ""); /* active: tid == 32 or (tid & 0x1f) == 0x10 */
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); /* active lanes use tmp2, all others tmp1 */
+               } else if (maxprefix > 16) {
+                       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid,
+                                              LLVMConstInt(ctx->i32, 16, false), ""); /* active: tid == 16 */
+
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); /* only lane 16 uses tmp2 */
+               } /* NOTE(review): with maxprefix <= 16 nothing returns here, so control continues into the SI/CI sequence below and tmp1/tmp2 go unused - confirm that is intended */
+       } else if (ctx->chip_class >= GFX8) {
+               return ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false); /* GFX8 and later, below GFX10: a single DPP op with dpp_wf_sr1 */
+       }
+
+       /* wavefront shift_right by 1 on SI/CI */
+       LLVMValueRef active, tmp1, tmp2;
+       LLVMValueRef tid = ac_get_thread_id(ctx);
+       tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x4, 0), ""); /* active: (tid & 0x7) == 0x4 */
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x8, 0), ""); /* active: (tid & 0xf) == 0x8 */
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x10, 0), ""); /* active: (tid & 0x1f) == 0x10 */
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), ""); /* active: tid == 32, which takes src read from lane 31 */
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), ""); /* active: tid == 0 */
+       return LLVMBuildSelect(ctx->builder, active, identity, tmp1, ""); /* lane 0 receives identity */
+}
+
 /**
  * \param maxprefix specifies that the result only needs to be correct for a
  *     prefix of this many threads
- *
- * TODO: add inclusive and excluse scan functions for GFX6.
  */
 static LLVMValueRef
 ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
@@ -4024,13 +4125,54 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
 {
        LLVMValueRef result, tmp;
 
-       if (ctx->chip_class >= GFX10) {
-               result = inclusive ? src : identity;
-       } else {
-               if (!inclusive)
-                       src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
-               result = src;
+       if (!inclusive)
+               src = ac_wavefront_shift_right_1(ctx, src, identity, maxprefix);
+
+       result = src;
+
+       if (ctx->chip_class <= GFX7) {
+               assert(maxprefix == 64);
+               LLVMValueRef tid = ac_get_thread_id(ctx);
+               LLVMValueRef active;
+               tmp = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x1e, 0x00, 0x00));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, ctx->i32_1, ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1c, 0x01, 0x00));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 2, 0), ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x18, 0x03, 0x00));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 4, 0), ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x10, 0x07, 0x00));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 8, 0), ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 16, 0), ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               tmp = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, 0));
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 32, 0), ""),
+                                      ctx->i32_0, "");
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+               result = ac_build_alu_op(ctx, result, tmp, op);
+               return result;
        }
+
        if (maxprefix <= 1)
                return result;
        tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
@@ -4055,33 +4197,31 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
                return result;
 
        if (ctx->chip_class >= GFX10) {
-               /* dpp_row_bcast{15,31} are not supported on gfx10. */
-               LLVMBuilderRef builder = ctx->builder;
                LLVMValueRef tid = ac_get_thread_id(ctx);
-               LLVMValueRef cc;
-               /* TODO-GFX10: Can we get better code-gen by putting this into
-                * a branch so that LLVM generates EXEC mask manipulations? */
-               if (inclusive)
-                       tmp = result;
-               else
-                       tmp = ac_build_alu_op(ctx, result, src, op);
-               tmp = ac_build_permlane16(ctx, tmp, ~(uint64_t)0, true, false);
-               tmp = ac_build_alu_op(ctx, result, tmp, op);
-               cc = LLVMBuildAnd(builder, tid, LLVMConstInt(ctx->i32, 16, false), "");
-               cc = LLVMBuildICmp(builder, LLVMIntNE, cc, ctx->i32_0, "");
-               result = LLVMBuildSelect(builder, cc, tmp, result, "");
+               LLVMValueRef active;
+
+               tmp = ac_build_permlane16(ctx, result, ~(uint64_t)0, true, false);
+
+               active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
+                                      LLVMBuildAnd(ctx->builder, tid,
+                                                   LLVMConstInt(ctx->i32, 16, false), ""),
+                                      ctx->i32_0, "");
+
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+
+               result = ac_build_alu_op(ctx, result, tmp, op);
+
                if (maxprefix <= 32)
                        return result;
 
-               if (inclusive)
-                       tmp = result;
-               else
-                       tmp = ac_build_alu_op(ctx, result, src, op);
-               tmp = ac_build_readlane(ctx, tmp, LLVMConstInt(ctx->i32, 31, false));
-               tmp = ac_build_alu_op(ctx, result, tmp, op);
-               cc = LLVMBuildICmp(builder, LLVMIntUGE, tid,
-                                  LLVMConstInt(ctx->i32, 32, false), "");
-               result = LLVMBuildSelect(builder, cc, tmp, result, "");
+               tmp = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
+
+               active = LLVMBuildICmp(ctx->builder, LLVMIntUGE, tid,
+                                      LLVMConstInt(ctx->i32, 32, false), "");
+
+               tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
+
+               result = ac_build_alu_op(ctx, result, tmp, op);
                return result;
        }
 
@@ -4460,7 +4600,7 @@ ac_build_canonicalize(struct ac_llvm_context *ctx, LLVMValueRef src0,
        } else if (bitsize == 32) {
                intr = "llvm.canonicalize.f32";
                type = ctx->f32;
-       } else if (bitsize == 64) {
+       } else {
                intr = "llvm.canonicalize.f64";
                type = ctx->f64;
        }
@@ -4579,3 +4719,82 @@ ac_export_mrt_z(struct ac_llvm_context *ctx, LLVMValueRef depth,
        args->enabled_channels = mask;
 }
 
+static LLVMTypeRef /* returns the LLVM type used for one argument of kind `type` and size `size` */
+arg_llvm_type(enum ac_arg_type type, unsigned size, struct ac_llvm_context *ctx)
+{
+       if (type == AC_ARG_FLOAT) { /* size 1 -> scalar f32, otherwise an f32 vector with `size` elements */
+               return size == 1 ? ctx->f32 : LLVMVectorType(ctx->f32, size);
+       } else if (type == AC_ARG_INT) { /* same rule with i32 */
+               return size == 1 ? ctx->i32 : LLVMVectorType(ctx->i32, size);
+       } else { /* every other kind: choose an element type, then wrap it with one of the const address-space helpers */
+               LLVMTypeRef ptr_type;
+               switch (type) {
+               case AC_ARG_CONST_PTR:
+                       ptr_type = ctx->i8;
+                       break;
+               case AC_ARG_CONST_FLOAT_PTR:
+                       ptr_type = ctx->f32;
+                       break;
+               case AC_ARG_CONST_PTR_PTR:
+                       ptr_type = ac_array_in_const32_addr_space(ctx->i8); /* element type is itself produced by the const32 helper */
+                       break;
+               case AC_ARG_CONST_DESC_PTR:
+                       ptr_type = ctx->v4i32;
+                       break;
+               case AC_ARG_CONST_IMAGE_PTR:
+                       ptr_type = ctx->v8i32;
+                       break;
+               default:
+                       unreachable("unknown arg type");
+               }
+               if (size == 1) { /* NOTE(review): size 1 vs 2 presumably distinguishes 32-bit from 64-bit pointers - confirm against the two helpers */
+                       return ac_array_in_const32_addr_space(ptr_type);
+               } else {
+                       assert(size == 2); /* no other size is accepted for these kinds */
+                       return ac_array_in_const_addr_space(ptr_type);
+               }
+       }
+}
+
+LLVMValueRef /* adds function `name` to `module`, typed from `args`; positions ctx->builder in its first block and returns the function */
+ac_build_main(const struct ac_shader_args *args,
+             struct ac_llvm_context *ctx,
+             enum ac_llvm_calling_convention convention,
+             const char *name, LLVMTypeRef ret_type,
+             LLVMModuleRef module)
+{
+       LLVMTypeRef arg_types[AC_MAX_ARGS]; /* NOTE(review): assumes args->arg_count <= AC_MAX_ARGS; not checked in this function */
+
+       for (unsigned i = 0; i < args->arg_count; i++) { /* translate every declared argument to its LLVM type */
+               arg_types[i] = arg_llvm_type(args->args[i].type,
+                                            args->args[i].size, ctx);
+       }
+
+       LLVMTypeRef main_function_type =
+               LLVMFunctionType(ret_type, arg_types, args->arg_count, 0); /* trailing 0: not variadic */
+
+       LLVMValueRef main_function =
+           LLVMAddFunction(module, name, main_function_type);
+       LLVMBasicBlockRef main_function_body =
+           LLVMAppendBasicBlockInContext(ctx->context, main_function, "main_body");
+       LLVMPositionBuilderAtEnd(ctx->builder, main_function_body); /* later builder calls append to "main_body" */
+
+       LLVMSetFunctionCallConv(main_function, convention);
+       for (unsigned i = 0; i < args->arg_count; ++i) { /* attach parameter attributes; only SGPR arguments get any */
+               LLVMValueRef P = LLVMGetParam(main_function, i);
+
+               if (args->args[i].file != AC_ARG_SGPR)
+                       continue;
+
+               ac_add_function_attr(ctx->context, main_function, i + 1, AC_FUNC_ATTR_INREG); /* NOTE(review): i + 1 is presumably a 1-based parameter index (0 = return value) - confirm against ac_add_function_attr */
+
+               if (LLVMGetTypeKind(LLVMTypeOf(P)) == LLVMPointerTypeKind) { /* pointer-typed SGPR arguments additionally get NOALIAS and dereferenceable(UINT64_MAX) */
+                       ac_add_function_attr(ctx->context, main_function, i + 1, AC_FUNC_ATTR_NOALIAS);
+                       ac_add_attr_dereferenceable(P, UINT64_MAX);
+               }
+       }
+
+       ctx->main_function = main_function; /* also recorded on the context for later use */
+       return main_function;
+}
+