impl traits for scalar types
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 3 May 2021 02:38:29 +0000 (19:38 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 3 May 2021 02:38:29 +0000 (19:38 -0700)
Cargo.toml
src/f16.rs [new file with mode: 0644]
src/lib.rs
src/scalar.rs [new file with mode: 0644]
src/traits.rs

index 02c3dcd85f54fd19301f1a3af660f249b39c3f3b..9e011cba91d62121a71cdc3a2752f0e912edc2b8 100644 (file)
@@ -9,5 +9,7 @@ license = "MIT OR Apache-2.0"
 half = { version = "1.7.1", optional = true }
 
 [features]
-default = ["f16"]
+default = ["f16", "fma"]
 f16 = ["half"]
+fma = ["std"]
+std = []
diff --git a/src/f16.rs b/src/f16.rs
new file mode 100644 (file)
index 0000000..04e8fbd
--- /dev/null
@@ -0,0 +1,287 @@
+use core::ops::{
+    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
+};
+
+use crate::traits::{ConvertTo, Float};
+
+#[cfg(feature = "f16")]
+use half::f16 as F16Impl;
+
+#[cfg(not(feature = "f16"))]
+#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
+enum F16Impl {}
+
+#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
+#[cfg_attr(feature = "f16", repr(transparent))]
+pub struct F16(F16Impl);
+
+#[cfg(feature = "f16")]
+macro_rules! f16_impl {
+    ($v:expr, [$($vars:ident),*]) => {
+        $v
+    };
+}
+
+#[cfg(not(feature = "f16"))]
+macro_rules! f16_impl {
+    ($v:expr, [$($vars:ident),*]) => {
+        {
+            $(let _ = $vars;)*
+            panic!("f16 feature is not enabled")
+        }
+    };
+}
+
+impl From<F16Impl> for F16 {
+    fn from(v: F16Impl) -> Self {
+        F16(v)
+    }
+}
+
+impl From<F16> for F16Impl {
+    fn from(v: F16) -> Self {
+        v.0
+    }
+}
+
+macro_rules! impl_f16_from {
+    ($($ty:ident,)*) => {
+        $(
+            impl From<$ty> for F16 {
+                fn from(v: $ty) -> Self {
+                    f16_impl!(F16(F16Impl::from(v)), [v])
+                }
+            }
+
+            impl ConvertTo<F16> for $ty {
+                fn to(self) -> F16 {
+                    self.into()
+                }
+            }
+        )*
+    };
+}
+
+macro_rules! impl_from_f16 {
+    ($($ty:ident,)*) => {
+        $(
+            impl From<F16> for $ty {
+                fn from(v: F16) -> Self {
+                    #[cfg(feature = "f16")]
+                    return v.0.into();
+                    #[cfg(not(feature = "f16"))]
+                    match v.0 {}
+                }
+            }
+
+            impl ConvertTo<$ty> for F16 {
+                fn to(self) -> $ty {
+                    self.into()
+                }
+            }
+        )*
+    };
+}
+
+impl_f16_from![i8, u8,];
+
+impl_from_f16![f32, f64,];
+
+macro_rules! impl_int_to_f16 {
+    ($($int:ident),*) => {
+        $(
+            impl ConvertTo<F16> for $int {
+                fn to(self) -> F16 {
+                    // f32 has enough mantissa bits such that f16 overflows to
+                    // infinity before f32 can't properly represent integer
+                    // values, making the below conversion correct.
+                    (self as f32).to()
+                }
+            }
+        )*
+    };
+}
+
+macro_rules! impl_f16_to_int {
+    ($($int:ident),*) => {
+        $(
+            impl ConvertTo<$int> for F16 {
+                fn to(self) -> $int {
+                    f32::from(self) as $int
+                }
+            }
+        )*
+    };
+}
+
+impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
+impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
+
+impl ConvertTo<F16> for f32 {
+    fn to(self) -> F16 {
+        f16_impl!(F16(F16Impl::from_f32(self)), [])
+    }
+}
+
+impl ConvertTo<F16> for f64 {
+    fn to(self) -> F16 {
+        f16_impl!(F16(F16Impl::from_f64(self)), [])
+    }
+}
+
+impl Neg for F16 {
+    type Output = Self;
+
+    fn neg(self) -> Self::Output {
+        Self::from_bits(self.to_bits() ^ 0x8000)
+    }
+}
+
+macro_rules! impl_bin_op_using_f32 {
+    ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => {
+        $(
+            impl $op for F16 {
+                type Output = Self;
+
+                fn $op_fn(self, rhs: Self) -> Self::Output {
+                    f32::from(self).$op_fn(f32::from(rhs)).to()
+                }
+            }
+
+            impl $op_assign for F16 {
+                fn $op_assign_fn(&mut self, rhs: Self) {
+                    *self = (*self).$op_fn(rhs);
+                }
+            }
+        )*
+    };
+}
+
+impl_bin_op_using_f32! {
+    Add, add, AddAssign, add_assign;
+    Sub, sub, SubAssign, sub_assign;
+    Mul, mul, MulAssign, mul_assign;
+    Div, div, DivAssign, div_assign;
+    Rem, rem, RemAssign, rem_assign;
+}
+
+impl Float<u32> for F16 {
+    type BitsType = u16;
+
+    fn abs(self) -> Self {
+        Self::from_bits(self.to_bits() & 0x7FFF)
+    }
+
+    fn trunc(self) -> Self {
+        f32::from(self).trunc().to()
+    }
+
+    fn ceil(self) -> Self {
+        f32::from(self).ceil().to()
+    }
+
+    fn floor(self) -> Self {
+        f32::from(self).floor().to()
+    }
+
+    fn round(self) -> Self {
+        f32::from(self).round().to()
+    }
+
+    #[cfg(feature = "fma")]
+    fn fma(self, a: Self, b: Self) -> Self {
+        (f64::from(self) * f64::from(a) + f64::from(b)).to()
+    }
+
+    fn is_nan(self) -> Self::Bool {
+        f16_impl!(self.0.is_nan(), [])
+    }
+
+    fn is_infinite(self) -> Self::Bool {
+        f16_impl!(self.0.is_infinite(), [])
+    }
+
+    fn is_finite(self) -> Self::Bool {
+        f16_impl!(self.0.is_finite(), [])
+    }
+
+    fn from_bits(v: Self::BitsType) -> Self {
+        f16_impl!(F16(F16Impl::from_bits(v)), [v])
+    }
+
+    fn to_bits(self) -> Self::BitsType {
+        f16_impl!(self.0.to_bits(), [])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use core::cmp::Ordering;
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_abs() {
+        assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0);
+        assert_eq!(F16::from_bits(0).abs().to_bits(), 0);
+        assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC);
+        assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00);
+        assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_neg() {
+        assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0);
+        assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000);
+        assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC);
+        assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00);
+        assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00);
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_int_to_f16() {
+        assert_eq!(F16::to_bits(0u32.to()), 0);
+        for v in 1..0x20000u32 {
+            let leading_zeros = u32::leading_zeros(v);
+            let shifted_v = v << leading_zeros;
+            // round to nearest, ties to even
+            let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) {
+                Ordering::Less => false,
+                Ordering::Equal => (shifted_v & 0x200000) != 0,
+                Ordering::Greater => true,
+            };
+            let (rounded, carry) =
+                (shifted_v & !0x1FFFFF).overflowing_add(round_up.then(|| 0x200000).unwrap_or(0));
+            let mantissa;
+            if carry {
+                mantissa = (rounded >> 22) as u16 + 0x400;
+            } else {
+                mantissa = (rounded >> 21) as u16;
+            }
+            assert_eq!((mantissa & !0x3FF), 0x400);
+            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
+            let expected = if exponent < 0x1F {
+                (mantissa & 0x3FF) + (exponent << 10)
+            } else {
+                0x7C00
+            };
+            let actual = F16::to_bits(v.to());
+            assert_eq!(
+                actual, expected,
+                "actual = {:#X}, expected = {:#X}, v = {:#X}",
+                actual, expected, v
+            );
+        }
+    }
+}
index 2f64c51b340ce1152fd70dc73cfe039da5543d22..551ef7d5d073e58082486b7e0b3da1a7eaa9ab38 100644 (file)
@@ -1,11 +1,9 @@
 #![no_std]
