X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Ftraits.rs;h=c02b60919cddc332266ce23b36318f2a45ec9e76;hb=8a170330691c442c16cf6a7c6606fc19493e9e81;hp=942c67f827734139aa610d17c9ef9400e729fdf8;hpb=af91717c9fff32fdb389c9a21bcf602b71ec7ccb;p=vector-math.git diff --git a/src/traits.rs b/src/traits.rs index 942c67f..c02b609 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,414 +1,15 @@ +use crate::{ + f16::F16, + prim::{PrimFloat, PrimSInt, PrimUInt}, +}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; -use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar}; - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint] - $uint_smaller:ident; - #[int] - $int_smaller:ident; - $( - #[float] - $float_smaller:ident; - )? - },)* - ], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - }, - [ - $({ - #[uint] - $uint_larger:ident; - #[int] - $int_larger:ident; - $( - #[float] - $float_larger:ident; - )? - },)* - ] - ) => { - type $float: Float - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - + ConvertTo - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ Into + ConvertTo)?)*; - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - }, - [$($larger:tt,)*] - ) => {}; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_uint_int_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint($($uint_smaller_traits:tt)*)] - $uint_smaller:ident; - #[int($($int_smaller_traits:tt)*)] - $int_smaller:ident; - $( - #[float($($float_smaller_traits:tt)*)] - $float_smaller:ident; - )? - },)* - ], - { - #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)] - $uint:ident; - #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)] - $int:ident; - $( - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - )? - }, - [ - $({ - #[uint($($uint_larger_traits:tt)*)] - $uint_larger:ident; - #[int($($int_larger_traits:tt)*)] - $int_larger:ident; - $( - #[float($($float_larger_traits:tt)*)] - $float_larger:ident; - )? - },)* - ] - ) => { - type $uint: UInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ Into + ConvertTo)* - $(+ Into + ConvertTo)* - $($(+ Into + ConvertTo)?)*; - type $int: SInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ ConvertTo)* - $(+ Into + ConvertTo)* - $($(+ Into + ConvertTo)?)*; - make_float_type! { - #[u32 = $u32] - #[bool = $bool] - [ - $({ - #[uint] - $uint_smaller; - #[int] - $int_smaller; - $( - #[float] - $float_smaller; - )? - },)* - ], - { - #[uint] - $uint; - #[int] - $int; - $( - #[float(prim = $float_prim $(, scalar = $float_scalar)?)] - $float; - )? - }, - [ - $({ - #[uint] - $uint_larger; - #[int] - $int_larger; - $( - #[float] - $float_larger; - )? - },)* - ] - } - }; -} - -macro_rules! make_uint_int_float_types { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [$first_larger:tt, $($larger:tt,)*] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [$first_larger, $($larger,)*] - } - make_uint_int_float_types! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)* $current,], - $first_larger, - [$($larger,)*] - } - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [] - } - }; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_types { - ( - #[bool] - $(#[scalar = $ScalarBool:ident])? - type $Bool:ident; - - #[u8] - $(#[scalar = $ScalarU8:ident])? - type $U8:ident; - - #[u16] - $(#[scalar = $ScalarU16:ident])? - type $U16:ident; - - #[u32] - $(#[scalar = $ScalarU32:ident])? - type $U32:ident; - - #[u64] - $(#[scalar = $ScalarU64:ident])? - type $U64:ident; - - #[i8] - $(#[scalar = $ScalarI8:ident])? - type $I8:ident; - - #[i16] - $(#[scalar = $ScalarI16:ident])? - type $I16:ident; - - #[i32] - $(#[scalar = $ScalarI32:ident])? - type $I32:ident; - - #[i64] - $(#[scalar = $ScalarI64:ident])? - type $I64:ident; - - #[f16] - $(#[scalar = $ScalarF16:ident])? - type $F16:ident; - - #[f32] - $(#[scalar = $ScalarF32:ident])? - type $F32:ident; - - #[f64] - $(#[scalar = $ScalarF64:ident])? - type $F64:ident; - ) => { - type $Bool: Bool - $(+ From)? - + Make - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select; - make_uint_int_float_types! { - #[u32 = $U32] - #[bool = $Bool] - [], - { - #[uint(prim = u8 $(, scalar = $ScalarU8)?)] - $U8; - #[int(prim = i8 $(, scalar = $ScalarI8)?)] - $I8; - }, - [ - { - #[uint(prim = u16 $(, scalar = $ScalarU16)?)] - $U16; - #[int(prim = i16 $(, scalar = $ScalarI16)?)] - $I16; - #[float(prim = F16 $(, scalar = $ScalarF16)?)] - $F16; - }, - { - #[uint(prim = u32 $(, scalar = $ScalarU32)?)] - $U32; - #[int(prim = i32 $(, scalar = $ScalarI32)?)] - $I32; - #[float(prim = f32 $(, scalar = $ScalarF32)?)] - $F32; - }, - { - #[uint(prim = u64 $(, scalar = $ScalarU64)?)] - $U64; - #[int(prim = i64 $(, scalar = $ScalarI64)?)] - $I64; - #[float(prim = f64 $(, scalar = $ScalarF64)?)] - $F64; - }, - ] - } - }; -} - /// reference used to build IR for Kazan; an empty type for `core::simd` pub trait Context: Copy { - make_types! { - #[bool] - type Bool; - - #[u8] - type U8; - - #[u16] - type U16; - - #[u32] - type U32; - - #[u64] - type U64; - - #[i8] - type I8; - - #[i16] - type I16; - - #[i32] - type I32; - - #[i64] - type I64; - - #[f16] - type F16; - - #[f32] - type F32; - - #[f64] - type F64; - } - make_types! { - #[bool] - #[scalar = Bool] - type VecBool; - - #[u8] - #[scalar = U8] - type VecU8; - - #[u16] - #[scalar = U16] - type VecU16; - - #[u32] - #[scalar = U32] - type VecU32; - - #[u64] - #[scalar = U64] - type VecU64; - - #[i8] - #[scalar = I8] - type VecI8; - - #[i16] - #[scalar = I16] - type VecI16; - - #[i32] - #[scalar = I32] - type VecI32; - - #[i64] - #[scalar = I64] - type VecI64; - - #[f16] - #[scalar = F16] - type VecF16; - - #[f32] - #[scalar = F32] - type VecF32; - - #[f64] - #[scalar = F64] - type VecF64; - } + vector_math_proc_macro::make_context_types!(); fn make>(self, v: T::Prim) -> T { T::make(self, v) } @@ -421,37 +22,47 @@ pub trait Make: Copy { fn make(ctx: Self::Context, v: Self::Prim) -> Self; } +pub trait ConvertFrom: Sized { + fn cvt_from(v: T) -> Self; +} + +impl ConvertFrom for T { + fn cvt_from(v: T) -> Self { + v + } +} + pub trait ConvertTo { fn to(self) -> T; } -macro_rules! impl_convert_to_using_as { - ($($src:ident -> [$($dest:ident),*];)*) => { - $($( - impl ConvertTo<$dest> for $src { - fn to(self) -> $dest { - self as $dest +impl> ConvertTo for F { + fn to(self) -> T { + T::cvt_from(self) + } +} + +macro_rules! impl_convert_from_using_as { + ($first:ident $(, $ty:ident)*) => { + $( + impl ConvertFrom<$first> for $ty { + fn cvt_from(v: $first) -> Self { + v as _ } } - )*)* - }; - ([$($src:ident),*] -> $dest:tt;) => { - impl_convert_to_using_as! { - $( - $src -> $dest; - )* - } + impl ConvertFrom<$ty> for $first { + fn cvt_from(v: $ty) -> Self { + v as _ + } + } + )* + impl_convert_from_using_as![$($ty),*]; }; - ([$($src:ident),*];) => { - impl_convert_to_using_as! { - [$($src),*] -> [$($src),*]; - } + () => { }; } -impl_convert_to_using_as! { - [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; -} +impl_convert_from_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; pub trait Number: Compare @@ -507,13 +118,8 @@ impl BitOps for T where { } -pub trait Int: - Number - + BitOps - + Shl - + Shr - + ShlAssign - + ShrAssign +pub trait Int: + Number + BitOps + Shl + Shr + ShlAssign + ShrAssign { fn leading_zeros(self) -> Self; fn leading_ones(self) -> Self { @@ -529,76 +135,61 @@ pub trait Int: fn count_ones(self) -> Self; } -pub trait UInt: Int {} - -pub trait SInt: Int + Neg {} - -macro_rules! impl_int { - ($ty:ident) => { - impl Int for $ty { - fn leading_zeros(self) -> Self { - self.leading_zeros() as Self - } - fn leading_ones(self) -> Self { - self.leading_ones() as Self - } - fn trailing_zeros(self) -> Self { - self.trailing_zeros() as Self - } - fn trailing_ones(self) -> Self { - self.trailing_ones() as Self - } - fn count_zeros(self) -> Self { - self.count_zeros() as Self - } - fn count_ones(self) -> Self { - self.count_ones() as Self - } - } - }; -} - -macro_rules! impl_uint { - ($($ty:ident),*) => { - $( - impl_int!($ty); - impl UInt for $ty {} - )* - }; +pub trait UInt: Int + Make + ConvertFrom { + type PrimUInt: PrimUInt::PrimSInt>; + type SignedType: SInt + + ConvertFrom + + Make + + Compare; } -impl_uint![u8, u16, u32, u64]; - -macro_rules! impl_sint { - ($($ty:ident),*) => { - $( - impl_int!($ty); - impl SInt for $ty {} - )* - }; +pub trait SInt: + Int + Neg + Make + ConvertFrom +{ + type PrimSInt: PrimSInt::PrimUInt>; + type UnsignedType: UInt + + ConvertFrom + + Make + + Compare; } -impl_sint![i8, i16, i32, i64]; - -pub trait Float>: - Number + Neg +pub trait Float: + Number + + Neg + + Make + + ConvertFrom + + ConvertFrom { - type FloatEncoding: FloatEncoding + Make::Prim>; - type BitsType: UInt - + Make>::BitsType> - + ConvertTo - + Compare; - type SignedBitsType: SInt - + Make>::SignedBitsType> - + ConvertTo - + Compare; + type PrimFloat: PrimFloat; + type BitsType: UInt::BitsType, SignedType = Self::SignedBitsType> + + Make::BitsType> + + Compare + + ConvertFrom; + type SignedBitsType: SInt< + PrimSInt = ::SignedBitsType, + UnsignedType = Self::BitsType, + > + Make::SignedBitsType> + + Compare + + ConvertFrom; fn abs(self) -> Self; + fn copy_sign(self, sign: Self) -> Self { + crate::algorithms::base::copy_sign(self.ctx(), self, sign) + } fn trunc(self) -> Self; fn ceil(self) -> Self; fn floor(self) -> Self; + /// round to nearest integer, unspecified which way half-way cases are rounded fn round(self) -> Self; + /// returns `self * a + b` but only rounding once #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self; + /// returns `self * a + b` either using `fma` or `self * a + b` + fn mul_add_fast(self, a: Self, b: Self) -> Self { + #[cfg(feature = "fma")] + return self.fma(a, b); + #[cfg(not(feature = "fma"))] + return self * a + b; + } fn is_nan(self) -> Self::Bool { self.ne(self) } @@ -606,127 +197,76 @@ pub trait Float>: self.abs().eq(Self::infinity(self.ctx())) } fn infinity(ctx: Self::Context) -> Self { - Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS)) + Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS)) } fn nan(ctx: Self::Context) -> Self { - Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS)) + Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS)) } fn is_finite(self) -> Self::Bool; fn is_zero_or_subnormal(self) -> Self::Bool { - self.extract_exponent_field().eq(self - .ctx() - .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT)) + self.extract_exponent_field() + .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT)) } fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; fn extract_exponent_field(self) -> Self::BitsType { - let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK); - let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT); + let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK); + let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); (self.to_bits() & mask) >> shift } fn extract_exponent_unbiased(self) -> Self::SignedBitsType { Self::sub_exponent_bias(self.extract_exponent_field()) } fn extract_mantissa_field(self) -> Self::BitsType { - let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK); + let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK); self.to_bits() & mask } + fn is_sign_negative(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .ne(self.to_bits() & mask) + } + fn is_sign_positive(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .eq(self.to_bits() & mask) + } + fn extract_sign_field(self) -> Self::BitsType { + let shift = self.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + self.to_bits() >> shift + } + fn from_fields( + sign_field: Self::BitsType, + exponent_field: Self::BitsType, + mantissa_field: Self::BitsType, + ) -> Self { + let sign_shift = sign_field.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + let exponent_shift = sign_field.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); + Self::from_bits( + (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field, + ) + } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { - exponent_field.to() + Self::SignedBitsType::cvt_from(exponent_field) - exponent_field .ctx() - .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED) + .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED) } fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType { - (exponent - + exponent - .ctx() - .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)) - .to() + (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to() } } -macro_rules! impl_float { - ($ty:ty, $bits:ty, $signed_bits:ty) => { - impl Float for $ty { - type FloatEncoding = $ty; - type BitsType = $bits; - type SignedBitsType = $signed_bits; - fn abs(self) -> Self { - #[cfg(feature = "std")] - return self.abs(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn trunc(self) -> Self { - #[cfg(feature = "std")] - return self.trunc(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn ceil(self) -> Self { - #[cfg(feature = "std")] - return self.ceil(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn floor(self) -> Self { - #[cfg(feature = "std")] - return self.floor(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn round(self) -> Self { - #[cfg(feature = "std")] - return self.round(); - #[cfg(not(feature = "std"))] - todo!(); - } - #[cfg(feature = "fma")] - fn fma(self, a: Self, b: Self) -> Self { - self.mul_add(a, b) - } - fn is_nan(self) -> Self::Bool { - self.is_nan() - } - fn is_infinite(self) -> Self::Bool { - self.is_infinite() - } - fn is_finite(self) -> Self::Bool { - self.is_finite() - } - fn from_bits(v: Self::BitsType) -> Self { - <$ty>::from_bits(v) - } - fn to_bits(self) -> Self::BitsType { - self.to_bits() - } - } - }; -} - -impl_float!(f32, u32, i32); -impl_float!(f64, u64, i64); - -pub trait Bool: Make + BitOps {} +pub trait Bool: Make + BitOps + Select {} -impl Bool for bool {} - -pub trait Select: Bool { +pub trait Select { fn select(self, true_v: T, false_v: T) -> T; } -impl Select for bool { - fn select(self, true_v: T, false_v: T) -> T { - if self { - true_v - } else { - false_v - } - } -} pub trait Compare: Make { - type Bool: Bool + Select; + type Bool: Bool + Select + Make; fn eq(self, rhs: Self) -> Self::Bool; fn ne(self, rhs: Self) -> Self::Bool; fn lt(self, rhs: Self) -> Self::Bool; @@ -734,33 +274,3 @@ pub trait Compare: Make { fn le(self, rhs: Self) -> Self::Bool; fn ge(self, rhs: Self) -> Self::Bool; } - -macro_rules! impl_compare_using_partial_cmp { - ($($ty:ty),*) => { - $( - impl Compare for $ty { - type Bool = bool; - fn eq(self, rhs: Self) -> Self::Bool { - self == rhs - } - fn ne(self, rhs: Self) -> Self::Bool { - self != rhs - } - fn lt(self, rhs: Self) -> Self::Bool { - self < rhs - } - fn gt(self, rhs: Self) -> Self::Bool { - self > rhs - } - fn le(self, rhs: Self) -> Self::Bool { - self <= rhs - } - fn ge(self, rhs: Self) -> Self::Bool { - self >= rhs - } - } - )* - }; -} - -impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];