ac: add 16bit conversion operations
author Daniel Schürmann <daniel.schuermann@campus.tu-berlin.de>
Sat, 3 Feb 2018 13:37:26 +0000 (14:37 +0100)
committer Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Mon, 23 Jul 2018 21:16:25 +0000 (23:16 +0200)
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
src/amd/common/ac_llvm_build.c
src/amd/common/ac_nir_to_llvm.c

index a77c29270d1960102c3465e8271bcd96b4d2a2d0..4078b005e540a046d5860560241d94accc87d1fc 100644 (file)
@@ -175,6 +175,8 @@ ac_get_type_size(LLVMTypeRef type)
        switch (kind) {
        case LLVMIntegerTypeKind:
                return LLVMGetIntTypeWidth(type) / 8;
+       case LLVMHalfTypeKind:
+               return 2;
        case LLVMFloatTypeKind:
                return 4;
        case LLVMDoubleTypeKind:
@@ -320,6 +322,9 @@ void ac_build_type_name_for_intr(LLVMTypeRef type, char *buf, unsigned bufsize)
        case LLVMIntegerTypeKind:
                snprintf(buf, bufsize, "i%d", LLVMGetIntTypeWidth(elem_type));
                break;
+       case LLVMHalfTypeKind:
+               snprintf(buf, bufsize, "f16");
+               break;
        case LLVMFloatTypeKind:
                snprintf(buf, bufsize, "f32");
                break;
@@ -1819,11 +1824,9 @@ LLVMValueRef ac_build_cvt_pkrtz_f16(struct ac_llvm_context *ctx,
 {
        LLVMTypeRef v2f16 =
                LLVMVectorType(LLVMHalfTypeInContext(ctx->context), 2);
-       LLVMValueRef res =
-               ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz",
-                                  v2f16, args, 2,
-                                  AC_FUNC_ATTR_READNONE);
-       return LLVMBuildBitCast(ctx->builder, res, ctx->i32, "");
+
+       return ac_build_intrinsic(ctx, "llvm.amdgcn.cvt.pkrtz", v2f16,
+                                 args, 2, AC_FUNC_ATTR_READNONE);
 }
 
 /* Upper 16 bits must be zero. */
index 10d1773850924971e3dc8f69cdca0743bde192ae..c55f3c50681238f938dd827be440b980c0cedb60 100644 (file)
@@ -478,7 +478,8 @@ static LLVMValueRef emit_pack_half_2x16(struct ac_llvm_context *ctx,
        comp[0] = LLVMBuildExtractElement(ctx->builder, src0, ctx->i32_0, "");
        comp[1] = LLVMBuildExtractElement(ctx->builder, src0, ctx->i32_1, "");
 
-       return ac_build_cvt_pkrtz_f16(ctx, comp);
+       return LLVMBuildBitCast(ctx->builder, ac_build_cvt_pkrtz_f16(ctx, comp),
+                               ctx->i32, "");
 }
 
 static LLVMValueRef emit_unpack_half_2x16(struct ac_llvm_context *ctx,
@@ -857,34 +858,47 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                        src[i] = ac_to_integer(&ctx->ac, src[i]);
                result = ac_build_gather_values(&ctx->ac, src, num_components);
                break;
+       case nir_op_f2i16:
        case nir_op_f2i32:
        case nir_op_f2i64:
                src[0] = ac_to_float(&ctx->ac, src[0]);
                result = LLVMBuildFPToSI(ctx->ac.builder, src[0], def_type, "");
                break;
+       case nir_op_f2u16:
        case nir_op_f2u32:
        case nir_op_f2u64:
                src[0] = ac_to_float(&ctx->ac, src[0]);
                result = LLVMBuildFPToUI(ctx->ac.builder, src[0], def_type, "");
                break;
+       case nir_op_i2f16:
        case nir_op_i2f32:
        case nir_op_i2f64:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
                result = LLVMBuildSIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
+       case nir_op_u2f16:
        case nir_op_u2f32:
        case nir_op_u2f64:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
-       case nir_op_f2f64:
+       case nir_op_f2f16_rtz:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
+               result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
+               result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
                break;
+       case nir_op_f2f16_rtne:
+       case nir_op_f2f16_undef:
        case nir_op_f2f32:
+       case nir_op_f2f64:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
+                       result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               else
+                       result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
+       case nir_op_u2u16:
        case nir_op_u2u32:
        case nir_op_u2u64:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
@@ -893,6 +907,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                else
                        result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
                break;
+       case nir_op_i2i16:
        case nir_op_i2i32:
        case nir_op_i2i64:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
@@ -1098,6 +1113,10 @@ static void visit_load_const(struct ac_nir_context *ctx,
 
        for (unsigned i = 0; i < instr->def.num_components; ++i) {
                switch (instr->def.bit_size) {
+               case 16:
+                       values[i] = LLVMConstInt(element_type,
+                                                instr->value.u16[i], false);
+                       break;
                case 32:
                        values[i] = LLVMConstInt(element_type,
                                                 instr->value.u32[i], false);