From 77cfb7eaa2407576c3a49938c58e26b12ed9699d Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Sun, 2 May 2021 19:38:29 -0700 Subject: [PATCH] impl traits for scalar types --- Cargo.toml | 4 +- src/f16.rs | 287 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 14 ++- src/scalar.rs | 62 +++++++++++ src/traits.rs | 187 ++++++++++++++++++++++++++++++-- 5 files changed, 538 insertions(+), 16 deletions(-) create mode 100644 src/f16.rs create mode 100644 src/scalar.rs diff --git a/Cargo.toml b/Cargo.toml index 02c3dcd..9e011cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,5 +9,7 @@ license = "MIT OR Apache-2.0" half = { version = "1.7.1", optional = true } [features] -default = ["f16"] +default = ["f16", "fma"] f16 = ["half"] +fma = ["std"] +std = [] diff --git a/src/f16.rs b/src/f16.rs new file mode 100644 index 0000000..04e8fbd --- /dev/null +++ b/src/f16.rs @@ -0,0 +1,287 @@ +use core::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; + +use crate::traits::{ConvertTo, Float}; + +#[cfg(feature = "f16")] +use half::f16 as F16Impl; + +#[cfg(not(feature = "f16"))] +#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] +enum F16Impl {} + +#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)] +#[cfg_attr(feature = "f16", repr(transparent))] +pub struct F16(F16Impl); + +#[cfg(feature = "f16")] +macro_rules! f16_impl { + ($v:expr, [$($vars:ident),*]) => { + $v + }; +} + +#[cfg(not(feature = "f16"))] +macro_rules! f16_impl { + ($v:expr, [$($vars:ident),*]) => { + { + $(let _ = $vars;)* + panic!("f16 feature is not enabled") + } + }; +} + +impl From for F16 { + fn from(v: F16Impl) -> Self { + F16(v) + } +} + +impl From for F16Impl { + fn from(v: F16) -> Self { + v.0 + } +} + +macro_rules! impl_f16_from { + ($($ty:ident,)*) => { + $( + impl From<$ty> for F16 { + fn from(v: $ty) -> Self { + f16_impl!(F16(F16Impl::from(v)), [v]) + } + } + + impl ConvertTo for $ty { + fn to(self) -> F16 { + self.into() + } + } + )* + }; +} + +macro_rules! impl_from_f16 { + ($($ty:ident,)*) => { + $( + impl From for $ty { + fn from(v: F16) -> Self { + #[cfg(feature = "f16")] + return v.0.into(); + #[cfg(not(feature = "f16"))] + match v.0 {} + } + } + + impl ConvertTo<$ty> for F16 { + fn to(self) -> $ty { + self.into() + } + } + )* + }; +} + +impl_f16_from![i8, u8,]; + +impl_from_f16![f32, f64,]; + +macro_rules! impl_int_to_f16 { + ($($int:ident),*) => { + $( + impl ConvertTo for $int { + fn to(self) -> F16 { + // f32 has enough mantissa bits such that f16 overflows to + // infinity before f32 can't properly represent integer + // values, making the below conversion correct. + (self as f32).to() + } + } + )* + }; +} + +macro_rules! impl_f16_to_int { + ($($int:ident),*) => { + $( + impl ConvertTo<$int> for F16 { + fn to(self) -> $int { + f32::from(self) as $int + } + } + )* + }; +} + +impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128]; +impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128]; + +impl ConvertTo for f32 { + fn to(self) -> F16 { + f16_impl!(F16(F16Impl::from_f32(self)), []) + } +} + +impl ConvertTo for f64 { + fn to(self) -> F16 { + f16_impl!(F16(F16Impl::from_f64(self)), []) + } +} + +impl Neg for F16 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::from_bits(self.to_bits() ^ 0x8000) + } +} + +macro_rules! impl_bin_op_using_f32 { + ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => { + $( + impl $op for F16 { + type Output = Self; + + fn $op_fn(self, rhs: Self) -> Self::Output { + f32::from(self).$op_fn(f32::from(rhs)).to() + } + } + + impl $op_assign for F16 { + fn $op_assign_fn(&mut self, rhs: Self) { + *self = (*self).$op_fn(rhs); + } + } + )* + }; +} + +impl_bin_op_using_f32! { + Add, add, AddAssign, add_assign; + Sub, sub, SubAssign, sub_assign; + Mul, mul, MulAssign, mul_assign; + Div, div, DivAssign, div_assign; + Rem, rem, RemAssign, rem_assign; +} + +impl Float for F16 { + type BitsType = u16; + + fn abs(self) -> Self { + Self::from_bits(self.to_bits() & 0x7FFF) + } + + fn trunc(self) -> Self { + f32::from(self).trunc().to() + } + + fn ceil(self) -> Self { + f32::from(self).ceil().to() + } + + fn floor(self) -> Self { + f32::from(self).floor().to() + } + + fn round(self) -> Self { + f32::from(self).round().to() + } + + #[cfg(feature = "fma")] + fn fma(self, a: Self, b: Self) -> Self { + (f64::from(self) * f64::from(a) + f64::from(b)).to() + } + + fn is_nan(self) -> Self::Bool { + f16_impl!(self.0.is_nan(), []) + } + + fn is_infinite(self) -> Self::Bool { + f16_impl!(self.0.is_infinite(), []) + } + + fn is_finite(self) -> Self::Bool { + f16_impl!(self.0.is_finite(), []) + } + + fn from_bits(v: Self::BitsType) -> Self { + f16_impl!(F16(F16Impl::from_bits(v)), [v]) + } + + fn to_bits(self) -> Self::BitsType { + f16_impl!(self.0.to_bits(), []) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::cmp::Ordering; + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_abs() { + assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0); + assert_eq!(F16::from_bits(0).abs().to_bits(), 0); + assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC); + assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00); + assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00); + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_neg() { + assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0); + assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000); + assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC); + assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00); + assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00); + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_int_to_f16() { + assert_eq!(F16::to_bits(0u32.to()), 0); + for v in 1..0x20000u32 { + let leading_zeros = u32::leading_zeros(v); + let shifted_v = v << leading_zeros; + // round to nearest, ties to even + let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) { + Ordering::Less => false, + Ordering::Equal => (shifted_v & 0x200000) != 0, + Ordering::Greater => true, + }; + let (rounded, carry) = + (shifted_v & !0x1FFFFF).overflowing_add(round_up.then(|| 0x200000).unwrap_or(0)); + let mantissa; + if carry { + mantissa = (rounded >> 22) as u16 + 0x400; + } else { + mantissa = (rounded >> 21) as u16; + } + assert_eq!((mantissa & !0x3FF), 0x400); + let exponent = 31 - leading_zeros as u16 + 15 + carry as u16; + let expected = if exponent < 0x1F { + (mantissa & 0x3FF) + (exponent << 10) + } else { + 0x7C00 + }; + let actual = F16::to_bits(v.to()); + assert_eq!( + actual, expected, + "actual = {:#X}, expected = {:#X}, v = {:#X}", + actual, expected, v + ); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2f64c51..551ef7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,9 @@ #![no_std] +#![deny(unconditional_recursion)] -pub mod traits; - -#[cfg(feature = "f16")] -pub use half::f16; +#[cfg(any(feature = "std", test))] +extern crate std; -#[cfg(not(feature = "f16"))] -#[allow(non_camel_case_types)] -#[derive(Clone, Copy, PartialEq, PartialOrd, Debug, Hash)] -pub enum f16 {} +pub mod f16; +pub mod scalar; +pub mod traits; diff --git a/src/scalar.rs b/src/scalar.rs new file mode 100644 index 0000000..c6794e2 --- /dev/null +++ b/src/scalar.rs @@ -0,0 +1,62 @@ +use crate::traits::{Context, Make}; + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)] +pub struct Scalar; + +impl Context for Scalar { + type Bool = bool; + + type U8 = u8; + + type I8 = i8; + + type U16 = u16; + + type I16 = i16; + + type F16 = crate::f16::F16; + + type U32 = u32; + + type I32 = i32; + + type F32 = f32; + + type U64 = u64; + + type I64 = i64; + + type F64 = f64; + + type VecBool = bool; + + type VecU8 = u8; + + type VecI8 = i8; + + type VecU16 = u16; + + type VecI16 = i16; + + type VecF16 = crate::f16::F16; + + type VecU32 = u32; + + type VecI32 = i32; + + type VecF32 = f32; + + type VecU64 = u64; + + type VecI64 = i64; + + type VecF64 = f64; +} + +impl Make for T { + type Prim = T; + + fn make(_ctx: Scalar, v: Self::Prim) -> Self { + v + } +} diff --git a/src/traits.rs b/src/traits.rs index c1d8095..fa46654 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -3,7 +3,7 @@ use core::ops::{ Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; -use crate::f16; +use crate::f16::F16; #[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 macro_rules! make_float_type { @@ -286,7 +286,7 @@ macro_rules! make_types { $U16; #[int(prim = i16 $(, scalar = $ScalarI16)?)] $I16; - #[float(prim = f16 $(, scalar = $ScalarF16)?)] + #[float(prim = F16 $(, scalar = $ScalarF16)?)] $F16; }, { @@ -412,10 +412,32 @@ pub trait ConvertTo { fn to(self) -> T; } -impl> ConvertTo for U { - fn to(self) -> T { - self.into() - } +macro_rules! impl_convert_to_using_as { + ($($src:ident -> [$($dest:ident),*];)*) => { + $($( + impl ConvertTo<$dest> for $src { + fn to(self) -> $dest { + self as $dest + } + } + )*)* + }; + ([$($src:ident),*] -> $dest:tt;) => { + impl_convert_to_using_as! { + $( + $src -> $dest; + )* + } + }; + ([$($src:ident),*];) => { + impl_convert_to_using_as! { + [$($src),*] -> [$($src),*]; + } + }; +} + +impl_convert_to_using_as! { + [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; } pub trait Number: @@ -433,6 +455,21 @@ pub trait Number: { } +impl Number for T where + T: Compare + + Add + + Sub + + Mul + + Div + + Rem + + AddAssign + + SubAssign + + MulAssign + + DivAssign + + RemAssign +{ +} + pub trait BitOps: Copy + BitAnd @@ -445,6 +482,18 @@ pub trait BitOps: { } +impl BitOps for T where + T: Copy + + BitAnd + + BitOr + + BitXor + + Not + + BitAndAssign + + BitOrAssign + + BitXorAssign +{ +} + pub trait Int: Number + BitOps @@ -459,6 +508,28 @@ pub trait UInt: Int {} pub trait SInt: Int + Neg {} +macro_rules! impl_uint { + ($($ty:ident),*) => { + $( + impl Int for $ty {} + impl UInt for $ty {} + )* + }; +} + +impl_uint![u8, u16, u32, u64]; + +macro_rules! impl_int { + ($($ty:ident),*) => { + $( + impl Int for $ty {} + impl SInt for $ty {} + )* + }; +} + +impl_int![i8, i16, i32, i64]; + pub trait Float: Number + Neg { type BitsType: UInt; fn abs(self) -> Self; @@ -466,20 +537,92 @@ pub trait Float: Number + Neg { fn ceil(self) -> Self; fn floor(self) -> Self; fn round(self) -> Self; + #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self; fn is_nan(self) -> Self::Bool; - fn is_infinity(self) -> Self::Bool; + fn is_infinite(self) -> Self::Bool; fn is_finite(self) -> Self::Bool; fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; } +macro_rules! impl_float { + ($ty:ty, $bits:ty) => { + impl Float for $ty { + type BitsType = $bits; + fn abs(self) -> Self { + #[cfg(feature = "std")] + return self.abs(); + #[cfg(not(feature = "std"))] + todo!(); + } + fn trunc(self) -> Self { + #[cfg(feature = "std")] + return self.trunc(); + #[cfg(not(feature = "std"))] + todo!(); + } + fn ceil(self) -> Self { + #[cfg(feature = "std")] + return self.ceil(); + #[cfg(not(feature = "std"))] + todo!(); + } + fn floor(self) -> Self { + #[cfg(feature = "std")] + return self.floor(); + #[cfg(not(feature = "std"))] + todo!(); + } + fn round(self) -> Self { + #[cfg(feature = "std")] + return self.round(); + #[cfg(not(feature = "std"))] + todo!(); + } + #[cfg(feature = "fma")] + fn fma(self, a: Self, b: Self) -> Self { + self.mul_add(a, b) + } + fn is_nan(self) -> Self::Bool { + self.is_nan() + } + fn is_infinite(self) -> Self::Bool { + self.is_infinite() + } + fn is_finite(self) -> Self::Bool { + self.is_finite() + } + fn from_bits(v: Self::BitsType) -> Self { + <$ty>::from_bits(v) + } + fn to_bits(self) -> Self::BitsType { + self.to_bits() + } + } + }; +} + +impl_float!(f32, u32); +impl_float!(f64, u64); + pub trait Bool: BitOps {} +impl Bool for bool {} + pub trait Select: Bool { fn select(self, true_v: T, false_v: T) -> T; } +impl Select for bool { + fn select(self, true_v: T, false_v: T) -> T { + if self { + true_v + } else { + false_v + } + } +} pub trait Compare: Copy { type Bool: Bool + Select; fn eq(self, rhs: Self) -> Self::Bool; @@ -489,3 +632,33 @@ pub trait Compare: Copy { fn le(self, rhs: Self) -> Self::Bool; fn ge(self, rhs: Self) -> Self::Bool; } + +macro_rules! impl_compare_using_partial_cmp { + ($($ty:ty),*) => { + $( + impl Compare for $ty { + type Bool = bool; + fn eq(self, rhs: Self) -> Self::Bool { + self == rhs + } + fn ne(self, rhs: Self) -> Self::Bool { + self != rhs + } + fn lt(self, rhs: Self) -> Self::Bool { + self < rhs + } + fn gt(self, rhs: Self) -> Self::Bool { + self > rhs + } + fn le(self, rhs: Self) -> Self::Bool { + self <= rhs + } + fn ge(self, rhs: Self) -> Self::Bool { + self >= rhs + } + } + )* + }; +} + +impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64]; -- 2.30.2