ac/nir: remove type and num_channels args from ac_build_buffer_store_common
[mesa.git] / src / amd / llvm / ac_llvm_build.c
index ebcb91bd4b3374c02d1bb60e19773d15c4822cdd..5483b3146c01c284eb7153f4cb67e5f1720a319b 100644 (file)
@@ -89,6 +89,9 @@ ac_llvm_context_init(struct ac_llvm_context *ctx,
        ctx->f32 = LLVMFloatTypeInContext(ctx->context);
        ctx->f64 = LLVMDoubleTypeInContext(ctx->context);
        ctx->v2i16 = LLVMVectorType(ctx->i16, 2);
+       ctx->v4i16 = LLVMVectorType(ctx->i16, 4);
+       ctx->v2f16 = LLVMVectorType(ctx->f16, 2);
+       ctx->v4f16 = LLVMVectorType(ctx->f16, 4);
        ctx->v2i32 = LLVMVectorType(ctx->i32, 2);
        ctx->v3i32 = LLVMVectorType(ctx->i32, 3);
        ctx->v4i32 = LLVMVectorType(ctx->i32, 4);
@@ -144,7 +147,7 @@ int
 ac_get_llvm_num_components(LLVMValueRef value)
 {
        LLVMTypeRef type = LLVMTypeOf(value);
-       unsigned num_components = LLVMGetTypeKind(type) == LLVMFixedVectorTypeKind
+       unsigned num_components = LLVMGetTypeKind(type) == LLVMVectorTypeKind
                                      ? LLVMGetVectorSize(type)
                                      : 1;
        return num_components;
@@ -155,7 +158,7 @@ ac_llvm_extract_elem(struct ac_llvm_context *ac,
                     LLVMValueRef value,
                     int index)
 {
-       if (LLVMGetTypeKind(LLVMTypeOf(value)) != LLVMFixedVectorTypeKind) {
+       if (LLVMGetTypeKind(LLVMTypeOf(value)) != LLVMVectorTypeKind) {
                assert(index == 0);
                return value;
        }
@@ -167,7 +170,7 @@ ac_llvm_extract_elem(struct ac_llvm_context *ac,
 int
 ac_get_elem_bits(struct ac_llvm_context *ctx, LLVMTypeRef type)
 {
-       if (LLVMGetTypeKind(type) == LLVMFixedVectorTypeKind)
+       if (LLVMGetTypeKind(type) == LLVMVectorTypeKind)
                type = LLVMGetElementType(type);
 
        if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind)
@@ -206,7 +209,7 @@ ac_get_type_size(LLVMTypeRef type)
                if (LLVMGetPointerAddressSpace(type) == AC_ADDR_SPACE_CONST_32BIT)
                        return 4;
                return 8;
-       case LLVMFixedVectorTypeKind:
+       case LLVMVectorTypeKind:
                return LLVMGetVectorSize(type) *
                       ac_get_type_size(LLVMGetElementType(type));
        case LLVMArrayTypeKind:
@@ -235,7 +238,7 @@ static LLVMTypeRef to_integer_type_scalar(struct ac_llvm_context *ctx, LLVMTypeR
 LLVMTypeRef
 ac_to_integer_type(struct ac_llvm_context *ctx, LLVMTypeRef t)
 {
-       if (LLVMGetTypeKind(t) == LLVMFixedVectorTypeKind) {
+       if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) {
                LLVMTypeRef elem_type = LLVMGetElementType(t);
                return LLVMVectorType(to_integer_type_scalar(ctx, elem_type),
                                      LLVMGetVectorSize(t));
@@ -290,7 +293,7 @@ static LLVMTypeRef to_float_type_scalar(struct ac_llvm_context *ctx, LLVMTypeRef
 LLVMTypeRef
 ac_to_float_type(struct ac_llvm_context *ctx, LLVMTypeRef t)
 {
-       if (LLVMGetTypeKind(t) == LLVMFixedVectorTypeKind) {
+       if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) {
                LLVMTypeRef elem_type = LLVMGetElementType(t);
                return LLVMVectorType(to_float_type_scalar(ctx, elem_type),
                                      LLVMGetVectorSize(t));
@@ -352,7 +355,7 @@ void ac_build_type_name_for_intr(LLVMTypeRef type, char *buf, unsigned bufsize)
 
        assert(bufsize >= 8);
 
-       if (LLVMGetTypeKind(type) == LLVMFixedVectorTypeKind) {
+       if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
                int ret = snprintf(buf, bufsize, "v%u",
                                        LLVMGetVectorSize(type));
                if (ret < 0) {
@@ -457,11 +460,10 @@ ac_build_optimization_barrier(struct ac_llvm_context *ctx,
 }
 
 LLVMValueRef
-ac_build_shader_clock(struct ac_llvm_context *ctx)
+ac_build_shader_clock(struct ac_llvm_context *ctx, nir_scope scope)
 {
-       const char *intr = LLVM_VERSION_MAJOR >= 9 && ctx->chip_class >= GFX8 ?
-                               "llvm.amdgcn.s.memrealtime" : "llvm.readcyclecounter";
-       LLVMValueRef tmp = ac_build_intrinsic(ctx, intr, ctx->i64, NULL, 0, 0);
+       const char *name = scope == NIR_SCOPE_DEVICE ? "llvm.amdgcn.s.memrealtime" : "llvm.amdgcn.s.memtime";
+       LLVMValueRef tmp = ac_build_intrinsic(ctx, name, ctx->i64, NULL, 0, 0);
        return LLVMBuildBitCast(ctx->builder, tmp, ctx->v2i32, "");
 }
 
@@ -627,7 +629,7 @@ ac_build_expand(struct ac_llvm_context *ctx,
        LLVMTypeRef elemtype;
        LLVMValueRef chan[dst_channels];
 
-       if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMFixedVectorTypeKind) {
+       if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMVectorTypeKind) {
                unsigned vec_size = LLVMGetVectorSize(LLVMTypeOf(value));
 
                if (src_channels == dst_channels && vec_size == dst_channels)
@@ -1181,8 +1183,6 @@ ac_build_buffer_store_common(struct ac_llvm_context *ctx,
                             LLVMValueRef vindex,
                             LLVMValueRef voffset,
                             LLVMValueRef soffset,
-                            unsigned num_channels,
-                            LLVMTypeRef return_channel_type,
                             unsigned cache_policy,
                             bool use_format,
                             bool structurized)
@@ -1196,12 +1196,10 @@ ac_build_buffer_store_common(struct ac_llvm_context *ctx,
        args[idx++] = voffset ? voffset : ctx->i32_0;
        args[idx++] = soffset ? soffset : ctx->i32_0;
        args[idx++] = LLVMConstInt(ctx->i32, cache_policy, 0);
-       unsigned func = !ac_has_vec3_support(ctx->chip_class, use_format) && num_channels == 3 ? 4 : num_channels;
        const char *indexing_kind = structurized ? "struct" : "raw";
        char name[256], type_name[8];
 
-       LLVMTypeRef type = func > 1 ? LLVMVectorType(return_channel_type, func) : return_channel_type;
-       ac_build_type_name_for_intr(type, type_name, sizeof(type_name));
+       ac_build_type_name_for_intr(LLVMTypeOf(data), type_name, sizeof(type_name));
 
        if (use_format) {
                snprintf(name, sizeof(name), "llvm.amdgcn.%s.buffer.store.format.%s",
@@ -1221,13 +1219,10 @@ ac_build_buffer_store_format(struct ac_llvm_context *ctx,
                             LLVMValueRef data,
                             LLVMValueRef vindex,
                             LLVMValueRef voffset,
-                            unsigned num_channels,
                             unsigned cache_policy)
 {
-       ac_build_buffer_store_common(ctx, rsrc, data, vindex,
-                                    voffset, NULL, num_channels,
-                                    ctx->f32, cache_policy,
-                                    true, true);
+       ac_build_buffer_store_common(ctx, rsrc, data, vindex, voffset, NULL,
+                                    cache_policy, true, true);
 }
 
 /* TBUFFER_STORE_FORMAT_{X,XY,XYZ,XYZW} <- the suffix is selected by num_channels=1..4.
@@ -1276,7 +1271,6 @@ ac_build_buffer_store_dword(struct ac_llvm_context *ctx,
 
                ac_build_buffer_store_common(ctx, rsrc, ac_to_float(ctx, vdata),
                                             ctx->i32_0, voffset, offset,
-                                            num_channels, ctx->f32,
                                             cache_policy, false, false);
                return;
        }
@@ -1934,8 +1928,7 @@ ac_build_tbuffer_store_short(struct ac_llvm_context *ctx,
        if (LLVM_VERSION_MAJOR >= 9) {
                /* LLVM 9+ supports i8/i16 with struct/raw intrinsics. */
                ac_build_buffer_store_common(ctx, rsrc, vdata, NULL,
-                                            voffset, soffset, 1,
-                                            ctx->i16, cache_policy,
+                                            voffset, soffset, cache_policy,
                                             false, false);
        } else {
                unsigned dfmt = V_008F0C_BUF_DATA_FORMAT_16;
@@ -1961,8 +1954,7 @@ ac_build_tbuffer_store_byte(struct ac_llvm_context *ctx,
        if (LLVM_VERSION_MAJOR >= 9) {
                /* LLVM 9+ supports i8/i16 with struct/raw intrinsics. */
                ac_build_buffer_store_common(ctx, rsrc, vdata, NULL,
-                                            voffset, soffset, 1,
-                                            ctx->i8, cache_policy,
+                                            voffset, soffset, cache_policy,
                                             false, false);
        } else {
                unsigned dfmt = V_008F0C_BUF_DATA_FORMAT_8;
@@ -2190,8 +2182,10 @@ ac_build_umsb(struct ac_llvm_context *ctx,
 LLVMValueRef ac_build_fmin(struct ac_llvm_context *ctx, LLVMValueRef a,
                           LLVMValueRef b)
 {
-       char name[64];
-       snprintf(name, sizeof(name), "llvm.minnum.f%d", ac_get_elem_bits(ctx, LLVMTypeOf(a)));
+       char name[64], type[64];
+
+       ac_build_type_name_for_intr(LLVMTypeOf(a), type, sizeof(type));
+       snprintf(name, sizeof(name), "llvm.minnum.%s", type);
        LLVMValueRef args[2] = {a, b};
        return ac_build_intrinsic(ctx, name, LLVMTypeOf(a), args, 2,
                                  AC_FUNC_ATTR_READNONE);
@@ -2200,8 +2194,10 @@ LLVMValueRef ac_build_fmin(struct ac_llvm_context *ctx, LLVMValueRef a,
 LLVMValueRef ac_build_fmax(struct ac_llvm_context *ctx, LLVMValueRef a,
                           LLVMValueRef b)
 {
-       char name[64];
-       snprintf(name, sizeof(name), "llvm.maxnum.f%d", ac_get_elem_bits(ctx, LLVMTypeOf(a)));
+       char name[64], type[64];
+
+       ac_build_type_name_for_intr(LLVMTypeOf(a), type, sizeof(type));
+       snprintf(name, sizeof(name), "llvm.maxnum.%s", type);
        LLVMValueRef args[2] = {a, b};
        return ac_build_intrinsic(ctx, name, LLVMTypeOf(a), args, 2,
                                  AC_FUNC_ATTR_READNONE);
@@ -2250,13 +2246,10 @@ void ac_build_export(struct ac_llvm_context *ctx, struct ac_export_args *a)
        args[1] = LLVMConstInt(ctx->i32, a->enabled_channels, 0);
 
        if (a->compr) {
-               LLVMTypeRef i16 = LLVMInt16TypeInContext(ctx->context);
-               LLVMTypeRef v2i16 = LLVMVectorType(i16, 2);
-
                args[2] = LLVMBuildBitCast(ctx->builder, a->out[0],
-                               v2i16, "");
+                               ctx->v2i16, "");
                args[3] = LLVMBuildBitCast(ctx->builder, a->out[1],
-                               v2i16, "");
+                               ctx->v2i16, "");
                args[4] = LLVMConstInt(ctx->i1, a->done, 0);
                args[5] = LLVMConstInt(ctx->i1, a->valid_mask, 0);
 
@@ -2541,10 +2534,7 @@ LLVMValueRef ac_build_image_get_sample_count(struct ac_llvm_context *ctx,
 LLVMValueRef ac_build_cvt_pkrtz_f16(struct ac_llvm_context *ctx,
                                    LLVMValueRef args[2])
 {
-       LLVMTypeRef v2f16 =
-               LLVMVectorType(LLVMHalfTypeInContext(ctx->context), 2);
-
-       return ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz", v2f16,
+       return ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz", ctx->v2f16,
                                  args, 2, AC_FUNC_ATTR_READNONE);
 }
 
@@ -3713,9 +3703,7 @@ ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
                                          (LLVMValueRef []) { mask, ctx->i32_0 },
                                          2, AC_FUNC_ATTR_READNONE);
        }
-       LLVMValueRef mask_vec = LLVMBuildBitCast(ctx->builder, mask,
-                                                LLVMVectorType(ctx->i32, 2),
-                                                "");
+       LLVMValueRef mask_vec = LLVMBuildBitCast(ctx->builder, mask, ctx->v2i32, "");
        LLVMValueRef mask_lo = LLVMBuildExtractElement(ctx->builder, mask_vec,
                                                       ctx->i32_0, "");
        LLVMValueRef mask_hi = LLVMBuildExtractElement(ctx->builder, mask_vec,