--- /dev/null
+use core::ops::{
+ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
+};
+
+use crate::traits::{ConvertTo, Float};
+
+#[cfg(feature = "f16")]
+use half::f16 as F16Impl;
+
+/// Placeholder backing type used when the `f16` feature is disabled.
+///
+/// The enum has no variants, so a value of this type can never be constructed.
+#[cfg(not(feature = "f16"))]
+#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)]
+enum F16Impl {}
+
+/// 16-bit float newtype wrapping `F16Impl`.
+///
+/// With the `f16` feature enabled, `F16Impl` is `half::f16` and this wrapper
+/// is `repr(transparent)` over it.
+#[cfg_attr(feature = "f16", repr(transparent))]
+#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)]
+pub struct F16(F16Impl);
+
+// `f16` feature enabled: expand to the given expression unchanged.
+// The bracketed identifier list is accepted but ignored.
+#[cfg(feature = "f16")]
+macro_rules! f16_impl {
+    ($body:expr, [$($unused:ident),*]) => {
+        $body
+    };
+}
+
+// `f16` feature disabled: discard the expression and panic instead.
+// Each identifier in the bracketed list is bound to `_` first, presumably so
+// that otherwise-unused bindings do not produce warnings.
+#[cfg(not(feature = "f16"))]
+macro_rules! f16_impl {
+    ($body:expr, [$($unused:ident),*]) => {{
+        $(let _ = $unused;)*
+        panic!("f16 feature is not enabled")
+    }};
+}
+
+/// Wraps a raw backing value in the public [`F16`] newtype.
+impl From<F16Impl> for F16 {
+    fn from(v: F16Impl) -> Self {
+        Self(v)
+    }
+}
+
+/// Unwraps an [`F16`] back into its raw backing value.
+impl From<F16> for F16Impl {
+    fn from(v: F16) -> Self {
+        let F16(inner) = v;
+        inner
+    }
+}
+
+// For every listed source type, generates:
+//   * `From<$src> for F16`, delegating to `F16Impl::from` (or panicking when
+//     the `f16` feature is disabled, via `f16_impl!`), and
+//   * `ConvertTo<F16> for $src`, which forwards to that `From` impl.
+macro_rules! impl_f16_from {
+    ($($src:ident,)*) => {
+        $(
+            impl From<$src> for F16 {
+                fn from(v: $src) -> Self {
+                    f16_impl!(Self(F16Impl::from(v)), [v])
+                }
+            }
+
+            impl ConvertTo<F16> for $src {
+                fn to(self) -> F16 {
+                    F16::from(self)
+                }
+            }
+        )*
+    };
+}
+
+// For every listed destination type, generates:
+//   * `From<F16> for $dest`, delegating to the backing type's conversion, and
+//   * `ConvertTo<$dest> for F16`, which forwards to that `From` impl.
+// Without the `f16` feature the backing type is uninhabited, so the body is an
+// empty `match` that can never execute.
+macro_rules! impl_from_f16 {
+    ($($dest:ident,)*) => {
+        $(
+            impl From<F16> for $dest {
+                fn from(v: F16) -> Self {
+                    #[cfg(feature = "f16")]
+                    return <$dest>::from(v.0);
+                    #[cfg(not(feature = "f16"))]
+                    match v.0 {}
+                }
+            }
+
+            impl ConvertTo<$dest> for F16 {
+                fn to(self) -> $dest {
+                    <$dest>::from(self)
+                }
+            }
+        )*
+    };
+}
+
+impl_f16_from![i8, u8,];
+
+impl_from_f16![f32, f64,];
+
+// Generates `ConvertTo<F16>` for each listed integer type by going through `f32`.
+macro_rules! impl_int_to_f16 {
+    ($($src:ident),*) => {
+        $(
+            impl ConvertTo<F16> for $src {
+                fn to(self) -> F16 {
+                    // Routing through f32 is correct: f16 overflows to
+                    // infinity long before integers become large enough
+                    // that f32's mantissa can no longer represent them
+                    // properly.
+                    let widened = self as f32;
+                    <f32 as ConvertTo<F16>>::to(widened)
+                }
+            }
+        )*
+    };
+}
+
+// Generates `ConvertTo<$dest>` for `F16` for each listed integer type: the
+// value is widened to `f32` and then converted with an `as` cast.
+macro_rules! impl_f16_to_int {
+    ($($dest:ident),*) => {
+        $(
+            impl ConvertTo<$dest> for F16 {
+                fn to(self) -> $dest {
+                    let widened: f32 = self.into();
+                    widened as $dest
+                }
+            }
+        )*
+    };
+}
+
+impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
+impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
+
+/// Converts an `f32` to [`F16`] using the backing type's `from_f32`.
+///
+/// Panics with "f16 feature is not enabled" when built without the `f16` feature.
+impl ConvertTo<F16> for f32 {
+    fn to(self) -> F16 {
+        f16_impl!(F16::from(F16Impl::from_f32(self)), [])
+    }
+}
+
+/// Converts an `f64` to [`F16`] using the backing type's `from_f64`.
+///
+/// Panics with "f16 feature is not enabled" when built without the `f16` feature.
+impl ConvertTo<F16> for f64 {
+    fn to(self) -> F16 {
+        f16_impl!(F16::from(F16Impl::from_f64(self)), [])
+    }
+}
+
+/// Negation implemented purely on the bit pattern.
+impl Neg for F16 {
+    type Output = F16;
+
+    fn neg(self) -> F16 {
+        // Bit 15 is the sign bit; toggling it negates the value.
+        let flipped = self.to_bits() ^ 0x8000;
+        Self::from_bits(flipped)
+    }
+}
+
+// For each `(Trait, method, AssignTrait, assign_method)` row, implements the
+// binary operator on `F16` by widening both operands to `f32`, applying the
+// same operator there, and converting the result back to `F16`. The
+// compound-assignment form is defined in terms of the plain operator.
+macro_rules! impl_bin_op_using_f32 {
+    ($($trait_name:ident, $method:ident, $assign_trait:ident, $assign_method:ident;)*) => {
+        $(
+            impl $trait_name for F16 {
+                type Output = Self;
+
+                fn $method(self, rhs: Self) -> Self::Output {
+                    let lhs_wide = f32::from(self);
+                    let rhs_wide = f32::from(rhs);
+                    $trait_name::$method(lhs_wide, rhs_wide).to()
+                }
+            }
+
+            impl $assign_trait for F16 {
+                fn $assign_method(&mut self, rhs: Self) {
+                    *self = $trait_name::$method(*self, rhs);
+                }
+            }
+        )*
+    };
+}
+
+impl_bin_op_using_f32! {
+    Add, add, AddAssign, add_assign;
+    Sub, sub, SubAssign, sub_assign;
+    Mul, mul, MulAssign, mul_assign;
+    Div, div, DivAssign, div_assign;
+    Rem, rem, RemAssign, rem_assign;
+}
+
+/// [`Float`] implementation for [`F16`] with `u16` as the bit representation.
+///
+/// Rounding operations are computed in `f32` and converted back; `fma` is
+/// computed in `f64`. Methods that go through `f16_impl!` panic with
+/// "f16 feature is not enabled" when built without the `f16` feature.
+impl Float<u32> for F16 {
+    type BitsType = u16;
+
+    fn abs(self) -> Self {
+        // Mask off bit 15 (the sign bit), keeping everything else.
+        let magnitude = self.to_bits() & 0x7FFF;
+        Self::from_bits(magnitude)
+    }
+
+    fn trunc(self) -> Self {
+        let wide = f32::from(self);
+        wide.trunc().to()
+    }
+
+    fn ceil(self) -> Self {
+        let wide = f32::from(self);
+        wide.ceil().to()
+    }
+
+    fn floor(self) -> Self {
+        let wide = f32::from(self);
+        wide.floor().to()
+    }
+
+    fn round(self) -> Self {
+        let wide = f32::from(self);
+        wide.round().to()
+    }
+
+    #[cfg(feature = "fma")]
+    fn fma(self, a: Self, b: Self) -> Self {
+        let product = f64::from(self) * f64::from(a);
+        let sum = product + f64::from(b);
+        sum.to()
+    }
+
+    fn is_nan(self) -> Self::Bool {
+        f16_impl!(self.0.is_nan(), [])
+    }
+
+    fn is_infinite(self) -> Self::Bool {
+        f16_impl!(self.0.is_infinite(), [])
+    }
+
+    fn is_finite(self) -> Self::Bool {
+        f16_impl!(self.0.is_finite(), [])
+    }
+
+    fn from_bits(v: Self::BitsType) -> Self {
+        f16_impl!(Self(F16Impl::from_bits(v)), [v])
+    }
+
+    fn to_bits(self) -> Self::BitsType {
+        f16_impl!(self.0.to_bits(), [])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use core::cmp::Ordering;
+
+    /// `abs` clears the sign bit and leaves all other bits untouched.
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_abs() {
+        let cases: [(u16, u16); 5] = [
+            (0x8000, 0),
+            (0, 0),
+            (0x8ABC, 0xABC),
+            (0xFE00, 0x7E00),
+            (0x7E00, 0x7E00),
+        ];
+        for &(input, expected) in cases.iter() {
+            assert_eq!(F16::from_bits(input).abs().to_bits(), expected);
+        }
+    }
+
+    /// `neg` toggles the sign bit and leaves all other bits untouched.
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_neg() {
+        let cases: [(u16, u16); 5] = [
+            (0x8000, 0),
+            (0, 0x8000),
+            (0x8ABC, 0xABC),
+            (0xFE00, 0x7E00),
+            (0x7E00, 0xFE00),
+        ];
+        for &(input, expected) in cases.iter() {
+            assert_eq!(F16::from_bits(input).neg().to_bits(), expected);
+        }
+    }
+
+    /// Checks `u32 -> F16` against a reference encoder written with integer
+    /// arithmetic only (round to nearest, ties to even, overflow to infinity).
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_int_to_f16() {
+        assert_eq!(F16::to_bits(0u32.to()), 0);
+        for v in 1..0x20000u32 {
+            // Normalize so the most significant set bit sits at bit 31; the
+            // 11 significant bits kept are then bits 31..=21.
+            let leading_zeros = u32::leading_zeros(v);
+            let shifted_v = v << leading_zeros;
+            // round to nearest, ties to even
+            let discarded = shifted_v & 0x1FFFFF;
+            let round_up = match discarded.cmp(&0x100000) {
+                Ordering::Greater => true,
+                Ordering::Equal => (shifted_v & 0x200000) != 0,
+                Ordering::Less => false,
+            };
+            let increment: u32 = if round_up { 0x200000 } else { 0 };
+            let (rounded, carry) = (shifted_v & !0x1FFFFF).overflowing_add(increment);
+            let mantissa = if carry {
+                (rounded >> 22) as u16 + 0x400
+            } else {
+                (rounded >> 21) as u16
+            };
+            // The implicit leading one must be the only bit above the stored mantissa.
+            assert_eq!(mantissa & !0x3FF, 0x400);
+            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
+            let expected = if exponent < 0x1F {
+                (mantissa & 0x3FF) + (exponent << 10)
+            } else {
+                0x7C00
+            };
+            let actual = F16::to_bits(v.to());
+            assert_eq!(
+                actual, expected,
+                "actual = {:#X}, expected = {:#X}, v = {:#X}",
+                actual, expected, v
+            );
+        }
+    }
+}
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
-use crate::f16;
+use crate::f16::F16;
#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
macro_rules! make_float_type {
$U16;
#[int(prim = i16 $(, scalar = $ScalarI16)?)]
$I16;
- #[float(prim = f16 $(, scalar = $ScalarF16)?)]
+ #[float(prim = F16 $(, scalar = $ScalarF16)?)]
$F16;
},
{
fn to(self) -> T;
}
-impl<T, U: Into<T>> ConvertTo<T> for U {
- fn to(self) -> T {
- self.into()
- }
+// Generates `ConvertTo` impls that are plain `as` casts.
+//
+// Accepted forms:
+//   * `[a, b, ...];`               every listed type to every listed type
+//   * `[a, b, ...] -> [x, y, ...];` every source to every destination
+//   * `a -> [x, y, ...]; ...`      one source per row
+// The first two forms rewrite themselves into the last one.
+macro_rules! impl_convert_to_using_as {
+    ([$($ty:ident),*];) => {
+        impl_convert_to_using_as! {
+            [$($ty),*] -> [$($ty),*];
+        }
+    };
+    ([$($from:ident),*] -> $targets:tt;) => {
+        impl_convert_to_using_as! {
+            $(
+                $from -> $targets;
+            )*
+        }
+    };
+    ($($from:ident -> [$($target:ident),*];)*) => {
+        $($(
+            impl ConvertTo<$target> for $from {
+                fn to(self) -> $target {
+                    self as $target
+                }
+            }
+        )*)*
+    };
+}
+
+impl_convert_to_using_as! {
+    [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
+}
pub trait Number:
{
}
+/// Blanket impl: any `Compare` type that supports the five arithmetic
+/// operators (returning itself) and their compound-assignment forms is a `Number`.
+impl<T> Number for T
+where
+    T: Compare
+        + Add<Output = T>
+        + Sub<Output = T>
+        + Mul<Output = T>
+        + Div<Output = T>
+        + Rem<Output = T>
+        + AddAssign
+        + SubAssign
+        + MulAssign
+        + DivAssign
+        + RemAssign,
+{
+}
+
+
pub trait BitOps:
Copy
+ BitAnd<Output = Self>
{
}
+/// Blanket impl: any `Copy` type that supports and/or/xor/not (returning
+/// itself) and the and/or/xor compound-assignment forms is `BitOps`.
+impl<T> BitOps for T
+where
+    T: Copy
+        + BitAnd<Output = T>
+        + BitOr<Output = T>
+        + BitXor<Output = T>
+        + Not<Output = T>
+        + BitAndAssign
+        + BitOrAssign
+        + BitXorAssign,
+{
+}
+
pub trait Int<ShiftRhs>:
Number
+ BitOps
pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
+// Marks each listed primitive as `Int<u32>` and `UInt<u32>`.
+macro_rules! impl_uint {
+    ($($prim:ident),*) => {
+        $(
+            impl Int<u32> for $prim {}
+            impl UInt<u32> for $prim {}
+        )*
+    };
+}
+
+impl_uint![u8, u16, u32, u64];
+
+// Marks each listed primitive as `Int<u32>` and `SInt<u32>`.
+macro_rules! impl_int {
+    ($($prim:ident),*) => {
+        $(
+            impl Int<u32> for $prim {}
+            impl SInt<u32> for $prim {}
+        )*
+    };
+}
+
+impl_int![i8, i16, i32, i64];
+
pub trait Float<BitsShiftRhs>: Number + Neg<Output = Self> {
type BitsType: UInt<BitsShiftRhs>;
fn abs(self) -> Self;
fn ceil(self) -> Self;
fn floor(self) -> Self;
fn round(self) -> Self;
+ #[cfg(feature = "fma")]
fn fma(self, a: Self, b: Self) -> Self;
fn is_nan(self) -> Self::Bool;
- fn is_infinity(self) -> Self::Bool;
+ fn is_infinite(self) -> Self::Bool;
fn is_finite(self) -> Self::Bool;
fn from_bits(v: Self::BitsType) -> Self;
fn to_bits(self) -> Self::BitsType;
}
+// Implements `Float<u32>` for a primitive float `$ty` whose raw bit pattern
+// has type `$bits`.
+//
+// `abs` is done on the bit pattern, so it works with or without the `std`
+// feature. `trunc`, `ceil`, `floor` and `round` forward to the primitive's own
+// methods when `std` is enabled.
+//
+// Panics: without the `std` feature, `trunc`, `ceil`, `floor` and `round` are
+// still unimplemented and hit `todo!()`.
+macro_rules! impl_float {
+    ($ty:ty, $bits:ty) => {
+        impl Float<u32> for $ty {
+            type BitsType = $bits;
+            fn abs(self) -> Self {
+                // `MAX >> 1` is every bit except the top (sign) bit, so the
+                // mask clears the sign and keeps exponent and mantissa. This
+                // only needs `to_bits`/`from_bits`, which this impl already
+                // relies on unconditionally, so no `std` gate is required.
+                <$ty>::from_bits(self.to_bits() & (<$bits>::MAX >> 1))
+            }
+            fn trunc(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.trunc();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn ceil(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.ceil();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn floor(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.floor();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn round(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.round();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            // NOTE(review): `mul_add` is not gated on `std` here; presumably
+            // the `fma` feature implies `std` -- confirm in Cargo.toml.
+            #[cfg(feature = "fma")]
+            fn fma(self, a: Self, b: Self) -> Self {
+                self.mul_add(a, b)
+            }
+            fn is_nan(self) -> Self::Bool {
+                self.is_nan()
+            }
+            fn is_infinite(self) -> Self::Bool {
+                self.is_infinite()
+            }
+            fn is_finite(self) -> Self::Bool {
+                self.is_finite()
+            }
+            fn from_bits(v: Self::BitsType) -> Self {
+                <$ty>::from_bits(v)
+            }
+            fn to_bits(self) -> Self::BitsType {
+                self.to_bits()
+            }
+        }
+    };
+}
+
+impl_float!(f32, u32);
+impl_float!(f64, u64);
+
pub trait Bool: BitOps {}
+impl Bool for bool {} // marker impl: the primitive `bool` is usable as a `Bool`
+
pub trait Select<T>: Bool {
fn select(self, true_v: T, false_v: T) -> T;
}
+/// Scalar selection: returns `true_v` when `self` is `true`, otherwise `false_v`.
+impl<T> Select<T> for bool {
+    fn select(self, true_v: T, false_v: T) -> T {
+        match self {
+            true => true_v,
+            false => false_v,
+        }
+    }
+}
pub trait Compare: Copy {
type Bool: Bool + Select<Self>;
fn eq(self, rhs: Self) -> Self::Bool;
fn le(self, rhs: Self) -> Self::Bool;
fn ge(self, rhs: Self) -> Self::Bool;
}
+
+// Implements `Compare` with `Bool = bool` for each listed type by forwarding
+// every comparison to the type's `PartialEq` / `PartialOrd` implementation.
+macro_rules! impl_compare_using_partial_cmp {
+    ($($scalar:ty),*) => {
+        $(
+            impl Compare for $scalar {
+                type Bool = bool;
+                fn eq(self, rhs: Self) -> Self::Bool {
+                    PartialEq::eq(&self, &rhs)
+                }
+                fn ne(self, rhs: Self) -> Self::Bool {
+                    PartialEq::ne(&self, &rhs)
+                }
+                fn lt(self, rhs: Self) -> Self::Bool {
+                    PartialOrd::lt(&self, &rhs)
+                }
+                fn gt(self, rhs: Self) -> Self::Bool {
+                    PartialOrd::gt(&self, &rhs)
+                }
+                fn le(self, rhs: Self) -> Self::Bool {
+                    PartialOrd::le(&self, &rhs)
+                }
+                fn ge(self, rhs: Self) -> Self::Bool {
+                    PartialOrd::ge(&self, &rhs)
+                }
+            }
+        )*
+    };
+}
+
+impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];