From aa592610b0802e56c6b5ea2ea70c3d742f36e797 Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Thu, 13 May 2021 18:09:43 -0700
Subject: [PATCH] add round_to_nearest_ties_to_even

---
 src/algorithms/base.rs | 184 +++++++++++++++++++++++++++++++++++++++++
 src/prim.rs            |  26 +++++-
 2 files changed, 208 insertions(+), 2 deletions(-)

diff --git a/src/algorithms/base.rs b/src/algorithms/base.rs
index d387340..28a641d 100644
--- a/src/algorithms/base.rs
+++ b/src/algorithms/base.rs
@@ -54,12 +54,43 @@ pub fn trunc<
     out_of_range.select(out_of_range_value, in_range_value)
 }
 
+/// Rounds `v` to the nearest integer, with ties going to the even integer.
+///
+/// Inputs whose magnitude is not less than `PrimF::IMPLICIT_MANTISSA_BIT` converted to a
+/// float (including NaN, which fails the `lt` comparison) are returned unchanged. Inputs
+/// with magnitude `<= 0.5` become a zero carrying the sign of `v`. Everything else is
+/// rounded by adding and then subtracting `1 << PrimF::MANTISSA_FIELD_WIDTH` on `v.abs()`,
+/// after which the sign of `v` is copied back onto the result.
+pub fn round_to_nearest_ties_to_even<
+    Ctx: Context,
+    VecF: Float + Make,
+    VecU: UInt + Make,
+    PrimF: PrimFloat,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5)));
+    let out_of_range = big | small;
+    let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
+    let out_of_range_value = small.select(small_value, v);
+    let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
+    let offset_value: VecF = v.abs() + offset;
+    let in_range_value = (offset_value - offset).copy_sign(v);
+    out_of_range.select(out_of_range_value, in_range_value)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::{
         f16::F16,
+        prim::PrimSInt,
         scalar::{Scalar, Value},
+        traits::ConvertFrom,
     };
 
     #[test]
@@ -240,4 +271,157 @@ mod tests {
             );
         }
     }
+
+    /// Scalar reference implementation used to check `round_to_nearest_ties_to_even`.
+    ///
+    /// Values with magnitude below `S::MAX` are converted to the integer type `S` (the
+    /// logic assumes that conversion truncates toward zero). The converted value is kept
+    /// when the remainder is under 0.5, or exactly 0.5 with an even converted value;
+    /// otherwise it is adjusted by one in the direction of the remainder. Any other input
+    /// is returned as-is.
+    fn reference_round_to_nearest_ties_to_even<
+        F: PrimFloat,
+        U: PrimUInt,
+        S: PrimSInt + ConvertFrom,
+    >(
+        v: F,
+    ) -> F {
+        if v.abs() < F::cvt_from(S::MAX) {
+            let int_value: S = v.to();
+            let int_value_f: F = int_value.to();
+            let remainder: F = v - int_value_f;
+            if remainder.abs() < 0.5.to()
+                || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
+            {
+                int_value_f.copy_sign(v)
+            } else if remainder < 0.0.to() {
+                int_value_f - 1.0.to()
+            } else {
+                int_value_f + 1.0.to()
+            }
+        } else {
+            v
+        }
+    }
+
+    #[test]
+    fn test_reference_round_to_nearest_ties_to_even() {
+        #[track_caller]
+        fn case(v: f32, expected: f32) {
+            let result = reference_round_to_nearest_ties_to_even(v);
+            let same = if expected.is_nan() {
+                result.is_nan()
+            } else {
+                expected.to_bits() == result.to_bits()
+            };
+            assert!(
+                same,
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+        case(0.0, 0.0);
+        case(-0.0, -0.0);
+        case(0.499, 0.0);
+        case(-0.499, -0.0);
+        case(0.5, 0.0);
+        case(-0.5, -0.0);
+        case(0.501, 1.0);
+        case(-0.501, -1.0);
+        case(1.0, 1.0);
+        case(-1.0, -1.0);
+        case(1.499, 1.0);
+        case(-1.499, -1.0);
+        case(1.5, 2.0);
+        case(-1.5, -2.0);
+        case(1.501, 2.0);
+        case(-1.501, -2.0);
+        case(2.0, 2.0);
+        case(-2.0, -2.0);
+        case(2.499, 2.0);
+        case(-2.499, -2.0);
+        case(2.5, 2.0);
+        case(-2.5, -2.0);
+        case(2.501, 3.0);
+        case(-2.501, -3.0);
+        case(f32::INFINITY, f32::INFINITY);
+        case(-f32::INFINITY, -f32::INFINITY);
+        case(f32::NAN, f32::NAN);
+        case(1e30, 1e30);
+        case(-1e30, -1e30);
+        let i32_max = i32::MAX as f32;
+        let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
+        let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
+        case(i32_max, i32_max);
+        case(-i32_max, -i32_max);
+        case(i32_max_prev, i32_max_prev);
+        case(-i32_max_prev, -i32_max_prev);
+        case(i32_max_next, i32_max_next);
+        case(-i32_max_next, -i32_max_next);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_round_to_nearest_ties_to_even_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_round_to_nearest_ties_to_even_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_round_to_nearest_ties_to_even_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = reference_round_to_nearest_ties_to_even(v);
+            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
 }
diff --git a/src/prim.rs b/src/prim.rs
index 7f9fa30..41eca19 100644
--- a/src/prim.rs
+++ b/src/prim.rs
@@ -86,6 +86,10 @@ pub trait PrimInt:
     + ops::ShlAssign
     + ops::ShrAssign
 {
+    const ZERO: Self;
+    const ONE: Self;
+    const MIN: Self;
+    const MAX: Self;
 }
 
 pub trait PrimUInt: PrimInt + ConvertFrom {
@@ -100,8 +104,18 @@ macro_rules! impl_int {
     ($uint:ident, $sint:ident) => {
         impl PrimBase for $uint {}
         impl PrimBase for $sint {}
-        impl PrimInt for $uint {}
-        impl PrimInt for $sint {}
+        impl PrimInt for $uint {
+            const ZERO: Self = 0;
+            const ONE: Self = 1;
+            const MIN: Self = 0;
+            const MAX: Self = !0;
+        }
+        impl PrimInt for $sint {
+            const ZERO: Self = 0;
+            const ONE: Self = 1;
+            const MIN: Self = $sint::MIN;
+            const MAX: Self = $sint::MAX;
+        }
         impl PrimUInt for $uint {
             type SignedType = $sint;
         }
@@ -146,6 +160,8 @@ pub trait PrimFloat:
     }
     fn is_finite(self) -> bool;
     fn trunc(self) -> Self;
+    /// Returns `self` with its sign replaced by the sign of `sign`.
+    fn copy_sign(self, sign: Self) -> Self;
 }
 
 macro_rules! impl_float {
@@ -207,6 +223,12 @@ macro_rules! impl_float {
                 #[cfg(not(feature = "std"))]
                 return crate::algorithms::base::trunc(Scalar, Value(self)).0;
             }
+            fn copy_sign(self, sign: Self) -> Self {
+                #[cfg(feature = "std")]
+                return $float::copysign(self, sign);
+                #[cfg(not(feature = "std"))]
+                return crate::algorithms::base::copy_sign(Scalar, Value(self), Value(sign)).0;
+            }
         }
     };
 }
-- 
2.30.2