add sin_pi_kernel_f64 and cos_pi_kernel_f64
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 12 May 2021 03:14:29 +0000 (20:14 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 12 May 2021 03:14:29 +0000 (20:14 -0700)
Cargo.toml
src/algorithms/trig_pi.rs

index bd3fbffade2f78bb18ac1a3b3dda109de5b89c45..a5b7fd4ef293b6c4b9ccce87626ee64696336178 100644 (file)
@@ -17,13 +17,13 @@ f16 = ["half"]
 fma = ["std"]
 std = []
 ir = ["std", "typed-arena"]
-stdsimd = [
-    "core_simd",
-    # for `f32::round` and similar
-    "std",
-]
+stdsimd = ["core_simd", "std"] # for `f32::round` and similar
 # enable slow tests
 full_tests = []
 
+[dev-dependencies]
+rug = "1.12.0"
+az = "1.1.1"
+
 [workspace]
 members = [".", "vector-math-proc-macro"]
index fca45d9d2998ea173e2d9eb8870d2ba822f01b7a..b2a870f5e4a5aa271d63ce78d56b42e9ff08b25f 100644 (file)
@@ -6,6 +6,7 @@ use crate::{
 
 mod consts {
     #![allow(clippy::excessive_precision)]
+    #![allow(dead_code)]
 
     /// coefficients of taylor series for `sin(pi * x)` centered at `0`
     /// generated using:
@@ -109,6 +110,37 @@ pub fn cos_pi_kernel_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32
     v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
 }
 
+/// computes `sin(pi * x)` for `-0.25 <= x <= 0.25`
+/// not guaranteed to give correct sign for zero result
+/// has an error of up to 2ULP
+pub fn sin_pi_kernel_f64<Ctx: Context>(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 {
+    let x_sq = x * x;
+    let mut v: Ctx::VecF64 = ctx.make(consts::SINPI_KERNEL_TAYLOR_15.to());
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_13.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_11.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_9.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_7.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_5.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_3.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::SINPI_KERNEL_TAYLOR_1.to()));
+    v * x
+}
+
+/// computes `cos(pi * x)` for `-0.25 <= x <= 0.25`
+/// has an error of up to 2ULP
+pub fn cos_pi_kernel_f64<Ctx: Context>(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 {
+    let x_sq = x * x;
+    let mut v: Ctx::VecF64 = ctx.make(consts::COSPI_KERNEL_TAYLOR_16.to());
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_14.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_12.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_10.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_8.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_6.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_4.to()));
+    v = v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_2.to()));
+    v.mul_add_fast(x_sq, ctx.make(consts::COSPI_KERNEL_TAYLOR_0.to()))
+}
+
 /// computes `(sin(pi * x), cos(pi * x))`
 /// not guaranteed to give correct sign for zero results
 /// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
