add copy_sign and genericify abs
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 13 May 2021 02:49:31 +0000 (19:49 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 13 May 2021 02:49:31 +0000 (19:49 -0700)
src/algorithms/base.rs
src/f16.rs
src/scalar.rs

index 56691a3e136a035077352447a9a9c094d518300a..0b6dcb60993100b966d73f308ceb9084826cbfd1 100644 (file)
@@ -1,15 +1,33 @@
-use crate::traits::{Context, Float};
+use crate::{
+    prim::{PrimFloat, PrimUInt},
+    traits::{Context, Float, Make},
+};
 
-pub fn abs_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
-    Ctx::VecF16::from_bits(x.to_bits() & ctx.make(0x7FFFu16))
+pub fn abs<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    x: VecF,
+) -> VecF {
+    VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
 }
 
-pub fn abs_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
-    Ctx::VecF32::from_bits(x.to_bits() & ctx.make(!(1u32 << 31)))
-}
-
-pub fn abs_f64<Ctx: Context>(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 {
-    Ctx::VecF64::from_bits(x.to_bits() & ctx.make(!(1u64 << 63)))
+pub fn copy_sign<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    mag: VecF,
+    sign: VecF,
+) -> VecF {
+    let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
+    let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
+    VecF::from_bits(mag_bits | sign_bit)
 }
 
 #[cfg(test)]
@@ -29,7 +47,7 @@ mod tests {
         for bits in 0..=u16::MAX {
             let v = F16::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f16(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
     }
@@ -39,7 +57,7 @@ mod tests {
         for bits in (0..=u32::MAX).step_by(10001) {
             let v = f32::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f32(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
     }
@@ -49,8 +67,84 @@ mod tests {
         for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
             let v = f64::from_bits(bits);
             let expected = v.abs();
-            let result = abs_f64(Scalar, Value(v)).0;
+            let result = abs(Scalar, Value(v)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_copy_sign_f16() {
+        #[track_caller]
+        fn check(mag_bits: u16, sign_bits: u16) {
+            let mag = F16::from_bits(mag_bits);
+            let sign = F16::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (16 - 4)),
+                            sign_low_bits | (sign_high_bits << (16 - 4)),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_copy_sign_f32() {
+        #[track_caller]
+        fn check(mag_bits: u32, sign_bits: u32) {
+            let mag = f32::from_bits(mag_bits);
+            let sign = f32::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+            assert_eq!(expected.to_bits(), result.to_bits());
+        }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (32 - 4)),
+                            sign_low_bits | (sign_high_bits << (32 - 4)),
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_copy_sign_f64() {
+        #[track_caller]
+        fn check(mag_bits: u64, sign_bits: u64) {
+            let mag = f64::from_bits(mag_bits);
+            let sign = f64::from_bits(sign_bits);
+            let expected = mag.copysign(sign);
+            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
             assert_eq!(expected.to_bits(), result.to_bits());
         }
+        for mag_low_bits in 0..16 {
+            for mag_high_bits in 0..16 {
+                for sign_low_bits in 0..16 {
+                    for sign_high_bits in 0..16 {
+                        check(
+                            mag_low_bits | (mag_high_bits << (64 - 4)),
+                            sign_low_bits | (sign_high_bits << (64 - 4)),
+                        );
+                    }
+                }
+            }
+        }
     }
 }
index 5253fef282e136c966950c2d3c4f536bc120fa0c..280d00d4231f751192b12a5d715cd26734318fa6 100644 (file)
@@ -204,6 +204,12 @@ impl F16 {
     pub fn abs(self) -> Self {
         f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
     }
+    pub fn copysign(self, sign: Self) -> Self {
+        f16_impl!(
+            Self::from_bits((self.to_bits() & 0x7FFF) | (sign.to_bits() & 0x8000)),
+            [sign]
+        )
+    }
     pub fn trunc(self) -> Self {
         #[cfg(feature = "std")]
         return f32::from(self).trunc().to();
index 30aaa9e9cfc3eb981da354f0d9bb61aeafc98439..4e5009598aff43900ae158d35d4918c8e0d87140 100644 (file)
@@ -350,7 +350,7 @@ macro_rules! impl_float {
                 #[cfg(feature = "std")]
                 return Value(self.0.abs());
                 #[cfg(not(feature = "std"))]
-                todo!();
+                return crate::algorithms::base::abs(Scalar, self);
             }
             fn trunc(self) -> Self {
                 #[cfg(feature = "std")]