use crate::{
prim::{PrimFloat, PrimUInt},
- traits::{Context, ConvertTo, Float, Make, UInt},
+ traits::{Context, ConvertTo, Float, Make, Select, UInt},
};
pub fn initial_rsqrt_approximation<
v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
}
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method
-/// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
-pub fn sqrt_rsqrt_kernel<
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+pub fn sqrt_rsqrt_kernel_fast<
Ctx: Context,
VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
let y = initial_rsqrt_approximation(ctx, v);
let mut x = v * y;
let one_half: VecF = ctx.make(0.5.to());
- let three_halves: VecF = ctx.make((3.0 / 2.0).to());
let mut neg_h = y * -one_half;
for _ in 0..iteration_count {
let r = x.mul_add_fast(neg_h, one_half);
}
let sqrt = x;
let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
- // do one step of Newton's method
- let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt));
(sqrt, rsqrt)
}
+/// calculate `sqrt(v)`, error inherited from `kernel_fn`. Calls `kernel_fn` with inputs in the range `0.5 <= x < 2.0`.
+pub fn sqrt_impl<
+ Ctx: Context,
+ VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+ VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+ PrimF: PrimFloat<BitsType = PrimU>,
+ PrimU: PrimUInt,
+ KernelFn: FnOnce(Ctx, VecF) -> VecF,
+>(
+ ctx: Ctx,
+ v: VecF,
+ kernel_fn: KernelFn,
+) -> VecF {
+ let is_normal_case = v.gt(ctx.make(0.0.to())) & v.is_finite();
+ let is_zero_or_positive = v.ge(ctx.make(0.0.to()));
+ let exceptional_retval = is_zero_or_positive.select(v, VecF::nan(ctx));
+ let need_subnormal_scale = v.is_zero_or_subnormal();
+ let subnormal_result_scale_exponent: PrimU = (PrimF::MANTISSA_FIELD_WIDTH + 1.to()) / 2.to();
+ let subnormal_input_scale_exponent = subnormal_result_scale_exponent * 2.to();
+ let subnormal_result_scale_prim =
+ PrimF::cvt_from(1) / PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_result_scale_exponent);
+ let subnormal_input_scale_prim =
+ PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_input_scale_exponent);
+ let subnormal_result_scale: VecF =
+ need_subnormal_scale.select(ctx.make(subnormal_result_scale_prim), ctx.make(1.0.to()));
+ let subnormal_input_scale: VecF =
+ need_subnormal_scale.select(ctx.make(subnormal_input_scale_prim), ctx.make(1.0.to()));
+ let v = v * subnormal_input_scale;
+ let exponent_field = v.extract_exponent_field();
+ let normal_result_scale_exponent_field_offset: PrimU =
+ PrimF::EXPONENT_BIAS_UNSIGNED - (PrimF::EXPONENT_BIAS_UNSIGNED >> 1.to());
+ let shifted_exponent_field = exponent_field >> ctx.make(1.to());
+ let normal_result_scale_exponent_field =
+ shifted_exponent_field + ctx.make(normal_result_scale_exponent_field_offset);
+ let normal_result_scale = ctx
+ .make::<VecF>(1.to())
+ .with_exponent_field(normal_result_scale_exponent_field);
+ let v = v.with_exponent_field(
+ (exponent_field & ctx.make(1.to()))
+ | ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED & !PrimU::cvt_from(1)),
+ );
+ let normal_result = kernel_fn(ctx, v) * (normal_result_scale * subnormal_result_scale);
+ is_normal_case.select(normal_result, exceptional_retval)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f16<Ctx: Context>(ctx: Ctx, v: Ctx::VecF16) -> Ctx::VecF16 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 3).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 3ULP
+pub fn sqrt_fast_f32<Ctx: Context>(ctx: Ctx, v: Ctx::VecF32) -> Ctx::VecF32 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 4).0)
+}
+
+/// computes `sqrt(x)`
+/// has an error of up to 2ULP
+pub fn sqrt_fast_f64<Ctx: Context>(ctx: Ctx, v: Ctx::VecF64) -> Ctx::VecF64 {
+ sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 5).0)
+}
+
#[cfg(test)]
mod tests {
use super::*;
}
}
+ fn approx_same<F: PrimFloat>(a: F, b: F, max_ulp: F::BitsType) -> bool {
+ if a.is_finite() && b.is_finite() {
+ let a = a.to_bits();
+ let b = b.to_bits();
+ let ulp = if a < b { b - a } else { a - b };
+ ulp <= max_ulp
+ } else {
+ a == b || (a.is_nan() && b.is_nan())
+ }
+ }
+
struct DisplayValueAndBits<F: PrimFloat>(F);
impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{} {:#X}", self.0, self.0.to_bits())
+ write!(f, "{:e} {:#X}", self.0, self.0.to_bits())
+ }
+ }
+
+ fn reference_sqrt_f64(v: f64) -> f64 {
+ use az::Cast;
+ use rug::Float;
+ if v.is_nan() || v < 0.0 {
+ f64::NAN
+ } else if v == 0.0 || !v.is_finite() {
+ v
+ } else {
+ let precision = 100;
+ Float::with_val(precision, v).sqrt().cast()
}
}
#[test]
- #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
+ #[cfg(not(test))] // TODO: broken for now -- create sqrt_rsqrt_kernel variant using float-float arithmetic
#[cfg_attr(
not(feature = "f16"),
should_panic(expected = "f16 feature is not enabled")
);
}
}
+
+ #[test]
+ #[cfg_attr(
+ not(feature = "f16"),
+ should_panic(expected = "f16 feature is not enabled")
+ )]
+ fn test_sqrt_fast_f16() {
+ for bits in 0..=u16::MAX {
+ let v = F16::from_bits(bits);
+ let expected: F16 = f64::from(v).sqrt().to();
+ let result = sqrt_fast_f16(Scalar, Value(v)).0;
+ assert!(
+ approx_same(expected, result, 2),
+ "case failed: v={v}, expected={expected}, result={result}",
+ v = DisplayValueAndBits(v),
+ expected = DisplayValueAndBits(expected),
+ result = DisplayValueAndBits(result),
+ );
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_sqrt_fast_f32() {
+ for bits in (0..=u32::MAX).step_by(1 << 8) {
+ let v = f32::from_bits(bits);
+ let expected: f32 = f64::from(v).sqrt().to();
+ let result = sqrt_fast_f32(Scalar, Value(v)).0;
+ assert!(
+ approx_same(expected, result, 3),
+ "case failed: v={v}, expected={expected}, result={result}",
+ v = DisplayValueAndBits(v),
+ expected = DisplayValueAndBits(expected),
+ result = DisplayValueAndBits(result),
+ );
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_sqrt_fast_f64() {
+ for bits in (0..=u64::MAX).step_by(1 << 40) {
+ let v = f64::from_bits(bits);
+ let expected: f64 = reference_sqrt_f64(v);
+ let result = sqrt_fast_f64(Scalar, Value(v)).0;
+ assert!(
+ approx_same(expected, result, 2),
+ "case failed: v={v}, expected={expected}, result={result}",
+ v = DisplayValueAndBits(v),
+ expected = DisplayValueAndBits(expected),
+ result = DisplayValueAndBits(result),
+ );
+ }
+ }
}