+#![deny(unconditional_recursion)]
 
-pub mod traits;
-
-#[cfg(feature = "f16")]
-pub use half::f16;
+#[cfg(any(feature = "std", test))]
+extern crate std;
 
-#[cfg(not(feature = "f16"))]
-#[allow(non_camel_case_types)]
-#[derive(Clone, Copy, PartialEq, PartialOrd, Debug, Hash)]
-pub enum f16 {}
+pub mod f16;
+pub mod scalar;
+pub mod traits;
diff --git a/src/scalar.rs b/src/scalar.rs
new file mode 100644 (file)
index 0000000..c6794e2
--- /dev/null
@@ -0,0 +1,62 @@
+use crate::traits::{Context, Make};
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
+pub struct Scalar;
+
+impl Context for Scalar {
+    type Bool = bool;
+
+    type U8 = u8;
+
+    type I8 = i8;
+
+    type U16 = u16;
+
+    type I16 = i16;
+
+    type F16 = crate::f16::F16;
+
+    type U32 = u32;
+
+    type I32 = i32;
+
+    type F32 = f32;
+
+    type U64 = u64;
+
+    type I64 = i64;
+
+    type F64 = f64;
+
+    type VecBool = bool;
+
+    type VecU8 = u8;
+
+    type VecI8 = i8;
+
+    type VecU16 = u16;
+
+    type VecI16 = i16;
+
+    type VecF16 = crate::f16::F16;
+
+    type VecU32 = u32;
+
+    type VecI32 = i32;
+
+    type VecF32 = f32;
+
+    type VecU64 = u64;
+
+    type VecI64 = i64;
+
+    type VecF64 = f64;
+}
+
+impl<T> Make<Scalar> for T {
+    type Prim = T;
+
+    fn make(_ctx: Scalar, v: Self::Prim) -> Self {
+        v
+    }
+}
index c1d8095f2ca78bcde9e13ee19132afd676b25448..fa466542470170c81004ac958f736ee3b7da1647 100644 (file)
@@ -3,7 +3,7 @@ use core::ops::{
     Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
 };
 
