use round_to_nearest_ties_to_even to implement round
[vector-math.git] / src / prim.rs
1 use crate::{
2 f16::F16,
3 scalar::{Scalar, Value},
4 traits::{ConvertFrom, ConvertTo},
5 };
6 use core::{fmt, hash, ops};
7
8 mod sealed {
9 use crate::f16::F16;
10
11 pub trait Sealed {}
12 impl Sealed for F16 {}
13 impl Sealed for f32 {}
14 impl Sealed for f64 {}
15 impl Sealed for u8 {}
16 impl Sealed for u16 {}
17 impl Sealed for u32 {}
18 impl Sealed for u64 {}
19 impl Sealed for i8 {}
20 impl Sealed for i16 {}
21 impl Sealed for i32 {}
22 impl Sealed for i64 {}
23 }
24
25 pub trait PrimBase:
26 sealed::Sealed
27 + Copy
28 + 'static
29 + Send
30 + Sync
31 + PartialOrd
32 + fmt::Debug
33 + fmt::Display
34 + ops::Add<Output = Self>
35 + ops::Sub<Output = Self>
36 + ops::Mul<Output = Self>
37 + ops::Div<Output = Self>
38 + ops::Rem<Output = Self>
39 + ops::AddAssign
40 + ops::SubAssign
41 + ops::MulAssign
42 + ops::DivAssign
43 + ops::RemAssign
44 + ConvertFrom<i8>
45 + ConvertFrom<u8>
46 + ConvertFrom<i16>
47 + ConvertFrom<u16>
48 + ConvertFrom<F16>
49 + ConvertFrom<i32>
50 + ConvertFrom<u32>
51 + ConvertFrom<f32>
52 + ConvertFrom<i64>
53 + ConvertFrom<u64>
54 + ConvertFrom<f64>
55 + ConvertTo<i8>
56 + ConvertTo<u8>
57 + ConvertTo<i16>
58 + ConvertTo<u16>
59 + ConvertTo<F16>
60 + ConvertTo<i32>
61 + ConvertTo<u32>
62 + ConvertTo<f32>
63 + ConvertTo<i64>
64 + ConvertTo<u64>
65 + ConvertTo<f64>
66 {
67 }
68
69 pub trait PrimInt:
70 PrimBase
71 + Ord
72 + hash::Hash
73 + fmt::Binary
74 + fmt::LowerHex
75 + fmt::Octal
76 + fmt::UpperHex
77 + ops::BitAnd<Output = Self>
78 + ops::BitOr<Output = Self>
79 + ops::BitXor<Output = Self>
80 + ops::Shl<Output = Self>
81 + ops::Shr<Output = Self>
82 + ops::Not<Output = Self>
83 + ops::BitAndAssign
84 + ops::BitOrAssign
85 + ops::BitXorAssign
86 + ops::ShlAssign
87 + ops::ShrAssign
88 {
89 const ZERO: Self;
90 const ONE: Self;
91 const MIN: Self;
92 const MAX: Self;
93 }
94
95 pub trait PrimUInt: PrimInt + ConvertFrom<Self::SignedType> {
96 type SignedType: PrimSInt<UnsignedType = Self> + ConvertFrom<Self>;
97 }
98
99 pub trait PrimSInt: PrimInt + ops::Neg<Output = Self> + ConvertFrom<Self::UnsignedType> {
100 type UnsignedType: PrimUInt<SignedType = Self> + ConvertFrom<Self>;
101 }
102
103 macro_rules! impl_int {
104 ($uint:ident, $sint:ident) => {
105 impl PrimBase for $uint {}
106 impl PrimBase for $sint {}
107 impl PrimInt for $uint {
108 const ZERO: Self = 0;
109 const ONE: Self = 1;
110 const MIN: Self = 0;
111 const MAX: Self = !0;
112 }
113 impl PrimInt for $sint {
114 const ZERO: Self = 0;
115 const ONE: Self = 1;
116 const MIN: Self = $sint::MIN;
117 const MAX: Self = $sint::MAX;
118 }
119 impl PrimUInt for $uint {
120 type SignedType = $sint;
121 }
122 impl PrimSInt for $sint {
123 type UnsignedType = $uint;
124 }
125 };
126 }
127
128 impl_int!(u8, i8);
129 impl_int!(u16, i16);
130 impl_int!(u32, i32);
131 impl_int!(u64, i64);
132
133 pub trait PrimFloat:
134 PrimBase + ops::Neg<Output = Self> + ConvertFrom<Self::BitsType> + ConvertFrom<Self::SignedBitsType>
135 {
136 type BitsType: PrimUInt<SignedType = Self::SignedBitsType> + ConvertFrom<Self>;
137 type SignedBitsType: PrimSInt<UnsignedType = Self::BitsType> + ConvertFrom<Self>;
138 const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
139 const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
140 const SIGN_FIELD_WIDTH: Self::BitsType;
141 const EXPONENT_FIELD_WIDTH: Self::BitsType;
142 const MANTISSA_FIELD_WIDTH: Self::BitsType;
143 const SIGN_FIELD_SHIFT: Self::BitsType;
144 const EXPONENT_FIELD_SHIFT: Self::BitsType;
145 const MANTISSA_FIELD_SHIFT: Self::BitsType;
146 const SIGN_FIELD_MASK: Self::BitsType;
147 const EXPONENT_FIELD_MASK: Self::BitsType;
148 const MANTISSA_FIELD_MASK: Self::BitsType;
149 const IMPLICIT_MANTISSA_BIT: Self::BitsType;
150 const ZERO_SUBNORMAL_EXPONENT: Self::BitsType;
151 const NAN_INFINITY_EXPONENT: Self::BitsType;
152 const INFINITY_BITS: Self::BitsType;
153 const NAN_BITS: Self::BitsType;
154 fn is_nan(self) -> bool;
155 fn from_bits(bits: Self::BitsType) -> Self;
156 fn to_bits(self) -> Self::BitsType;
157 fn abs(self) -> Self;
158 fn max_contiguous_integer() -> Self {
159 (Self::BitsType::cvt_from(1) << (Self::MANTISSA_FIELD_WIDTH + 1.to())).to()
160 }
161 fn is_finite(self) -> bool;
162 fn trunc(self) -> Self;
163 /// round to nearest, ties to unspecified
164 fn round(self) -> Self;
165 fn copy_sign(self, sign: Self) -> Self;
166 }
167
/// Implements `PrimBase` and `PrimFloat` for a binary floating-point type.
///
/// Every bit-layout constant is derived from the exponent and mantissa field
/// widths. The `u32` written in the invocation syntax is only part of the
/// matcher. The generated constants have type `Self::BitsType`.
///
/// Under `feature = "std"` the inherent methods of `$float` are used.
/// Otherwise the generic implementations in `crate::algorithms::base` are
/// used.
macro_rules! impl_float {
    (
        impl PrimFloat for $float:ident {
            type BitsType = $bits_type:ident;
            type SignedBitsType = $signed_bits_type:ident;
            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal;
            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal;
        }
    ) => {
        impl PrimBase for $float {}

        impl PrimFloat for $float {
            type BitsType = $bits_type;
            type SignedBitsType = $signed_bits_type;

            // Bit layout, most to least significant: [sign | exponent | mantissa].
            const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
                (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
            const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
            const SIGN_FIELD_WIDTH: Self::BitsType = 1;
            const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width;
            const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width;
            const SIGN_FIELD_SHIFT: Self::BitsType =
                Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
            const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH;
            const MANTISSA_FIELD_SHIFT: Self::BitsType = 0;
            const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
            const EXPONENT_FIELD_MASK: Self::BitsType =
                ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
            const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1;
            const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH;
            const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0;
            // All-ones exponent field.
            const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1;
            const INFINITY_BITS: Self::BitsType =
                Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT;
            // All-ones exponent plus the most significant mantissa bit (a quiet NaN).
            const NAN_BITS: Self::BitsType =
                Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1));

            // The `$float::method` paths below resolve to the *inherent*
            // methods, which take precedence over these trait methods.
            fn is_nan(self) -> bool {
                $float::is_nan(self)
            }
            fn from_bits(bits: Self::BitsType) -> Self {
                $float::from_bits(bits)
            }
            fn to_bits(self) -> Self::BitsType {
                $float::to_bits(self)
            }
            fn abs(self) -> Self {
                #[cfg(feature = "std")]
                return $float::abs(self);
                #[cfg(not(feature = "std"))]
                return crate::algorithms::base::abs(Scalar, Value(self)).0;
            }
            fn is_finite(self) -> bool {
                $float::is_finite(self)
            }
            fn trunc(self) -> Self {
                #[cfg(feature = "std")]
                return $float::trunc(self);
                #[cfg(not(feature = "std"))]
                return crate::algorithms::base::trunc(Scalar, Value(self)).0;
            }
            fn round(self) -> Self {
                // NOTE: for `f32`/`f64` the std `round` sends halfway cases away
                // from zero, while the no_std path sends them to even. The
                // trait only promises "ties to unspecified", so both are allowed.
                #[cfg(feature = "std")]
                return $float::round(self);
                #[cfg(not(feature = "std"))]
                return crate::algorithms::base::round_to_nearest_ties_to_even(Scalar, Value(self))
                    .0;
            }
            fn copy_sign(self, sign: Self) -> Self {
                // The inherent `copysign` takes the sign source as its second
                // argument. Omitting it broke every `std` build.
                #[cfg(feature = "std")]
                return $float::copysign(self, sign);
                #[cfg(not(feature = "std"))]
                return crate::algorithms::base::copy_sign(Scalar, Value(self), Value(sign)).0;
            }
        }
    };
}
243
244 impl_float! {
245 impl PrimFloat for F16 {
246 type BitsType = u16;
247 type SignedBitsType = i16;
248 const EXPONENT_FIELD_WIDTH: u32 = 5;
249 const MANTISSA_FIELD_WIDTH: u32 = 10;
250 }
251 }
252
253 impl_float! {
254 impl PrimFloat for f32 {
255 type BitsType = u32;
256 type SignedBitsType = i32;
257 const EXPONENT_FIELD_WIDTH: u32 = 8;
258 const MANTISSA_FIELD_WIDTH: u32 = 23;
259 }
260 }
261
262 impl_float! {
263 impl PrimFloat for f64 {
264 type BitsType = u64;
265 type SignedBitsType = i64;
266 const EXPONENT_FIELD_WIDTH: u32 = 11;
267 const MANTISSA_FIELD_WIDTH: u32 = 52;
268 }
269 }