refactor to easily allow algorithms generic over f16/32/64
[vector-math.git] / src / algorithms / trig_pi.rs
index 38104a655ee23870c5c04cfc7c1dc259d6cd32f9..5b07c2a3994593cc35f61badd9d866c4f8fc3af3 100644 (file)
@@ -1,7 +1,7 @@
 use crate::{
     f16::F16,
-    ieee754::FloatEncoding,
-    traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Select},
+    prim::PrimFloat,prim::PrimSInt,prim::PrimUInt,
+    traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Make, Select},
 };
 
 mod consts {
@@ -87,39 +87,58 @@ pub fn cos_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16
 
 /// computes `(sin(pi * x), cos(pi * x))`
 /// not guaranteed to give correct sign for zero results
-/// has an error of up to 2ULP
-pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
-    let two_f16: Ctx::VecF16 = ctx.make(2.0.to());
-    let one_half: Ctx::VecF16 = ctx.make(0.5.to());
-    let max_contiguous_integer: Ctx::VecF16 =
-        ctx.make((1u16 << (F16::MANTISSA_FIELD_WIDTH + 1)).to());
+/// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
+pub fn sin_cos_pi_impl<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+    SinPiKernel: FnOnce(Ctx, VecF) -> VecF,
+    CosPiKernel: FnOnce(Ctx, VecF) -> VecF,
+>(
+    ctx: Ctx,
+    x: VecF,
+    sin_pi_kernel: SinPiKernel,
+    cos_pi_kernel: CosPiKernel,
+) -> (VecF, VecF) {
+    let two_f: VecF = ctx.make(2.0.to());
+    let one_half: VecF = ctx.make(0.5.to());
+    let max_contiguous_integer: VecF =
+        ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to());
     // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer
     let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range
     let is_finite = x.is_finite();
-    let nan: Ctx::VecF16 = ctx.make(f32::NAN.to());
-    let zero_f16: Ctx::VecF16 = ctx.make(0.to());
-    let one_f16: Ctx::VecF16 = ctx.make(1.to());
-    let zero_i16: Ctx::VecI16 = ctx.make(0.to());
-    let one_i16: Ctx::VecI16 = ctx.make(1.to());
-    let two_i16: Ctx::VecI16 = ctx.make(2.to());
-    let out_of_range_sin = is_finite.select(zero_f16, nan);
-    let out_of_range_cos = is_finite.select(one_f16, nan);
-    let xi = (x * two_f16).round();
+    let nan: VecF = ctx.make(f32::NAN.to());
+    let zero_f: VecF = ctx.make(0.to());
+    let one_f: VecF = ctx.make(1.to());
+    let zero_i: VecF::SignedBitsType = ctx.make(0.to());
+    let one_i: VecF::SignedBitsType = ctx.make(1.to());
+    let two_i: VecF::SignedBitsType = ctx.make(2.to());
+    let out_of_range_sin = is_finite.select(zero_f, nan);
+    let out_of_range_cos = is_finite.select(one_f, nan);
+    let xi = (x * two_f).round();
     let xk = x - xi * one_half;
-    let sk = sin_pi_kernel_f16(ctx, xk);
-    let ck = cos_pi_kernel_f16(ctx, xk);
-    let xi = Ctx::VecI16::cvt_from(xi);
-    let bit_0_clear = (xi & one_i16).eq(zero_i16);
+    let sk = sin_pi_kernel(ctx, xk);
+    let ck = cos_pi_kernel(ctx, xk);
+    let xi = VecF::SignedBitsType::cvt_from(xi);
+    let bit_0_clear = (xi & one_i).eq(zero_i);
     let st = bit_0_clear.select(sk, ck);
     let ct = bit_0_clear.select(ck, sk);
-    let s = (xi & two_i16).eq(zero_i16).select(st, -st);
-    let c = ((xi + one_i16) & two_i16).eq(zero_i16).select(ct, -ct);
+    let s = (xi & two_i).eq(zero_i).select(st, -st);
+    let c = ((xi + one_i) & two_i).eq(zero_i).select(ct, -ct);
     (
         in_range.select(s, out_of_range_sin),
         in_range.select(c, out_of_range_cos),
     )
 }
 
+/// computes `(sin(pi * x), cos(pi * x))`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
+    sin_cos_pi_impl(ctx, x, sin_pi_kernel_f16, cos_pi_kernel_f16)
+}
+
 /// computes `sin(pi * x)`
 /// not guaranteed to give correct sign for zero results
 /// has an error of up to 2ULP