use core::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, }; use crate::{ scalar::Value, traits::{ConvertFrom, ConvertTo, Float}, }; #[cfg(feature = "f16")] use half::f16 as F16Impl; #[cfg(not(feature = "f16"))] type F16Impl = u16; #[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] #[repr(transparent)] pub struct F16(F16Impl); #[cfg(not(feature = "f16"))] #[track_caller] pub(crate) fn panic_f16_feature_disabled() -> ! { panic!("f16 feature is not enabled") } #[cfg(feature = "f16")] macro_rules! f16_impl { ($v:expr, [$($vars:ident),*]) => { $v }; } #[cfg(not(feature = "f16"))] macro_rules! f16_impl { ($v:expr, [$($vars:ident),*]) => { { $(let _ = $vars;)* panic_f16_feature_disabled() } }; } impl Default for F16 { fn default() -> Self { f16_impl!(F16(F16Impl::default()), []) } } impl From for F16 { fn from(v: F16Impl) -> Self { F16(v) } } impl From for F16Impl { fn from(v: F16) -> Self { v.0 } } macro_rules! impl_f16_from { ($($ty:ident,)*) => { $( impl From<$ty> for F16 { fn from(v: $ty) -> Self { f16_impl!(F16(F16Impl::from(v)), [v]) } } impl ConvertFrom<$ty> for F16 { fn cvt_from(v: $ty) -> F16 { v.into() } } )* }; } macro_rules! impl_from_f16 { ($($ty:ident,)*) => { $( impl From for $ty { fn from(v: F16) -> Self { f16_impl!(v.0.into(), [v]) } } impl ConvertFrom for $ty { fn cvt_from(v: F16) -> Self { v.into() } } )* }; } impl_f16_from![i8, u8,]; impl_from_f16![f32, f64,]; macro_rules! impl_int_to_f16 { ($($int:ident),*) => { $( impl ConvertFrom<$int> for F16 { fn cvt_from(v: $int) -> Self { // f32 has enough mantissa bits such that f16 overflows to // infinity before f32 stops being able to properly // represent integer values, making the below conversion correct. F16::cvt_from(v as f32) } } )* }; } macro_rules! impl_f16_to_int { ($($int:ident),*) => { $( impl ConvertFrom for $int { fn cvt_from(v: F16) -> Self { f32::from(v) as $int } } )* }; } impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128]; impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128]; impl ConvertFrom for F16 { fn cvt_from(v: f32) -> Self { f16_impl!(F16(F16Impl::from_f32(v)), [v]) } } impl ConvertFrom for F16 { fn cvt_from(v: f64) -> Self { f16_impl!(F16(F16Impl::from_f64(v)), [v]) } } impl Neg for F16 { type Output = Self; fn neg(self) -> Self::Output { f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), []) } } macro_rules! impl_bin_op_using_f32 { ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => { $( impl $op for F16 { type Output = Self; fn $op_fn(self, rhs: Self) -> Self::Output { f32::from(self).$op_fn(f32::from(rhs)).to() } } impl $op_assign for F16 { fn $op_assign_fn(&mut self, rhs: Self) { *self = (*self).$op_fn(rhs); } } )* }; } impl_bin_op_using_f32! { Add, add, AddAssign, add_assign; Sub, sub, SubAssign, sub_assign; Mul, mul, MulAssign, mul_assign; Div, div, DivAssign, div_assign; Rem, rem, RemAssign, rem_assign; } impl F16 { pub fn from_bits(v: u16) -> Self { #[cfg(feature = "f16")] return F16(F16Impl::from_bits(v)); #[cfg(not(feature = "f16"))] return F16(v); } pub fn to_bits(self) -> u16 { #[cfg(feature = "f16")] return self.0.to_bits(); #[cfg(not(feature = "f16"))] return self.0; } pub fn abs(self) -> Self { f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) } pub fn trunc(self) -> Self { f32::from(self).trunc().to() } pub fn ceil(self) -> Self { f32::from(self).ceil().to() } pub fn floor(self) -> Self { f32::from(self).floor().to() } pub fn round(self) -> Self { f32::from(self).round().to() } #[cfg(feature = "fma")] pub fn fma(self, a: Self, b: Self) -> Self { (f64::from(self) * f64::from(a) + f64::from(b)).to() } pub fn is_nan(self) -> bool { f16_impl!(self.0.is_nan(), []) } pub fn is_infinite(self) -> bool { f16_impl!(self.0.is_infinite(), []) } pub fn is_finite(self) -> bool { f16_impl!(self.0.is_finite(), []) } } impl Float for Value { type FloatEncoding = F16; type BitsType = Value; type SignedBitsType = Value; fn abs(self) -> Self { Value(self.0.abs()) } fn trunc(self) -> Self { Value(self.0.trunc()) } fn ceil(self) -> Self { Value(self.0.ceil()) } fn floor(self) -> Self { Value(self.0.floor()) } fn round(self) -> Self { Value(self.0.round()) } #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { Value(self.0.fma(a.0, b.0)) } fn is_nan(self) -> Self::Bool { Value(self.0.is_nan()) } fn is_infinite(self) -> Self::Bool { Value(self.0.is_infinite()) } fn is_finite(self) -> Self::Bool { Value(self.0.is_finite()) } fn from_bits(v: Self::BitsType) -> Self { Value(F16::from_bits(v.0)) } fn to_bits(self) -> Self::BitsType { Value(self.0.to_bits()) } } #[cfg(test)] mod tests { use super::*; use core::cmp::Ordering; #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_abs() { assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0); assert_eq!(F16::from_bits(0).abs().to_bits(), 0); assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC); assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00); assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00); } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_neg() { assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0); assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000); assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC); assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00); assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00); } #[test] #[cfg_attr( not(feature = "f16"), should_panic(expected = "f16 feature is not enabled") )] fn test_int_to_f16() { assert_eq!(F16::to_bits(0u32.to()), 0); for v in 1..0x20000u32 { let leading_zeros = u32::leading_zeros(v); let shifted_v = v << leading_zeros; // round to nearest, ties to even let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) { Ordering::Less => false, Ordering::Equal => (shifted_v & 0x200000) != 0, Ordering::Greater => true, }; let (rounded, carry) = (shifted_v & !0x1FFFFF).overflowing_add(round_up.then(|| 0x200000).unwrap_or(0)); let mantissa; if carry { mantissa = (rounded >> 22) as u16 + 0x400; } else { mantissa = (rounded >> 21) as u16; } assert_eq!((mantissa & !0x3FF), 0x400); let exponent = 31 - leading_zeros as u16 + 15 + carry as u16; let expected = if exponent < 0x1F { (mantissa & 0x3FF) + (exponent << 10) } else { 0x7C00 }; let actual = F16::to_bits(v.to()); assert_eq!( actual, expected, "actual = {:#X}, expected = {:#X}, v = {:#X}", actual, expected, v ); } } }