@@ -309,7 +341,7 @@ mod tests {
         let check = |x| {
             check_ulp(
                 x,
-                |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 },
+                |arg| arg.distance_in_ulp <= if arg.expected == 0. { 0 } else { 2 },
                 |x| sin_pi_kernel_f32(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).sin(),
             )
@@ -327,7 +359,7 @@ mod tests {
         let check = |x| {
             check_ulp(
                 x,
-                |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(),
+                |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.,
                 |x| cos_pi_kernel_f32(Scalar, Value(x)).0,
                 |x| (f64::consts::PI * x).cos(),
             )
@@ -339,6 +371,42 @@ mod tests {
         }
     }
 
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_sin_pi_kernel_f64() {
+        let check = |x| {
+            check_ulp(
+                x,
+                sin_cos_pi_check_ulp_callback,
+                |x| sin_pi_kernel_f64(Scalar, Value(x)).0,
+                |x| reference_sin_cos_pi_f64(x).0,
+            )
+        };
+        let quarter = 0.25f32.to_bits();
+        for bits in (0..=quarter).rev().step_by(1 << 5) {
+            check(f32::from_bits(bits) as f64);
+            check(-f32::from_bits(bits) as f64);
+        }
+    }
+
+    #[test]
+    #[cfg(feature = "full_tests")]
+    fn test_cos_pi_kernel_f64() {
+        let check = |x| {
+            check_ulp(
+                x,
+                sin_cos_pi_check_ulp_callback,
+                |x| cos_pi_kernel_f64(Scalar, Value(x)).0,
+                |x| reference_sin_cos_pi_f64(x).1,
+            )
+        };
+        let quarter = 0.25f32.to_bits();
+        for bits in (0..=quarter).rev().step_by(1 << 5) {
+            check(f32::from_bits(bits) as f64);
+            check(-f32::from_bits(bits) as f64);
+        }
+    }
+
     fn sin_cos_pi_check_ulp_callback<F: PrimFloat>(arg: CheckUlpCallbackArg<F, u64>) -> bool {
         if arg.x % 0.5.to() == 0.0.to() {
             arg.distance_in_ulp == 0
@@ -404,6 +472,305 @@ mod tests {
         }
     }
 
+    fn reference_sin_cos_pi_f64(mut v: f64) -> (f64, f64) {
+        use az::Cast;
+        use rug::{float::Constant, Float};
+        if !v.is_finite() {
+            return (f64::NAN, f64::NAN);
+        }
+        v %= 2.0;
+        if v >= 1.0 {
+            v -= 2.0;
+        } else if v <= -1.0 {
+            v += 2.0;
+        }
+        v *= 2.0;
+        let part = v.round() as i32;
+        v -= part as f64;
+        let precision = 100;
+        let mut v = Float::with_val(precision, v);
+        let pi = Float::with_val(precision, Constant::Pi);
+        let pi_2 = pi / 2;
+        v *= &pi_2;
+        let cos = pi_2; // just a temp var, value is ignored
+        let (sin, cos) = v.sin_cos(cos);
+        let sin: f64 = sin.cast();
+        let cos: f64 = cos.cast();
+        match part {
+            0 => (sin, cos),
+            1 => (cos, -sin),
+            2 => (-sin, -cos),
+            -2 => (-sin, -cos),
+            -1 => (-cos, sin),
+            _ => panic!("not implemented: part={}", part),
+        }
+    }
+
+    macro_rules! test_reference_sin_cos_pi_test_cases {
+        ($case:expr, $ty:ident) => {
+            $case($ty::NAN, $ty::NAN, $ty::NAN);
+            $case($ty::INFINITY, $ty::NAN, $ty::NAN);
+            $case(-$ty::INFINITY, $ty::NAN, $ty::NAN);
+            $case(-4., 0., 1.);
+            $case(
+                -3.875,
+                0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                -3.75,
+                0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -3.625,
+                0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(-3.5, 1., -0.);
+            $case(
+                -3.375,
+                0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                -3.25,
+                0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -3.125,
+                0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(-3., -0., -1.);
+            $case(
+                -2.875,
+                -0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                -2.75,
+                -0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -2.625,
+                -0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(-2.5, -1., 0.);
+            $case(
+                -2.375,
+                -0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                -2.25,
+                -0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -2.125,
+                -0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(-2., 0., 1.);
+            $case(
+                -1.875,
+                0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                -1.75,
+                0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -1.625,
+                0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(-1.5, 1., -0.);
+            $case(
+                -1.375,
+                0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                -1.25,
+                0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -1.125,
+                0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(-1., -0., -1.);
+            $case(
+                -0.875,
+                -0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                -0.75,
+                -0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -0.625,
+                -0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(-0.5, -1., 0.);
+            $case(
+                -0.375,
+                -0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                -0.25,
+                -0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                -0.125,
+                -0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(0., 0., 1.);
+            $case(
+                0.125,
+                0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                0.25,
+                0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                0.375,
+                0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(0.5, 1., 0.);
+            $case(
+                0.625,
+                0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                0.75,
+                0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                0.875,
+                0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(1., 0., -1.);
+            $case(
+                1.125,
+                -0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                1.25,
+                -0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                1.375,
+                -0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(1.5, -1., -0.);
+            $case(
+                1.625,
+                -0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                1.75,
+                -0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                1.875,
+                -0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(2., -0., 1.);
+            $case(
+                2.125,
+                0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                2.25,
+                0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                2.375,
+                0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(2.5, 1., 0.);
+            $case(
+                2.625,
+                0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                2.75,
+                0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                2.875,
+                0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(3., 0., -1.);
+            $case(
+                3.125,
+                -0.38268343236508977172845998403039886676134456248563,
+                -0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(
+                3.25,
+                -0.70710678118654752440084436210484903928483593768847,
+                -0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                3.375,
+                -0.92387953251128675612818318939678828682241662586364,
+                -0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(3.5, -1., -0.);
+            $case(
+                3.625,
+                -0.92387953251128675612818318939678828682241662586364,
+                0.38268343236508977172845998403039886676134456248563,
+            );
+            $case(
+                3.75,
+                -0.70710678118654752440084436210484903928483593768847,
+                0.70710678118654752440084436210484903928483593768847,
+            );
+            $case(
+                3.875,
+                -0.38268343236508977172845998403039886676134456248563,
+                0.92387953251128675612818318939678828682241662586364,
+            );
+            $case(4., -0., 1.);
+        };
+    }
+
     #[test]
     fn test_reference_sin_cos_pi_f32() {
         fn approx_same(a: f32, b: f32) -> bool {
@@ -427,74 +794,32 @@ mod tests {
                 ref_cos=ref_cos,
             );
         }
-        case(f32::NAN, f32::NAN, f32::NAN);
-        case(f32::INFINITY, f32::NAN, f32::NAN);
-        case(-f32::INFINITY, f32::NAN, f32::NAN);
-        case(-4.0, 0.0, 1.0);
-        case(-3.875, 0.3826834323650906, 0.9238795325112864);
-        case(-3.75, 0.7071067811865475, 0.7071067811865475);
-        case(-3.625, 0.9238795325112867, 0.3826834323650898);
-        case(-3.5, 1.0, 0.0);
-        case(-3.375, 0.9238795325112864, -0.3826834323650905);
-        case(-3.25, 0.7071067811865475, -0.7071067811865475);
-        case(-3.125, 0.3826834323650898, -0.9238795325112867);
-        case(-3.0, 0.0, -1.0);
-        case(-2.875, -0.3826834323650905, -0.9238795325112864);
-        case(-2.75, -0.7071067811865475, -0.7071067811865475);
-        case(-2.625, -0.9238795325112867, -0.3826834323650899);
-        case(-2.5, -1.0, 0.0);
-        case(-2.375, -0.9238795325112865, 0.3826834323650904);
-        case(-2.25, -0.7071067811865475, 0.7071067811865475);
-        case(-2.125, -0.3826834323650899, 0.9238795325112867);
-        case(-2.0, 0.0, 1.0);
-        case(-1.875, 0.3826834323650904, 0.9238795325112865);
-        case(-1.75, 0.7071067811865475, 0.7071067811865475);
-        case(-1.625, 0.9238795325112866, 0.38268343236509);
-        case(-1.5, 1.0, 0.0);
-        case(-1.375, 0.9238795325112865, -0.3826834323650903);
-        case(-1.25, 0.7071067811865475, -0.7071067811865475);
-        case(-1.125, 0.3826834323650896, -0.9238795325112869);
-        case(-1.0, 0.0, -1.0);
-        case(-0.875, -0.3826834323650899, -0.9238795325112867);
-        case(-0.75, -0.7071067811865475, -0.7071067811865475);
-        case(-0.625, -0.9238795325112867, -0.3826834323650897);
-        case(-0.5, -1.0, 0.0);
-        case(-0.375, -0.9238795325112867, 0.3826834323650898);
-        case(-0.25, -0.7071067811865475, 0.7071067811865475);
-        case(-0.125, -0.3826834323650898, 0.9238795325112867);
-        case(0.0, 0.0, 1.0);
-        case(0.125, 0.3826834323650898, 0.9238795325112867);
-        case(0.25, 0.7071067811865475, 0.7071067811865475);
-        case(0.375, 0.9238795325112867, 0.3826834323650898);
-        case(0.5, 1.0, 0.0);
-        case(0.625, 0.9238795325112867, -0.3826834323650897);
-        case(0.75, 0.7071067811865475, -0.7071067811865475);
-        case(0.875, 0.3826834323650899, -0.9238795325112867);
-        case(1.0, 0.0, -1.0);
-        case(1.125, -0.3826834323650896, -0.9238795325112869);
-        case(1.25, -0.7071067811865475, -0.7071067811865475);
-        case(1.375, -0.9238795325112865, -0.3826834323650903);
-        case(1.5, -1.0, 0.0);
-        case(1.625, -0.9238795325112866, 0.38268343236509);
-        case(1.75, -0.7071067811865475, 0.7071067811865475);
-        case(1.875, -0.3826834323650904, 0.9238795325112865);
-        case(2.0, 0.0, 1.0);
-        case(2.125, 0.3826834323650899, 0.9238795325112867);
-        case(2.25, 0.7071067811865475, 0.7071067811865475);
-        case(2.375, 0.9238795325112865, 0.3826834323650904);
-        case(2.5, 1.0, 0.0);
-        case(2.625, 0.9238795325112867, -0.3826834323650899);
-        case(2.75, 0.7071067811865475, -0.7071067811865475);
-        case(2.875, 0.3826834323650905, -0.9238795325112864);
-        case(3.0, 0.0, -1.0);
-        case(3.125, -0.3826834323650898, -0.9238795325112867);
-        case(3.25, -0.7071067811865475, -0.7071067811865475);
-        case(3.375, -0.9238795325112864, -0.3826834323650905);
-        case(3.5, -1.0, 0.0);
-        case(3.625, -0.9238795325112867, 0.3826834323650898);
-        case(3.75, -0.7071067811865475, 0.7071067811865475);
-        case(3.875, -0.3826834323650906, 0.9238795325112864);
-        case(4.0, 0.0, 1.0);
+        test_reference_sin_cos_pi_test_cases!(case, f32);
+    }
+
+    #[test]
+    fn test_reference_sin_cos_pi_f64() {
+        fn same(a: f64, b: f64) -> bool {
+            if a.is_finite() && b.is_finite() {
+                a == b
+            } else {
+                a == b || (a.is_nan() && b.is_nan())
+            }
+        }
+        #[track_caller]
+        fn case(x: f64, expected_sin: f64, expected_cos: f64) {
+            let (ref_sin, ref_cos) = reference_sin_cos_pi_f64(x);
+            assert!(
+                same(ref_sin, expected_sin) && same(ref_cos, expected_cos),
+                "case failed: x={x}, expected_sin={expected_sin}, expected_cos={expected_cos}, ref_sin={ref_sin}, ref_cos={ref_cos}",
+                x=x,
+                expected_sin=expected_sin,
+                expected_cos=expected_cos,
+                ref_sin=ref_sin,
+                ref_cos=ref_cos,
+            );
+        }
+        test_reference_sin_cos_pi_test_cases!(case, f64);
     }
 
     #[test]