Revert "spirv: Use a simpler and more correct implementaiton of tanh()"
authorKristian H. Kristensen <hoegsberg@google.com>
Thu, 27 Feb 2020 19:38:53 +0000 (11:38 -0800)
committerMarge Bot <eric+marge@anholt.net>
Thu, 5 Mar 2020 15:23:31 +0000 (15:23 +0000)
This reverts commit da1c49171d0df185545cfbbd600e287f7c6160fa.

The reduced formula has precision problems on fp16 around 0.  Bring
back the old formula, but make sure to keep the clamping.

Tested-by: Marge Bot <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4054>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4054>

src/compiler/spirv/vtn_glsl450.c

index a074b6f3fb48cbb40376fb5a8f4d627bd21e6325..571f06cbed2a59ffdbd0f85ab2e6037f6e341df3 100644 (file)
@@ -458,25 +458,24 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       return;
 
    case GLSLstd450Tanh: {
-      /* tanh(x) := (0.5 * (e^x - e^(-x))) / (0.5 * (e^x + e^(-x)))
+      /* tanh(x) := (e^x - e^(-x)) / (e^x + e^(-x))
        *
-       * With a little algebra this reduces to (e^2x - 1) / (e^2x + 1)
+       * We clamp x to [-10, +10] to avoid precision problems.  When x > 10,
+       * e^x dominates the sum, e^(-x) is lost and tanh(x) is 1.0 for 32 bit
+       * floating point.
        *
-       * We clamp x to (-inf, +10] to avoid precision problems.  When x > 10,
-       * e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the
-       * computation e^2x +/- 1 so it can be ignored.
-       *
-       * For 16-bit precision we clamp x to (-inf, +4.2] since the maximum
-       * representable number is only 65,504 and e^(2*6) exceeds that. Also,
-       * if x > 4.2, tanh(x) will return 1.0 in fp16.
+       * For 16-bit precision this we clamp x to [-4.2, +4.2].
        */
       const uint32_t bit_size = src[0]->bit_size;
       const double clamped_x = bit_size > 16 ? 10.0 : 4.2;
-      nir_ssa_def *x = nir_fmin(nb, src[0],
-                                    nir_imm_floatN_t(nb, clamped_x, bit_size));
-      nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0));
-      val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0),
-                                   nir_fadd_imm(nb, exp2x, 1.0));
+      nir_ssa_def *x = nir_fclamp(nb, src[0],
+                                  nir_imm_floatN_t(nb, -clamped_x, bit_size),
+                                  nir_imm_floatN_t(nb, clamped_x, bit_size));
+      val->ssa->def =
+         nir_fdiv(nb, nir_fsub(nb, build_exp(nb, x),
+                               build_exp(nb, nir_fneg(nb, x))),
+                  nir_fadd(nb, build_exp(nb, x),
+                           build_exp(nb, nir_fneg(nb, x))));
       return;
    }