use crate::{ prim::{PrimFloat, PrimUInt}, traits::{Context, ConvertTo, Float, Make, Select, UInt}, }; pub fn abs< Ctx: Context, VecF: Float + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, x: VecF, ) -> VecF { VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK)) } pub fn copy_sign< Ctx: Context, VecF: Float + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, mag: VecF, sign: VecF, ) -> VecF { let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK); let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK); VecF::from_bits(mag_bits | sign_bit) } pub fn trunc< Ctx: Context, VecF: Float + Make, VecU: UInt + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, v: VecF, ) -> VecF { let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big let small = v.abs().lt(ctx.make(PrimF::cvt_from(1))); let out_of_range = big | small; let small_value = ctx.make::(0.to()).copy_sign(v); let out_of_range_value = small.select(small_value, v); let exponent_field = v.extract_exponent_field(); let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED); let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK); mask >>= right_shift_amount; let in_range_value = VecF::from_bits(v.to_bits() & !mask); out_of_range.select(out_of_range_value, in_range_value) } pub fn round_to_nearest_ties_to_even< Ctx: Context, VecF: Float + Make, VecU: UInt + Make, PrimF: PrimFloat, PrimU: PrimUInt, >( ctx: Ctx, v: VecF, ) -> VecF { let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5))); let out_of_range = big | small; let small_value = ctx.make::(0.to()).copy_sign(v); let out_of_range_value = small.select(small_value, v); let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to()); let offset_value: VecF = v.abs() + offset; let in_range_value = (offset_value - offset).copy_sign(v); out_of_range.select(out_of_range_value, in_range_value) } #[cfg(test)] mod tests { use super::*; use crate::{ f16::F16, prim::PrimSInt, scalar::{Scalar, Value}, traits::ConvertFrom, }; #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_abs_f16() { for bits in 0..=u16::MAX { let v = F16::from_bits(bits); let expected = v.abs(); let result = abs(Scalar, Value(v)).0; assert_eq!(expected.to_bits(), result.to_bits()); } } #[test] fn test_abs_f32() { for bits in (0..=u32::MAX).step_by(10001) { let v = f32::from_bits(bits); let expected = v.abs(); let result = abs(Scalar, Value(v)).0; assert_eq!(expected.to_bits(), result.to_bits()); } } #[test] fn test_abs_f64() { for bits in (0..=u64::MAX).step_by(100_000_000_000_001) { let v = f64::from_bits(bits); let expected = v.abs(); let result = abs(Scalar, Value(v)).0; assert_eq!(expected.to_bits(), result.to_bits()); } } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_copy_sign_f16() { #[track_caller] fn check(mag_bits: u16, sign_bits: u16) { let mag = F16::from_bits(mag_bits); let sign = F16::from_bits(sign_bits); let expected = mag.copysign(sign); let result = copy_sign(Scalar, Value(mag), Value(sign)).0; assert_eq!(expected.to_bits(), result.to_bits()); } for mag_low_bits in 0..16 { for mag_high_bits in 0..16 { for sign_low_bits in 0..16 { for sign_high_bits in 0..16 { check( mag_low_bits | (mag_high_bits << (16 - 4)), sign_low_bits | (sign_high_bits << (16 - 4)), ); } } } } } #[test] fn test_copy_sign_f32() { #[track_caller] fn check(mag_bits: u32, sign_bits: u32) { let mag = f32::from_bits(mag_bits); let sign = f32::from_bits(sign_bits); let expected = mag.copysign(sign); let result = copy_sign(Scalar, Value(mag), Value(sign)).0; assert_eq!(expected.to_bits(), result.to_bits()); } for mag_low_bits in 0..16 { for mag_high_bits in 0..16 { for sign_low_bits in 0..16 { for sign_high_bits in 0..16 { check( mag_low_bits | (mag_high_bits << (32 - 4)), sign_low_bits | (sign_high_bits << (32 - 4)), ); } } } } } #[test] fn test_copy_sign_f64() { #[track_caller] fn check(mag_bits: u64, sign_bits: u64) { let mag = f64::from_bits(mag_bits); let sign = f64::from_bits(sign_bits); let expected = mag.copysign(sign); let result = copy_sign(Scalar, Value(mag), Value(sign)).0; assert_eq!(expected.to_bits(), result.to_bits()); } for mag_low_bits in 0..16 { for mag_high_bits in 0..16 { for sign_low_bits in 0..16 { for sign_high_bits in 0..16 { check( mag_low_bits | (mag_high_bits << (64 - 4)), sign_low_bits | (sign_high_bits << (64 - 4)), ); } } } } } fn same(a: F, b: F) -> bool { if a.is_finite() && b.is_finite() { a == b } else { a == b || (a.is_nan() && b.is_nan()) } } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_trunc_f16() { for bits in 0..=u16::MAX { let v = F16::from_bits(bits); let expected = v.trunc(); let result = trunc(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } #[test] fn test_trunc_f32() { for bits in (0..=u32::MAX).step_by(0x10000) { let v = f32::from_bits(bits); let expected = v.trunc(); let result = trunc(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } #[test] fn test_trunc_f64() { for bits in (0..=u64::MAX).step_by(1 << 48) { let v = f64::from_bits(bits); let expected = v.trunc(); let result = trunc(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } fn reference_round_to_nearest_ties_to_even< F: PrimFloat, U: PrimUInt, S: PrimSInt + ConvertFrom, >( v: F, ) -> F { if v.abs() < F::cvt_from(S::MAX) { let int_value: S = v.to(); let int_value_f: F = int_value.to(); let remainder: F = v - int_value_f; if remainder.abs() < 0.5.to() || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to()) { int_value_f.copy_sign(v) } else if remainder < 0.0.to() { int_value_f - 1.0.to() } else { int_value_f + 1.0.to() } } else { v } } #[test] fn test_reference_round_to_nearest_ties_to_even() { #[track_caller] fn case(v: f32, expected: f32) { let result = reference_round_to_nearest_ties_to_even(v); let same = if expected.is_nan() { result.is_nan() } else { expected.to_bits() == result.to_bits() }; assert!( same, "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } case(0.0, 0.0); case(-0.0, -0.0); case(0.499, 0.0); case(-0.499, -0.0); case(0.5, 0.0); case(-0.5, -0.0); case(0.501, 1.0); case(-0.501, -1.0); case(1.0, 1.0); case(-1.0, -1.0); case(1.499, 1.0); case(-1.499, -1.0); case(1.5, 2.0); case(-1.5, -2.0); case(1.501, 2.0); case(-1.501, -2.0); case(2.0, 2.0); case(-2.0, -2.0); case(2.499, 2.0); case(-2.499, -2.0); case(2.5, 2.0); case(-2.5, -2.0); case(2.501, 3.0); case(-2.501, -3.0); case(f32::INFINITY, f32::INFINITY); case(-f32::INFINITY, -f32::INFINITY); case(f32::NAN, f32::NAN); case(1e30, 1e30); case(-1e30, -1e30); let i32_max = i32::MAX as f32; let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1); let i32_max_next = f32::from_bits(i32_max.to_bits() + 1); case(i32_max, i32_max); case(-i32_max, -i32_max); case(i32_max_prev, i32_max_prev); case(-i32_max_prev, -i32_max_prev); case(i32_max_next, i32_max_next); case(-i32_max_next, -i32_max_next); } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_round_to_nearest_ties_to_even_f16() { for bits in 0..=u16::MAX { let v = F16::from_bits(bits); let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } #[test] fn test_round_to_nearest_ties_to_even_f32() { for bits in (0..=u32::MAX).step_by(0x10000) { let v = f32::from_bits(bits); let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } #[test] fn test_round_to_nearest_ties_to_even_f64() { for bits in (0..=u64::MAX).step_by(1 << 48) { let v = f64::from_bits(bits); let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), expected=expected, expected_bits=expected.to_bits(), result=result, result_bits=result.to_bits(), ); } } }