out_of_range.select(out_of_range_value, in_range_value)
}
+pub fn round_to_nearest_ties_to_even<
+ Ctx: Context,
+ VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+ VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+ PrimF: PrimFloat<BitsType = PrimU>,
+ PrimU: PrimUInt,
+>(
+ ctx: Ctx,
+ v: VecF,
+) -> VecF {
+ let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+ let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+ let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5)));
+ let out_of_range = big | small;
+ let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
+ let out_of_range_value = small.select(small_value, v);
+ let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+ let offset_value: VecF = v.abs() + offset;
+ let in_range_value = (offset_value - offset).copy_sign(v);
+ out_of_range.select(out_of_range_value, in_range_value)
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::{
f16::F16,
+ prim::PrimSInt,
scalar::{Scalar, Value},
+ traits::ConvertFrom,
};
#[test]
);
}
}
+
+ fn reference_round_to_nearest_ties_to_even<
+ F: PrimFloat<BitsType = U, SignedBitsType = S>,
+ U: PrimUInt,
+ S: PrimSInt + ConvertFrom<F>,
+ >(
+ v: F,
+ ) -> F {
+ if v.abs() < F::cvt_from(S::MAX) {
+ let int_value: S = v.to();
+ let int_value_f: F = int_value.to();
+ let remainder: F = v - int_value_f;
+ if remainder.abs() < 0.5.to()
+ || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
+ {
+ int_value_f.copy_sign(v)
+ } else if remainder < 0.0.to() {
+ int_value_f - 1.0.to()
+ } else {
+ int_value_f + 1.0.to()
+ }
+ } else {
+ v
+ }
+ }
+
+ #[test]
+ fn test_reference_round_to_nearest_ties_to_even() {
+ #[track_caller]
+ fn case(v: f32, expected: f32) {
+ let result = reference_round_to_nearest_ties_to_even(v);
+ let same = if expected.is_nan() {
+ result.is_nan()
+ } else {
+ expected.to_bits() == result.to_bits()
+ };
+ assert!(
+ same,
+ "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+ v=v,
+ v_bits=v.to_bits(),
+ expected=expected,
+ expected_bits=expected.to_bits(),
+ result=result,
+ result_bits=result.to_bits(),
+ );
+ }
+ case(0.0, 0.0);
+ case(-0.0, -0.0);
+ case(0.499, 0.0);
+ case(-0.499, -0.0);
+ case(0.5, 0.0);
+ case(-0.5, -0.0);
+ case(0.501, 1.0);
+ case(-0.501, -1.0);
+ case(1.0, 1.0);
+ case(-1.0, -1.0);
+ case(1.499, 1.0);
+ case(-1.499, -1.0);
+ case(1.5, 2.0);
+ case(-1.5, -2.0);
+ case(1.501, 2.0);
+ case(-1.501, -2.0);
+ case(2.0, 2.0);
+ case(-2.0, -2.0);
+ case(2.499, 2.0);
+ case(-2.499, -2.0);
+ case(2.5, 2.0);
+ case(-2.5, -2.0);
+ case(2.501, 3.0);
+ case(-2.501, -3.0);
+ case(f32::INFINITY, f32::INFINITY);
+ case(-f32::INFINITY, -f32::INFINITY);
+ case(f32::NAN, f32::NAN);
+ case(1e30, 1e30);
+ case(-1e30, -1e30);
+ let i32_max = i32::MAX as f32;
+ let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
+ let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
+ case(i32_max, i32_max);
+ case(-i32_max, -i32_max);
+ case(i32_max_prev, i32_max_prev);
+ case(-i32_max_prev, -i32_max_prev);
+ case(i32_max_next, i32_max_next);
+ case(-i32_max_next, -i32_max_next);
+ }
+
+ #[test]
+ #[cfg_attr(
+ not(feature = "f16"),
+ should_panic(expected = "f16 feature is not enabled")
+ )]
+ fn test_round_to_nearest_ties_to_even_f16() {
+ for bits in 0..=u16::MAX {
+ let v = F16::from_bits(bits);
+ let expected = reference_round_to_nearest_ties_to_even(v);
+ let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+ assert!(
+ same(expected, result),
+ "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+ v=v,
+ v_bits=v.to_bits(),
+ expected=expected,
+ expected_bits=expected.to_bits(),
+ result=result,
+ result_bits=result.to_bits(),
+ );
+ }
+ }
+
+ #[test]
+ fn test_round_to_nearest_ties_to_even_f32() {
+ for bits in (0..=u32::MAX).step_by(0x10000) {
+ let v = f32::from_bits(bits);
+ let expected = reference_round_to_nearest_ties_to_even(v);
+ let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+ assert!(
+ same(expected, result),
+ "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+ v=v,
+ v_bits=v.to_bits(),
+ expected=expected,
+ expected_bits=expected.to_bits(),
+ result=result,
+ result_bits=result.to_bits(),
+ );
+ }
+ }
+
+ #[test]
+ fn test_round_to_nearest_ties_to_even_f64() {
+ for bits in (0..=u64::MAX).step_by(1 << 48) {
+ let v = f64::from_bits(bits);
+ let expected = reference_round_to_nearest_ties_to_even(v);
+ let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+ assert!(
+ same(expected, result),
+ "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+ v=v,
+ v_bits=v.to_bits(),
+ expected=expected,
+ expected_bits=expected.to_bits(),
+ result=result,
+ result_bits=result.to_bits(),
+ );
+ }
+ }
}
+ ops::ShlAssign
+ ops::ShrAssign
{
+ const ZERO: Self;
+ const ONE: Self;
+ const MIN: Self;
+ const MAX: Self;
}
pub trait PrimUInt: PrimInt + ConvertFrom<Self::SignedType> {
($uint:ident, $sint:ident) => {
impl PrimBase for $uint {}
impl PrimBase for $sint {}
- impl PrimInt for $uint {}
- impl PrimInt for $sint {}
+ impl PrimInt for $uint {
+ const ZERO: Self = 0;
+ const ONE: Self = 1;
+ const MIN: Self = 0;
+ const MAX: Self = !0;
+ }
+ impl PrimInt for $sint {
+ const ZERO: Self = 0;
+ const ONE: Self = 1;
+ const MIN: Self = $sint::MIN;
+ const MAX: Self = $sint::MAX;
+ }
impl PrimUInt for $uint {
type SignedType = $sint;
}
}
fn is_finite(self) -> bool;
fn trunc(self) -> Self;
+ fn copy_sign(self, sign: Self) -> Self;
}
macro_rules! impl_float {
#[cfg(not(feature = "std"))]
return crate::algorithms::base::trunc(Scalar, Value(self)).0;
}
+ fn copy_sign(self, sign: Self) -> Self {
+ #[cfg(feature = "std")]
+ return $float::copysign(self);
+ #[cfg(not(feature = "std"))]
+ return crate::algorithms::base::copy_sign(Scalar, Value(self), Value(sign)).0;
+ }
}
};
}