From f2d096a09d964631f8273f2c8ca0d1fa0ddfbdfc Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 20 May 2021 20:29:53 -0700 Subject: [PATCH] working on sqrt_rsqrt_kernel --- src/algorithms/roots.rs | 57 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/src/algorithms/roots.rs b/src/algorithms/roots.rs index 4318bc9..0004476 100644 --- a/src/algorithms/roots.rs +++ b/src/algorithms/roots.rs @@ -22,7 +22,8 @@ pub fn initial_rsqrt_approximation< v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to())) } -/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm +/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method +/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic pub fn sqrt_rsqrt_kernel< Ctx: Context, VecF: Float + Make, @@ -38,18 +39,68 @@ pub fn sqrt_rsqrt_kernel< let y = initial_rsqrt_approximation(ctx, v); let mut x = v * y; let one_half: VecF = ctx.make(0.5.to()); + let three_halves: VecF = ctx.make((3.0 / 2.0).to()); let mut neg_h = y * -one_half; for _ in 0..iteration_count { let r = x.mul_add_fast(neg_h, one_half); x = x.mul_add_fast(r, x); neg_h = neg_h.mul_add_fast(r, neg_h); } - (x, neg_h * ctx.make(PrimF::cvt_from(-2))) + let sqrt = x; + let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2)); + // do one step of Newton's method + let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt)); + (sqrt, rsqrt) } #[cfg(test)] mod tests { use super::*; + use crate::{ + f16::F16, + scalar::{Scalar, Value}, + }; + use std::fmt; - // TODO: add tests for `sqrt_rsqrt_kernel` + fn same(a: F, b: F) -> bool { + if a.is_finite() && b.is_finite() { + a.to_bits() == b.to_bits() + } else { + a == b || (a.is_nan() && b.is_nan()) + } + } + + struct DisplayValueAndBits(F); + + impl fmt::Display for DisplayValueAndBits { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {:#X}", self.0, self.0.to_bits()) + } + } + + #[test] + #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_sqrt_rsqrt_kernel_f16() { + let start = F16::to_bits(0.5.to()); + let end = F16::to_bits(2.0.to()); + for bits in start..=end { + let v = F16::from_bits(bits); + let expected_sqrt: F16 = f64::from(v).sqrt().to(); + let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to(); + let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3); + assert!( + same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt), + "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}", + v = DisplayValueAndBits(v), + expected_sqrt = DisplayValueAndBits(expected_sqrt), + result_sqrt = DisplayValueAndBits(result_sqrt), + expected_rsqrt = DisplayValueAndBits(expected_rsqrt), + result_rsqrt = DisplayValueAndBits(result_rsqrt), + ); + } + } } -- 2.30.2