v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
}
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method
+/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
pub fn sqrt_rsqrt_kernel<
Ctx: Context,
VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
let y = initial_rsqrt_approximation(ctx, v);
let mut x = v * y;
let one_half: VecF = ctx.make(0.5.to());
+ let three_halves: VecF = ctx.make((3.0 / 2.0).to());
let mut neg_h = y * -one_half;
for _ in 0..iteration_count {
let r = x.mul_add_fast(neg_h, one_half);
x = x.mul_add_fast(r, x);
neg_h = neg_h.mul_add_fast(r, neg_h);
}
- (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
+ let sqrt = x;
+ let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
+ // do one step of Newton's method
+ let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt));
+ (sqrt, rsqrt)
}
#[cfg(test)]
mod tests {
use super::*;
+ use crate::{
+ f16::F16,
+ scalar::{Scalar, Value},
+ };
+ use std::fmt;
- // TODO: add tests for `sqrt_rsqrt_kernel`
+ fn same<F: PrimFloat>(a: F, b: F) -> bool {
+ if a.is_finite() && b.is_finite() {
+ a.to_bits() == b.to_bits()
+ } else {
+ a == b || (a.is_nan() && b.is_nan())
+ }
+ }
+
+ struct DisplayValueAndBits<F: PrimFloat>(F);
+
+ impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{} {:#X}", self.0, self.0.to_bits())
+ }
+ }
+
+ #[test]
+ #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
+ #[cfg_attr(
+ not(feature = "f16"),
+ should_panic(expected = "f16 feature is not enabled")
+ )]
+ fn test_sqrt_rsqrt_kernel_f16() {
+ let start = F16::to_bits(0.5.to());
+ let end = F16::to_bits(2.0.to());
+ for bits in start..=end {
+ let v = F16::from_bits(bits);
+ let expected_sqrt: F16 = f64::from(v).sqrt().to();
+ let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to();
+ let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3);
+ assert!(
+ same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt),
+ "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}",
+ v = DisplayValueAndBits(v),
+ expected_sqrt = DisplayValueAndBits(expected_sqrt),
+ result_sqrt = DisplayValueAndBits(result_sqrt),
+ expected_rsqrt = DisplayValueAndBits(expected_rsqrt),
+ result_rsqrt = DisplayValueAndBits(result_rsqrt),
+ );
+ }
+ }
}