6555016d98793e1c70b8715fe85e5f3f3a5878cf
[vector-math.git] / src / traits.rs
1 use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
2 use core::ops::{
3 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
4 Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
5 };
6
7 /// reference used to build IR for Kazan; an empty type for `core::simd`
8 pub trait Context: Copy {
9 vector_math_proc_macro::make_context_types!();
10 fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
11 T::make(self, v)
12 }
13 }
14
15 pub trait Make: Copy {
16 type Prim: Copy;
17 type Context: Context;
18 fn ctx(self) -> Self::Context;
19 fn make(ctx: Self::Context, v: Self::Prim) -> Self;
20 }
21
/// Conversion of `self` into a value of type `T`.
///
/// Between distinct primitive numeric types the conversion is an `as` cast, so
/// it has `as` semantics (e.g. integer narrowing truncates, float-to-integer
/// saturates).
pub trait ConvertTo<T> {
    /// Converts `self` into a `T`.
    fn to(self) -> T;
}

/// Identity conversion: every type converts to itself unchanged.
impl<T> ConvertTo<T> for T {
    fn to(self) -> T {
        self
    }
}

/// For every pair of distinct types in the list, implements `ConvertTo` in
/// both directions using `as`. It recurses on the tail so each unordered pair
/// is handled exactly once (same-type conversions come from the identity impl).
macro_rules! impl_convert_to_using_as {
    () => {};
    ($head:ident $(, $tail:ident)*) => {
        $(
            impl ConvertTo<$head> for $tail {
                fn to(self) -> $head {
                    self as $head
                }
            }
            impl ConvertTo<$tail> for $head {
                fn to(self) -> $tail {
                    self as $tail
                }
            }
        )*
        impl_convert_to_using_as![$($tail),*];
    };
}

// All primitive integer and float types that interconvert via `as`.
impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
53
54 pub trait Number:
55 Compare
56 + Add<Output = Self>
57 + Sub<Output = Self>
58 + Mul<Output = Self>
59 + Div<Output = Self>
60 + Rem<Output = Self>
61 + AddAssign
62 + SubAssign
63 + MulAssign
64 + DivAssign
65 + RemAssign
66 {
67 }
68
69 impl<T> Number for T where
70 T: Compare
71 + Add<Output = Self>
72 + Sub<Output = Self>
73 + Mul<Output = Self>
74 + Div<Output = Self>
75 + Rem<Output = Self>
76 + AddAssign
77 + SubAssign
78 + MulAssign
79 + DivAssign
80 + RemAssign
81 {
82 }
83
/// Bitwise operations (`& | ^ !` and the compound-assignment forms of `& | ^`)
/// on a `Copy` type.
///
/// Implemented automatically for every type that satisfies the bounds.
pub trait BitOps:
    Copy
    + BitAnd<Output = Self>
    + BitOr<Output = Self>
    + BitXor<Output = Self>
    + Not<Output = Self>
    + BitAndAssign
    + BitOrAssign
    + BitXorAssign
{
}

