From 29577b21230a588b048b8e445fdf0dfabc695373 Mon Sep 17 00:00:00 2001 From: Bas Nieuwenhuizen Date: Thu, 5 Jan 2017 01:09:12 +0100 Subject: [PATCH] radv/ac: Add core Float64 support. Signed-off-by: Bas Nieuwenhuizen Reviewed-by: Dave Airlie --- src/amd/common/ac_nir_to_llvm.c | 173 ++++++++++++++++++++++++-------- 1 file changed, 129 insertions(+), 44 deletions(-) diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c index c622c006cec..be315851904 100644 --- a/src/amd/common/ac_nir_to_llvm.c +++ b/src/amd/common/ac_nir_to_llvm.c @@ -119,6 +119,7 @@ struct nir_to_llvm_context { LLVMTypeRef v3i32; LLVMTypeRef v4i32; LLVMTypeRef v8i32; + LLVMTypeRef f64; LLVMTypeRef f32; LLVMTypeRef f16; LLVMTypeRef v2f32; @@ -313,34 +314,78 @@ static LLVMValueRef get_shared_memory_ptr(struct nir_to_llvm_context *ctx, return ptr; } +static LLVMTypeRef to_integer_type_scalar(struct nir_to_llvm_context *ctx, LLVMTypeRef t) +{ + if (t == ctx->f16 || t == ctx->i16) + return ctx->i16; + else if (t == ctx->f32 || t == ctx->i32) + return ctx->i32; + else if (t == ctx->f64 || t == ctx->i64) + return ctx->i64; + else + unreachable("Unhandled integer size"); +} + +static LLVMTypeRef to_integer_type(struct nir_to_llvm_context *ctx, LLVMTypeRef t) +{ + if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) { + LLVMTypeRef elem_type = LLVMGetElementType(t); + return LLVMVectorType(to_integer_type_scalar(ctx, elem_type), + LLVMGetVectorSize(t)); + } + return to_integer_type_scalar(ctx, t); +} + static LLVMValueRef to_integer(struct nir_to_llvm_context *ctx, LLVMValueRef v) { LLVMTypeRef type = LLVMTypeOf(v); - if (type == ctx->f32) { - return LLVMBuildBitCast(ctx->builder, v, ctx->i32, ""); - } else if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) { - LLVMTypeRef elem_type = LLVMGetElementType(type); - if (elem_type == ctx->f32) { - LLVMTypeRef nt = LLVMVectorType(ctx->i32, LLVMGetVectorSize(type)); - return LLVMBuildBitCast(ctx->builder, v, nt, ""); - } + return LLVMBuildBitCast(ctx->builder, v, to_integer_type(ctx, type), ""); +} + +static LLVMTypeRef to_float_type_scalar(struct nir_to_llvm_context *ctx, LLVMTypeRef t) +{ + if (t == ctx->i16 || t == ctx->f16) + return ctx->f16; + else if (t == ctx->i32 || t == ctx->f32) + return ctx->f32; + else if (t == ctx->i64 || t == ctx->f64) + return ctx->f64; + else + unreachable("Unhandled float size"); +} + +static LLVMTypeRef to_float_type(struct nir_to_llvm_context *ctx, LLVMTypeRef t) +{ + if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) { + LLVMTypeRef elem_type = LLVMGetElementType(t); + return LLVMVectorType(to_float_type_scalar(ctx, elem_type), + LLVMGetVectorSize(t)); } - return v; + return to_float_type_scalar(ctx, t); } static LLVMValueRef to_float(struct nir_to_llvm_context *ctx, LLVMValueRef v) { LLVMTypeRef type = LLVMTypeOf(v); - if (type == ctx->i32) { - return LLVMBuildBitCast(ctx->builder, v, ctx->f32, ""); - } else if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) { - LLVMTypeRef elem_type = LLVMGetElementType(type); - if (elem_type == ctx->i32) { - LLVMTypeRef nt = LLVMVectorType(ctx->f32, LLVMGetVectorSize(type)); - return LLVMBuildBitCast(ctx->builder, v, nt, ""); - } - } - return v; + return LLVMBuildBitCast(ctx->builder, v, to_float_type(ctx, type), ""); +} + +static int get_elem_bits(struct nir_to_llvm_context *ctx, LLVMTypeRef type) +{ + if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) + type = LLVMGetElementType(type); + + if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind) + return LLVMGetIntTypeWidth(type); + + if (type == ctx->f16) + return 16; + if (type == ctx->f32) + return 32; + if (type == ctx->f64) + return 64; + + unreachable("Unhandled type kind in get_elem_bits"); } static LLVMValueRef unpack_param(struct nir_to_llvm_context *ctx, @@ -710,6 +755,7 @@ static void setup_types(struct nir_to_llvm_context *ctx) ctx->v8i32 = LLVMVectorType(ctx->i32, 8); ctx->f32 = LLVMFloatTypeInContext(ctx->context); ctx->f16 = LLVMHalfTypeInContext(ctx->context); + ctx->f64 = LLVMDoubleTypeInContext(ctx->context); ctx->v2f32 = LLVMVectorType(ctx->f32, 2); ctx->v4f32 = LLVMVectorType(ctx->f32, 4); ctx->v16i8 = LLVMVectorType(ctx->i8, 16); @@ -894,35 +940,47 @@ static LLVMValueRef emit_float_cmp(struct nir_to_llvm_context *ctx, static LLVMValueRef emit_intrin_1f_param(struct nir_to_llvm_context *ctx, const char *intrin, + LLVMTypeRef result_type, LLVMValueRef src0) { + char name[64]; LLVMValueRef params[] = { to_float(ctx, src0), }; - return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 1, AC_FUNC_ATTR_READNONE); + + sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type)); + return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 1, AC_FUNC_ATTR_READNONE); } static LLVMValueRef emit_intrin_2f_param(struct nir_to_llvm_context *ctx, const char *intrin, + LLVMTypeRef result_type, LLVMValueRef src0, LLVMValueRef src1) { + char name[64]; LLVMValueRef params[] = { to_float(ctx, src0), to_float(ctx, src1), }; - return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 2, AC_FUNC_ATTR_READNONE); + + sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type)); + return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 2, AC_FUNC_ATTR_READNONE); } static LLVMValueRef emit_intrin_3f_param(struct nir_to_llvm_context *ctx, const char *intrin, + LLVMTypeRef result_type, LLVMValueRef src0, LLVMValueRef src1, LLVMValueRef src2) { + char name[64]; LLVMValueRef params[] = { to_float(ctx, src0), to_float(ctx, src1), to_float(ctx, src2), }; - return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 3, AC_FUNC_ATTR_READNONE); + + sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type)); + return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 3, AC_FUNC_ATTR_READNONE); } static LLVMValueRef emit_bcsel(struct nir_to_llvm_context *ctx, @@ -1345,6 +1403,7 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr) LLVMValueRef src[4], result = NULL; unsigned num_components = instr->dest.dest.ssa.num_components; unsigned src_components; + LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa); assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src)); switch (instr->op) { @@ -1410,7 +1469,8 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr) src[0] = to_float(ctx, src[0]); src[1] = to_float(ctx, src[1]); result = ac_emit_fdiv(&ctx->ac, src[0], src[1]); - result = emit_intrin_1f_param(ctx, "llvm.floor.f32", result); + result = emit_intrin_1f_param(ctx, "llvm.floor", + to_float_type(ctx, def_type), result); result = LLVMBuildFMul(ctx->builder, src[1] , result, ""); result = LLVMBuildFSub(ctx->builder, src[0], result, ""); break; @@ -1491,7 +1551,8 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr) result = emit_float_cmp(ctx, LLVMRealUGE, src[0], src[1]); break; case nir_op_fabs: - result = emit_intrin_1f_param(ctx, "llvm.fabs.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.fabs", + to_float_type(ctx, def_type), src[0]); break; case nir_op_iabs: result = emit_iabs(ctx, src[0]); @@ -1516,50 +1577,64 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr) result = emit_fsign(ctx, src[0]); break; case nir_op_ffloor: - result = emit_intrin_1f_param(ctx, "llvm.floor.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.floor", + to_float_type(ctx, def_type), src[0]); break; case nir_op_ftrunc: - result = emit_intrin_1f_param(ctx, "llvm.trunc.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.trunc", + to_float_type(ctx, def_type), src[0]); break; case nir_op_fceil: - result = emit_intrin_1f_param(ctx, "llvm.ceil.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.ceil", + to_float_type(ctx, def_type), src[0]); break; case nir_op_fround_even: - result = emit_intrin_1f_param(ctx, "llvm.rint.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.rint", + to_float_type(ctx, def_type),src[0]); break; case nir_op_ffract: result = emit_ffract(ctx, src[0]); break; case nir_op_fsin: - result = emit_intrin_1f_param(ctx, "llvm.sin.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.sin", + to_float_type(ctx, def_type), src[0]); break; case nir_op_fcos: - result = emit_intrin_1f_param(ctx, "llvm.cos.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.cos", + to_float_type(ctx, def_type), src[0]); break; case nir_op_fsqrt: - result = emit_intrin_1f_param(ctx, "llvm.sqrt.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.sqrt", + to_float_type(ctx, def_type), src[0]); break; case nir_op_fexp2: - result = emit_intrin_1f_param(ctx, "llvm.exp2.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.exp2", + to_float_type(ctx, def_type), src[0]); break; case nir_op_flog2: - result = emit_intrin_1f_param(ctx, "llvm.log2.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.log2", + to_float_type(ctx, def_type), src[0]); break; case nir_op_frsq: - result = emit_intrin_1f_param(ctx, "llvm.sqrt.f32", src[0]); + result = emit_intrin_1f_param(ctx, "llvm.sqrt", + to_float_type(ctx, def_type), src[0]); result = ac_emit_fdiv(&ctx->ac, ctx->f32one, result); break; case nir_op_fpow: - result = emit_intrin_2f_param(ctx, "llvm.pow.f32", src[0], src[1]); + result = emit_intrin_2f_param(ctx, "llvm.pow", + to_float_type(ctx, def_type), src[0], src[1]); break; case nir_op_fmax: - result = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", src[0], src[1]); + result = emit_intrin_2f_param(ctx, "llvm.maxnum", + to_float_type(ctx, def_type), src[0], src[1]); break; case nir_op_fmin: - result = emit_intrin_2f_param(ctx, "llvm.minnum.f32", src[0], src[1]); + result = emit_intrin_2f_param(ctx, "llvm.minnum", + to_float_type(ctx, def_type), src[0], src[1]); break; case nir_op_ffma: - result = emit_intrin_3f_param(ctx, "llvm.fma.f32", src[0], src[1], src[2]); + result = emit_intrin_3f_param(ctx, "llvm.fma", + to_float_type(ctx, def_type), src[0], src[1], src[2]); break; case nir_op_ibitfield_extract: result = emit_bitfield_extract(ctx, "llvm.AMDGPU.bfe.i32", src); @@ -1583,19 +1658,29 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr) src[i] = to_integer(ctx, src[i]); result = ac_build_gather_values(&ctx->ac, src, num_components); break; + case nir_op_d2i: case nir_op_f2i: src[0] = to_float(ctx, src[0]); - result = LLVMBuildFPToSI(ctx->builder, src[0], ctx->i32, ""); + result = LLVMBuildFPToSI(ctx->builder, src[0], def_type, ""); break; + case nir_op_d2u: case nir_op_f2u: src[0] = to_float(ctx, src[0]); - result = LLVMBuildFPToUI(ctx->builder, src[0], ctx->i32, ""); + result = LLVMBuildFPToUI(ctx->builder, src[0], def_type, ""); break; + case nir_op_i2d: case nir_op_i2f: - result = LLVMBuildSIToFP(ctx->builder, src[0], ctx->f32, ""); + result = LLVMBuildSIToFP(ctx->builder, src[0], to_float_type(ctx, def_type), ""); break; + case nir_op_u2d: case nir_op_u2f: - result = LLVMBuildUIToFP(ctx->builder, src[0], ctx->f32, ""); + result = LLVMBuildUIToFP(ctx->builder, src[0], to_float_type(ctx, def_type), ""); + break; + case nir_op_f2d: + result = LLVMBuildFPExt(ctx->builder, src[0], to_float_type(ctx, def_type), ""); + break; + case nir_op_d2f: + result = LLVMBuildFPTrunc(ctx->builder, src[0], to_float_type(ctx, def_type), ""); break; case nir_op_bcsel: result = emit_bcsel(ctx, src[0], src[1], src[2]); @@ -4249,8 +4334,8 @@ static LLVMValueRef emit_float_saturate(struct nir_to_llvm_context *ctx, LLVMValueRef v, float lo, float hi) { v = to_float(ctx, v); - v = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", v, LLVMConstReal(ctx->f32, lo)); - return emit_intrin_2f_param(ctx, "llvm.minnum.f32", v, LLVMConstReal(ctx->f32, hi)); + v = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", ctx->f32, v, LLVMConstReal(ctx->f32, lo)); + return emit_intrin_2f_param(ctx, "llvm.minnum.f32", ctx->f32, v, LLVMConstReal(ctx->f32, hi)); } -- 2.30.2