From a43a29ccd5824b9b6dcb45baa3ccb3f8c5ea72e1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 27 May 2021 21:04:01 -0700 Subject: [PATCH] add sqrt_fast_f16/f32/f64 --- src/algorithms/roots.rs | 155 +++++++++++++++++++++++++++++++++++++--- src/f16.rs | 12 ++++ src/prim.rs | 7 +- src/traits.rs | 8 +++ 4 files changed, 172 insertions(+), 10 deletions(-) diff --git a/src/algorithms/roots.rs b/src/algorithms/roots.rs index 0004476..22cecc4 100644 --- a/src/algorithms/roots.rs +++ b/src/algorithms/roots.rs @@ -1,6 +1,6 @@ use crate::{ prim::{PrimFloat, PrimUInt}, - traits::{Context, ConvertTo, Float, Make, UInt}, + traits::{Context, ConvertTo, Float, Make, Select, UInt}, }; pub fn initial_rsqrt_approximation< @@ -22,9 +22,8 @@ pub fn initial_rsqrt_approximation< v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to())) } -/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method -/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic -pub fn sqrt_rsqrt_kernel< +/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm +pub fn sqrt_rsqrt_kernel_fast< Ctx: Context, VecF: Float + Make, VecU: UInt + Make, @@ -39,7 +38,6 @@ pub fn sqrt_rsqrt_kernel< let y = initial_rsqrt_approximation(ctx, v); let mut x = v * y; let one_half: VecF = ctx.make(0.5.to()); - let three_halves: VecF = ctx.make((3.0 / 2.0).to()); let mut neg_h = y * -one_half; for _ in 0..iteration_count { let r = x.mul_add_fast(neg_h, one_half); @@ -48,11 +46,72 @@ pub fn sqrt_rsqrt_kernel< } let sqrt = x; let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2)); - // do one step of Newton's method - let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt)); (sqrt, rsqrt) } +/// calculate `sqrt(v)`, error inherited from `kernel_fn`. Calls `kernel_fn` with inputs in the range `0.5 <= x < 2.0`. +pub fn sqrt_impl< + Ctx: Context, + VecF: Float + Make, + VecU: UInt + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, + KernelFn: FnOnce(Ctx, VecF) -> VecF, +>( + ctx: Ctx, + v: VecF, + kernel_fn: KernelFn, +) -> VecF { + let is_normal_case = v.gt(ctx.make(0.0.to())) & v.is_finite(); + let is_zero_or_positive = v.ge(ctx.make(0.0.to())); + let exceptional_retval = is_zero_or_positive.select(v, VecF::nan(ctx)); + let need_subnormal_scale = v.is_zero_or_subnormal(); + let subnormal_result_scale_exponent: PrimU = (PrimF::MANTISSA_FIELD_WIDTH + 1.to()) / 2.to(); + let subnormal_input_scale_exponent = subnormal_result_scale_exponent * 2.to(); + let subnormal_result_scale_prim = + PrimF::cvt_from(1) / PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_result_scale_exponent); + let subnormal_input_scale_prim = + PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_input_scale_exponent); + let subnormal_result_scale: VecF = + need_subnormal_scale.select(ctx.make(subnormal_result_scale_prim), ctx.make(1.0.to())); + let subnormal_input_scale: VecF = + need_subnormal_scale.select(ctx.make(subnormal_input_scale_prim), ctx.make(1.0.to())); + let v = v * subnormal_input_scale; + let exponent_field = v.extract_exponent_field(); + let normal_result_scale_exponent_field_offset: PrimU = + PrimF::EXPONENT_BIAS_UNSIGNED - (PrimF::EXPONENT_BIAS_UNSIGNED >> 1.to()); + let shifted_exponent_field = exponent_field >> ctx.make(1.to()); + let normal_result_scale_exponent_field = + shifted_exponent_field + ctx.make(normal_result_scale_exponent_field_offset); + let normal_result_scale = ctx + .make::(1.to()) + .with_exponent_field(normal_result_scale_exponent_field); + let v = v.with_exponent_field( + (exponent_field & ctx.make(1.to())) + | ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED & !PrimU::cvt_from(1)), + ); + let normal_result = kernel_fn(ctx, v) * (normal_result_scale * subnormal_result_scale); + is_normal_case.select(normal_result, exceptional_retval) +} + +/// computes `sqrt(x)` +/// has an error of up to 2ULP +pub fn sqrt_fast_f16(ctx: Ctx, v: Ctx::VecF16) -> Ctx::VecF16 { + sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 3).0) +} + +/// computes `sqrt(x)` +/// has an error of up to 3ULP +pub fn sqrt_fast_f32(ctx: Ctx, v: Ctx::VecF32) -> Ctx::VecF32 { + sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 4).0) +} + +/// computes `sqrt(x)` +/// has an error of up to 2ULP +pub fn sqrt_fast_f64(ctx: Ctx, v: Ctx::VecF64) -> Ctx::VecF64 { + sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 5).0) +} + #[cfg(test)] mod tests { use super::*; @@ -70,16 +129,40 @@ mod tests { } } + fn approx_same(a: F, b: F, max_ulp: F::BitsType) -> bool { + if a.is_finite() && b.is_finite() { + let a = a.to_bits(); + let b = b.to_bits(); + let ulp = if a < b { b - a } else { a - b }; + ulp <= max_ulp + } else { + a == b || (a.is_nan() && b.is_nan()) + } + } + struct DisplayValueAndBits(F); impl fmt::Display for DisplayValueAndBits { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} {:#X}", self.0, self.0.to_bits()) + write!(f, "{:e} {:#X}", self.0, self.0.to_bits()) + } + } + + fn reference_sqrt_f64(v: f64) -> f64 { + use az::Cast; + use rug::Float; + if v.is_nan() || v < 0.0 { + f64::NAN + } else if v == 0.0 || !v.is_finite() { + v + } else { + let precision = 100; + Float::with_val(precision, v).sqrt().cast() } } #[test] - #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic + #[cfg(not(test))] // TODO: broken for now -- create sqrt_rsqrt_kernel variant using float-float arithmetic #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") @@ -103,4 +186,58 @@ mod tests { ); } } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_sqrt_fast_f16() { + for bits in 0..=u16::MAX { + let v = F16::from_bits(bits); + let expected: F16 = f64::from(v).sqrt().to(); + let result = sqrt_fast_f16(Scalar, Value(v)).0; + assert!( + approx_same(expected, result, 2), + "case failed: v={v}, expected={expected}, result={result}", + v = DisplayValueAndBits(v), + expected = DisplayValueAndBits(expected), + result = DisplayValueAndBits(result), + ); + } + } + + #[test] + #[cfg(feature = "full_tests")] + fn test_sqrt_fast_f32() { + for bits in (0..=u32::MAX).step_by(1 << 8) { + let v = f32::from_bits(bits); + let expected: f32 = f64::from(v).sqrt().to(); + let result = sqrt_fast_f32(Scalar, Value(v)).0; + assert!( + approx_same(expected, result, 3), + "case failed: v={v}, expected={expected}, result={result}", + v = DisplayValueAndBits(v), + expected = DisplayValueAndBits(expected), + result = DisplayValueAndBits(result), + ); + } + } + + #[test] + #[cfg(feature = "full_tests")] + fn test_sqrt_fast_f64() { + for bits in (0..=u64::MAX).step_by(1 << 40) { + let v = f64::from_bits(bits); + let expected: f64 = reference_sqrt_f64(v); + let result = sqrt_fast_f64(Scalar, Value(v)).0; + assert!( + approx_same(expected, result, 2), + "case failed: v={v}, expected={expected}, result={result}", + v = DisplayValueAndBits(v), + expected = DisplayValueAndBits(expected), + result = DisplayValueAndBits(result), + ); + } + } } diff --git a/src/f16.rs b/src/f16.rs index b5d84d5..9d81f9e 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -47,6 +47,18 @@ impl fmt::Display for F16 { } } +impl fmt::LowerExp for F16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f16_impl!(self.0.fmt(f), [f]) + } +} + +impl fmt::UpperExp for F16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f16_impl!(self.0.fmt(f), [f]) + } +} + impl fmt::Debug for F16 { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f16_impl!(self.0.fmt(f), [f]) diff --git a/src/prim.rs b/src/prim.rs index 7ba23e5..39bfa87 100644 --- a/src/prim.rs +++ b/src/prim.rs @@ -135,7 +135,12 @@ impl_int!(u32, i32); impl_int!(u64, i64); pub trait PrimFloat: - PrimBase + ops::Neg + ConvertFrom + ConvertFrom + PrimBase + + ops::Neg + + ConvertFrom + + ConvertFrom + + fmt::LowerExp + + fmt::UpperExp { type BitsType: PrimUInt + ConvertFrom; type SignedBitsType: PrimSInt + ConvertFrom; diff --git a/src/traits.rs b/src/traits.rs index e923e1a..67daa22 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -242,6 +242,14 @@ pub trait Float: (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field, ) } + fn with_exponent_field(self, exponent_field: Self::BitsType) -> Self { + let exponent_shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); + let not_exponent_mask = self.ctx().make(!Self::PrimFloat::EXPONENT_FIELD_MASK); + Self::from_bits((self.to_bits() & not_exponent_mask) | (exponent_field << exponent_shift)) + } + fn with_exponent_unbiased(self, exponent: Self::SignedBitsType) -> Self { + self.with_exponent_field(Self::add_exponent_bias(exponent)) + } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { Self::SignedBitsType::cvt_from(exponent_field) - exponent_field -- 2.30.2