impl<
        B: Copy
            + BitAnd<Output = B>
            + BitOr<Output = B>
            + BitXor<Output = B>
            + Not<Output = B>
            + BitAndAssign
            + BitOrAssign
            + BitXorAssign,
    > BitOps for B
{
}
107
108 pub trait Int:
109 Number + BitOps + Shl<Output = Self> + Shr<Output = Self> + ShlAssign + ShrAssign
110 {
111 fn leading_zeros(self) -> Self;
112 fn leading_ones(self) -> Self {
113 self.not().leading_zeros()
114 }
115 fn trailing_zeros(self) -> Self;
116 fn trailing_ones(self) -> Self {
117 self.not().trailing_zeros()
118 }
119 fn count_zeros(self) -> Self {
120 self.not().count_ones()
121 }
122 fn count_ones(self) -> Self;
123 }
124
125 pub trait UInt: Int {}
126
127 pub trait SInt: Int + Neg<Output = Self> {}
128
/// Implements [`Int`] for a primitive integer type. Each method calls the
/// type's inherent method of the same name and casts its `u32` result back to
/// the type. The explicit `$ty::` path selects the inherent method rather than
/// the `Int` method being defined.
macro_rules! impl_int {
    ($ty:ident) => {
        impl Int for $ty {
            fn leading_zeros(self) -> Self {
                let n = $ty::leading_zeros(self);
                n as $ty
            }
            fn leading_ones(self) -> Self {
                let n = $ty::leading_ones(self);
                n as $ty
            }
            fn trailing_zeros(self) -> Self {
                let n = $ty::trailing_zeros(self);
                n as $ty
            }
            fn trailing_ones(self) -> Self {
                let n = $ty::trailing_ones(self);
                n as $ty
            }
            fn count_zeros(self) -> Self {
                let n = $ty::count_zeros(self);
                n as $ty
            }
            fn count_ones(self) -> Self {
                let n = $ty::count_ones(self);
                n as $ty
            }
        }
    };
}
153
154 macro_rules! impl_uint {
155 ($($ty:ident),*) => {
156 $(
157 impl_int!($ty);
158 impl UInt for $ty {}
159 )*
160 };
161 }
162
163 impl_uint![u8, u16, u32, u64];
164
165 macro_rules! impl_sint {
166 ($($ty:ident),*) => {
167 $(
168 impl_int!($ty);
169 impl SInt for $ty {}
170 )*
171 };
172 }
173
174 impl_sint![i8, i16, i32, i64];
175
/// Floating-point operations.
///
/// Many methods have default implementations written in terms of the raw bit
/// representation (`to_bits`/`from_bits`) and the constants supplied by
/// [`Float::FloatEncoding`]. `Self::Bool` is the comparison result type from
/// [`Compare`], which is a supertrait via [`Number`].
pub trait Float: Number + Neg<Output = Self> {
    /// Type supplying the encoding constants used by the default methods
    /// (`INFINITY_BITS`, `NAN_BITS`, `ZERO_SUBNORMAL_EXPONENT`,
    /// `EXPONENT_FIELD_MASK`, `EXPONENT_FIELD_SHIFT`, `MANTISSA_FIELD_MASK`,
    /// `EXPONENT_BIAS_SIGNED`). It lives in the [`Scalar`] context and has the
    /// same `Prim` as `Self`.
    type FloatEncoding: FloatEncoding + Make<Context = Scalar, Prim = <Self as Make>::Prim>;
    /// Unsigned integer type holding the raw bits of `Self` (see `to_bits` /
    /// `from_bits`), in the same context as `Self`, convertible to
    /// `SignedBitsType`, and comparing to the same `Bool` type.
    type BitsType: UInt
        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::BitsType>
        + ConvertTo<Self::SignedBitsType>
        + Compare<Bool = Self::Bool>;
    /// Signed counterpart of `BitsType`; used for exponents with the bias
    /// removed (see `sub_exponent_bias` / `add_exponent_bias`).
    type SignedBitsType: SInt
        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::SignedBitsType>
        + ConvertTo<Self::BitsType>
        + Compare<Bool = Self::Bool>;
    /// Absolute value.
    fn abs(self) -> Self;
    /// Integer part, rounding toward zero.
    fn trunc(self) -> Self;
    /// Smallest integer greater than or equal to `self`.
    fn ceil(self) -> Self;
    /// Largest integer less than or equal to `self`.
    fn floor(self) -> Self;
    /// Round to the nearest integer; which way half-way cases are rounded is
    /// unspecified.
    fn round(self) -> Self;
    /// Returns `self * a + b` with only a single rounding (fused multiply-add).
    /// Only present with the `fma` feature.
    #[cfg(feature = "fma")]
    fn fma(self, a: Self, b: Self) -> Self;
    /// Returns `self * a + b`. With the `fma` feature this forwards to `fma`
    /// (one rounding); otherwise it is computed as a plain multiply then add
    /// (two roundings).
    fn mul_add_fast(self, a: Self, b: Self) -> Self {
        // Exactly one of these two `return`s is compiled in.
        #[cfg(feature = "fma")]
        return self.fma(a, b);
        #[cfg(not(feature = "fma"))]
        return self * a + b;
    }
    /// True where `self` is NaN.
    fn is_nan(self) -> Self::Bool {
        // NaN is the only value that compares unequal to itself.
        self.ne(self)
    }
    /// True where `self` is positive or negative infinity.
    fn is_infinite(self) -> Self::Bool {
        self.abs().eq(Self::infinity(self.ctx()))
    }
    /// Positive infinity in `ctx`, built from `FloatEncoding::INFINITY_BITS`.
    fn infinity(ctx: Self::Context) -> Self {
        Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS))
    }
    /// A NaN in `ctx`, built from `FloatEncoding::NAN_BITS`.
    fn nan(ctx: Self::Context) -> Self {
        Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS))
    }
    /// True where `self` is finite. Implementors must provide this (no default).
    fn is_finite(self) -> Self::Bool;
    /// True where the exponent field equals
    /// `FloatEncoding::ZERO_SUBNORMAL_EXPONENT`, i.e. `self` is zero or subnormal.
    fn is_zero_or_subnormal(self) -> Self::Bool {
        self.extract_exponent_field().eq(self
            .ctx()
            .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT))
    }
    /// Reinterprets raw bits as a float.
    fn from_bits(v: Self::BitsType) -> Self;
    /// Returns the raw bits of `self`.
    fn to_bits(self) -> Self::BitsType;
    /// The (still biased) exponent field, moved down to the low bits:
    /// `(bits & EXPONENT_FIELD_MASK) >> EXPONENT_FIELD_SHIFT`.
    fn extract_exponent_field(self) -> Self::BitsType {
        let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK);
        let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT);
        (self.to_bits() & mask) >> shift
    }
    /// The exponent field with the bias removed, as a signed value.
    fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
        Self::sub_exponent_bias(self.extract_exponent_field())
    }
    /// The mantissa field: `bits & MANTISSA_FIELD_MASK`. It is left in place,
    /// with no shift and no implicit leading bit added.
    fn extract_mantissa_field(self) -> Self::BitsType {
        let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK);
        self.to_bits() & mask
    }
    /// Converts an exponent field to the signed type and subtracts
    /// `EXPONENT_BIAS_SIGNED`. No range checking is performed.
    fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
        // The subtraction's output is the return value, so `to()` must produce
        // `SignedBitsType`, i.e. it uses the `ConvertTo<Self::SignedBitsType>`
        // bound on `BitsType`.
        exponent_field.to()
            - exponent_field
                .ctx()
                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)
    }
    /// Adds back `EXPONENT_BIAS_SIGNED` (in the signed type) and converts the
    /// result to `BitsType`. This is the inverse of `sub_exponent_bias`. No
    /// range checking is performed.
    fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
        (exponent
            + exponent
                .ctx()
                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED))
        .to()
    }
}
248
249 macro_rules! impl_float {
250 ($ty:ty, $bits:ty, $signed_bits:ty) => {
251 impl Float for $ty {
252 type FloatEncoding = $ty;
253 type BitsType = $bits;
254 type SignedBitsType = $signed_bits;
255 fn abs(self) -> Self {
256 #[cfg(feature = "std")]
257 return self.abs();
258 #[cfg(not(feature = "std"))]
259 todo!();
260 }
261 fn trunc(self) -> Self {
262 #[cfg(feature = "std")]
263 return self.trunc();
264 #[cfg(not(feature = "std"))]
265 todo!();
266 }
267 fn ceil(self) -> Self {
268 #[cfg(feature = "std")]
269 return self.ceil();
270 #[cfg(not(feature = "std"))]
271 todo!();
272 }
273 fn floor(self) -> Self {
274 #[cfg(feature = "std")]
275 return self.floor();
276 #[cfg(not(feature = "std"))]
277 todo!();
278 }
279 fn round(self) -> Self {
280 #[cfg(feature = "std")]
281 return self.round();
282 #[cfg(not(feature = "std"))]
283 todo!();
284 }
285 #[cfg(feature = "fma")]
286 fn fma(self, a: Self, b: Self) -> Self {
287 self.mul_add(a, b)
288 }
289 fn is_nan(self) -> Self::Bool {
290 self.is_nan()
291 }
292 fn is_infinite(self) -> Self::Bool {
293 self.is_infinite()
294 }
295 fn is_finite(self) -> Self::Bool {
296 self.is_finite()
297 }
298 fn from_bits(v: Self::BitsType) -> Self {
299 <$ty>::from_bits(v)
300 }
301 fn to_bits(self) -> Self::BitsType {
302 self.to_bits()
303 }
304 }
305 };
306 }
307
308 impl_float!(f32, u32, i32);
309 impl_float!(f64, u64, i64);
310
311 pub trait Bool: Make + BitOps {}
312
313 impl Bool for bool {}
314
315 pub trait Select<T>: Bool {
316 fn select(self, true_v: T, false_v: T) -> T;
317 }
318
319 impl<T> Select<T> for bool {
320 fn select(self, true_v: T, false_v: T) -> T {
321 if self {
322 true_v
323 } else {
324 false_v
325 }
326 }
327 }
328 pub trait Compare: Make {
329 type Bool: Bool + Select<Self>;
330 fn eq(self, rhs: Self) -> Self::Bool;
331 fn ne(self, rhs: Self) -> Self::Bool;
332 fn lt(self, rhs: Self) -> Self::Bool;
333 fn gt(self, rhs: Self) -> Self::Bool;
334 fn le(self, rhs: Self) -> Self::Bool;
335 fn ge(self, rhs: Self) -> Self::Bool;
336 }
337
338 macro_rules! impl_compare_using_partial_cmp {
339 ($($ty:ty),*) => {
340 $(
341 impl Compare for $ty {
342 type Bool = bool;
343 fn eq(self, rhs: Self) -> Self::Bool {
344 self == rhs
345 }
346 fn ne(self, rhs: Self) -> Self::Bool {
347 self != rhs
348 }
349 fn lt(self, rhs: Self) -> Self::Bool {
350 self < rhs
351 }
352 fn gt(self, rhs: Self) -> Self::Bool {
353 self > rhs
354 }
355 fn le(self, rhs: Self) -> Self::Bool {
356 self <= rhs
357 }
358 fn ge(self, rhs: Self) -> Self::Bool {
359 self >= rhs
360 }
361 }
362 )*
363 };
364 }
365
366 impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];