use crate::{ prim::{PrimFloat, PrimUInt}, traits::{Context, ConvertTo, Float, Make, UInt}, }; pub fn initial_rsqrt_approximation< Ctx: Context, VecF: Float + Make, VecU: UInt + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, v: VecF, ) -> VecF { // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation // where `CONST` is optimized for use for Goldschmidt's algorithm. // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root // but using different constants. const FACTOR: f64 = -0.5; const TERM: f64 = 1.6; v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to())) } /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method /// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic pub fn sqrt_rsqrt_kernel< Ctx: Context, VecF: Float + Make, VecU: UInt + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, v: VecF, iteration_count: usize, ) -> (VecF, VecF) { // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm let y = initial_rsqrt_approximation(ctx, v); let mut x = v * y; let one_half: VecF = ctx.make(0.5.to()); let three_halves: VecF = ctx.make((3.0 / 2.0).to()); let mut neg_h = y * -one_half; for _ in 0..iteration_count { let r = x.mul_add_fast(neg_h, one_half); x = x.mul_add_fast(r, x); neg_h = neg_h.mul_add_fast(r, neg_h); } let sqrt = x; let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2)); // do one step of Newton's method let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt)); (sqrt, rsqrt) } #[cfg(test)] mod tests { use super::*; use crate::{ f16::F16, scalar::{Scalar, Value}, }; use std::fmt; fn same(a: F, b: F) -> bool { if a.is_finite() && b.is_finite() { a.to_bits() == b.to_bits() } else { a == b || (a.is_nan() && b.is_nan()) } } struct DisplayValueAndBits(F); impl fmt::Display for DisplayValueAndBits { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{} {:#X}", self.0, self.0.to_bits()) } } #[test] #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_sqrt_rsqrt_kernel_f16() { let start = F16::to_bits(0.5.to()); let end = F16::to_bits(2.0.to()); for bits in start..=end { let v = F16::from_bits(bits); let expected_sqrt: F16 = f64::from(v).sqrt().to(); let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to(); let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3); assert!( same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt), "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}", v = DisplayValueAndBits(v), expected_sqrt = DisplayValueAndBits(expected_sqrt), result_sqrt = DisplayValueAndBits(result_sqrt), expected_rsqrt = DisplayValueAndBits(expected_rsqrt), result_rsqrt = DisplayValueAndBits(result_rsqrt), ); } } }