From 7d3c34197a5d357473f8b3090dd24d6e0dfea2e4 Mon Sep 17 00:00:00 2001 From: Iago Toral Quiroga Date: Wed, 18 Apr 2018 10:14:11 +0200 Subject: [PATCH] compiler/spirv: implement 16-bit hyperbolic trigonometric functions v2: - use nir_fadd_imm and nir_fmul_imm helpers (Jason) v3: - since we need to define one for fsub use it for fdiv too (Jason) Reviewed-by: Jason Ekstrand --- src/compiler/spirv/vtn_glsl450.c | 44 +++++++++++++++++++------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c index 7984c7cc776..396ec641562 100644 --- a/src/compiler/spirv/vtn_glsl450.c +++ b/src/compiler/spirv/vtn_glsl450.c @@ -654,17 +654,17 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint, case GLSLstd450Sinh: /* 0.5 * (e^x - e^(-x)) */ val->ssa->def = - nir_fmul(nb, nir_imm_float(nb, 0.5f), - nir_fsub(nb, build_exp(nb, src[0]), - build_exp(nb, nir_fneg(nb, src[0])))); + nir_fmul_imm(nb, nir_fsub(nb, build_exp(nb, src[0]), + build_exp(nb, nir_fneg(nb, src[0]))), + 0.5f); return; case GLSLstd450Cosh: /* 0.5 * (e^x + e^(-x)) */ val->ssa->def = - nir_fmul(nb, nir_imm_float(nb, 0.5f), - nir_fadd(nb, build_exp(nb, src[0]), - build_exp(nb, nir_fneg(nb, src[0])))); + nir_fmul_imm(nb, nir_fadd(nb, build_exp(nb, src[0]), + build_exp(nb, nir_fneg(nb, src[0]))), + 0.5f); return; case GLSLstd450Tanh: { @@ -675,30 +675,38 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint, * We clamp x to (-inf, +10] to avoid precision problems. When x > 10, * e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the * computation e^2x +/- 1 so it can be ignored. + * + * For 16-bit precision we clamp x to (-inf, +4.2] since the maximum + * representable number is only 65,504 and e^(2*6) exceeds that. Also, + * if x > 4.2, tanh(x) will return 1.0 in fp16. */ - nir_ssa_def *x = nir_fmin(nb, src[0], nir_imm_float(nb, 10)); - nir_ssa_def *exp2x = build_exp(nb, nir_fmul(nb, x, nir_imm_float(nb, 2))); - val->ssa->def = nir_fdiv(nb, nir_fsub(nb, exp2x, nir_imm_float(nb, 1)), - nir_fadd(nb, exp2x, nir_imm_float(nb, 1))); + const uint32_t bit_size = src[0]->bit_size; + const double clamped_x = bit_size > 16 ? 10.0 : 4.2; + nir_ssa_def *x = nir_fmin(nb, src[0], + nir_imm_floatN_t(nb, clamped_x, bit_size)); + nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0)); + val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0), + nir_fadd_imm(nb, exp2x, 1.0)); return; } case GLSLstd450Asinh: val->ssa->def = nir_fmul(nb, nir_fsign(nb, src[0]), build_log(nb, nir_fadd(nb, nir_fabs(nb, src[0]), - nir_fsqrt(nb, nir_fadd(nb, nir_fmul(nb, src[0], src[0]), - nir_imm_float(nb, 1.0f)))))); + nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]), + 1.0f))))); return; case GLSLstd450Acosh: val->ssa->def = build_log(nb, nir_fadd(nb, src[0], - nir_fsqrt(nb, nir_fsub(nb, nir_fmul(nb, src[0], src[0]), - nir_imm_float(nb, 1.0f))))); + nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]), + -1.0f)))); return; case GLSLstd450Atanh: { - nir_ssa_def *one = nir_imm_float(nb, 1.0); - val->ssa->def = nir_fmul(nb, nir_imm_float(nb, 0.5f), - build_log(nb, nir_fdiv(nb, nir_fadd(nb, one, src[0]), - nir_fsub(nb, one, src[0])))); + nir_ssa_def *one = nir_imm_floatN_t(nb, 1.0, src[0]->bit_size); + val->ssa->def = + nir_fmul_imm(nb, build_log(nb, nir_fdiv(nb, nir_fadd(nb, src[0], one), + nir_fsub(nb, one, src[0]))), + 0.5f); return; } -- 2.30.2