add round_to_nearest_ties_to_even
[vector-math.git] / src / traits.rs
1 use crate::{
2 f16::F16,
3 prim::{PrimFloat, PrimSInt, PrimUInt},
4 };
5 use core::ops::{
6 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
7 Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
8 };
9
10 /// reference used to build IR for Kazan; an empty type for `core::simd`
11 pub trait Context: Copy {
12 vector_math_proc_macro::make_context_types!();
13 fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
14 T::make(self, v)
15 }
16 }
17
18 pub trait Make: Copy {
19 type Prim: Copy;
20 type Context: Context;
21 fn ctx(self) -> Self::Context;
22 fn make(ctx: Self::Context, v: Self::Prim) -> Self;
23 }
24
/// Conversion into `Self` from a value of type `T`.
pub trait ConvertFrom<T>: Sized {
    /// Converts `v` into `Self`.
    fn cvt_from(v: T) -> Self;
}
28
29 impl<T> ConvertFrom<T> for T {
30 fn cvt_from(v: T) -> Self {
31 v
32 }
33 }
34
/// Method-call form of a conversion: turns `self` into a value of type `T`.
pub trait ConvertTo<T> {
    /// Consumes `self` and returns the converted value.
    fn to(self) -> T;
}
38
39 impl<F, T: ConvertFrom<F>> ConvertTo<T> for F {
40 fn to(self) -> T {
41 T::cvt_from(self)
42 }
43 }
44
45 macro_rules! impl_convert_from_using_as {
46 ($first:ident $(, $ty:ident)*) => {
47 $(
48 impl ConvertFrom<$first> for $ty {
49 fn cvt_from(v: $first) -> Self {
50 v as _
51 }
52 }
53 impl ConvertFrom<$ty> for $first {
54 fn cvt_from(v: $ty) -> Self {
55 v as _
56 }
57 }
58 )*
59 impl_convert_from_using_as![$($ty),*];
60 };
61 () => {
62 };
63 }
64
65 impl_convert_from_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
66
67 pub trait Number:
68 Compare
69 + Add<Output = Self>
70 + Sub<Output = Self>
71 + Mul<Output = Self>
72 + Div<Output = Self>
73 + Rem<Output = Self>
74 + AddAssign
75 + SubAssign
76 + MulAssign
77 + DivAssign
78 + RemAssign
79 {
80 }
81
82 impl<T> Number for T where
83 T: Compare
84 + Add<Output = Self>
85 + Sub<Output = Self>
86 + Mul<Output = Self>
87 + Div<Output = Self>
88 + Rem<Output = Self>
89 + AddAssign
90 + SubAssign
91 + MulAssign
92 + DivAssign
93 + RemAssign
94 {
95 }
96
/// Copyable value type closed under the bitwise operators `&`, `|`, `^`
/// and `!`, together with the compound-assignment forms of the binary ones.
///
/// The trait has no items of its own; it only bundles the bounds below.
pub trait BitOps:
    Copy
    + Not<Output = Self>
    + BitAnd<Output = Self>
    + BitAndAssign
    + BitOr<Output = Self>
    + BitOrAssign
    + BitXor<Output = Self>
    + BitXorAssign
{
}
108
109 impl<T> BitOps for T where
110 T: Copy
111 + BitAnd<Output = Self>
112 + BitOr<Output = Self>
113 + BitXor<Output = Self>
114 + Not<Output = Self>
115 + BitAndAssign
116 + BitOrAssign
117 + BitXorAssign
118 {
119 }
120
121 pub trait Int:
122 Number + BitOps + Shl<Output = Self> + Shr<Output = Self> + ShlAssign + ShrAssign
123 {
124 fn leading_zeros(self) -> Self;
125 fn leading_ones(self) -> Self {
126 self.not().leading_zeros()
127 }
128 fn trailing_zeros(self) -> Self;
129 fn trailing_ones(self) -> Self {
130 self.not().trailing_zeros()
131 }
132 fn count_zeros(self) -> Self {
133 self.not().count_ones()
134 }
135 fn count_ones(self) -> Self;
136 }
137
138 pub trait UInt: Int + Make<Prim = Self::PrimUInt> + ConvertFrom<Self::SignedType> {
139 type PrimUInt: PrimUInt<SignedType = <Self::SignedType as SInt>::PrimSInt>;
140 type SignedType: SInt
141 + ConvertFrom<Self>
142 + Make<Context = Self::Context>
143 + Compare<Bool = Self::Bool>;
144 }
145
146 pub trait SInt:
147 Int + Neg<Output = Self> + Make<Prim = Self::PrimSInt> + ConvertFrom<Self::UnsignedType>
148 {
149 type PrimSInt: PrimSInt<UnsignedType = <Self::UnsignedType as UInt>::PrimUInt>;
150 type UnsignedType: UInt
151 + ConvertFrom<Self>
152 + Make<Context = Self::Context>
153 + Compare<Bool = Self::Bool>;
154 }
155
156 pub trait Float:
157 Number
158 + Neg<Output = Self>
159 + Make<Prim = Self::PrimFloat>
160 + ConvertFrom<Self::SignedBitsType>
161 + ConvertFrom<Self::BitsType>
162 {
163 type PrimFloat: PrimFloat;
164 type BitsType: UInt<PrimUInt = <Self::PrimFloat as PrimFloat>::BitsType, SignedType = Self::SignedBitsType>
165 + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::BitsType>
166 + Compare<Bool = Self::Bool>
167 + ConvertFrom<Self>;
168 type SignedBitsType: SInt<
169 PrimSInt = <Self::PrimFloat as PrimFloat>::SignedBitsType,
170 UnsignedType = Self::BitsType,
171 > + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::SignedBitsType>
172 + Compare<Bool = Self::Bool>
173 + ConvertFrom<Self>;
174 fn abs(self) -> Self;
175 fn copy_sign(self, sign: Self) -> Self {
176 crate::algorithms::base::copy_sign(self.ctx(), self, sign)
177 }
178 fn trunc(self) -> Self;
179 fn ceil(self) -> Self;
180 fn floor(self) -> Self;
181 /// round to nearest integer, unspecified which way half-way cases are rounded
182 fn round(self) -> Self;
183 /// returns `self * a + b` but only rounding once
184 #[cfg(feature = "fma")]
185 fn fma(self, a: Self, b: Self) -> Self;
186 /// returns `self * a + b` either using `fma` or `self * a + b`
187 fn mul_add_fast(self, a: Self, b: Self) -> Self {
188 #[cfg(feature = "fma")]
189 return self.fma(a, b);
190 #[cfg(not(feature = "fma"))]
191 return self * a + b;
192 }
193 fn is_nan(self) -> Self::Bool {
194 self.ne(self)
195 }
196 fn is_infinite(self) -> Self::Bool {
197 self.abs().eq(Self::infinity(self.ctx()))
198 }
199 fn infinity(ctx: Self::Context) -> Self {
200 Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS))
201 }
202 fn nan(ctx: Self::Context) -> Self {
203 Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS))
204 }
205 fn is_finite(self) -> Self::Bool;
206 fn is_zero_or_subnormal(self) -> Self::Bool {
207 self.extract_exponent_field()
208 .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT))
209 }
210 fn from_bits(v: Self::BitsType) -> Self;
211 fn to_bits(self) -> Self::BitsType;
212 fn extract_exponent_field(self) -> Self::BitsType {
213 let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK);
214 let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT);
215 (self.to_bits() & mask) >> shift
216 }
217 fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
218 Self::sub_exponent_bias(self.extract_exponent_field())
219 }
220 fn extract_mantissa_field(self) -> Self::BitsType {
221 let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK);
222 self.to_bits() & mask
223 }
224 fn is_sign_negative(self) -> Self::Bool {
225 let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK);
226 self.ctx()
227 .make::<Self::BitsType>(0.to())
228 .ne(self.to_bits() & mask)
229 }
230 fn is_sign_positive(self) -> Self::Bool {
231 let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK);
232 self.ctx()
233 .make::<Self::BitsType>(0.to())
234 .eq(self.to_bits() & mask)
235 }
236 fn extract_sign_field(self) -> Self::BitsType {
237 let shift = self.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT);
238 self.to_bits() >> shift
239 }
240 fn from_fields(
241 sign_field: Self::BitsType,
242 exponent_field: Self::BitsType,
243 mantissa_field: Self::BitsType,
244 ) -> Self {
245 let sign_shift = sign_field.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT);
246 let exponent_shift = sign_field.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT);
247 Self::from_bits(
248 (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field,
249 )
250 }
251 fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
252 Self::SignedBitsType::cvt_from(exponent_field)
253 - exponent_field
254 .ctx()
255 .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)
256 }
257 fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
258 (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to()
259 }
260 }
261
262 pub trait Bool: Make<Prim = bool> + BitOps + Select<Self> {}
263
/// Condition type that can choose between two values of type `T`.
pub trait Select<T> {
    /// Chooses between `true_v` and `false_v` according to `self`.
    fn select(self, true_v: T, false_v: T) -> T;
}
267
268 pub trait Compare: Make {
269 type Bool: Bool + Select<Self> + Make<Context = Self::Context>;
270 fn eq(self, rhs: Self) -> Self::Bool;
271 fn ne(self, rhs: Self) -> Self::Bool;
272 fn lt(self, rhs: Self) -> Self::Bool;
273 fn gt(self, rhs: Self) -> Self::Bool;
274 fn le(self, rhs: Self) -> Self::Bool;
275 fn ge(self, rhs: Self) -> Self::Bool;
276 }