use core::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, }; use crate::traits::{ConvertTo, Float}; #[cfg(feature = "f16")] use half::f16 as F16Impl; #[cfg(not(feature = "f16"))] type F16Impl = u16; #[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] #[repr(transparent)] pub struct F16(F16Impl); #[cfg(not(feature = "f16"))] #[track_caller] pub(crate) fn panic_f16_feature_disabled() -> ! { panic!("f16 feature is not enabled") } #[cfg(feature = "f16")] macro_rules! f16_impl { ($v:expr, [$($vars:ident),*]) => { $v }; } #[cfg(not(feature = "f16"))] macro_rules! f16_impl { ($v:expr, [$($vars:ident),*]) => { { $(let _ = $vars;)* panic_f16_feature_disabled() } }; } impl Default for F16 { fn default() -> Self { f16_impl!(F16(F16Impl::default()), []) } } impl From for F16 { fn from(v: F16Impl) -> Self { F16(v) } } impl From for F16Impl { fn from(v: F16) -> Self { v.0 } } macro_rules! impl_f16_from { ($($ty:ident,)*) => { $( impl From<$ty> for F16 { fn from(v: $ty) -> Self { f16_impl!(F16(F16Impl::from(v)), [v]) } } impl ConvertTo for $ty { fn to(self) -> F16 { self.into() } } )* }; } macro_rules! impl_from_f16 { ($($ty:ident,)*) => { $( impl From for $ty { fn from(v: F16) -> Self { f16_impl!(v.0.into(), [v]) } } impl ConvertTo<$ty> for F16 { fn to(self) -> $ty { self.into() } } )* }; } impl_f16_from![i8, u8,]; impl_from_f16![f32, f64,]; macro_rules! impl_int_to_f16 { ($($int:ident),*) => { $( impl ConvertTo for $int { fn to(self) -> F16 { // f32 has enough mantissa bits such that f16 overflows to // infinity before f32 stops being able to properly // represent integer values, making the below conversion correct. (self as f32).to() } } )* }; } macro_rules! impl_f16_to_int { ($($int:ident),*) => { $( impl ConvertTo<$int> for F16 { fn to(self) -> $int { f32::from(self) as $int } } )* }; } impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128]; impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128]; impl ConvertTo for f32 { fn to(self) -> F16 { f16_impl!(F16(F16Impl::from_f32(self)), []) } } impl ConvertTo for f64 { fn to(self) -> F16 { f16_impl!(F16(F16Impl::from_f64(self)), []) } } impl Neg for F16 { type Output = Self; fn neg(self) -> Self::Output { f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), []) } } macro_rules! impl_bin_op_using_f32 { ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => { $( impl $op for F16 { type Output = Self; fn $op_fn(self, rhs: Self) -> Self::Output { f32::from(self).$op_fn(f32::from(rhs)).to() } } impl $op_assign for F16 { fn $op_assign_fn(&mut self, rhs: Self) { *self = (*self).$op_fn(rhs); } } )* }; } impl_bin_op_using_f32! { Add, add, AddAssign, add_assign; Sub, sub, SubAssign, sub_assign; Mul, mul, MulAssign, mul_assign; Div, div, DivAssign, div_assign; Rem, rem, RemAssign, rem_assign; } impl Float for F16 { type FloatEncoding = F16; type BitsType = u16; type SignedBitsType = i16; fn abs(self) -> Self { f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) } fn trunc(self) -> Self { f32::from(self).trunc().to() } fn ceil(self) -> Self { f32::from(self).ceil().to() } fn floor(self) -> Self { f32::from(self).floor().to() } fn round(self) -> Self { f32::from(self).round().to() } #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { (f64::from(self) * f64::from(a) + f64::from(b)).to() } fn is_nan(self) -> Self::Bool { f16_impl!(self.0.is_nan(), []) } fn is_infinite(self) -> Self::Bool { f16_impl!(self.0.is_infinite(), []) } fn is_finite(self) -> Self::Bool { f16_impl!(self.0.is_finite(), []) } fn from_bits(v: Self::BitsType) -> Self { #[cfg(feature = "f16")] return F16(F16Impl::from_bits(v)); #[cfg(not(feature = "f16"))] return F16(v); } fn to_bits(self) -> Self::BitsType { #[cfg(feature = "f16")] return self.0.to_bits(); #[cfg(not(feature = "f16"))] return self.0; } } #[cfg(test)] mod tests { use super::*; use core::cmp::Ordering; #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_abs() { assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0); assert_eq!(F16::from_bits(0).abs().to_bits(), 0); assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC); assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00); assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00); } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_neg() { assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0); assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000); assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC); assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00); assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00); } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_int_to_f16() { assert_eq!(F16::to_bits(0u32.to()), 0); for v in 1..0x20000u32 { let leading_zeros = u32::leading_zeros(v); let shifted_v = v << leading_zeros; // round to nearest, ties to even let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) { Ordering::Less => false, Ordering::Equal => (shifted_v & 0x200000) != 0, Ordering::Greater => true, }; let (rounded, carry) = (shifted_v & !0x1FFFFF).overflowing_add(round_up.then(|| 0x200000).unwrap_or(0)); let mantissa; if carry { mantissa = (rounded >> 22) as u16 + 0x400; } else { mantissa = (rounded >> 21) as u16; } assert_eq!((mantissa & !0x3FF), 0x400); let exponent = 31 - leading_zeros as u16 + 15 + carry as u16; let expected = if exponent < 0x1F { (mantissa & 0x3FF) + (exponent << 10) } else { 0x7C00 }; let actual = F16::to_bits(v.to()); assert_eq!( actual, expected, "actual = {:#X}, expected = {:#X}, v = {:#X}", actual, expected, v ); } } }