use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; /// reference used to build IR for Kazan; an empty type for `core::simd` pub trait Context: Copy { vector_math_proc_macro::make_context_types!(); fn make>(self, v: T::Prim) -> T { T::make(self, v) } } pub trait Make: Copy { type Prim: Copy; type Context: Context; fn ctx(self) -> Self::Context; fn make(ctx: Self::Context, v: Self::Prim) -> Self; } pub trait ConvertTo { fn to(self) -> T; } impl ConvertTo for T { fn to(self) -> T { self } } macro_rules! impl_convert_to_using_as { ($first:ident $(, $ty:ident)*) => { $( impl ConvertTo<$first> for $ty { fn to(self) -> $first { self as $first } } impl ConvertTo<$ty> for $first { fn to(self) -> $ty { self as $ty } } )* impl_convert_to_using_as![$($ty),*]; }; () => { }; } impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; pub trait Number: Compare + Add + Sub + Mul + Div + Rem + AddAssign + SubAssign + MulAssign + DivAssign + RemAssign { } impl Number for T where T: Compare + Add + Sub + Mul + Div + Rem + AddAssign + SubAssign + MulAssign + DivAssign + RemAssign { } pub trait BitOps: Copy + BitAnd + BitOr + BitXor + Not + BitAndAssign + BitOrAssign + BitXorAssign { } impl BitOps for T where T: Copy + BitAnd + BitOr + BitXor + Not + BitAndAssign + BitOrAssign + BitXorAssign { } pub trait Int: Number + BitOps + Shl + Shr + ShlAssign + ShrAssign { fn leading_zeros(self) -> Self; fn leading_ones(self) -> Self { self.not().leading_zeros() } fn trailing_zeros(self) -> Self; fn trailing_ones(self) -> Self { self.not().trailing_zeros() } fn count_zeros(self) -> Self { self.not().count_ones() } fn count_ones(self) -> Self; } pub trait UInt: Int {} pub trait SInt: Int + Neg {} macro_rules! impl_int { ($ty:ident) => { impl Int for $ty { fn leading_zeros(self) -> Self { self.leading_zeros() as Self } fn leading_ones(self) -> Self { self.leading_ones() as Self } fn trailing_zeros(self) -> Self { self.trailing_zeros() as Self } fn trailing_ones(self) -> Self { self.trailing_ones() as Self } fn count_zeros(self) -> Self { self.count_zeros() as Self } fn count_ones(self) -> Self { self.count_ones() as Self } } }; } macro_rules! impl_uint { ($($ty:ident),*) => { $( impl_int!($ty); impl UInt for $ty {} )* }; } impl_uint![u8, u16, u32, u64]; macro_rules! impl_sint { ($($ty:ident),*) => { $( impl_int!($ty); impl SInt for $ty {} )* }; } impl_sint![i8, i16, i32, i64]; pub trait Float: Number + Neg { type FloatEncoding: FloatEncoding + Make::Prim>; type BitsType: UInt + Make::BitsType> + ConvertTo + Compare; type SignedBitsType: SInt + Make::SignedBitsType> + ConvertTo + Compare; fn abs(self) -> Self; fn trunc(self) -> Self; fn ceil(self) -> Self; fn floor(self) -> Self; /// round to nearest integer, unspecified which way half-way cases are rounded fn round(self) -> Self; /// returns `self * a + b` but only rounding once #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self; /// returns `self * a + b` either using `fma` or `self * a + b` fn mul_add_fast(self, a: Self, b: Self) -> Self { #[cfg(feature = "fma")] return self.fma(a, b); #[cfg(not(feature = "fma"))] return self * a + b; } fn is_nan(self) -> Self::Bool { self.ne(self) } fn is_infinite(self) -> Self::Bool { self.abs().eq(Self::infinity(self.ctx())) } fn infinity(ctx: Self::Context) -> Self { Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS)) } fn nan(ctx: Self::Context) -> Self { Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS)) } fn is_finite(self) -> Self::Bool; fn is_zero_or_subnormal(self) -> Self::Bool { self.extract_exponent_field().eq(self .ctx() .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT)) } fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; fn extract_exponent_field(self) -> Self::BitsType { let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK); let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT); (self.to_bits() & mask) >> shift } fn extract_exponent_unbiased(self) -> Self::SignedBitsType { Self::sub_exponent_bias(self.extract_exponent_field()) } fn extract_mantissa_field(self) -> Self::BitsType { let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK); self.to_bits() & mask } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { exponent_field.to() - exponent_field .ctx() .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED) } fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType { (exponent + exponent .ctx() .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)) .to() } } macro_rules! impl_float { ($ty:ty, $bits:ty, $signed_bits:ty) => { impl Float for $ty { type FloatEncoding = $ty; type BitsType = $bits; type SignedBitsType = $signed_bits; fn abs(self) -> Self { #[cfg(feature = "std")] return self.abs(); #[cfg(not(feature = "std"))] todo!(); } fn trunc(self) -> Self { #[cfg(feature = "std")] return self.trunc(); #[cfg(not(feature = "std"))] todo!(); } fn ceil(self) -> Self { #[cfg(feature = "std")] return self.ceil(); #[cfg(not(feature = "std"))] todo!(); } fn floor(self) -> Self { #[cfg(feature = "std")] return self.floor(); #[cfg(not(feature = "std"))] todo!(); } fn round(self) -> Self { #[cfg(feature = "std")] return self.round(); #[cfg(not(feature = "std"))] todo!(); } #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { self.mul_add(a, b) } fn is_nan(self) -> Self::Bool { self.is_nan() } fn is_infinite(self) -> Self::Bool { self.is_infinite() } fn is_finite(self) -> Self::Bool { self.is_finite() } fn from_bits(v: Self::BitsType) -> Self { <$ty>::from_bits(v) } fn to_bits(self) -> Self::BitsType { self.to_bits() } } }; } impl_float!(f32, u32, i32); impl_float!(f64, u64, i64); pub trait Bool: Make + BitOps {} impl Bool for bool {} pub trait Select: Bool { fn select(self, true_v: T, false_v: T) -> T; } impl Select for bool { fn select(self, true_v: T, false_v: T) -> T { if self { true_v } else { false_v } } } pub trait Compare: Make { type Bool: Bool + Select; fn eq(self, rhs: Self) -> Self::Bool; fn ne(self, rhs: Self) -> Self::Bool; fn lt(self, rhs: Self) -> Self::Bool; fn gt(self, rhs: Self) -> Self::Bool; fn le(self, rhs: Self) -> Self::Bool; fn ge(self, rhs: Self) -> Self::Bool; } macro_rules! impl_compare_using_partial_cmp { ($($ty:ty),*) => { $( impl Compare for $ty { type Bool = bool; fn eq(self, rhs: Self) -> Self::Bool { self == rhs } fn ne(self, rhs: Self) -> Self::Bool { self != rhs } fn lt(self, rhs: Self) -> Self::Bool { self < rhs } fn gt(self, rhs: Self) -> Self::Bool { self > rhs } fn le(self, rhs: Self) -> Self::Bool { self <= rhs } fn ge(self, rhs: Self) -> Self::Bool { self >= rhs } } )* }; } impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];