refactor to easily allow algorithms generic over f16/32/64
[vector-math.git] / src / scalar.rs
1 use crate::{
2 f16::F16,
3 prim::{PrimSInt, PrimUInt},
4 traits::{Bool, Compare, Context, ConvertFrom, Float, Int, Make, SInt, Select, UInt},
5 };
6 use core::ops::{
7 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
8 Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
9 };
10
/// Zero-sized context type for the scalar (one element at a time) backend.
///
/// Its `Context` implementation in this file maps every associated type,
/// including the `Vec*` ones, to a [`Value`] wrapper.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Scalar;
13
/// Newtype around a single primitive `T`, used as every element type of the
/// [`Scalar`] context.
///
/// `#[repr(transparent)]` guarantees the wrapper has the same layout as `T`.
#[repr(transparent)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Value<T>(pub T);
17
// Implements `ConvertFrom` between `Value<A>` and `Value<B>`, in both
// directions, for every pair of distinct types in the comma-separated list.
//
// The macro peels off the first type, pairs it with each remaining type, then
// recurses on the remainder. No `Value<T>` -> `Value<T>` impl is generated.
// Each impl simply forwards to the `ConvertFrom` impl of the wrapped types.
macro_rules! impl_convert_from {
    // Nothing left to pair up: end of the recursion.
    () => {};
    ($head:ident $(, $tail:ident)*) => {
        $(
            impl ConvertFrom<Value<$tail>> for Value<$head> {
                fn cvt_from(v: Value<$tail>) -> Self {
                    Self(ConvertFrom::cvt_from(v.0))
                }
            }
            impl ConvertFrom<Value<$head>> for Value<$tail> {
                fn cvt_from(v: Value<$head>) -> Self {
                    Self(ConvertFrom::cvt_from(v.0))
                }
            }
        )*
        impl_convert_from![$($tail),*];
    };
}
37
38 impl_convert_from![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64];
39
// Implements `&`, `|`, `^`, `!` and the assigning forms `&=`, `|=`, `^=` for
// `Value<$ty>` by applying the same operator to the wrapped values.
macro_rules! impl_bit_ops {
    // Internal helper: one binary operator together with its assigning form.
    (@binary $ty:ident, $Op:ident :: $op_fn:ident, $OpAssign:ident :: $assign_fn:ident, $op:tt) => {
        impl $Op for Value<$ty> {
            type Output = Self;

            fn $op_fn(self, other: Self) -> Self {
                Self(self.0 $op other.0)
            }
        }

        impl $OpAssign for Value<$ty> {
            fn $assign_fn(&mut self, other: Self) {
                self.0 = self.0 $op other.0;
            }
        }
    };
    ($ty:ident) => {
        impl_bit_ops!(@binary $ty, BitAnd::bitand, BitAndAssign::bitand_assign, &);
        impl_bit_ops!(@binary $ty, BitOr::bitor, BitOrAssign::bitor_assign, |);
        impl_bit_ops!(@binary $ty, BitXor::bitxor, BitXorAssign::bitxor_assign, ^);

        impl Not for Value<$ty> {
            type Output = Self;

            fn not(self) -> Self {
                Self(!self.0)
            }
        }
    };
}
93
// Implements the arithmetic and shift operators (and their assigning forms)
// for `Value<$ty>` on top of the primitive `wrapping_*` methods, so overflow
// wraps around instead of panicking.
//
// Shift amounts are cast to `u32` and handed to `wrapping_shl`/`wrapping_shr`.
macro_rules! impl_wrapping_int_ops {
    // Internal helper: binary operator whose right-hand side is passed through as is.
    (@binary $ty:ident, $Op:ident :: $op_fn:ident, $OpAssign:ident :: $assign_fn:ident, $wrapping:ident) => {
        impl $Op for Value<$ty> {
            type Output = Self;

            fn $op_fn(self, other: Self) -> Self {
                Self(self.0.$wrapping(other.0))
            }
        }

        impl $OpAssign for Value<$ty> {
            fn $assign_fn(&mut self, other: Self) {
                *self = <Self as $Op>::$op_fn(*self, other);
            }
        }
    };
    // Internal helper: shift operator, whose amount is converted to `u32` first.
    (@shift $ty:ident, $Op:ident :: $op_fn:ident, $OpAssign:ident :: $assign_fn:ident, $wrapping:ident) => {
        impl $Op for Value<$ty> {
            type Output = Self;

            fn $op_fn(self, amount: Self) -> Self {
                Self(self.0.$wrapping(amount.0 as u32))
            }
        }

        impl $OpAssign for Value<$ty> {
            fn $assign_fn(&mut self, amount: Self) {
                *self = <Self as $Op>::$op_fn(*self, amount);
            }
        }
    };
    ($ty:ident) => {
        impl_wrapping_int_ops!(@binary $ty, Add::add, AddAssign::add_assign, wrapping_add);
        impl_wrapping_int_ops!(@binary $ty, Sub::sub, SubAssign::sub_assign, wrapping_sub);
        impl_wrapping_int_ops!(@binary $ty, Mul::mul, MulAssign::mul_assign, wrapping_mul);
        impl_wrapping_int_ops!(@binary $ty, Div::div, DivAssign::div_assign, wrapping_div);
        impl_wrapping_int_ops!(@binary $ty, Rem::rem, RemAssign::rem_assign, wrapping_rem);
        impl_wrapping_int_ops!(@shift $ty, Shl::shl, ShlAssign::shl_assign, wrapping_shl);
        impl_wrapping_int_ops!(@shift $ty, Shr::shr, ShrAssign::shr_assign, wrapping_shr);

        impl Neg for Value<$ty> {
            type Output = Self;

            fn neg(self) -> Self {
                Self(self.0.wrapping_neg())
            }
        }
    };
}
// Gives `Value<$ty>` everything shared by signed and unsigned integers: the
// bitwise operators, the wrapping arithmetic operators and the `Int` trait.
//
// Every `Int` method calls the primitive method of the same name and casts the
// resulting count back to `$ty`, so the result has the same type as the input.
macro_rules! impl_int {
    // Internal helper: one forwarding `Int` method per listed name.
    (@bit_counts $ty:ident: $($method:ident),+) => {
        impl Int for Value<$ty> {
            $(
                fn $method(self) -> Self {
                    Self(self.0.$method() as $ty)
                }
            )+
        }
    };
    ($ty:ident) => {
        impl_bit_ops!($ty);
        impl_wrapping_int_ops!($ty);
        impl_int!(
            @bit_counts $ty:
            leading_zeros,
            leading_ones,
            trailing_zeros,
            trailing_ones,
            count_zeros,
            count_ones
        );
    };
}
229
// For each listed unsigned primitive: implements `UInt` for `Value<prim>` and
// pulls in the shared integer impls via `impl_int!`.
//
// The signed counterpart is `Value` wrapped around `PrimUInt::SignedType` of
// the primitive.
macro_rules! impl_uint {
    ($($prim:ident),*) => {
        $(
            impl UInt for Value<$prim> {
                type SignedType = Value<<$prim as PrimUInt>::SignedType>;
                type PrimUInt = $prim;
            }
            impl_int!($prim);
        )*
    };
}
241
242 impl_uint![u8, u16, u32, u64];
243
// For each listed signed primitive: implements `SInt` for `Value<prim>` and
// pulls in the shared integer impls via `impl_int!`.
//
// The unsigned counterpart is `Value` wrapped around `PrimSInt::UnsignedType`
// of the primitive.
macro_rules! impl_sint {
    ($($prim:ident),*) => {
        $(
            impl SInt for Value<$prim> {
                type UnsignedType = Value<<$prim as PrimSInt>::UnsignedType>;
                type PrimSInt = $prim;
            }
            impl_int!($prim);
        )*
    };
}
255
256 impl_sint![i8, i16, i32, i64];
257
// Implements `+`, `-`, `*`, `/`, `%`, unary `-` and the assigning forms for
// `Value<$ty>` by forwarding to the operator impls of the wrapped float type.
macro_rules! impl_float_ops {
    // Internal helper: one binary operator together with its assigning form.
    (@binary $ty:ident, $Op:ident :: $op_fn:ident, $OpAssign:ident :: $assign_fn:ident) => {
        impl $Op for Value<$ty> {
            type Output = Self;

            fn $op_fn(self, other: Self) -> Self {
                Self(self.0.$op_fn(other.0))
            }
        }

        impl $OpAssign for Value<$ty> {
            fn $assign_fn(&mut self, other: Self) {
                *self = <Self as $Op>::$op_fn(*self, other);
            }
        }
    };
    ($ty:ident) => {
        impl_float_ops!(@binary $ty, Add::add, AddAssign::add_assign);
        impl_float_ops!(@binary $ty, Sub::sub, SubAssign::sub_assign);
        impl_float_ops!(@binary $ty, Mul::mul, MulAssign::mul_assign);
        impl_float_ops!(@binary $ty, Div::div, DivAssign::div_assign);
        impl_float_ops!(@binary $ty, Rem::rem, RemAssign::rem_assign);

        impl Neg for Value<$ty> {
            type Output = Self;

            fn neg(self) -> Self {
                Self(self.0.neg())
            }
        }
    };
}
339
340 impl_float_ops!(F16);
341
// Implements the arithmetic operators (via `impl_float_ops!`) and the `Float`
// trait for `Value<$ty>`.
//
// Arguments:
// * `$ty`          - the primitive float type (`f32`, `f64`)
// * `$bits`        - unsigned integer type with the same width as `$ty`
// * `$signed_bits` - signed integer type with the same width as `$ty`
//
// `trunc`, `ceil`, `floor` and `round` forward to the primitive methods when
// the `std` feature is enabled and are still unimplemented (`todo!()`, i.e.
// they panic) without it. `abs` works in both configurations.
macro_rules! impl_float {
    ($ty:ident, $bits:ty, $signed_bits:ty) => {
        impl_float_ops!($ty);
        impl Float for Value<$ty> {
            type PrimFloat = $ty;
            type BitsType = Value<$bits>;
            type SignedBitsType = Value<$signed_bits>;
            fn abs(self) -> Self {
                #[cfg(feature = "std")]
                return Value(self.0.abs());
                // Without `std` the primitive `abs` may not be available, but
                // the absolute value is just the input with its sign bit (the
                // most significant bit) cleared, which needs only
                // `to_bits`/`from_bits`.
                #[cfg(not(feature = "std"))]
                return Value(<$ty>::from_bits(self.0.to_bits() & (<$bits>::MAX >> 1)));
            }
            fn trunc(self) -> Self {
                #[cfg(feature = "std")]
                return Value(self.0.trunc());
                #[cfg(not(feature = "std"))]
                todo!();
            }
            fn ceil(self) -> Self {
                #[cfg(feature = "std")]
                return Value(self.0.ceil());
                #[cfg(not(feature = "std"))]
                todo!();
            }
            fn floor(self) -> Self {
                #[cfg(feature = "std")]
                return Value(self.0.floor());
                #[cfg(not(feature = "std"))]
                todo!();
            }
            fn round(self) -> Self {
                #[cfg(feature = "std")]
                return Value(self.0.round());
                #[cfg(not(feature = "std"))]
                todo!();
            }
            // Only compiled when the `fma` feature is enabled.
            // NOTE(review): this relies on the primitive `mul_add`; confirm
            // that enabling `fma` also guarantees it is available (e.g. that
            // `fma` implies `std`).
            #[cfg(feature = "fma")]
            fn fma(self, a: Self, b: Self) -> Self {
                Value(self.0.mul_add(a.0, b.0))
            }
            fn is_nan(self) -> Self::Bool {
                Value(self.0.is_nan())
            }
            fn is_infinite(self) -> Self::Bool {
                Value(self.0.is_infinite())
            }
            fn is_finite(self) -> Self::Bool {
                Value(self.0.is_finite())
            }
            fn from_bits(v: Self::BitsType) -> Self {
                Value(<$ty>::from_bits(v.0))
            }
            fn to_bits(self) -> Self::BitsType {
                Value(self.0.to_bits())
            }
        }
    };
}
401
402 impl_float!(f32, u32, i32);
403 impl_float!(f64, u64, i64);
404
// Implements `Compare` for `Value<ty>` for each listed type, using the
// `PartialEq`/`PartialOrd` impls of `Value<ty>` itself. Every comparison
// yields a `Value<bool>`.
macro_rules! impl_compare_using_partial_cmp {
    ($($ty:ty),*) => {
        $(
            impl Compare for Value<$ty> {
                type Bool = Value<bool>;
                fn eq(self, other: Self) -> Self::Bool {
                    Value(PartialEq::eq(&self, &other))
                }
                fn ne(self, other: Self) -> Self::Bool {
                    Value(PartialEq::ne(&self, &other))
                }
                fn lt(self, other: Self) -> Self::Bool {
                    Value(PartialOrd::lt(&self, &other))
                }
                fn gt(self, other: Self) -> Self::Bool {
                    Value(PartialOrd::gt(&self, &other))
                }
                fn le(self, other: Self) -> Self::Bool {
                    Value(PartialOrd::le(&self, &other))
                }
                fn ge(self, other: Self) -> Self::Bool {
                    Value(PartialOrd::ge(&self, &other))
                }
            }
        )*
    };
}
432
433 impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];
434
435 impl Bool for Value<bool> {}
436
437 impl_bit_ops!(bool);
438
439 impl<T> Select<Value<T>> for Value<bool> {
440 fn select(self, true_v: Value<T>, false_v: Value<T>) -> Value<T> {
441 if self.0 {
442 true_v
443 } else {
444 false_v
445 }
446 }
447 }
448
// `impl_from!(src => [dest, ...])` implements `From<Value<src>>` for each
// `Value<dest>` by converting the wrapped value with `Into`.
macro_rules! impl_from {
    ($source:ident => [$($target:ident),*]) => {
        $(
            impl From<Value<$source>> for Value<$target> {
                fn from(Value(inner): Value<$source>) -> Self {
                    Self(inner.into())
                }
            }
        )*
    };
}
460
461 impl_from!(u8 => [u16, i16, F16, u32, i32, f32, u64, i64, f64]);
462 impl_from!(u16 => [u32, i32, f32, u64, i64, f64]);
463 impl_from!(u32 => [u64, i64, f64]);
464 impl_from!(i8 => [i16, F16, i32, f32, i64, f64]);
465 impl_from!(i16 => [i32, f32, i64, f64]);
466 impl_from!(i32 => [i64, f64]);
467 impl_from!(F16 => [f32, f64]);
468 impl_from!(f32 => [f64]);
469
// Takes something shaped like an `impl Context for Scalar { ... }` block whose
// associated types are split by a `#[vec]` marker, and emits:
//
// * the real `impl Context for Scalar` containing all listed types, and
// * one `Make` impl per type listed *before* the marker.
//
// `#[vec]` is only a separator matched by the macro pattern; it is never
// emitted as an attribute. No `Make` impl is generated for the entries after
// it.
macro_rules! impl_context {
    (
        impl Context for Scalar {
            $(type $scalar_name:ident = Value<$prim:ident>;)*
            #[vec]
            $(type $vec_name:ident = Value<$vec_prim:ident>;)*
        }
    ) => {
        $(
            impl Make for Value<$prim> {
                type Prim = $prim;
                type Context = Scalar;
                fn ctx(self) -> Self::Context {
                    Scalar
                }
                fn make(_ctx: Self::Context, v: Self::Prim) -> Self {
                    Self(v)
                }
            }
        )*

        impl Context for Scalar {
            $(type $scalar_name = Value<$prim>;)*
            $(type $vec_name = Value<$vec_prim>;)*
        }
    };
}
497
498 impl_context! {
499 impl Context for Scalar {
500 type Bool = Value<bool>;
501 type U8 = Value<u8>;
502 type I8 = Value<i8>;
503 type U16 = Value<u16>;
504 type I16 = Value<i16>;
505 type F16 = Value<F16>;
506 type U32 = Value<u32>;
507 type I32 = Value<i32>;
508 type F32 = Value<f32>;
509 type U64 = Value<u64>;
510 type I64 = Value<i64>;
511 type F64 = Value<f64>;
512 #[vec]
513 type VecBool8 = Value<bool>;
514 type VecU8 = Value<u8>;
515 type VecI8 = Value<i8>;
516 type VecBool16 = Value<bool>;
517 type VecU16 = Value<u16>;
518 type VecI16 = Value<i16>;
519 type VecF16 = Value<F16>;
520 type VecBool32 = Value<bool>;
521 type VecU32 = Value<u32>;
522 type VecI32 = Value<i32>;
523 type VecF32 = Value<f32>;
524 type VecBool64 = Value<bool>;
525 type VecU64 = Value<u64>;
526 type VecI64 = Value<i64>;
527 type VecF64 = Value<f64>;
528 }
529 }