v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
}
+/// computes `sin(pi * x)` for `-0.25 <= x <= 0.25`
+/// not guaranteed to give correct sign for zero result
+/// has an error of up to 2ULP
+pub fn sin_pi_kernel_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+ let x_sq = x * x;
+ let mut v: Ctx::VecF32 = ctx.make(consts::SINPI_KERNEL_TAYLOR_9.to());
+ v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_7.to()));
+ v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_5.to()));
+ v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_3.to()));
+ v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_1.to()));
+ v * x
+}
+
+/// computes `cos(pi * x)` for `-0.25 <= x <= 0.25`
+/// has an error of up to 2ULP
+pub fn cos_pi_kernel_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+ let x_sq = x * x;
+ let mut v: Ctx::VecF32 = ctx.make(consts::COSPI_KERNEL_TAYLOR_8.to());
+ v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_6.to()));
+ v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_4.to()));
+ v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_2.to()));
+ v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
+}
+
/// computes `(sin(pi * x), cos(pi * x))`
/// not guaranteed to give correct sign for zero results
/// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
sin_cos_pi_f16(ctx, x).1
}
+/// computes `(sin(pi * x), cos(pi * x))`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_cos_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> (Ctx::VecF32, Ctx::VecF32) {
+ sin_cos_pi_impl(ctx, x, sin_pi_kernel_f32, cos_pi_kernel_f32)
+}
+
+/// computes `sin(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+ sin_cos_pi_f32(ctx, x).0
+}
+
+/// computes `cos(pi * x)`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn cos_pi_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
+ sin_cos_pi_f32(ctx, x).1
+}
+
#[cfg(test)]
mod tests {
use super::*;
}
}
- fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg<F16, u64>) -> bool {
- if f32::cvt_from(arg.x) % 0.5 == 0.0 {
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_sin_pi_kernel_f32() {
+ let check = |x| {
+ check_ulp(
+ x,
+ |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 },
+ |x| sin_pi_kernel_f32(Scalar, Value(x)).0,
+ |x| (f64::consts::PI * x).sin(),
+ )
+ };
+ let quarter = 0.25f32.to_bits();
+ for bits in (0..=quarter).rev() {
+ check(f32::from_bits(bits));
+ check(-f32::from_bits(bits));
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_cos_pi_kernel_f32() {
+ let check = |x| {
+ check_ulp(
+ x,
+ |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(),
+ |x| cos_pi_kernel_f32(Scalar, Value(x)).0,
+ |x| (f64::consts::PI * x).cos(),
+ )
+ };
+ let quarter = 0.25f32.to_bits();
+ for bits in (0..=quarter).rev() {
+ check(f32::from_bits(bits));
+ check(-f32::from_bits(bits));
+ }
+ }
+
+ fn sin_cos_pi_check_ulp_callback<F: PrimFloat>(arg: CheckUlpCallbackArg<F, u64>) -> bool {
+ if arg.x % 0.5.to() == 0.0.to() {
arg.distance_in_ulp == 0
} else {
arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to()
for bits in 0..=u16::MAX {
check_ulp(
F16::from_bits(bits),
- sin_cos_pi_check_ulp_callback_f16,
+ sin_cos_pi_check_ulp_callback,
|x| sin_pi_f16(Scalar, Value(x)).0,
|x| (f64::consts::PI * x).sin(),
);
for bits in 0..=u16::MAX {
check_ulp(
F16::from_bits(bits),
- sin_cos_pi_check_ulp_callback_f16,
+ sin_cos_pi_check_ulp_callback,
|x| cos_pi_f16(Scalar, Value(x)).0,
|x| (f64::consts::PI * x).cos(),
);
}
}
+
+ fn reference_sin_cos_pi_f32(mut v: f64) -> (f64, f64) {
+ if !v.is_finite() {
+ return (f64::NAN, f64::NAN);
+ }
+ v %= 2.0;
+ if v >= 1.0 {
+ v -= 2.0;
+ } else if v <= -1.0 {
+ v += 2.0;
+ }
+ v *= 2.0;
+ let part = v.round() as i32;
+ v -= part as f64;
+ v *= f64::consts::PI / 2.0;
+ let (sin, cos) = v.sin_cos();
+ match part {
+ 0 => (sin, cos),
+ 1 => (cos, -sin),
+ 2 => (-sin, -cos),
+ -2 => (-sin, -cos),
+ -1 => (-cos, sin),
+ _ => panic!("not implemented: part={}", part),
+ }
+ }
+
+ #[test]
+ fn test_reference_sin_cos_pi_f32() {
+ fn approx_same(a: f32, b: f32) -> bool {
+ if a.is_finite() && b.is_finite() {
+ (a - b).abs() < 1e-6
+ } else {
+ a == b || (a.is_nan() && b.is_nan())
+ }
+ }
+ #[track_caller]
+ fn case(x: f32, expected_sin: f32, expected_cos: f32) {
+ let (ref_sin, ref_cos) = reference_sin_cos_pi_f32(x as f64);
+ assert!(
+ approx_same(ref_sin as f32, expected_sin)
+ && approx_same(ref_cos as f32, expected_cos),
+ "case failed: x={x}, expected_sin={expected_sin}, expected_cos={expected_cos}, ref_sin={ref_sin}, ref_cos={ref_cos}",
+ x=x,
+ expected_sin=expected_sin,
+ expected_cos=expected_cos,
+ ref_sin=ref_sin,
+ ref_cos=ref_cos,
+ );
+ }
+ case(f32::NAN, f32::NAN, f32::NAN);
+ case(f32::INFINITY, f32::NAN, f32::NAN);
+ case(-f32::INFINITY, f32::NAN, f32::NAN);
+ case(-4.0, 0.0, 1.0);
+ case(-3.875, 0.3826834323650906, 0.9238795325112864);
+ case(-3.75, 0.7071067811865475, 0.7071067811865475);
+ case(-3.625, 0.9238795325112867, 0.3826834323650898);
+ case(-3.5, 1.0, 0.0);
+ case(-3.375, 0.9238795325112864, -0.3826834323650905);
+ case(-3.25, 0.7071067811865475, -0.7071067811865475);
+ case(-3.125, 0.3826834323650898, -0.9238795325112867);
+ case(-3.0, 0.0, -1.0);
+ case(-2.875, -0.3826834323650905, -0.9238795325112864);
+ case(-2.75, -0.7071067811865475, -0.7071067811865475);
+ case(-2.625, -0.9238795325112867, -0.3826834323650899);
+ case(-2.5, -1.0, 0.0);
+ case(-2.375, -0.9238795325112865, 0.3826834323650904);
+ case(-2.25, -0.7071067811865475, 0.7071067811865475);
+ case(-2.125, -0.3826834323650899, 0.9238795325112867);
+ case(-2.0, 0.0, 1.0);
+ case(-1.875, 0.3826834323650904, 0.9238795325112865);
+ case(-1.75, 0.7071067811865475, 0.7071067811865475);
+ case(-1.625, 0.9238795325112866, 0.38268343236509);
+ case(-1.5, 1.0, 0.0);
+ case(-1.375, 0.9238795325112865, -0.3826834323650903);
+ case(-1.25, 0.7071067811865475, -0.7071067811865475);
+ case(-1.125, 0.3826834323650896, -0.9238795325112869);
+ case(-1.0, 0.0, -1.0);
+ case(-0.875, -0.3826834323650899, -0.9238795325112867);
+ case(-0.75, -0.7071067811865475, -0.7071067811865475);
+ case(-0.625, -0.9238795325112867, -0.3826834323650897);
+ case(-0.5, -1.0, 0.0);
+ case(-0.375, -0.9238795325112867, 0.3826834323650898);
+ case(-0.25, -0.7071067811865475, 0.7071067811865475);
+ case(-0.125, -0.3826834323650898, 0.9238795325112867);
+ case(0.0, 0.0, 1.0);
+ case(0.125, 0.3826834323650898, 0.9238795325112867);
+ case(0.25, 0.7071067811865475, 0.7071067811865475);
+ case(0.375, 0.9238795325112867, 0.3826834323650898);
+ case(0.5, 1.0, 0.0);
+ case(0.625, 0.9238795325112867, -0.3826834323650897);
+ case(0.75, 0.7071067811865475, -0.7071067811865475);
+ case(0.875, 0.3826834323650899, -0.9238795325112867);
+ case(1.0, 0.0, -1.0);
+ case(1.125, -0.3826834323650896, -0.9238795325112869);
+ case(1.25, -0.7071067811865475, -0.7071067811865475);
+ case(1.375, -0.9238795325112865, -0.3826834323650903);
+ case(1.5, -1.0, 0.0);
+ case(1.625, -0.9238795325112866, 0.38268343236509);
+ case(1.75, -0.7071067811865475, 0.7071067811865475);
+ case(1.875, -0.3826834323650904, 0.9238795325112865);
+ case(2.0, 0.0, 1.0);
+ case(2.125, 0.3826834323650899, 0.9238795325112867);
+ case(2.25, 0.7071067811865475, 0.7071067811865475);
+ case(2.375, 0.9238795325112865, 0.3826834323650904);
+ case(2.5, 1.0, 0.0);
+ case(2.625, 0.9238795325112867, -0.3826834323650899);
+ case(2.75, 0.7071067811865475, -0.7071067811865475);
+ case(2.875, 0.3826834323650905, -0.9238795325112864);
+ case(3.0, 0.0, -1.0);
+ case(3.125, -0.3826834323650898, -0.9238795325112867);
+ case(3.25, -0.7071067811865475, -0.7071067811865475);
+ case(3.375, -0.9238795325112864, -0.3826834323650905);
+ case(3.5, -1.0, 0.0);
+ case(3.625, -0.9238795325112867, 0.3826834323650898);
+ case(3.75, -0.7071067811865475, 0.7071067811865475);
+ case(3.875, -0.3826834323650906, 0.9238795325112864);
+ case(4.0, 0.0, 1.0);
+ }
+
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_sin_pi_f32() {
+ for bits in 0..=u32::MAX {
+ check_ulp(
+ f32::from_bits(bits),
+ sin_cos_pi_check_ulp_callback,
+ |x| sin_pi_f32(Scalar, Value(x)).0,
+ |x| reference_sin_cos_pi_f32(x).0,
+ );
+ }
+ }
+
+ #[test]
+ #[cfg(feature = "full_tests")]
+ fn test_cos_pi_f32() {
+ for bits in 0..=u32::MAX {
+ check_ulp(
+ f32::from_bits(bits),
+ sin_cos_pi_check_ulp_callback,
+ |x| cos_pi_f32(Scalar, Value(x)).0,
+ |x| reference_sin_cos_pi_f32(x).1,
+ );
+ }
+ }
}