From d79f43bed2398cbc4f6b75b8e55ee317289599a1 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 10 May 2021 00:28:02 -0700 Subject: [PATCH] implement sin_cos_pi_f32 --- src/algorithms/trig_pi.rs | 233 +++++++++++++++++++++++++++++++++++++- src/prim.rs | 4 + 2 files changed, 233 insertions(+), 4 deletions(-) diff --git a/src/algorithms/trig_pi.rs b/src/algorithms/trig_pi.rs index a9b275c..fca45d9 100644 --- a/src/algorithms/trig_pi.rs +++ b/src/algorithms/trig_pi.rs @@ -85,6 +85,30 @@ pub fn cos_pi_kernel_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to())) } +/// computes `sin(pi * x)` for `-0.25 <= x <= 0.25` +/// not guaranteed to give correct sign for zero result +/// has an error of up to 2ULP +pub fn sin_pi_kernel_f32(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 { + let x_sq = x * x; + let mut v: Ctx::VecF32 = ctx.make(consts::SINPI_KERNEL_TAYLOR_9.to()); + v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_7.to())); + v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_5.to())); + v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_3.to())); + v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_1.to())); + v * x +} + +/// computes `cos(pi * x)` for `-0.25 <= x <= 0.25` +/// has an error of up to 2ULP +pub fn cos_pi_kernel_f32(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 { + let x_sq = x * x; + let mut v: Ctx::VecF32 = ctx.make(consts::COSPI_KERNEL_TAYLOR_8.to()); + v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_6.to())); + v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_4.to())); + v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_2.to())); + v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to())) +} + /// computes `(sin(pi * x), cos(pi * x))` /// not guaranteed to give correct sign for zero results /// inherits error from `sin_pi_kernel` and `cos_pi_kernel` @@ -153,6 +177,27 @@ pub fn cos_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 { sin_cos_pi_f16(ctx, x).1 } +/// computes `(sin(pi * x), cos(pi * x))` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn sin_cos_pi_f32(ctx: Ctx, x: Ctx::VecF32) -> (Ctx::VecF32, Ctx::VecF32) { + sin_cos_pi_impl(ctx, x, sin_pi_kernel_f32, cos_pi_kernel_f32) +} + +/// computes `sin(pi * x)` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn sin_pi_f32(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 { + sin_cos_pi_f32(ctx, x).0 +} + +/// computes `cos(pi * x)` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn cos_pi_f32(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 { + sin_cos_pi_f32(ctx, x).1 +} + #[cfg(test)] mod tests { use super::*; @@ -258,8 +303,44 @@ mod tests { } } - fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg) -> bool { - if f32::cvt_from(arg.x) % 0.5 == 0.0 { + #[test] + #[cfg(feature = "full_tests")] + fn test_sin_pi_kernel_f32() { + let check = |x| { + check_ulp( + x, + |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 }, + |x| sin_pi_kernel_f32(Scalar, Value(x)).0, + |x| (f64::consts::PI * x).sin(), + ) + }; + let quarter = 0.25f32.to_bits(); + for bits in (0..=quarter).rev() { + check(f32::from_bits(bits)); + check(-f32::from_bits(bits)); + } + } + + #[test] + #[cfg(feature = "full_tests")] + fn test_cos_pi_kernel_f32() { + let check = |x| { + check_ulp( + x, + |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(), + |x| cos_pi_kernel_f32(Scalar, Value(x)).0, + |x| (f64::consts::PI * x).cos(), + ) + }; + let quarter = 0.25f32.to_bits(); + for bits in (0..=quarter).rev() { + check(f32::from_bits(bits)); + check(-f32::from_bits(bits)); + } + } + + fn sin_cos_pi_check_ulp_callback(arg: CheckUlpCallbackArg) -> bool { + if arg.x % 0.5.to() == 0.0.to() { arg.distance_in_ulp == 0 } else { arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to() @@ -275,7 +356,7 @@ mod tests { for bits in 0..=u16::MAX { check_ulp( F16::from_bits(bits), - sin_cos_pi_check_ulp_callback_f16, + sin_cos_pi_check_ulp_callback, |x| sin_pi_f16(Scalar, Value(x)).0, |x| (f64::consts::PI * x).sin(), ); @@ -291,10 +372,154 @@ mod tests { for bits in 0..=u16::MAX { check_ulp( F16::from_bits(bits), - sin_cos_pi_check_ulp_callback_f16, + sin_cos_pi_check_ulp_callback, |x| cos_pi_f16(Scalar, Value(x)).0, |x| (f64::consts::PI * x).cos(), ); } } + + fn reference_sin_cos_pi_f32(mut v: f64) -> (f64, f64) { + if !v.is_finite() { + return (f64::NAN, f64::NAN); + } + v %= 2.0; + if v >= 1.0 { + v -= 2.0; + } else if v <= -1.0 { + v += 2.0; + } + v *= 2.0; + let part = v.round() as i32; + v -= part as f64; + v *= f64::consts::PI / 2.0; + let (sin, cos) = v.sin_cos(); + match part { + 0 => (sin, cos), + 1 => (cos, -sin), + 2 => (-sin, -cos), + -2 => (-sin, -cos), + -1 => (-cos, sin), + _ => panic!("not implemented: part={}", part), + } + } + + #[test] + fn test_reference_sin_cos_pi_f32() { + fn approx_same(a: f32, b: f32) -> bool { + if a.is_finite() && b.is_finite() { + (a - b).abs() < 1e-6 + } else { + a == b || (a.is_nan() && b.is_nan()) + } + } + #[track_caller] + fn case(x: f32, expected_sin: f32, expected_cos: f32) { + let (ref_sin, ref_cos) = reference_sin_cos_pi_f32(x as f64); + assert!( + approx_same(ref_sin as f32, expected_sin) + && approx_same(ref_cos as f32, expected_cos), + "case failed: x={x}, expected_sin={expected_sin}, expected_cos={expected_cos}, ref_sin={ref_sin}, ref_cos={ref_cos}", + x=x, + expected_sin=expected_sin, + expected_cos=expected_cos, + ref_sin=ref_sin, + ref_cos=ref_cos, + ); + } + case(f32::NAN, f32::NAN, f32::NAN); + case(f32::INFINITY, f32::NAN, f32::NAN); + case(-f32::INFINITY, f32::NAN, f32::NAN); + case(-4.0, 0.0, 1.0); + case(-3.875, 0.3826834323650906, 0.9238795325112864); + case(-3.75, 0.7071067811865475, 0.7071067811865475); + case(-3.625, 0.9238795325112867, 0.3826834323650898); + case(-3.5, 1.0, 0.0); + case(-3.375, 0.9238795325112864, -0.3826834323650905); + case(-3.25, 0.7071067811865475, -0.7071067811865475); + case(-3.125, 0.3826834323650898, -0.9238795325112867); + case(-3.0, 0.0, -1.0); + case(-2.875, -0.3826834323650905, -0.9238795325112864); + case(-2.75, -0.7071067811865475, -0.7071067811865475); + case(-2.625, -0.9238795325112867, -0.3826834323650899); + case(-2.5, -1.0, 0.0); + case(-2.375, -0.9238795325112865, 0.3826834323650904); + case(-2.25, -0.7071067811865475, 0.7071067811865475); + case(-2.125, -0.3826834323650899, 0.9238795325112867); + case(-2.0, 0.0, 1.0); + case(-1.875, 0.3826834323650904, 0.9238795325112865); + case(-1.75, 0.7071067811865475, 0.7071067811865475); + case(-1.625, 0.9238795325112866, 0.38268343236509); + case(-1.5, 1.0, 0.0); + case(-1.375, 0.9238795325112865, -0.3826834323650903); + case(-1.25, 0.7071067811865475, -0.7071067811865475); + case(-1.125, 0.3826834323650896, -0.9238795325112869); + case(-1.0, 0.0, -1.0); + case(-0.875, -0.3826834323650899, -0.9238795325112867); + case(-0.75, -0.7071067811865475, -0.7071067811865475); + case(-0.625, -0.9238795325112867, -0.3826834323650897); + case(-0.5, -1.0, 0.0); + case(-0.375, -0.9238795325112867, 0.3826834323650898); + case(-0.25, -0.7071067811865475, 0.7071067811865475); + case(-0.125, -0.3826834323650898, 0.9238795325112867); + case(0.0, 0.0, 1.0); + case(0.125, 0.3826834323650898, 0.9238795325112867); + case(0.25, 0.7071067811865475, 0.7071067811865475); + case(0.375, 0.9238795325112867, 0.3826834323650898); + case(0.5, 1.0, 0.0); + case(0.625, 0.9238795325112867, -0.3826834323650897); + case(0.75, 0.7071067811865475, -0.7071067811865475); + case(0.875, 0.3826834323650899, -0.9238795325112867); + case(1.0, 0.0, -1.0); + case(1.125, -0.3826834323650896, -0.9238795325112869); + case(1.25, -0.7071067811865475, -0.7071067811865475); + case(1.375, -0.9238795325112865, -0.3826834323650903); + case(1.5, -1.0, 0.0); + case(1.625, -0.9238795325112866, 0.38268343236509); + case(1.75, -0.7071067811865475, 0.7071067811865475); + case(1.875, -0.3826834323650904, 0.9238795325112865); + case(2.0, 0.0, 1.0); + case(2.125, 0.3826834323650899, 0.9238795325112867); + case(2.25, 0.7071067811865475, 0.7071067811865475); + case(2.375, 0.9238795325112865, 0.3826834323650904); + case(2.5, 1.0, 0.0); + case(2.625, 0.9238795325112867, -0.3826834323650899); + case(2.75, 0.7071067811865475, -0.7071067811865475); + case(2.875, 0.3826834323650905, -0.9238795325112864); + case(3.0, 0.0, -1.0); + case(3.125, -0.3826834323650898, -0.9238795325112867); + case(3.25, -0.7071067811865475, -0.7071067811865475); + case(3.375, -0.9238795325112864, -0.3826834323650905); + case(3.5, -1.0, 0.0); + case(3.625, -0.9238795325112867, 0.3826834323650898); + case(3.75, -0.7071067811865475, 0.7071067811865475); + case(3.875, -0.3826834323650906, 0.9238795325112864); + case(4.0, 0.0, 1.0); + } + + #[test] + #[cfg(feature = "full_tests")] + fn test_sin_pi_f32() { + for bits in 0..=u32::MAX { + check_ulp( + f32::from_bits(bits), + sin_cos_pi_check_ulp_callback, + |x| sin_pi_f32(Scalar, Value(x)).0, + |x| reference_sin_cos_pi_f32(x).0, + ); + } + } + + #[test] + #[cfg(feature = "full_tests")] + fn test_cos_pi_f32() { + for bits in 0..=u32::MAX { + check_ulp( + f32::from_bits(bits), + sin_cos_pi_check_ulp_callback, + |x| cos_pi_f32(Scalar, Value(x)).0, + |x| reference_sin_cos_pi_f32(x).1, + ); + } + } } diff --git a/src/prim.rs b/src/prim.rs index 08ede9e..184e5fc 100644 --- a/src/prim.rs +++ b/src/prim.rs @@ -139,6 +139,7 @@ pub trait PrimFloat: fn is_nan(self) -> bool; fn from_bits(bits: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; + fn abs(self) -> Self; } macro_rules! impl_float { @@ -185,6 +186,9 @@ macro_rules! impl_float { fn to_bits(self) -> Self::BitsType { self.to_bits() } + fn abs(self) -> Self { + $float::abs(self) + } } }; } -- 2.30.2