-use crate::f16;
+use crate::f16::F16;
 
 #[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
 macro_rules! make_float_type {
@@ -286,7 +286,7 @@ macro_rules! make_types {
                     $U16;
                     #[int(prim = i16 $(, scalar = $ScalarI16)?)]
                     $I16;
-                    #[float(prim = f16 $(, scalar = $ScalarF16)?)]
+                    #[float(prim = F16 $(, scalar = $ScalarF16)?)]
                     $F16;
                 },
                 {
@@ -412,10 +412,32 @@ pub trait ConvertTo<T> {
     fn to(self) -> T;
 }
 
-impl<T, U: Into<T>> ConvertTo<T> for U {
-    fn to(self) -> T {
-        self.into()
-    }
+macro_rules! impl_convert_to_using_as {
+    ($($src:ident -> [$($dest:ident),*];)*) => {
+        $($(
+            impl ConvertTo<$dest> for $src {
+                fn to(self) -> $dest {
+                    self as $dest
+                }
+            }
+        )*)*
+    };
+    ([$($src:ident),*] -> $dest:tt;) => {
+        impl_convert_to_using_as! {
+            $(
+                $src -> $dest;
+            )*
+        }
+    };
+    ([$($src:ident),*];) => {
+        impl_convert_to_using_as! {
+            [$($src),*] -> [$($src),*];
+        }
+    };
+}
+
+impl_convert_to_using_as! {
+    [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
 }
 
 pub trait Number:
@@ -433,6 +455,21 @@ pub trait Number:
 {
 }
 
+impl<T> Number for T where
+    T: Compare
+        + Add<Output = Self>
+        + Sub<Output = Self>
+        + Mul<Output = Self>
+        + Div<Output = Self>
+        + Rem<Output = Self>
+        + AddAssign
+        + SubAssign
+        + MulAssign
+        + DivAssign
+        + RemAssign
+{
+}
+
 pub trait BitOps:
     Copy
     + BitAnd<Output = Self>
@@ -445,6 +482,18 @@ pub trait BitOps:
 {
 }
 
+impl<T> BitOps for T where
+    T: Copy
+        + BitAnd<Output = Self>
+        + BitOr<Output = Self>
+        + BitXor<Output = Self>
+        + Not<Output = Self>
+        + BitAndAssign
+        + BitOrAssign
+        + BitXorAssign
+{
+}
+
 pub trait Int<ShiftRhs>:
     Number
     + BitOps
@@ -459,6 +508,28 @@ pub trait UInt<ShiftRhs>: Int<ShiftRhs> {}
 
 pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
 
+macro_rules! impl_uint {
+    ($($ty:ident),*) => {
+        $(
+            impl Int<u32> for $ty {}
+            impl UInt<u32> for $ty {}
+        )*
+    };
+}
+
+impl_uint![u8, u16, u32, u64];
+
+macro_rules! impl_int {
+    ($($ty:ident),*) => {
+        $(
+            impl Int<u32> for $ty {}
+            impl SInt<u32> for $ty {}
+        )*
+    };
+}
+
+impl_int![i8, i16, i32, i64];
+
 pub trait Float<BitsShiftRhs>: Number + Neg<Output = Self> {
     type BitsType: UInt<BitsShiftRhs>;
     fn abs(self) -> Self;
@@ -466,20 +537,92 @@ pub trait Float<BitsShiftRhs>: Number + Neg<Output = Self> {
     fn ceil(self) -> Self;
     fn floor(self) -> Self;
     fn round(self) -> Self;
+    #[cfg(feature = "fma")]
     fn fma(self, a: Self, b: Self) -> Self;
     fn is_nan(self) -> Self::Bool;
-    fn is_infinity(self) -> Self::Bool;
+    fn is_infinite(self) -> Self::Bool;
     fn is_finite(self) -> Self::Bool;
     fn from_bits(v: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
 }
 
+macro_rules! impl_float {
+    ($ty:ty, $bits:ty) => {
+        impl Float<u32> for $ty {
+            type BitsType = $bits;
+            fn abs(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.abs();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn trunc(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.trunc();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn ceil(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.ceil();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn floor(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.floor();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            fn round(self) -> Self {
+                #[cfg(feature = "std")]
+                return self.round();
+                #[cfg(not(feature = "std"))]
+                todo!();
+            }
+            #[cfg(feature = "fma")]
+            fn fma(self, a: Self, b: Self) -> Self {
+                self.mul_add(a, b)
+            }
+            fn is_nan(self) -> Self::Bool {
+                self.is_nan()
+            }
+            fn is_infinite(self) -> Self::Bool {
+                self.is_infinite()
+            }
+            fn is_finite(self) -> Self::Bool {
+                self.is_finite()
+            }
+            fn from_bits(v: Self::BitsType) -> Self {
+                <$ty>::from_bits(v)
+            }
+            fn to_bits(self) -> Self::BitsType {
+                self.to_bits()
+            }
+        }
+    };
+}
+
+impl_float!(f32, u32);
+impl_float!(f64, u64);
+
 pub trait Bool: BitOps {}
 
+impl Bool for bool {}
+
 pub trait Select<T>: Bool {
     fn select(self, true_v: T, false_v: T) -> T;
 }
 
+impl<T> Select<T> for bool {
+    fn select(self, true_v: T, false_v: T) -> T {
+        if self {
+            true_v
+        } else {
+            false_v
+        }
+    }
+}
 pub trait Compare: Copy {
     type Bool: Bool + Select<Self>;
     fn eq(self, rhs: Self) -> Self::Bool;
@@ -489,3 +632,33 @@ pub trait Compare: Copy {
     fn le(self, rhs: Self) -> Self::Bool;
     fn ge(self, rhs: Self) -> Self::Bool;
 }
+
+macro_rules! impl_compare_using_partial_cmp {
+    ($($ty:ty),*) => {
+        $(
+            impl Compare for $ty {
+                type Bool = bool;
+                fn eq(self, rhs: Self) -> Self::Bool {
+                    self == rhs
+                }
+                fn ne(self, rhs: Self) -> Self::Bool {
+                    self != rhs
+                }
+                fn lt(self, rhs: Self) -> Self::Bool {
+                    self < rhs
+                }
+                fn gt(self, rhs: Self) -> Self::Bool {
+                    self > rhs
+                }
+                fn le(self, rhs: Self) -> Self::Bool {
+                    self <= rhs
+                }
+                fn ge(self, rhs: Self) -> Self::Bool {
+                    self >= rhs
+                }
+            }
+        )*
+    };
+}
+
+impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];