add count_leading_zeros, count_trailing_zeros, and count_ones implementations
[vector-math.git] / src / algorithms / base.rs
index 56691a3e136a035077352447a9a9c094d518300a..4ebd8493ee3d7dc12e0753e469db896253e6f14f 100644 (file)
-use crate::traits::{Context, Float};
+use crate::{
+    prim::{PrimFloat, PrimUInt},
+    traits::{Context, ConvertTo, Float, Make, Select, UInt},
+};
 
-pub fn abs_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
-    Ctx::VecF16::from_bits(x.to_bits() & ctx.make(0x7FFFu16))
+pub fn abs<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    x: VecF,
+) -> VecF {
+    VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
 }
 
-pub fn abs_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
-    Ctx::VecF32::from_bits(x.to_bits() & ctx.make(!(1u32 << 31)))
+pub fn copy_sign<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    mag: VecF,
+    sign: VecF,
+) -> VecF {
+    let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
+    let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
+    VecF::from_bits(mag_bits | sign_bit)
 }
 
-pub fn abs_f64<Ctx: Context>(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 {
-    Ctx::VecF64::from_bits(x.to_bits() & ctx.make(!(1u64 << 63)))
+pub fn trunc<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
+    let out_of_range = big | small;
+    let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
+    let out_of_range_value = small.select(small_value, v);
+    let exponent_field = v.extract_exponent_field();
+    let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
+    let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
+    mask >>= right_shift_amount;
+    let in_range_value = VecF::from_bits(v.to_bits() & !mask);
+    out_of_range.select(out_of_range_value, in_range_value)
+}
+
+pub fn round_to_nearest_ties_to_even<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let in_range_value = (offset_value - offset).copy_sign(v);
+    big.select(v, in_range_value)
+}
+
+pub fn floor<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let rounded = (offset_value - offset).copy_sign(v);
+    let need_round_down = v.lt(rounded);
+    let in_range_value = need_round_down
+        .select(rounded - ctx.make(1.to()), rounded)
+        .copy_sign(v);
+    big.select(v, in_range_value)
+}
+
+pub fn ceil<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let rounded = (offset_value - offset).copy_sign(v);
+    let need_round_up = v.gt(rounded);
+    let in_range_value = need_round_up
+        .select(rounded + ctx.make(1.to()), rounded)
+        .copy_sign(v);
+    big.select(v, in_range_value)
 }
 
 #[cfg(test)]
@@ -17,7 +121,9 @@ mod tests {
     use super::*;
     use crate::{
         f16::F16,
+        prim::PrimSInt,
         scalar::{Scalar, Value},
+        traits::ConvertFrom,
     };
 
     #[test]
@@ -29,7 +135,7 @@ mod tests {
         for bits in 0..=u16::MAX {
             let v = F16::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f16(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
     }
@@ -39,7 +145,7 @@ mod tests {
         for bits in (0..=u32::MAX).step_by(10001) {
             let v = f32::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f32(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
     }
@@ -49,8 +155,417 @@ mod tests {
         for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
             let v = f64::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f64(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_copy_sign_f16() {
+        #[track_caller]
+        fn check(mag_bits: u16, sign_bits: u16) {
+            let mag = F16::from_bits(mag_bits);
+            let sign = F16::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (16 - 4)),
+                            sign_low_bits | (sign_high_bits << (16 - 4)),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_copy_sign_f32() {
+        #[track_caller]
+        fn check(mag_bits: u32, sign_bits: u32) {
+            let mag = f32::from_bits(mag_bits);
+            let sign = f32::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (32 - 4)),
+                            sign_low_bits | (sign_high_bits << (32 - 4)),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_copy_sign_f64() {
+        #[track_caller]
+        fn check(mag_bits: u64, sign_bits: u64) {
+            let mag = f64::from_bits(mag_bits);
+            let sign = f64::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (64 - 4)),
+                            sign_low_bits | (sign_high_bits << (64 - 4)),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    fn same<F: PrimFloat>(a: F, b: F) -> bool {
+        if a.is_finite() && b.is_finite() {
+            a.to_bits() == b.to_bits()
+        } else {
+            a == b || (a.is_nan() && b.is_nan())
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_trunc_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    fn reference_round_to_nearest_ties_to_even<
+        F: PrimFloat<BitsType = U, SignedBitsType = S>,
+        U: PrimUInt,
+        S: PrimSInt + ConvertFrom<F>,
+    >(
+        v: F,
+    ) -> F {
+        if v.abs() < F::cvt_from(S::MAX) {
+            let int_value: S = v.to();
+            let int_value_f: F = int_value.to();
+            let remainder: F = v - int_value_f;
+            if remainder.abs() < 0.5.to()
+                || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
+            {
+                int_value_f.copy_sign(v)
+            } else if remainder < 0.0.to() {
+                int_value_f - 1.0.to()
+            } else {
+                int_value_f + 1.0.to()
+            }
+        } else {
+            v
+        }
+    }
+
+    #[test]
+    fn test_reference_round_to_nearest_ties_to_even() {
+        #[track_caller]
+        fn case(v: f32, expected: f32) {
+            let result = reference_round_to_nearest_ties_to_even(v);
+            assert!(
+                same(result, expected),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+        case(0.0, 0.0);
+        case(-0.0, -0.0);
+        case(0.499, 0.0);
+        case(-0.499, -0.0);
+        case(0.5, 0.0);
+        case(-0.5, -0.0);
+        case(0.501, 1.0);
+        case(-0.501, -1.0);
+        case(1.0, 1.0);
+        case(-1.0, -1.0);
+        case(1.499, 1.0);
+        case(-1.499, -1.0);
+        case(1.5, 2.0);
+        case(-1.5, -2.0);
+        case(1.501, 2.0);
+        case(-1.501, -2.0);
+        case(2.0, 2.0);
+        case(-2.0, -2.0);
+        case(2.499, 2.0);
+        case(-2.499, -2.0);
+        case(2.5, 2.0);
+        case(-2.5, -2.0);
+        case(2.501, 3.0);
+        case(-2.501, -3.0);
+        case(f32::INFINITY, f32::INFINITY);
+        case(-f32::INFINITY, -f32::INFINITY);
+        case(f32::NAN, f32::NAN);
+        case(1e30, 1e30);
+        case(-1e30, -1e30);
+        let i32_max = i32::MAX as f32;
+        let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
+        let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
+        case(i32_max, i32_max);
+        case(-i32_max, -i32_max);
+        case(i32_max_prev, i32_max_prev);
+        case(-i32_max_prev, -i32_max_prev);
+        case(i32_max_next, i32_max_next);
+        case(-i32_max_next, -i32_max_next);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_round_to_nearest_ties_to_even_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(result, expected),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_round_to_nearest_ties_to_even_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(result, expected),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_round_to_nearest_ties_to_even_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(result, expected),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_floor_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_floor_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_floor_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.floor();
+            let result = floor(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_ceil_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_ceil_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_ceil_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.ceil();
+            let result = ceil(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
     }
 }