start adding stdsimd support
[vector-math.git] / src / traits.rs
1 use core::ops::{
2 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
3 Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
4 };
5
6 use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
7
// Declares the floating-point associated type for one element-width "level" of
// a `Context` trait body.
//
// Input (as produced by `make_uint_int_float_type!`):
// * `#[u32 = ..]` names the context's 32-bit unsigned type, which becomes the
//   shift-amount parameter of `Float`, and `#[bool = ..]` names its boolean type;
// * a bracketed list of the narrower levels (uint, int, optional float);
// * the current level, whose optional `#[float(prim = .., scalar = ..)]` member
//   selects the arm below;
// * a bracketed list of the wider levels.
//
// First arm: the level has a float member. Its type is bounded by `Float` (with
// this level's uint/int as the bit types), `Compare`, `Make`, an optional
// `From<scalar>`, and `ConvertTo` every integer type and every other float type.
// Wider floats must additionally be `Into` targets.
// Second arm: the level has no float member (the 8-bit level), so nothing is
// emitted.
#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
macro_rules! make_float_type {
    (
        #[u32 = $shift_ty:ident]
        #[bool = $bool_ty:ident]
        [
            $({
                #[uint]
                $narrow_uint:ident;
                #[int]
                $narrow_int:ident;
                $(
                    #[float]
                    $narrow_float:ident;
                )?
            },)*
        ],
        {
            #[uint]
            $bits_ty:ident;
            #[int]
            $signed_bits_ty:ident;
            #[float(prim = $prim:ident $(, scalar = $scalar:ident)?)]
            $float_ty:ident;
        },
        [
            $({
                #[uint]
                $wide_uint:ident;
                #[int]
                $wide_int:ident;
                $(
                    #[float]
                    $wide_float:ident;
                )?
            },)*
        ]
    ) => {
        type $float_ty: Float<
                Self::$shift_ty,
                BitsType = Self::$bits_ty,
                SignedBitsType = Self::$signed_bits_ty,
                FloatEncoding = $prim
            >
            $(+ From<Self::$scalar>)?
            + Compare<Bool = Self::$bool_ty>
            + Make<Context = Self, Prim = $prim>
            // every narrower level
            $(+ ConvertTo<Self::$narrow_uint>)*
            $(+ ConvertTo<Self::$narrow_int>)*
            $($(+ ConvertTo<Self::$narrow_float>)?)*
            // this level's integer types
            + ConvertTo<Self::$bits_ty>
            + ConvertTo<Self::$signed_bits_ty>
            // every wider level; wider floats are also required as `Into` targets
            $(+ ConvertTo<Self::$wide_uint>)*
            $(+ ConvertTo<Self::$wide_int>)*
            $($(+ Into<Self::$wide_float> + ConvertTo<Self::$wide_float>)?)*;
    };
    (
        #[u32 = $shift_ty:ident]
        #[bool = $bool_ty:ident]
        [$($narrower:tt,)*],
        {
            #[uint]
            $bits_ty:ident;
            #[int]
            $signed_bits_ty:ident;
        },
        [$($wider:tt,)*]
    ) => {};
}
72
// Declares the unsigned, signed and (if present) floating-point associated types
// for one element-width level of a `Context` trait body.
//
// It takes the same three groups as `make_float_type!` (narrower levels, current
// level, wider levels). Here every member carries attribute arguments:
// * The current level's members give `prim = ..`, the primitive type.
// * They may also give `scalar = ..`, a type the declared type must implement
//   `From` for.
// * The arguments of the narrower and wider levels are accepted and ignored.
//
// Conversions required of each declared type:
// * The uint type implements `UInt<u32 type>` and is `Into` every wider uint,
//   int and float.
// * The int type implements `SInt<u32 type>` and is `Into` every wider int and
//   float. Wider uints are reachable only through `ConvertTo`.
// * Both are `ConvertTo` every other numeric type.
// * The float member, if any, is delegated to `make_float_type!`.
#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
macro_rules! make_uint_int_float_type {
    (
        #[u32 = $shift_ty:ident]
        #[bool = $bool_ty:ident]
        [
            $({
                #[uint($($narrow_uint_args:tt)*)]
                $narrow_uint:ident;
                #[int($($narrow_int_args:tt)*)]
                $narrow_int:ident;
                $(
                    #[float($($narrow_float_args:tt)*)]
                    $narrow_float:ident;
                )?
            },)*
        ],
        {
            #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)]
            $uint_ty:ident;
            #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)]
            $int_ty:ident;
            $(
                #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)]
                $float_ty:ident;
            )?
        },
        [
            $({
                #[uint($($wide_uint_args:tt)*)]
                $wide_uint:ident;
                #[int($($wide_int_args:tt)*)]
                $wide_int:ident;
                $(
                    #[float($($wide_float_args:tt)*)]
                    $wide_float:ident;
                )?
            },)*
        ]
    ) => {
        type $uint_ty: UInt<Self::$shift_ty>
            $(+ From<Self::$uint_scalar>)?
            + Compare<Bool = Self::$bool_ty>
            + Make<Context = Self, Prim = $uint_prim>
            $(+ ConvertTo<Self::$narrow_uint>)*
            $(+ ConvertTo<Self::$narrow_int>)*
            $($(+ ConvertTo<Self::$narrow_float>)?)*
            + ConvertTo<Self::$int_ty>
            $(+ ConvertTo<Self::$float_ty>)?
            $(+ Into<Self::$wide_uint> + ConvertTo<Self::$wide_uint>)*
            $(+ Into<Self::$wide_int> + ConvertTo<Self::$wide_int>)*
            $($(+ Into<Self::$wide_float> + ConvertTo<Self::$wide_float>)?)*;
        type $int_ty: SInt<Self::$shift_ty>
            $(+ From<Self::$int_scalar>)?
            + Compare<Bool = Self::$bool_ty>
            + Make<Context = Self, Prim = $int_prim>
            $(+ ConvertTo<Self::$narrow_uint>)*
            $(+ ConvertTo<Self::$narrow_int>)*
            $($(+ ConvertTo<Self::$narrow_float>)?)*
            + ConvertTo<Self::$uint_ty>
            $(+ ConvertTo<Self::$float_ty>)?
            // a signed value cannot losslessly become a wider unsigned one
            $(+ ConvertTo<Self::$wide_uint>)*
            $(+ Into<Self::$wide_int> + ConvertTo<Self::$wide_int>)*
            $($(+ Into<Self::$wide_float> + ConvertTo<Self::$wide_float>)?)*;
        make_float_type! {
            #[u32 = $shift_ty]
            #[bool = $bool_ty]
            [
                $({
                    #[uint]
                    $narrow_uint;
                    #[int]
                    $narrow_int;
                    $(
                        #[float]
                        $narrow_float;
                    )?
                },)*
            ],
            {
                #[uint]
                $uint_ty;
                #[int]
                $int_ty;
                $(
                    #[float(prim = $float_prim $(, scalar = $float_scalar)?)]
                    $float_ty;
                )?
            },
            [
                $({
                    #[uint]
                    $wide_uint;
                    #[int]
                    $wide_int;
                    $(
                        #[float]
                        $wide_float;
                    )?
                },)*
            ]
        }
    };
}
177
// Walks a list of element-width levels, narrowest first. For each level it
// invokes `make_uint_int_float_type!`, passing every earlier level as "narrower"
// and every later level as "wider". The current level is then appended to the
// narrower list and the walk continues until no wider level remains.
macro_rules! make_uint_int_float_types {
    // Last level: nothing wider remains, so emit it and stop.
    (
        #[u32 = $shift_ty:ident]
        #[bool = $bool_ty:ident]
        [$($done:tt,)*],
        $level:tt,
        []
    ) => {
        make_uint_int_float_type! {
            #[u32 = $shift_ty]
            #[bool = $bool_ty]
            [$($done,)*],
            $level,
            []
        }
    };
    // Emit the current level, then shift it into the narrower list and recurse.
    (
        #[u32 = $shift_ty:ident]
        #[bool = $bool_ty:ident]
        [$($done:tt,)*],
        $level:tt,
        [$next:tt, $($rest:tt,)*]
    ) => {
        make_uint_int_float_type! {
            #[u32 = $shift_ty]
            #[bool = $bool_ty]
            [$($done,)*],
            $level,
            [$next, $($rest,)*]
        }
        make_uint_int_float_types! {
            #[u32 = $shift_ty]
            #[bool = $bool_ty]
            [$($done,)* $level,],
            $next,
            [$($rest,)*]
        }
    };
}
217
// Declares a complete family of associated types inside a `Context` trait body:
// * one boolean type;
// * unsigned and signed integers of 8, 16, 32 and 64 bits;
// * floats of 16, 32 and 64 bits.
//
// Each `#[scalar = X]` line is optional. When present, the declared type must
// also implement `From<Self::X>`. The boolean type must be able to `Select`
// between values of every type in the family. The numeric types are handed to
// `make_uint_int_float_types!` grouped by width (the 8-bit group has no float),
// with the 32-bit unsigned type acting as the shift-amount type.
#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
macro_rules! make_types {
    (
        #[bool]
        $(#[scalar = $scalar_bool:ident])?
        type $bool_ty:ident;

        #[u8]
        $(#[scalar = $scalar_u8:ident])?
        type $u8_ty:ident;

        #[u16]
        $(#[scalar = $scalar_u16:ident])?
        type $u16_ty:ident;

        #[u32]
        $(#[scalar = $scalar_u32:ident])?
        type $u32_ty:ident;

        #[u64]
        $(#[scalar = $scalar_u64:ident])?
        type $u64_ty:ident;

        #[i8]
        $(#[scalar = $scalar_i8:ident])?
        type $i8_ty:ident;

        #[i16]
        $(#[scalar = $scalar_i16:ident])?
        type $i16_ty:ident;

        #[i32]
        $(#[scalar = $scalar_i32:ident])?
        type $i32_ty:ident;

        #[i64]
        $(#[scalar = $scalar_i64:ident])?
        type $i64_ty:ident;

        #[f16]
        $(#[scalar = $scalar_f16:ident])?
        type $f16_ty:ident;

        #[f32]
        $(#[scalar = $scalar_f32:ident])?
        type $f32_ty:ident;

        #[f64]
        $(#[scalar = $scalar_f64:ident])?
        type $f64_ty:ident;
    ) => {
        type $bool_ty: Bool
            $(+ From<Self::$scalar_bool>)?
            + Make<Context = Self, Prim = bool>
            + Select<Self::$bool_ty>
            + Select<Self::$u8_ty>
            + Select<Self::$u16_ty>
            + Select<Self::$u32_ty>
            + Select<Self::$u64_ty>
            + Select<Self::$i8_ty>
            + Select<Self::$i16_ty>
            + Select<Self::$i32_ty>
            + Select<Self::$i64_ty>
            + Select<Self::$f16_ty>
            + Select<Self::$f32_ty>
            + Select<Self::$f64_ty>;
        make_uint_int_float_types! {
            #[u32 = $u32_ty]
            #[bool = $bool_ty]
            [],
            {
                #[uint(prim = u8 $(, scalar = $scalar_u8)?)]
                $u8_ty;
                #[int(prim = i8 $(, scalar = $scalar_i8)?)]
                $i8_ty;
            },
            [
                {
                    #[uint(prim = u16 $(, scalar = $scalar_u16)?)]
                    $u16_ty;
                    #[int(prim = i16 $(, scalar = $scalar_i16)?)]
                    $i16_ty;
                    #[float(prim = F16 $(, scalar = $scalar_f16)?)]
                    $f16_ty;
                },
                {
                    #[uint(prim = u32 $(, scalar = $scalar_u32)?)]
                    $u32_ty;
                    #[int(prim = i32 $(, scalar = $scalar_i32)?)]
                    $i32_ty;
                    #[float(prim = f32 $(, scalar = $scalar_f32)?)]
                    $f32_ty;
                },
                {
                    #[uint(prim = u64 $(, scalar = $scalar_u64)?)]
                    $u64_ty;
                    #[int(prim = i64 $(, scalar = $scalar_i64)?)]
                    $i64_ty;
                    #[float(prim = f64 $(, scalar = $scalar_f64)?)]
                    $f64_ty;
                },
            ]
        }
    };
}
323
324 /// reference used to build IR for Kazan; an empty type for `core::simd`
325 pub trait Context: Copy {
326 make_types! {
327 #[bool]
328 type Bool;
329
330 #[u8]
331 type U8;
332
333 #[u16]
334 type U16;
335
336 #[u32]
337 type U32;
338
339 #[u64]
340 type U64;
341
342 #[i8]
343 type I8;
344
345 #[i16]
346 type I16;
347
348 #[i32]
349 type I32;
350
351 #[i64]
352 type I64;
353
354 #[f16]
355 type F16;
356
357 #[f32]
358 type F32;
359
360 #[f64]
361 type F64;
362 }
363 make_types! {
364 #[bool]
365 #[scalar = Bool]
366 type VecBool;
367
368 #[u8]
369 #[scalar = U8]
370 type VecU8;
371
372 #[u16]
373 #[scalar = U16]
374 type VecU16;
375
376 #[u32]
377 #[scalar = U32]
378 type VecU32;
379
380 #[u64]
381 #[scalar = U64]
382 type VecU64;
383
384 #[i8]
385 #[scalar = I8]
386 type VecI8;
387
388 #[i16]
389 #[scalar = I16]
390 type VecI16;
391
392 #[i32]
393 #[scalar = I32]
394 type VecI32;
395
396 #[i64]
397 #[scalar = I64]
398 type VecI64;
399
400 #[f16]
401 #[scalar = F16]
402 type VecF16;
403
404 #[f32]
405 #[scalar = F32]
406 type VecF32;
407
408 #[f64]
409 #[scalar = F64]
410 type VecF64;
411 }
412 fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
413 T::make(self, v)
414 }
415 }
416
417 pub trait Make: Copy {
418 type Prim: Copy;
419 type Context: Context;
420 fn ctx(self) -> Self::Context;
421 fn make(ctx: Self::Context, v: Self::Prim) -> Self;
422 }
423
/// Conversion of `self` into a `T`.
///
/// Unlike `Into`, the conversion may lose information. The primitive numeric
/// implementations in this module are plain `as` casts.
pub trait ConvertTo<T> {
    /// Converts `self` into a `T`.
    fn to(self) -> T;
}
427
428 macro_rules! impl_convert_to_using_as {
429 ($($src:ident -> [$($dest:ident),*];)*) => {
430 $($(
431 impl ConvertTo<$dest> for $src {
432 fn to(self) -> $dest {
433 self as $dest
434 }
435 }
436 )*)*
437 };
438 ([$($src:ident),*] -> $dest:tt;) => {
439 impl_convert_to_using_as! {
440 $(
441 $src -> $dest;
442 )*
443 }
444 };
445 ([$($src:ident),*];) => {
446 impl_convert_to_using_as! {
447 [$($src),*] -> [$($src),*];
448 }
449 };
450 }
451
452 impl_convert_to_using_as! {
453 [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
454 }
455
456 pub trait Number:
457 Compare
458 + Add<Output = Self>
459 + Sub<Output = Self>
460 + Mul<Output = Self>
461 + Div<Output = Self>
462 + Rem<Output = Self>
463 + AddAssign
464 + SubAssign
465 + MulAssign
466 + DivAssign
467 + RemAssign
468 {
469 }
470
471 impl<T> Number for T where
472 T: Compare
473 + Add<Output = Self>
474 + Sub<Output = Self>
475 + Mul<Output = Self>
476 + Div<Output = Self>
477 + Rem<Output = Self>
478 + AddAssign
479 + SubAssign
480 + MulAssign
481 + DivAssign
482 + RemAssign
483 {
484 }
485
/// Bitwise operators for `Copy` types, closed over `Self`.
///
/// Covers `& | ^ !` and the compound-assignment forms. It is implemented
/// automatically for every type that has them.
pub trait BitOps:
    Copy
    + Not<Output = Self>
    + BitAnd<Output = Self>
    + BitAndAssign
    + BitOr<Output = Self>
    + BitOrAssign
    + BitXor<Output = Self>
    + BitXorAssign
{
}

impl<T> BitOps for T where
    T: Copy
        + Not<Output = T>
        + BitAnd<Output = T>
        + BitAndAssign
        + BitOr<Output = T>
        + BitOrAssign
        + BitXor<Output = T>
        + BitXorAssign
{
}
509
510 pub trait Int<ShiftRhs>:
511 Number
512 + BitOps
513 + Shl<ShiftRhs, Output = Self>
514 + Shr<ShiftRhs, Output = Self>
515 + ShlAssign<ShiftRhs>
516 + ShrAssign<ShiftRhs>
517 {
518 fn leading_zeros(self) -> Self;
519 fn leading_ones(self) -> Self {
520 self.not().leading_zeros()
521 }
522 fn trailing_zeros(self) -> Self;
523 fn trailing_ones(self) -> Self {
524 self.not().trailing_zeros()
525 }
526 fn count_zeros(self) -> Self {
527 self.not().count_ones()
528 }
529 fn count_ones(self) -> Self;
530 }
531
532 pub trait UInt<ShiftRhs>: Int<ShiftRhs> {}
533
534 pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
535
536 macro_rules! impl_int {
537 ($ty:ident) => {
538 impl Int<u32> for $ty {
539 fn leading_zeros(self) -> Self {
540 self.leading_zeros() as Self
541 }
542 fn leading_ones(self) -> Self {
543 self.leading_ones() as Self
544 }
545 fn trailing_zeros(self) -> Self {
546 self.trailing_zeros() as Self
547 }
548 fn trailing_ones(self) -> Self {
549 self.trailing_ones() as Self
550 }
551 fn count_zeros(self) -> Self {
552 self.count_zeros() as Self
553 }
554 fn count_ones(self) -> Self {
555 self.count_ones() as Self
556 }
557 }
558 };
559 }
560
561 macro_rules! impl_uint {
562 ($($ty:ident),*) => {
563 $(
564 impl_int!($ty);
565 impl UInt<u32> for $ty {}
566 )*
567 };
568 }
569
570 impl_uint![u8, u16, u32, u64];
571
572 macro_rules! impl_sint {
573 ($($ty:ident),*) => {
574 $(
575 impl_int!($ty);
576 impl SInt<u32> for $ty {}
577 )*
578 };
579 }
580
581 impl_sint![i8, i16, i32, i64];
582
583 pub trait Float<BitsShiftRhs: Make<Context = Self::Context, Prim = u32>>:
584 Number + Neg<Output = Self>
585 {
586 type FloatEncoding: FloatEncoding + Make<Context = Scalar, Prim = <Self as Make>::Prim>;
587 type BitsType: UInt<BitsShiftRhs>
588 + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::BitsType>
589 + ConvertTo<Self::SignedBitsType>
590 + Compare<Bool = Self::Bool>;
591 type SignedBitsType: SInt<BitsShiftRhs>
592 + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::SignedBitsType>
593 + ConvertTo<Self::BitsType>
594 + Compare<Bool = Self::Bool>;
595 fn abs(self) -> Self;
596 fn trunc(self) -> Self;
597 fn ceil(self) -> Self;
598 fn floor(self) -> Self;
599 fn round(self) -> Self;
600 #[cfg(feature = "fma")]
601 fn fma(self, a: Self, b: Self) -> Self;
602 fn is_nan(self) -> Self::Bool {
603 self.ne(self)
604 }
605 fn is_infinite(self) -> Self::Bool {
606 self.abs().eq(Self::infinity(self.ctx()))
607 }
608 fn infinity(ctx: Self::Context) -> Self {
609 Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS))
610 }
611 fn nan(ctx: Self::Context) -> Self {
612 Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS))
613 }
614 fn is_finite(self) -> Self::Bool;
615 fn is_zero_or_subnormal(self) -> Self::Bool {
616 self.extract_exponent_field().eq(self
617 .ctx()
618 .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT))
619 }
620 fn from_bits(v: Self::BitsType) -> Self;
621 fn to_bits(self) -> Self::BitsType;
622 fn extract_exponent_field(self) -> Self::BitsType {
623 let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK);
624 let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT);
625 (self.to_bits() & mask) >> shift
626 }
627 fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
628 Self::sub_exponent_bias(self.extract_exponent_field())
629 }
630 fn extract_mantissa_field(self) -> Self::BitsType {
631 let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK);
632 self.to_bits() & mask
633 }
634 fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
635 exponent_field.to()
636 - exponent_field
637 .ctx()
638 .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)
639 }
640 fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
641 (exponent
642 + exponent
643 .ctx()
644 .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED))
645 .to()
646 }
647 }
648
649 macro_rules! impl_float {
650 ($ty:ty, $bits:ty, $signed_bits:ty) => {
651 impl Float<u32> for $ty {
652 type FloatEncoding = $ty;
653 type BitsType = $bits;
654 type SignedBitsType = $signed_bits;
655 fn abs(self) -> Self {
656 #[cfg(feature = "std")]
657 return self.abs();
658 #[cfg(not(feature = "std"))]
659 todo!();
660 }
661 fn trunc(self) -> Self {
662 #[cfg(feature = "std")]
663 return self.trunc();
664 #[cfg(not(feature = "std"))]
665 todo!();
666 }
667 fn ceil(self) -> Self {
668 #[cfg(feature = "std")]
669 return self.ceil();
670 #[cfg(not(feature = "std"))]
671 todo!();
672 }
673 fn floor(self) -> Self {
674 #[cfg(feature = "std")]
675 return self.floor();
676 #[cfg(not(feature = "std"))]
677 todo!();
678 }
679 fn round(self) -> Self {
680 #[cfg(feature = "std")]
681 return self.round();
682 #[cfg(not(feature = "std"))]
683 todo!();
684 }
685 #[cfg(feature = "fma")]
686 fn fma(self, a: Self, b: Self) -> Self {
687 self.mul_add(a, b)
688 }
689 fn is_nan(self) -> Self::Bool {
690 self.is_nan()
691 }
692 fn is_infinite(self) -> Self::Bool {
693 self.is_infinite()
694 }
695 fn is_finite(self) -> Self::Bool {
696 self.is_finite()
697 }
698 fn from_bits(v: Self::BitsType) -> Self {
699 <$ty>::from_bits(v)
700 }
701 fn to_bits(self) -> Self::BitsType {
702 self.to_bits()
703 }
704 }
705 };
706 }
707
708 impl_float!(f32, u32, i32);
709 impl_float!(f64, u64, i64);
710
711 pub trait Bool: Make + BitOps {}
712
713 impl Bool for bool {}
714
715 pub trait Select<T>: Bool {
716 fn select(self, true_v: T, false_v: T) -> T;
717 }
718
719 impl<T> Select<T> for bool {
720 fn select(self, true_v: T, false_v: T) -> T {
721 if self {
722 true_v
723 } else {
724 false_v
725 }
726 }
727 }
728 pub trait Compare: Make {
729 type Bool: Bool + Select<Self>;
730 fn eq(self, rhs: Self) -> Self::Bool;
731 fn ne(self, rhs: Self) -> Self::Bool;
732 fn lt(self, rhs: Self) -> Self::Bool;
733 fn gt(self, rhs: Self) -> Self::Bool;
734 fn le(self, rhs: Self) -> Self::Bool;
735 fn ge(self, rhs: Self) -> Self::Bool;
736 }
737
738 macro_rules! impl_compare_using_partial_cmp {
739 ($($ty:ty),*) => {
740 $(
741 impl Compare for $ty {
742 type Bool = bool;
743 fn eq(self, rhs: Self) -> Self::Bool {
744 self == rhs
745 }
746 fn ne(self, rhs: Self) -> Self::Bool {
747 self != rhs
748 }
749 fn lt(self, rhs: Self) -> Self::Bool {
750 self < rhs
751 }
752 fn gt(self, rhs: Self) -> Self::Bool {
753 self > rhs
754 }
755 fn le(self, rhs: Self) -> Self::Bool {
756 self <= rhs
757 }
758 fn ge(self, rhs: Self) -> Self::Bool {
759 self >= rhs
760 }
761 }
762 )*
763 };
764 }
765
766 impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];