ac/nir: use more types from ac_llvm_context
authorMarek Olšák <marek.olsak@amd.com>
Mon, 4 May 2020 13:27:49 +0000 (09:27 -0400)
committerMarek Olšák <marek.olsak@amd.com>
Tue, 2 Jun 2020 20:29:25 +0000 (16:29 -0400)
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5003>

src/amd/llvm/ac_llvm_build.c
src/amd/llvm/ac_llvm_build.h
src/amd/llvm/ac_nir_to_llvm.c
src/gallium/drivers/radeonsi/si_shader_llvm_ps.c

index ef4c95e5673a02c1dfc87f13ede5960091c79afe..bafc884362ca2426b5b2ea2160d7600771258bbe 100644 (file)
@@ -89,6 +89,9 @@ ac_llvm_context_init(struct ac_llvm_context *ctx,
        ctx->f32 = LLVMFloatTypeInContext(ctx->context);
        ctx->f64 = LLVMDoubleTypeInContext(ctx->context);
        ctx->v2i16 = LLVMVectorType(ctx->i16, 2);
+       ctx->v4i16 = LLVMVectorType(ctx->i16, 4);
+       ctx->v2f16 = LLVMVectorType(ctx->f16, 2);
+       ctx->v4f16 = LLVMVectorType(ctx->f16, 4);
        ctx->v2i32 = LLVMVectorType(ctx->i32, 2);
        ctx->v3i32 = LLVMVectorType(ctx->i32, 3);
        ctx->v4i32 = LLVMVectorType(ctx->i32, 4);
@@ -2249,13 +2252,10 @@ void ac_build_export(struct ac_llvm_context *ctx, struct ac_export_args *a)
        args[1] = LLVMConstInt(ctx->i32, a->enabled_channels, 0);
 
        if (a->compr) {
-               LLVMTypeRef i16 = LLVMInt16TypeInContext(ctx->context);
-               LLVMTypeRef v2i16 = LLVMVectorType(i16, 2);
-
                args[2] = LLVMBuildBitCast(ctx->builder, a->out[0],
-                               v2i16, "");
+                               ctx->v2i16, "");
                args[3] = LLVMBuildBitCast(ctx->builder, a->out[1],
-                               v2i16, "");
+                               ctx->v2i16, "");
                args[4] = LLVMConstInt(ctx->i1, a->done, 0);
                args[5] = LLVMConstInt(ctx->i1, a->valid_mask, 0);
 
@@ -2540,10 +2540,7 @@ LLVMValueRef ac_build_image_get_sample_count(struct ac_llvm_context *ctx,
 LLVMValueRef ac_build_cvt_pkrtz_f16(struct ac_llvm_context *ctx,
                                    LLVMValueRef args[2])
 {
-       LLVMTypeRef v2f16 =
-               LLVMVectorType(LLVMHalfTypeInContext(ctx->context), 2);
-
-       return ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz", v2f16,
+       return ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz", ctx->v2f16,
                                  args, 2, AC_FUNC_ATTR_READNONE);
 }
 
@@ -3712,9 +3709,7 @@ ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
                                          (LLVMValueRef []) { mask, ctx->i32_0 },
                                          2, AC_FUNC_ATTR_READNONE);
        }
