working on sqrt_rsqrt_kernel
[vector-math.git] / src / algorithms / roots.rs
index 4318bc9e786c823a487b3d5b5c42391e11c71f42..00044760b16ffe794cf3c3ddbd5499fe1b6849cc 100644 (file)
@@ -22,7 +22,8 @@ pub fn initial_rsqrt_approximation<
     v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
 }
 
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm, then refine only `1 / sqrt(v)` with one step of Newton's method
+/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
 pub fn sqrt_rsqrt_kernel<
     Ctx: Context,
     VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
@@ -38,18 +39,68 @@ pub fn sqrt_rsqrt_kernel<
     let y = initial_rsqrt_approximation(ctx, v);
     let mut x = v * y;
     let one_half: VecF = ctx.make(0.5.to());
+    let three_halves: VecF = ctx.make((3.0 / 2.0).to());
     let mut neg_h = y * -one_half;
     for _ in 0..iteration_count {
         let r = x.mul_add_fast(neg_h, one_half);
         x = x.mul_add_fast(r, x);
         neg_h = neg_h.mul_add_fast(r, neg_h);
     }
-    (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
+    let sqrt = x;
+    let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
+    // refine rsqrt with one step of Newton's method: rsqrt * (3/2 - (v/2) * rsqrt^2)
+    let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt));
+    (sqrt, rsqrt)
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::{
+        f16::F16,
+        scalar::{Scalar, Value},
+    };
+    use std::fmt;
 
-    // TODO: add tests for `sqrt_rsqrt_kernel`
+    fn same<F: PrimFloat>(a: F, b: F) -> bool {
+        if a.is_finite() && b.is_finite() {
+            a.to_bits() == b.to_bits()
+        } else {
+            a == b || (a.is_nan() && b.is_nan())
+        }
+    }
+
+    struct DisplayValueAndBits<F: PrimFloat>(F);
+
+    impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            write!(f, "{} {:#X}", self.0, self.0.to_bits())
+        }
+    }
+
+    #[test]
+    #[cfg(not(test))] // TODO: test is compiled out because sqrt_rsqrt_kernel is broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_sqrt_rsqrt_kernel_f16() {
+        let start = F16::to_bits(0.5.to());
+        let end = F16::to_bits(2.0.to());
+        for bits in start..=end {
+            let v = F16::from_bits(bits);
+            let expected_sqrt: F16 = f64::from(v).sqrt().to();
+            let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to();
+            let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3);
+            assert!(
+                same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt),
+                "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}",
+                v = DisplayValueAndBits(v),
+                expected_sqrt = DisplayValueAndBits(expected_sqrt),
+                result_sqrt = DisplayValueAndBits(result_sqrt),
+                expected_rsqrt = DisplayValueAndBits(expected_rsqrt),
+                result_rsqrt = DisplayValueAndBits(result_rsqrt),
+            );
+        }
+    }
 }