From 94017d5d69a43da5ce46f5fdd24bab86fc43fe01 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 12 May 2021 19:49:31 -0700 Subject: [PATCH] add copy_sign and genericify abs --- src/algorithms/base.rs | 118 ++++++++++++++++++++++++++++++++++++----- src/f16.rs | 6 +++ src/scalar.rs | 2 +- 3 files changed, 113 insertions(+), 13 deletions(-) diff --git a/src/algorithms/base.rs b/src/algorithms/base.rs index 56691a3..0b6dcb6 100644 --- a/src/algorithms/base.rs +++ b/src/algorithms/base.rs @@ -1,15 +1,33 @@ -use crate::traits::{Context, Float}; +use crate::{ + prim::{PrimFloat, PrimUInt}, + traits::{Context, Float, Make}, +}; -pub fn abs_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 { - Ctx::VecF16::from_bits(x.to_bits() & ctx.make(0x7FFFu16)) +pub fn abs< + Ctx: Context, + VecF: Float + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, +>( + ctx: Ctx, + x: VecF, +) -> VecF { + VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK)) } -pub fn abs_f32(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 { - Ctx::VecF32::from_bits(x.to_bits() & ctx.make(!(1u32 << 31))) -} - -pub fn abs_f64(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 { - Ctx::VecF64::from_bits(x.to_bits() & ctx.make(!(1u64 << 63))) +pub fn copy_sign< + Ctx: Context, + VecF: Float + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, +>( + ctx: Ctx, + mag: VecF, + sign: VecF, +) -> VecF { + let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK); + let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK); + VecF::from_bits(mag_bits | sign_bit) } #[cfg(test)] @@ -29,7 +47,7 @@ mod tests { for bits in 0..=u16::MAX { let v = F16::from_bits(bits); let expected = v.abs(); - let result = abs_f16(Scalar, Value(v)).0; + let result = abs(Scalar, Value(v)).0; assert_eq!(expected.to_bits(), result.to_bits()); } } @@ -39,7 +57,7 @@ mod tests { for bits in (0..=u32::MAX).step_by(10001) { let v = f32::from_bits(bits); let expected = v.abs(); - let result = abs_f32(Scalar, Value(v)).0; + let result = abs(Scalar, Value(v)).0; assert_eq!(expected.to_bits(), result.to_bits()); } } @@ -49,8 +67,84 @@ mod tests { for bits in (0..=u64::MAX).step_by(100_000_000_000_001) { let v = f64::from_bits(bits); let expected = v.abs(); - let result = abs_f64(Scalar, Value(v)).0; + let result = abs(Scalar, Value(v)).0; + assert_eq!(expected.to_bits(), result.to_bits()); + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_copy_sign_f16() { + #[track_caller] + fn check(mag_bits: u16, sign_bits: u16) { + let mag = F16::from_bits(mag_bits); + let sign = F16::from_bits(sign_bits); + let expected = mag.copysign(sign); + let result = copy_sign(Scalar, Value(mag), Value(sign)).0; + assert_eq!(expected.to_bits(), result.to_bits()); + } + for mag_low_bits in 0..16 { + for mag_high_bits in 0..16 { + for sign_low_bits in 0..16 { + for sign_high_bits in 0..16 { + check( + mag_low_bits | (mag_high_bits << (16 - 4)), + sign_low_bits | (sign_high_bits << (16 - 4)), + ); + } + } + } + } + } + + #[test] + fn test_copy_sign_f32() { + #[track_caller] + fn check(mag_bits: u32, sign_bits: u32) { + let mag = f32::from_bits(mag_bits); + let sign = f32::from_bits(sign_bits); + let expected = mag.copysign(sign); + let result = copy_sign(Scalar, Value(mag), Value(sign)).0; + assert_eq!(expected.to_bits(), result.to_bits()); + } + for mag_low_bits in 0..16 { + for mag_high_bits in 0..16 { + for sign_low_bits in 0..16 { + for sign_high_bits in 0..16 { + check( + mag_low_bits | (mag_high_bits << (32 - 4)), + sign_low_bits | (sign_high_bits << (32 - 4)), + ); + } + } + } + } + } + + #[test] + fn test_copy_sign_f64() { + #[track_caller] + fn check(mag_bits: u64, sign_bits: u64) { + let mag = f64::from_bits(mag_bits); + let sign = f64::from_bits(sign_bits); + let expected = mag.copysign(sign); + let result = copy_sign(Scalar, Value(mag), Value(sign)).0; assert_eq!(expected.to_bits(), result.to_bits()); } + for mag_low_bits in 0..16 { + for mag_high_bits in 0..16 { + for sign_low_bits in 0..16 { + for sign_high_bits in 0..16 { + check( + mag_low_bits | (mag_high_bits << (64 - 4)), + sign_low_bits | (sign_high_bits << (64 - 4)), + ); + } + } + } + } } } diff --git a/src/f16.rs b/src/f16.rs index 5253fef..280d00d 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -204,6 +204,12 @@ impl F16 { pub fn abs(self) -> Self { f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) } + pub fn copysign(self, sign: Self) -> Self { + f16_impl!( + Self::from_bits((self.to_bits() & 0x7FFF) | (sign.to_bits() & 0x8000)), + [sign] + ) + } pub fn trunc(self) -> Self { #[cfg(feature = "std")] return f32::from(self).trunc().to(); diff --git a/src/scalar.rs b/src/scalar.rs index 30aaa9e..4e50095 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -350,7 +350,7 @@ macro_rules! impl_float { #[cfg(feature = "std")] return Value(self.0.abs()); #[cfg(not(feature = "std"))] - todo!(); + return crate::algorithms::base::abs(Scalar, self); } fn trunc(self) -> Self { #[cfg(feature = "std")] -- 2.30.2