-       LLVMValueRef mask_vec = LLVMBuildBitCast(ctx->builder, mask,
-                                                LLVMVectorType(ctx->i32, 2),
-                                                "");
+       LLVMValueRef mask_vec = LLVMBuildBitCast(ctx->builder, mask, ctx->v2i32, "");
        LLVMValueRef mask_lo = LLVMBuildExtractElement(ctx->builder, mask_vec,
                                                       ctx->i32_0, "");
        LLVMValueRef mask_hi = LLVMBuildExtractElement(ctx->builder, mask_vec,
index a4777f12f0783e0d13cfdb5d39e8e9f371cc6e96..edaedab4d97e708c691790577065afeadd7e1901 100644 (file)
@@ -79,6 +79,9 @@ struct ac_llvm_context {
        LLVMTypeRef f32;
        LLVMTypeRef f64;
        LLVMTypeRef v2i16;
+       LLVMTypeRef v4i16;
+       LLVMTypeRef v2f16;
+       LLVMTypeRef v4f16;
        LLVMTypeRef v2i32;
        LLVMTypeRef v3i32;
        LLVMTypeRef v4i32;
index 959998c0c5dbf007e01f4f76e827103d3401ab85..f150dca22cd62ba5f872ecaa1fe0604742ec416f 100644 (file)
@@ -1574,13 +1574,13 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
 
        if (instr->dest.ssa.bit_size == 8) {
                unsigned load_dwords = instr->dest.ssa.num_components > 1 ? 2 : 1;
-               LLVMTypeRef vec_type = LLVMVectorType(LLVMInt8TypeInContext(ctx->ac.context), 4 * load_dwords);
+               LLVMTypeRef vec_type = LLVMVectorType(ctx->ac.i8, 4 * load_dwords);
                ptr = ac_cast_ptr(&ctx->ac, ptr, vec_type);
                LLVMValueRef res = LLVMBuildLoad(ctx->ac.builder, ptr, "");
 
                LLVMValueRef params[3];
                if (load_dwords > 1) {
-                       LLVMValueRef res_vec = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(ctx->ac.i32, 2), "");
+                       LLVMValueRef res_vec = LLVMBuildBitCast(ctx->ac.builder, res, ctx->ac.v2i32, "");
                        params[0] = LLVMBuildExtractElement(ctx->ac.builder, res_vec, LLVMConstInt(ctx->ac.i32, 1, false), "");
                        params[1] = LLVMBuildExtractElement(ctx->ac.builder, res_vec, LLVMConstInt(ctx->ac.i32, 0, false), "");
                } else {
@@ -1593,11 +1593,11 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
 
                res = LLVMBuildTrunc(ctx->ac.builder, res, LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.num_components * 8), "");
                if (instr->dest.ssa.num_components > 1)
-                       res = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(LLVMInt8TypeInContext(ctx->ac.context), instr->dest.ssa.num_components), "");
+                       res = LLVMBuildBitCast(ctx->ac.builder, res, LLVMVectorType(ctx->ac.i8, instr->dest.ssa.num_components), "");
                return res;
        } else if (instr->dest.ssa.bit_size == 16) {
                unsigned load_dwords = instr->dest.ssa.num_components / 2 + 1;
-               LLVMTypeRef vec_type = LLVMVectorType(LLVMInt16TypeInContext(ctx->ac.context), 2 * load_dwords);
+               LLVMTypeRef vec_type = LLVMVectorType(ctx->ac.i16, 2 * load_dwords);
                ptr = ac_cast_ptr(&ctx->ac, ptr, vec_type);
                LLVMValueRef res = LLVMBuildLoad(ctx->ac.builder, ptr, "");
                res = LLVMBuildBitCast(ctx->ac.builder, res, vec_type, "");
index 3ff696d294b38112e784860d6b8d4ec18bcc780e..d04b28d64cb893ce9a078644b3a3df6dc962cd62 100644 (file)
@@ -165,7 +165,7 @@ static void interp_fs_color(struct si_shader_context *ctx, unsigned input_index,
 
    if (interp) {
       interp_param =
-         LLVMBuildBitCast(ctx->ac.builder, interp_param, LLVMVectorType(ctx->ac.f32, 2), "");
+         LLVMBuildBitCast(ctx->ac.builder, interp_param, ctx->ac.v2f32, "");
 
       i = LLVMBuildExtractElement(ctx->ac.builder, interp_param, ctx->ac.i32_0, "");
       j = LLVMBuildExtractElement(ctx->ac.builder, interp_param, ctx->ac.i32_1, "");