impl traits for scalar types
[vector-math.git] / src / f16.rs
diff --git a/src/f16.rs b/src/f16.rs
new file mode 100644 (file)
index 0000000..04e8fbd
--- /dev/null
@@ -0,0 +1,287 @@
+use core::ops::{
+    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
+};
+
+use crate::traits::{ConvertTo, Float};
+
+#[cfg(feature = "f16")]
+use half::f16 as F16Impl;
+
+#[cfg(not(feature = "f16"))]
+#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
+enum F16Impl {}
+
+#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
+#[cfg_attr(feature = "f16", repr(transparent))]
+pub struct F16(F16Impl);
+
+#[cfg(feature = "f16")]
+macro_rules! f16_impl {
+    ($v:expr, [$($vars:ident),*]) => {
+        $v
+    };
+}
+
+#[cfg(not(feature = "f16"))]
+macro_rules! f16_impl {
+    ($v:expr, [$($vars:ident),*]) => {
+        {
+            $(let _ = $vars;)*
+            panic!("f16 feature is not enabled")
+        }
+    };
+}
+
+impl From<F16Impl> for F16 {
+    fn from(v: F16Impl) -> Self {
+        F16(v)
+    }
+}
+
+impl From<F16> for F16Impl {
+    fn from(v: F16) -> Self {
+        v.0
+    }
+}
+
+macro_rules! impl_f16_from {
+    ($($ty:ident,)*) => {
+        $(
+            impl From<$ty> for F16 {
+                fn from(v: $ty) -> Self {
+                    f16_impl!(F16(F16Impl::from(v)), [v])
+                }
+            }
+
+            impl ConvertTo<F16> for $ty {
+                fn to(self) -> F16 {
+                    self.into()
+                }
+            }
+        )*
+    };
+}
+
+macro_rules! impl_from_f16 {
+    ($($ty:ident,)*) => {
+        $(
+            impl From<F16> for $ty {
+                fn from(v: F16) -> Self {
+                    #[cfg(feature = "f16")]
+                    return v.0.into();
+                    #[cfg(not(feature = "f16"))]
+                    match v.0 {}
+                }
+            }
+
+            impl ConvertTo<$ty> for F16 {
+                fn to(self) -> $ty {
+                    self.into()
+                }
+            }
+        )*
+    };
+}
+
+impl_f16_from![i8, u8,];
+
+impl_from_f16![f32, f64,];
+
+macro_rules! impl_int_to_f16 {
+    ($($int:ident),*) => {
+        $(
+            impl ConvertTo<F16> for $int {
+                fn to(self) -> F16 {
+                    // f32 has enough mantissa bits such that f16 overflows to
+                    // infinity before f32 can't properly represent integer
+                    // values, making the below conversion correct.
+                    (self as f32).to()
+                }
+            }
+        )*
+    };
+}
+
+macro_rules! impl_f16_to_int {
+    ($($int:ident),*) => {
+        $(
+            impl ConvertTo<$int> for F16 {
+                fn to(self) -> $int {
+                    f32::from(self) as $int
+                }
+            }
+        )*
+    };
+}
+
+impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
+impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
+
+impl ConvertTo<F16> for f32 {
+    fn to(self) -> F16 {
+        f16_impl!(F16(F16Impl::from_f32(self)), [])
+    }
+}
+
+impl ConvertTo<F16> for f64 {
+    fn to(self) -> F16 {
+        f16_impl!(F16(F16Impl::from_f64(self)), [])
+    }
+}
+
+impl Neg for F16 {
+    type Output = Self;
+
+    fn neg(self) -> Self::Output {
+        Self::from_bits(self.to_bits() ^ 0x8000)
+    }
+}
+
+macro_rules! impl_bin_op_using_f32 {
+    ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => {
+        $(
+            impl $op for F16 {
+                type Output = Self;
+
+                fn $op_fn(self, rhs: Self) -> Self::Output {
+                    f32::from(self).$op_fn(f32::from(rhs)).to()
+                }
+            }
+
+            impl $op_assign for F16 {
+                fn $op_assign_fn(&mut self, rhs: Self) {
+                    *self = (*self).$op_fn(rhs);
+                }
+            }
+        )*
+    };
+}
+
+impl_bin_op_using_f32! {
+    Add, add, AddAssign, add_assign;
+    Sub, sub, SubAssign, sub_assign;
+    Mul, mul, MulAssign, mul_assign;
+    Div, div, DivAssign, div_assign;
+    Rem, rem, RemAssign, rem_assign;
+}
+
+impl Float<u32> for F16 {
+    type BitsType = u16;
+
+    fn abs(self) -> Self {
+        Self::from_bits(self.to_bits() & 0x7FFF)
+    }
+
+    fn trunc(self) -> Self {
+        f32::from(self).trunc().to()
+    }
+
+    fn ceil(self) -> Self {
+        f32::from(self).ceil().to()
+    }
+
+    fn floor(self) -> Self {
+        f32::from(self).floor().to()
+    }
+
+    fn round(self) -> Self {
+        f32::from(self).round().to()
+    }
+
+    #[cfg(feature = "fma")]
+    fn fma(self, a: Self, b: Self) -> Self {
+        (f64::from(self) * f64::from(a) + f64::from(b)).to()
+    }
+
+    fn is_nan(self) -> Self::Bool {
+        f16_impl!(self.0.is_nan(), [])
+    }
+
+    fn is_infinite(self) -> Self::Bool {
+        f16_impl!(self.0.is_infinite(), [])
+    }
+
+    fn is_finite(self) -> Self::Bool {
+        f16_impl!(self.0.is_finite(), [])
+    }
+
+    fn from_bits(v: Self::BitsType) -> Self {
+        f16_impl!(F16(F16Impl::from_bits(v)), [v])
+    }
+
+    fn to_bits(self) -> Self::BitsType {
+        f16_impl!(self.0.to_bits(), [])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use core::cmp::Ordering;
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_abs() {
+        assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0);
+        assert_eq!(F16::from_bits(0).abs().to_bits(), 0);
+        assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC);
+        assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00);
+        assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_neg() {
+        assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0);
+        assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000);
+        assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC);
+        assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00);
+        assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_int_to_f16() {
+        assert_eq!(F16::to_bits(0u32.to()), 0);
+        for v in 1..0x20000u32 {
+            let leading_zeros = u32::leading_zeros(v);
+            let shifted_v = v << leading_zeros;
+            // round to nearest, ties to even
+            let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) {
+                Ordering::Less => false,
+                Ordering::Equal => (shifted_v & 0x200000) != 0,
+                Ordering::Greater => true,
+            };
+            let (rounded, carry) =
+                (shifted_v & !0x1FFFFF).overflowing_add(round_up.then(|| 0x200000).unwrap_or(0));
+            let mantissa;
+            if carry {
+                mantissa = (rounded >> 22) as u16 + 0x400;
+            } else {
+                mantissa = (rounded >> 21) as u16;
+            }
+            assert_eq!((mantissa & !0x3FF), 0x400);
+            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
+            let expected = if exponent < 0x1F {
+                (mantissa & 0x3FF) + (exponent << 10)
+            } else {
+                0x7C00
+            };
+            let actual = F16::to_bits(v.to());
+            assert_eq!(
+                actual, expected,
+                "actual = {:#X}, expected = {:#X}, v = {:#X}",
+                actual, expected, v
+            );
+        }
+    }
+}