use crate::{ f16::F16, prim::{PrimFloat, PrimSInt, PrimUInt}, }; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; /// reference used to build IR for Kazan; an empty type for `core::simd` pub trait Context: Copy { vector_math_proc_macro::make_context_types!(); fn make>(self, v: T::Prim) -> T { T::make(self, v) } } pub trait Make: Copy { type Prim: Copy; type Context: Context; fn ctx(self) -> Self::Context; fn make(ctx: Self::Context, v: Self::Prim) -> Self; } pub trait ConvertFrom: Sized { fn cvt_from(v: T) -> Self; } impl ConvertFrom for T { fn cvt_from(v: T) -> Self { v } } pub trait ConvertTo { fn to(self) -> T; } impl> ConvertTo for F { fn to(self) -> T { T::cvt_from(self) } } macro_rules! impl_convert_from_using_as { ($first:ident $(, $ty:ident)*) => { $( impl ConvertFrom<$first> for $ty { fn cvt_from(v: $first) -> Self { v as _ } } impl ConvertFrom<$ty> for $first { fn cvt_from(v: $ty) -> Self { v as _ } } )* impl_convert_from_using_as![$($ty),*]; }; () => { }; } impl_convert_from_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; pub trait Number: Compare + Add + Sub + Mul + Div + Rem + AddAssign + SubAssign + MulAssign + DivAssign + RemAssign { } impl Number for T where T: Compare + Add + Sub + Mul + Div + Rem + AddAssign + SubAssign + MulAssign + DivAssign + RemAssign { } pub trait BitOps: Copy + BitAnd + BitOr + BitXor + Not + BitAndAssign + BitOrAssign + BitXorAssign { } impl BitOps for T where T: Copy + BitAnd + BitOr + BitXor + Not + BitAndAssign + BitOrAssign + BitXorAssign { } pub trait Int: Number + BitOps + Shl + Shr + ShlAssign + ShrAssign { fn leading_zeros(self) -> Self; fn leading_ones(self) -> Self { self.not().leading_zeros() } fn trailing_zeros(self) -> Self; fn trailing_ones(self) -> Self { self.not().trailing_zeros() } fn count_zeros(self) -> Self { self.not().count_ones() } fn count_ones(self) -> Self; } pub trait UInt: Int + Make + ConvertFrom { type PrimUInt: PrimUInt::PrimSInt>; type SignedType: SInt + ConvertFrom + Make + Compare; } pub trait SInt: Int + Neg + Make + ConvertFrom { type PrimSInt: PrimSInt::PrimUInt>; type UnsignedType: UInt + ConvertFrom + Make + Compare; } pub trait Float: Number + Neg + Make + ConvertFrom + ConvertFrom { type PrimFloat: PrimFloat; type BitsType: UInt::BitsType, SignedType = Self::SignedBitsType> + Make::BitsType> + Compare + ConvertFrom; type SignedBitsType: SInt< PrimSInt = ::SignedBitsType, UnsignedType = Self::BitsType, > + Make::SignedBitsType> + Compare + ConvertFrom; fn abs(self) -> Self; fn trunc(self) -> Self; fn ceil(self) -> Self; fn floor(self) -> Self; /// round to nearest integer, unspecified which way half-way cases are rounded fn round(self) -> Self; /// returns `self * a + b` but only rounding once #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self; /// returns `self * a + b` either using `fma` or `self * a + b` fn mul_add_fast(self, a: Self, b: Self) -> Self { #[cfg(feature = "fma")] return self.fma(a, b); #[cfg(not(feature = "fma"))] return self * a + b; } fn is_nan(self) -> Self::Bool { self.ne(self) } fn is_infinite(self) -> Self::Bool { self.abs().eq(Self::infinity(self.ctx())) } fn infinity(ctx: Self::Context) -> Self { Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS)) } fn nan(ctx: Self::Context) -> Self { Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS)) } fn is_finite(self) -> Self::Bool; fn is_zero_or_subnormal(self) -> Self::Bool { self.extract_exponent_field() .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT)) } fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; fn extract_exponent_field(self) -> Self::BitsType { let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK); let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); (self.to_bits() & mask) >> shift } fn extract_exponent_unbiased(self) -> Self::SignedBitsType { Self::sub_exponent_bias(self.extract_exponent_field()) } fn extract_mantissa_field(self) -> Self::BitsType { let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK); self.to_bits() & mask } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { Self::SignedBitsType::cvt_from(exponent_field) - exponent_field .ctx() .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED) } fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType { (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to() } } pub trait Bool: Make + BitOps {} pub trait Select: Bool { fn select(self, true_v: T, false_v: T) -> T; } pub trait Compare: Make { type Bool: Bool + Select; fn eq(self, rhs: Self) -> Self::Bool; fn ne(self, rhs: Self) -> Self::Bool; fn lt(self, rhs: Self) -> Self::Bool; fn gt(self, rhs: Self) -> Self::Bool; fn le(self, rhs: Self) -> Self::Bool; fn ge(self, rhs: Self) -> Self::Bool; }