add sqrt_fast_f16/f32/f64
[vector-math.git] / src / algorithms / roots.rs
index 00044760b16ffe794cf3c3ddbd5499fe1b6849cc..22cecc4ded8a61ae227bac97193c88fbffeb3524 100644 (file)
@@ -1,6 +1,6 @@
 use crate::{
     prim::{PrimFloat, PrimUInt},
-    traits::{Context, ConvertTo, Float, Make, UInt},
+    traits::{Context, ConvertTo, Float, Make, Select, UInt},
 };
 
 pub fn initial_rsqrt_approximation<
@@ -22,9 +22,8 @@ pub fn initial_rsqrt_approximation<
     v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
 }
 
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method
-/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
-pub fn sqrt_rsqrt_kernel<
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+pub fn sqrt_rsqrt_kernel_fast<
     Ctx: Context,
     VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
     VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
@@ -39,7 +38,6 @@ pub fn sqrt_rsqrt_kernel<
     let y = initial_rsqrt_approximation(ctx, v);
     let mut x = v * y;
     let one_half: VecF = ctx.make(0.5.to());
-    let three_halves: VecF = ctx.make((3.0 / 2.0).to());
     let mut neg_h = y * -one_half;
     for _ in 0..iteration_count {
         let r = x.mul_add_fast(neg_h, one_half);
@@ -48,11 +46,72 @@ pub fn sqrt_rsqrt_kernel<
     }
     let sqrt = x;
     let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
-    // do one step of Newton's method
-    let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt));
     (sqrt, rsqrt)
 }
 
+/// calculate `sqrt(v)`, error inherited from `kernel_fn`. Calls `kernel_fn` with inputs in the range `0.5 <= x < 2.0`.
+pub fn sqrt_impl<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+    KernelFn: FnOnce(Ctx, VecF) -> VecF,
+>(
+    ctx: Ctx,
+    v: VecF,
+    kernel_fn: KernelFn,
+) -> VecF {
+    let is_normal_case = v.gt(ctx.make(0.0.to())) & v.is_finite();
+    let is_zero_or_positive = v.ge(ctx.make(0.0.to()));
+    let exceptional_retval = is_zero_or_positive.select(v, VecF::nan(ctx));
+    let need_subnormal_scale = v.is_zero_or_subnormal();
+    let subnormal_result_scale_exponent: PrimU = (PrimF::MANTISSA_FIELD_WIDTH + 1.to()) / 2.to();
+    let subnormal_input_scale_exponent = subnormal_result_scale_exponent * 2.to();
+    let subnormal_result_scale_prim =
+        PrimF::cvt_from(1) / PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_result_scale_exponent);
+    let subnormal_input_scale_prim =
+        PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_input_scale_exponent);
+    let subnormal_result_scale: VecF =
+        need_subnormal_scale.select(ctx.make(subnormal_result_scale_prim), ctx.make(1.0.to()));
+    let subnormal_input_scale: VecF =
+        need_subnormal_scale.select(ctx.make(subnormal_input_scale_prim), ctx.make(1.0.to()));
+    let v = v * subnormal_input_scale;
+    let exponent_field = v.extract_exponent_field();
+    let normal_result_scale_exponent_field_offset: PrimU =
+        PrimF::EXPONENT_BIAS_UNSIGNED - (PrimF::EXPONENT_BIAS_UNSIGNED >> 1.to());
+    let shifted_exponent_field = exponent_field >> ctx.make(1.to());
+    let normal_result_scale_exponent_field =
+        shifted_exponent_field + ctx.make(normal_result_scale_exponent_field_offset);
+    let normal_result_scale = ctx
+        .make::<VecF>(1.to())
+        .with_exponent_field(normal_result_scale_exponent_field);
+    let v = v.with_exponent_field(
+        (exponent_field & ctx.make(1.to()))
+            | ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED & !PrimU::cvt_from(1)),
+    );
+    let normal_result = kernel_fn(ctx, v) * (normal_result_scale * subnormal_result_scale);
+    is_normal_case.select(normal_result, exceptional_retval)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f16<Ctx: Context>(ctx: Ctx, v: Ctx::VecF16) -> Ctx::VecF16 {
+    sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 3).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 3ULP
+pub fn sqrt_fast_f32<Ctx: Context>(ctx: Ctx, v: Ctx::VecF32) -> Ctx::VecF32 {
+    sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 4).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f64<Ctx: Context>(ctx: Ctx, v: Ctx::VecF64) -> Ctx::VecF64 {
+    sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 5).0)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -70,16 +129,40 @@ mod tests {
         }
     }
 
+    fn approx_same<F: PrimFloat>(a: F, b: F, max_ulp: F::BitsType) -> bool {
+        if a.is_finite() && b.is_finite() {
+            let a = a.to_bits();
+            let b = b.to_bits();
+            let ulp = if a < b { b - a } else { a - b };
+            ulp <= max_ulp
+        } else {
+            a == b || (a.is_nan() && b.is_nan())
+        }
+    }
+
     struct DisplayValueAndBits<F: PrimFloat>(F);
 
     impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-            write!(f, "{} {:#X}", self.0, self.0.to_bits())
+            write!(f, "{:e} {:#X}", self.0, self.0.to_bits())
+        }
+    }
+
+    fn reference_sqrt_f64(v: f64) -> f64 {
+        use az::Cast;
+        use rug::Float;
+        if v.is_nan() || v < 0.0 {
+            f64::NAN
+        } else if v == 0.0 || !v.is_finite() {
+            v
+        } else {
+            let precision = 100;
+            Float::with_val(precision, v).sqrt().cast()
         }
     }
 
     #[test]
-    #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
+    #[cfg(not(test))] // TODO: broken for now -- create sqrt_rsqrt_kernel variant using float-float arithmetic
     #[cfg_attr(
         not(feature = "f16"),
         should_panic(expected = "f16 feature is not enabled")
@@ -103,4 +186,58 @@ mod tests {
             );
         }
     }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_sqrt_fast_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected: F16 = f64::from(v).sqrt().to();
+            let result = sqrt_fast_f16(Scalar, Value(v)).0;
+            assert!(
+                approx_same(expected, result, 2),
+                "case failed: v={v}, expected={expected}, result={result}",
+                v = DisplayValueAndBits(v),
+                expected = DisplayValueAndBits(expected),
+                result = DisplayValueAndBits(result),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_sqrt_fast_f32() {
+        for bits in (0..=u32::MAX).step_by(1 << 8) {
+            let v = f32::from_bits(bits);
+            let expected: f32 = f64::from(v).sqrt().to();
+            let result = sqrt_fast_f32(Scalar, Value(v)).0;
+            assert!(
+                approx_same(expected, result, 3),
+                "case failed: v={v}, expected={expected}, result={result}",
+                v = DisplayValueAndBits(v),
+                expected = DisplayValueAndBits(expected),
+                result = DisplayValueAndBits(result),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_sqrt_fast_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 40) {
+            let v = f64::from_bits(bits);
+            let expected: f64 = reference_sqrt_f64(v);
+            let result = sqrt_fast_f64(Scalar, Value(v)).0;
+            assert!(
+                approx_same(expected, result, 2),
+                "case failed: v={v}, expected={expected}, result={result}",
+                v = DisplayValueAndBits(v),
+                expected = DisplayValueAndBits(expected),
+                result = DisplayValueAndBits(result),
+            );
+        }
+    }
 }