refactor to easily allow algorithms generic over f16/32/64
[vector-math.git] / src / algorithms / trig_pi.rs
1 use crate::{
2 f16::F16,
3 prim::PrimFloat,prim::PrimSInt,prim::PrimUInt,
4 traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Make, Select},
5 };
6
mod consts {
    // the literals below are printed with 50 significant digits (`fpprec:50` in the
    // maxima snippets), which is more than an `f64` can hold, so clippy's
    // `excessive_precision` lint is silenced for this module instead of truncating
    // the generated values by hand
    #![allow(clippy::excessive_precision)]

    /// coefficients of taylor series for `sin(pi * x)` centered at `0`
    ///
    /// `SINPI_KERNEL_TAYLOR_n` is the coefficient of `x^n`; only the odd powers
    /// `1, 3, ..., 19` are emitted by the generator loop below
    ///
    /// generated using:
    /// ```maxima,text
    /// fpprec:50$
    /// sinpi: bfloat(taylor(sin(%pi*x),x,0,19))$
    /// for i: 1 step 2 thru 19 do
    /// printf(true, "pub(crate) const SINPI_KERNEL_TAYLOR_~d: f64 = ~a;~%", i, ssubst("e", "b", string(coeff(sinpi, x, i))))$
    /// ```
    pub(crate) const SINPI_KERNEL_TAYLOR_1: f64 =
        3.1415926535897932384626433832795028841971693993751e0;
    pub(crate) const SINPI_KERNEL_TAYLOR_3: f64 =
        -5.1677127800499700292460525111835658670375480943142e0;
    pub(crate) const SINPI_KERNEL_TAYLOR_5: f64 =
        2.550164039877345443856177583695296720669172555234e0;
    pub(crate) const SINPI_KERNEL_TAYLOR_7: f64 =
        -5.9926452932079207688773938354604004601536358636814e-1;
    pub(crate) const SINPI_KERNEL_TAYLOR_9: f64 =
        8.2145886611128228798802365523698344807837460797753e-2;
    pub(crate) const SINPI_KERNEL_TAYLOR_11: f64 =
        -7.370430945714350777259089957290781501211638236021e-3;
    pub(crate) const SINPI_KERNEL_TAYLOR_13: f64 =
        4.6630280576761256442062891447027174382819981361599e-4;
    pub(crate) const SINPI_KERNEL_TAYLOR_15: f64 =
        -2.1915353447830215827384652057094188859248708765956e-5;
    pub(crate) const SINPI_KERNEL_TAYLOR_17: f64 =
        7.9520540014755127847832068624575890327682459384282e-7;
    pub(crate) const SINPI_KERNEL_TAYLOR_19: f64 =
        -2.2948428997269873110203872385571587856074785581088e-8;

