X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Ftraits.rs;h=c02b60919cddc332266ce23b36318f2a45ec9e76;hb=ebb72b1eb96f918ebc34fb74221562c4cf65365b;hp=c1d8095f2ca78bcde9e13ee19132afd676b25448;hpb=bb3868c63a5e0fb3b7383bac22253f3387ab5dbb;p=vector-math.git diff --git a/src/traits.rs b/src/traits.rs index c1d8095..c02b609 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,423 +1,69 @@ +use crate::{ + f16::F16, + prim::{PrimFloat, PrimSInt, PrimUInt}, +}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; -use crate::f16; - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint] - $uint_smaller:ident; - #[int] - $int_smaller:ident; - $( - #[float] - $float_smaller:ident; - )? - },)* - ], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - }, - [ - $({ - #[uint] - $uint_larger:ident; - #[int] - $int_larger:ident; - $( - #[float] - $float_larger:ident; - )? - },)* - ] - ) => { - type $float: Float - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - + ConvertTo - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ Into)?)*; - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - }, - [$($larger:tt,)*] - ) => {}; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_uint_int_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint($($uint_smaller_traits:tt)*)] - $uint_smaller:ident; - #[int($($int_smaller_traits:tt)*)] - $int_smaller:ident; - $( - #[float($($float_smaller_traits:tt)*)] - $float_smaller:ident; - )? - },)* - ], - { - #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)] - $uint:ident; - #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)] - $int:ident; - $( - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - )? - }, - [ - $({ - #[uint($($uint_larger_traits:tt)*)] - $uint_larger:ident; - #[int($($int_larger_traits:tt)*)] - $int_larger:ident; - $( - #[float($($float_larger_traits:tt)*)] - $float_larger:ident; - )? - },)* - ] - ) => { - type $uint: UInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ Into)* - $(+ Into)* - $($(+ Into)?)*; - type $int: SInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ ConvertTo)* - $(+ Into)* - $($(+ Into)?)*; - make_float_type! { - #[u32 = $u32] - #[bool = $bool] - [ - $({ - #[uint] - $uint_smaller; - #[int] - $int_smaller; - $( - #[float] - $float_smaller; - )? - },)* - ], - { - #[uint] - $uint; - #[int] - $int; - $( - #[float(prim = $float_prim $(, scalar = $float_scalar)?)] - $float; - )? - }, - [ - $({ - #[uint] - $uint_larger; - #[int] - $int_larger; - $( - #[float] - $float_larger; - )? - },)* - ] - } - }; -} - -macro_rules! make_uint_int_float_types { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [$first_larger:tt, $($larger:tt,)*] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [$first_larger, $($larger,)*] - } - make_uint_int_float_types! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)* $current,], - $first_larger, - [$($larger,)*] - } - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [] - } - }; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_types { - ( - #[bool] - $(#[scalar = $ScalarBool:ident])? - type $Bool:ident; - - #[u8] - $(#[scalar = $ScalarU8:ident])? - type $U8:ident; - - #[u16] - $(#[scalar = $ScalarU16:ident])? - type $U16:ident; - - #[u32] - $(#[scalar = $ScalarU32:ident])? - type $U32:ident; - - #[u64] - $(#[scalar = $ScalarU64:ident])? - type $U64:ident; - - #[i8] - $(#[scalar = $ScalarI8:ident])? - type $I8:ident; - - #[i16] - $(#[scalar = $ScalarI16:ident])? - type $I16:ident; - - #[i32] - $(#[scalar = $ScalarI32:ident])? - type $I32:ident; - - #[i64] - $(#[scalar = $ScalarI64:ident])? - type $I64:ident; - - #[f16] - $(#[scalar = $ScalarF16:ident])? - type $F16:ident; - - #[f32] - $(#[scalar = $ScalarF32:ident])? - type $F32:ident; - - #[f64] - $(#[scalar = $ScalarF64:ident])? - type $F64:ident; - ) => { - type $Bool: Bool - $(+ From)? - + Make - + Select; - make_uint_int_float_types! { - #[u32 = $U32] - #[bool = $Bool] - [], - { - #[uint(prim = u8 $(, scalar = $ScalarU8)?)] - $U8; - #[int(prim = i8 $(, scalar = $ScalarI8)?)] - $I8; - }, - [ - { - #[uint(prim = u16 $(, scalar = $ScalarU16)?)] - $U16; - #[int(prim = i16 $(, scalar = $ScalarI16)?)] - $I16; - #[float(prim = f16 $(, scalar = $ScalarF16)?)] - $F16; - }, - { - #[uint(prim = u32 $(, scalar = $ScalarU32)?)] - $U32; - #[int(prim = i32 $(, scalar = $ScalarI32)?)] - $I32; - #[float(prim = f32 $(, scalar = $ScalarF32)?)] - $F32; - }, - { - #[uint(prim = u64 $(, scalar = $ScalarU64)?)] - $U64; - #[int(prim = i64 $(, scalar = $ScalarI64)?)] - $I64; - #[float(prim = f64 $(, scalar = $ScalarF64)?)] - $F64; - }, - ] - } - }; -} - /// reference used to build IR for Kazan; an empty type for `core::simd` pub trait Context: Copy { - make_types! { - #[bool] - type Bool; - - #[u8] - type U8; - - #[u16] - type U16; - - #[u32] - type U32; - - #[u64] - type U64; - - #[i8] - type I8; - - #[i16] - type I16; - - #[i32] - type I32; - - #[i64] - type I64; - - #[f16] - type F16; - - #[f32] - type F32; - - #[f64] - type F64; + vector_math_proc_macro::make_context_types!(); + fn make>(self, v: T::Prim) -> T { + T::make(self, v) } - make_types! { - #[bool] - #[scalar = Bool] - type VecBool; - - #[u8] - #[scalar = U8] - type VecU8; - - #[u16] - #[scalar = U16] - type VecU16; - - #[u32] - #[scalar = U32] - type VecU32; - - #[u64] - #[scalar = U64] - type VecU64; - - #[i8] - #[scalar = I8] - type VecI8; - - #[i16] - #[scalar = I16] - type VecI16; - - #[i32] - #[scalar = I32] - type VecI32; - - #[i64] - #[scalar = I64] - type VecI64; - - #[f16] - #[scalar = F16] - type VecF16; +} - #[f32] - #[scalar = F32] - type VecF32; +pub trait Make: Copy { + type Prim: Copy; + type Context: Context; + fn ctx(self) -> Self::Context; + fn make(ctx: Self::Context, v: Self::Prim) -> Self; +} - #[f64] - #[scalar = F64] - type VecF64; - } - fn make>(self, v: T::Prim) -> T { - T::make(self, v) - } +pub trait ConvertFrom: Sized { + fn cvt_from(v: T) -> Self; } -pub trait Make: Sized { - type Prim; - fn make(ctx: Context, v: Self::Prim) -> Self; +impl ConvertFrom for T { + fn cvt_from(v: T) -> Self { + v + } } pub trait ConvertTo { fn to(self) -> T; } -impl> ConvertTo for U { +impl> ConvertTo for F { fn to(self) -> T { - self.into() + T::cvt_from(self) } } +macro_rules! impl_convert_from_using_as { + ($first:ident $(, $ty:ident)*) => { + $( + impl ConvertFrom<$first> for $ty { + fn cvt_from(v: $first) -> Self { + v as _ + } + } + impl ConvertFrom<$ty> for $first { + fn cvt_from(v: $ty) -> Self { + v as _ + } + } + )* + impl_convert_from_using_as![$($ty),*]; + }; + () => { + }; +} + +impl_convert_from_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; + pub trait Number: Compare + Add @@ -433,6 +79,21 @@ pub trait Number: { } +impl Number for T where + T: Compare + + Add + + Sub + + Mul + + Div + + Rem + + AddAssign + + SubAssign + + MulAssign + + DivAssign + + RemAssign +{ +} + pub trait BitOps: Copy + BitAnd @@ -445,43 +106,167 @@ pub trait BitOps: { } -pub trait Int: - Number - + BitOps - + Shl - + Shr - + ShlAssign - + ShrAssign +impl BitOps for T where + T: Copy + + BitAnd + + BitOr + + BitXor + + Not + + BitAndAssign + + BitOrAssign + + BitXorAssign { } -pub trait UInt: Int {} +pub trait Int: + Number + BitOps + Shl + Shr + ShlAssign + ShrAssign +{ + fn leading_zeros(self) -> Self; + fn leading_ones(self) -> Self { + self.not().leading_zeros() + } + fn trailing_zeros(self) -> Self; + fn trailing_ones(self) -> Self { + self.not().trailing_zeros() + } + fn count_zeros(self) -> Self { + self.not().count_ones() + } + fn count_ones(self) -> Self; +} + +pub trait UInt: Int + Make + ConvertFrom { + type PrimUInt: PrimUInt::PrimSInt>; + type SignedType: SInt + + ConvertFrom + + Make + + Compare; +} -pub trait SInt: Int + Neg {} +pub trait SInt: + Int + Neg + Make + ConvertFrom +{ + type PrimSInt: PrimSInt::PrimUInt>; + type UnsignedType: UInt + + ConvertFrom + + Make + + Compare; +} -pub trait Float: Number + Neg { - type BitsType: UInt; +pub trait Float: + Number + + Neg + + Make + + ConvertFrom + + ConvertFrom +{ + type PrimFloat: PrimFloat; + type BitsType: UInt::BitsType, SignedType = Self::SignedBitsType> + + Make::BitsType> + + Compare + + ConvertFrom; + type SignedBitsType: SInt< + PrimSInt = ::SignedBitsType, + UnsignedType = Self::BitsType, + > + Make::SignedBitsType> + + Compare + + ConvertFrom; fn abs(self) -> Self; + fn copy_sign(self, sign: Self) -> Self { + crate::algorithms::base::copy_sign(self.ctx(), self, sign) + } fn trunc(self) -> Self; fn ceil(self) -> Self; fn floor(self) -> Self; + /// round to nearest integer, unspecified which way half-way cases are rounded fn round(self) -> Self; + /// returns `self * a + b` but only rounding once + #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self; - fn is_nan(self) -> Self::Bool; - fn is_infinity(self) -> Self::Bool; + /// returns `self * a + b` either using `fma` or `self * a + b` + fn mul_add_fast(self, a: Self, b: Self) -> Self { + #[cfg(feature = "fma")] + return self.fma(a, b); + #[cfg(not(feature = "fma"))] + return self * a + b; + } + fn is_nan(self) -> Self::Bool { + self.ne(self) + } + fn is_infinite(self) -> Self::Bool { + self.abs().eq(Self::infinity(self.ctx())) + } + fn infinity(ctx: Self::Context) -> Self { + Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS)) + } + fn nan(ctx: Self::Context) -> Self { + Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS)) + } fn is_finite(self) -> Self::Bool; + fn is_zero_or_subnormal(self) -> Self::Bool { + self.extract_exponent_field() + .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT)) + } fn from_bits(v: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; + fn extract_exponent_field(self) -> Self::BitsType { + let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK); + let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); + (self.to_bits() & mask) >> shift + } + fn extract_exponent_unbiased(self) -> Self::SignedBitsType { + Self::sub_exponent_bias(self.extract_exponent_field()) + } + fn extract_mantissa_field(self) -> Self::BitsType { + let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK); + self.to_bits() & mask + } + fn is_sign_negative(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .ne(self.to_bits() & mask) + } + fn is_sign_positive(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .eq(self.to_bits() & mask) + } + fn extract_sign_field(self) -> Self::BitsType { + let shift = self.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + self.to_bits() >> shift + } + fn from_fields( + sign_field: Self::BitsType, + exponent_field: Self::BitsType, + mantissa_field: Self::BitsType, + ) -> Self { + let sign_shift = sign_field.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + let exponent_shift = sign_field.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); + Self::from_bits( + (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field, + ) + } + fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { + Self::SignedBitsType::cvt_from(exponent_field) + - exponent_field + .ctx() + .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED) + } + fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType { + (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to() + } } -pub trait Bool: BitOps {} +pub trait Bool: Make + BitOps + Select {} -pub trait Select: Bool { +pub trait Select { fn select(self, true_v: T, false_v: T) -> T; } -pub trait Compare: Copy { - type Bool: Bool + Select; +pub trait Compare: Make { + type Bool: Bool + Select + Make; fn eq(self, rhs: Self) -> Self::Bool; fn ne(self, rhs: Self) -> Self::Bool; fn lt(self, rhs: Self) -> Self::Bool;