add sqrt_fast_f16/f32/f64
[vector-math.git] / src / algorithms / roots.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn initial_rsqrt_approximation<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
9 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
10 PrimF: PrimFloat<BitsType = PrimU>,
11 PrimU: PrimUInt,
12 >(
13 ctx: Ctx,
14 v: VecF,
15 ) -> VecF {
16 // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
17 // where `CONST` is optimized for use for Goldschmidt's algorithm.
18 // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
19 // but using different constants.
20 const FACTOR: f64 = -0.5;
21 const TERM: f64 = 1.6;
22 v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
23 }
24
25 /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
26 pub fn sqrt_rsqrt_kernel_fast<
27 Ctx: Context,
28 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
29 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
30 PrimF: PrimFloat<BitsType = PrimU>,
31 PrimU: PrimUInt,
32 >(
33 ctx: Ctx,
34 v: VecF,
35 iteration_count: usize,
36 ) -> (VecF, VecF) {
37 // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
38 let y = initial_rsqrt_approximation(ctx, v);
39 let mut x = v * y;
40 let one_half: VecF = ctx.make(0.5.to());
41 let mut neg_h = y * -one_half;
42 for _ in 0..iteration_count {
43 let r = x.mul_add_fast(neg_h, one_half);
44 x = x.mul_add_fast(r, x);
45 neg_h = neg_h.mul_add_fast(r, neg_h);
46 }
47 let sqrt = x;
48 let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
49 (sqrt, rsqrt)
50 }
51
52 /// calculate `sqrt(v)`, error inherited from `kernel_fn`. Calls `kernel_fn` with inputs in the range `0.5 <= x < 2.0`.
53 pub fn sqrt_impl<
54 Ctx: Context,
55 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
56 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
57 PrimF: PrimFloat<BitsType = PrimU>,
58 PrimU: PrimUInt,
59 KernelFn: FnOnce(Ctx, VecF) -> VecF,
60 >(
61 ctx: Ctx,
62 v: VecF,
63 kernel_fn: KernelFn,
64 ) -> VecF {
65 let is_normal_case = v.gt(ctx.make(0.0.to())) & v.is_finite();
66 let is_zero_or_positive = v.ge(ctx.make(0.0.to()));
67 let exceptional_retval = is_zero_or_positive.select(v, VecF::nan(ctx));
68 let need_subnormal_scale = v.is_zero_or_subnormal();
69 let subnormal_result_scale_exponent: PrimU = (PrimF::MANTISSA_FIELD_WIDTH + 1.to()) / 2.to();
70 let subnormal_input_scale_exponent = subnormal_result_scale_exponent * 2.to();
71 let subnormal_result_scale_prim =
72 PrimF::cvt_from(1) / PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_result_scale_exponent);
73 let subnormal_input_scale_prim =
74 PrimF::cvt_from(PrimU::cvt_from(1) << subnormal_input_scale_exponent);
75 let subnormal_result_scale: VecF =
76 need_subnormal_scale.select(ctx.make(subnormal_result_scale_prim), ctx.make(1.0.to()));
77 let subnormal_input_scale: VecF =
78 need_subnormal_scale.select(ctx.make(subnormal_input_scale_prim), ctx.make(1.0.to()));
79 let v = v * subnormal_input_scale;
80 let exponent_field = v.extract_exponent_field();
81 let normal_result_scale_exponent_field_offset: PrimU =
82 PrimF::EXPONENT_BIAS_UNSIGNED - (PrimF::EXPONENT_BIAS_UNSIGNED >> 1.to());
83 let shifted_exponent_field = exponent_field >> ctx.make(1.to());
84 let normal_result_scale_exponent_field =
85 shifted_exponent_field + ctx.make(normal_result_scale_exponent_field_offset);
86 let normal_result_scale = ctx
87 .make::<VecF>(1.to())
88 .with_exponent_field(normal_result_scale_exponent_field);
89 let v = v.with_exponent_field(
90 (exponent_field & ctx.make(1.to()))
91 | ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED & !PrimU::cvt_from(1)),
92 );
93 let normal_result = kernel_fn(ctx, v) * (normal_result_scale * subnormal_result_scale);
94 is_normal_case.select(normal_result, exceptional_retval)
95 }
96
97 /// computes `sqrt(x)`
98 /// has an error of up to 2ULP
99 pub fn sqrt_fast_f16<Ctx: Context>(ctx: Ctx, v: Ctx::VecF16) -> Ctx::VecF16 {
100 sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 3).0)
101 }
102
103 /// computes `sqrt(x)`
104 /// has an error of up to 3ULP
105 pub fn sqrt_fast_f32<Ctx: Context>(ctx: Ctx, v: Ctx::VecF32) -> Ctx::VecF32 {
106 sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 4).0)
107 }
108
109 /// computes `sqrt(x)`
110 /// has an error of up to 2ULP
111 pub fn sqrt_fast_f64<Ctx: Context>(ctx: Ctx, v: Ctx::VecF64) -> Ctx::VecF64 {
112 sqrt_impl(ctx, v, |ctx, v| sqrt_rsqrt_kernel_fast(ctx, v, 5).0)
113 }
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        scalar::{Scalar, Value},
    };
    use std::fmt;

    /// Exact comparison: finite values must have identical bit patterns;
    /// otherwise the values must compare equal or both be NaN.
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        if a.is_finite() && b.is_finite() {
            a.to_bits() == b.to_bits()
        } else {
            a == b || (a.is_nan() && b.is_nan())
        }
    }

    /// Approximate comparison: finite values may differ by at most `max_ulp`,
    /// measured as the absolute difference of their raw bit patterns;
    /// otherwise the values must compare equal or both be NaN.
    fn approx_same<F: PrimFloat>(a: F, b: F, max_ulp: F::BitsType) -> bool {
        if a.is_finite() && b.is_finite() {
            let a = a.to_bits();
            let b = b.to_bits();
            let ulp = if a < b { b - a } else { a - b };
            ulp <= max_ulp
        } else {
            a == b || (a.is_nan() && b.is_nan())
        }
    }

    /// Formats a float as its value in scientific notation followed by its
    /// bit pattern in hexadecimal, for use in assertion messages.
    struct DisplayValueAndBits<F: PrimFloat>(F);

    impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{:e} {:#X}", self.0, self.0.to_bits())
        }
    }

    /// Reference `sqrt` for `f64`: NaN for NaN/negative inputs, the input
    /// itself for zeros and infinities, and otherwise the square root
    /// computed at 100 bits of precision with `rug` and cast back to `f64`.
    fn reference_sqrt_f64(v: f64) -> f64 {
        use az::Cast;
        use rug::Float;
        if v.is_nan() || v < 0.0 {
            f64::NAN
        } else if v == 0.0 || !v.is_finite() {
            v
        } else {
            let precision = 100;
            Float::with_val(precision, v).sqrt().cast()
        }
    }

    #[test]
    #[cfg(not(test))] // TODO: broken for now -- create sqrt_rsqrt_kernel variant using float-float arithmetic
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_sqrt_rsqrt_kernel_f16() {
        let start = F16::to_bits(0.5.to());
        let end = F16::to_bits(2.0.to());
        for bits in start..=end {
            let v = F16::from_bits(bits);
            let expected_sqrt: F16 = f64::from(v).sqrt().to();
            let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to();
            let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3);
            assert!(
                same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt),
                "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}",
                v = DisplayValueAndBits(v),
                expected_sqrt = DisplayValueAndBits(expected_sqrt),
                result_sqrt = DisplayValueAndBits(result_sqrt),
                expected_rsqrt = DisplayValueAndBits(expected_rsqrt),
                result_rsqrt = DisplayValueAndBits(result_rsqrt),
            );
        }
    }

    /// Exhaustively checks `sqrt_fast_f16` against `f64::sqrt` rounded to
    /// `F16`, allowing 2 ULP of error.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_sqrt_fast_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected: F16 = f64::from(v).sqrt().to();
            let result = sqrt_fast_f16(Scalar, Value(v)).0;
            assert!(
                approx_same(expected, result, 2),
                "case failed: v={v}, expected={expected}, result={result}",
                v = DisplayValueAndBits(v),
                expected = DisplayValueAndBits(expected),
                result = DisplayValueAndBits(result),
            );
        }
    }

    /// Checks `sqrt_fast_f32` on every 256th bit pattern against `f64::sqrt`
    /// rounded to `f32`, allowing 3 ULP of error.
    #[test]
    #[cfg(feature = "full_tests")]
    fn test_sqrt_fast_f32() {
        for bits in (0..=u32::MAX).step_by(1 << 8) {
            let v = f32::from_bits(bits);
            let expected: f32 = f64::from(v).sqrt().to();
            let result = sqrt_fast_f32(Scalar, Value(v)).0;
            assert!(
                approx_same(expected, result, 3),
                "case failed: v={v}, expected={expected}, result={result}",
                v = DisplayValueAndBits(v),
                expected = DisplayValueAndBits(expected),
                result = DisplayValueAndBits(result),
            );
        }
    }

    /// Checks `sqrt_fast_f64` on every bit pattern whose low 40 bits are zero
    /// against `reference_sqrt_f64`, allowing 2 ULP of error.
    #[test]
    #[cfg(feature = "full_tests")]
    fn test_sqrt_fast_f64() {
        // The stride is applied in `u64` instead of through `step_by(1 << 40)`:
        // `step_by` takes a `usize`, so `1 << 40` overflows on targets with
        // 32-bit pointers. The visited bit patterns are the same.
        const STRIDE_SHIFT: u32 = 40;
        for bits in (0..=(u64::MAX >> STRIDE_SHIFT)).map(|index| index << STRIDE_SHIFT) {
            let v = f64::from_bits(bits);
            let expected: f64 = reference_sqrt_f64(v);
            let result = sqrt_fast_f64(Scalar, Value(v)).0;
            assert!(
                approx_same(expected, result, 2),
                "case failed: v={v}, expected={expected}, result={result}",
                v = DisplayValueAndBits(v),
                expected = DisplayValueAndBits(expected),
                result = DisplayValueAndBits(result),
            );
        }
    }
}