    /// coefficients of taylor series for `cos(pi * x)` centered at `0`
    ///
    /// `COSPI_KERNEL_TAYLOR_n` is the coefficient of `x^n`; only the even powers
    /// `0, 2, ..., 18` are emitted by the generator loop below
    ///
    /// generated using:
    /// ```maxima,text
    /// fpprec:50$
    /// cospi: bfloat(taylor(cos(%pi*x),x,0,18))$
    /// for i: 0 step 2 thru 18 do
    /// printf(true, "pub(crate) const COSPI_KERNEL_TAYLOR_~d: f64 = ~a;~%", i, ssubst("e", "b", string(coeff(cospi, x, i))))$
    /// ```
    pub(crate) const COSPI_KERNEL_TAYLOR_0: f64 = 1.0e0;
    pub(crate) const COSPI_KERNEL_TAYLOR_2: f64 =
        -4.9348022005446793094172454999380755676568497036204e0;
    pub(crate) const COSPI_KERNEL_TAYLOR_4: f64 =
        4.0587121264167682181850138620293796354053160696952e0;
    pub(crate) const COSPI_KERNEL_TAYLOR_6: f64 =
        -1.3352627688545894958753047828505831928711354556681e0;
    pub(crate) const COSPI_KERNEL_TAYLOR_8: f64 =
        2.3533063035889320454187935277546542154506893530856e-1;
    pub(crate) const COSPI_KERNEL_TAYLOR_10: f64 =
        -2.5806891390014060012598294252898849657186441048147e-2;
    pub(crate) const COSPI_KERNEL_TAYLOR_12: f64 =
        1.9295743094039230479033455636859576401684718150003e-3;
    pub(crate) const COSPI_KERNEL_TAYLOR_14: f64 =
        -1.0463810492484570711801672835223932761029733149091e-4;
    pub(crate) const COSPI_KERNEL_TAYLOR_16: f64 =
        4.3030695870329470072978237149669233008960901556009e-6;
    pub(crate) const COSPI_KERNEL_TAYLOR_18: f64 =
        -1.387895246221377211446808750399309343777037849978e-7;
}
67
68 /// computes `sin(pi * x)` for `-0.25 <= x <= 0.25`
69 /// not guaranteed to give correct sign for zero result
70 /// has an error of up to 2ULP
71 pub fn sin_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
72 let x_sq = x * x;
73 let mut v: Ctx::VecF16 = ctx.make(consts::SINPI_KERNEL_TAYLOR_5.to());
74 v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_3.to()));
75 v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_1.to()));
76 v * x
77 }
78
79 /// computes `cos(pi * x)` for `-0.25 <= x <= 0.25`
80 /// has an error of up to 2ULP
81 pub fn cos_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
82 let x_sq = x * x;
83 let mut v: Ctx::VecF16 = ctx.make(consts::COSPI_KERNEL_TAYLOR_4.to());
84 v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_2.to()));
85 v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
86 }
87
88 /// computes `(sin(pi * x), cos(pi * x))`
89 /// not guaranteed to give correct sign for zero results
90 /// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
91 pub fn sin_cos_pi_impl<
92 Ctx: Context,
93 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
94 PrimF: PrimFloat<BitsType = PrimU>,
95 PrimU: PrimUInt,
96 SinPiKernel: FnOnce(Ctx, VecF) -> VecF,
97 CosPiKernel: FnOnce(Ctx, VecF) -> VecF,
98 >(
99 ctx: Ctx,
100 x: VecF,
101 sin_pi_kernel: SinPiKernel,
102 cos_pi_kernel: CosPiKernel,
103 ) -> (VecF, VecF) {
104 let two_f: VecF = ctx.make(2.0.to());
105 let one_half: VecF = ctx.make(0.5.to());
106 let max_contiguous_integer: VecF =
107 ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to());
108 // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer
109 let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range
110 let is_finite = x.is_finite();
111 let nan: VecF = ctx.make(f32::NAN.to());
112 let zero_f: VecF = ctx.make(0.to());
113 let one_f: VecF = ctx.make(1.to());
114 let zero_i: VecF::SignedBitsType = ctx.make(0.to());
115 let one_i: VecF::SignedBitsType = ctx.make(1.to());
116 let two_i: VecF::SignedBitsType = ctx.make(2.to());
117 let out_of_range_sin = is_finite.select(zero_f, nan);
118 let out_of_range_cos = is_finite.select(one_f, nan);
119 let xi = (x * two_f).round();
120 let xk = x - xi * one_half;
121 let sk = sin_pi_kernel(ctx, xk);
122 let ck = cos_pi_kernel(ctx, xk);
123 let xi = VecF::SignedBitsType::cvt_from(xi);
124 let bit_0_clear = (xi & one_i).eq(zero_i);
125 let st = bit_0_clear.select(sk, ck);
126 let ct = bit_0_clear.select(ck, sk);
127 let s = (xi & two_i).eq(zero_i).select(st, -st);
128 let c = ((xi + one_i) & two_i).eq(zero_i).select(ct, -ct);
129 (
130 in_range.select(s, out_of_range_sin),
131 in_range.select(c, out_of_range_cos),
132 )
133 }
134
135 /// computes `(sin(pi * x), cos(pi * x))`
136 /// not guaranteed to give correct sign for zero results
137 /// has an error of up to 2ULP
138 pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
139 sin_cos_pi_impl(ctx, x, sin_pi_kernel_f16, cos_pi_kernel_f16)
140 }
141
142 /// computes `sin(pi * x)`
143 /// not guaranteed to give correct sign for zero results
144 /// has an error of up to 2ULP
145 pub fn sin_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
146 sin_cos_pi_f16(ctx, x).0
147 }
148
149 /// computes `cos(pi * x)`
150 /// not guaranteed to give correct sign for zero results
151 /// has an error of up to 2ULP
152 pub fn cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
153 sin_cos_pi_f16(ctx, x).1
154 }
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        scalar::{Scalar, Value},
    };
    use std::f64;

    /// what a tolerance callback gets to inspect when `result` and `expected`
    /// do not compare equal
    struct CheckUlpCallbackArg<F, I> {
        distance_in_ulp: I,
        x: F,
        expected: F,
        result: F,
    }

    /// compares `fn_f16(x)` against `fn_f64(x)` converted to `F16`
    ///
    /// accepts the sample when the two values compare equal, when both are
    /// nan, or when neither is nan and `is_ok` approves the mismatch;
    /// panics with a diagnostic message otherwise
    #[track_caller]
    fn check_ulp_f16(
        x: F16,
        is_ok: impl Fn(CheckUlpCallbackArg<F16, u32>) -> bool,
        fn_f16: impl Fn(F16) -> F16,
        fn_f64: impl Fn(f64) -> f64,
    ) {
        let x_wide: f64 = x.to();
        let expected_wide = fn_f64(x_wide);
        let expected: F16 = expected_wide.to();
        let result = fn_f16(x);
        if result == expected || (result.is_nan() && expected.is_nan()) {
            return;
        }
        // distance between the raw bit patterns, not a sign-aware ulp distance
        let expected_bits = i32::from(expected.to_bits());
        let result_bits = i32::from(result.to_bits());
        let distance_in_ulp = (expected_bits - result_bits).unsigned_abs();
        // a nan on exactly one side is never acceptable, so `is_ok` is not consulted
        let any_nan = result.is_nan() || expected.is_nan();
        let accepted = !any_nan
            && is_ok(CheckUlpCallbackArg {
                distance_in_ulp,
                x,
                expected,
                result,
            });
        if accepted {
            return;
        }
        panic!(
            "error is too big: \
             x = {x:?} {x_bits:#X}, \
             result = {result:?} {result_bits:#X}, \
             expected = {expected:?} {expected_bits:#X}, \
             distance_in_ulp = {distance_in_ulp}",
            x = x,
            x_bits = x.to_bits(),
            result = result,
            result_bits = result.to_bits(),
            expected = expected,
            expected_bits = expected.to_bits(),
            distance_in_ulp = distance_in_ulp,
        );
    }

    /// calls `check` on every `F16` magnitude from `0.25` down to `0`,
    /// passing `+v` and then `-v` for each magnitude `v`
    fn for_each_kernel_input(check: impl Fn(F16)) {
        let quarter_bits = F16::to_bits(0.25f32.to());
        for magnitude_bits in (0..=quarter_bits).rev() {
            check(F16::from_bits(magnitude_bits));
            check(-F16::from_bits(magnitude_bits));
        }
    }

    /// runs `check_ulp_f16` with the `sin_cos_pi` tolerance on every `F16`
    /// bit pattern, in increasing bit order
    #[track_caller]
    fn check_every_f16(fn_f16: impl Fn(F16) -> F16, fn_f64: impl Fn(f64) -> f64) {
        for bits in 0..=u16::MAX {
            check_ulp_f16(
                F16::from_bits(bits),
                sin_cos_pi_check_ulp_callback_f16,
                &fn_f16,
                &fn_f64,
            );
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_sin_pi_kernel_f16() {
        for_each_kernel_input(|x| {
            check_ulp_f16(
                x,
                |arg| {
                    // an expected zero has to be hit exactly
                    let max_distance = if arg.expected == 0.to() { 0 } else { 2 };
                    arg.distance_in_ulp <= max_distance
                },
                |x| sin_pi_kernel_f16(Scalar, Value(x)).0,
                |x| (f64::consts::PI * x).sin(),
            )
        });
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_cos_pi_kernel_f16() {
        for_each_kernel_input(|x| {
            check_ulp_f16(
                x,
                // the result must additionally never exceed `1`
                |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(),
                |x| cos_pi_kernel_f16(Scalar, Value(x)).0,
                |x| (f64::consts::PI * x).cos(),
            )
        });
    }

    /// tolerance for the full-range tests: inputs that are a multiple of `0.5`
    /// must match exactly, all others may be off by 2 as long as the result
    /// stays within `[-1, 1]`
    fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg<F16, u32>) -> bool {
        let is_multiple_of_half = f32::cvt_from(arg.x) % 0.5 == 0.0;
        if is_multiple_of_half {
            return arg.distance_in_ulp == 0;
        }
        arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to()
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_sin_pi_f16() {
        check_every_f16(
            |x| sin_pi_f16(Scalar, Value(x)).0,
            |x| (f64::consts::PI * x).sin(),
        );
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_cos_pi_f16() {
        check_every_f16(
            |x| cos_pi_f16(Scalar, Value(x)).0,
            |x| (f64::consts::PI * x).cos(),
        );
    }
}