5bc1119443466e949640876b70e54c6bf15f24d6
[vector-math.git] / src / f16.rs
1 use core::ops::{
2 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
3 };
4
5 use crate::traits::{ConvertTo, Float};
6
7 #[cfg(feature = "f16")]
8 use half::f16 as F16Impl;
9
/// Backing storage used when the `f16` feature is disabled: the value is kept
/// as a plain `u16`, which `from_bits`/`to_bits` store and return unchanged.
#[cfg(not(feature = "f16"))]
type F16Impl = u16;

/// Half-precision float wrapper around `F16Impl`.
///
/// The wrapper is `repr(transparent)`, so it has the same layout as the
/// backing storage. The comparison traits are derived, so they compare
/// whatever `F16Impl` is; with the `f16` feature disabled that is a plain
/// `u16` comparison of the stored bits.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F16(F16Impl);
16
/// Aborts the current operation because it needs real half-precision
/// arithmetic, which is only available when the `f16` feature is enabled.
///
/// # Panics
///
/// Always panics with the message `f16 feature is not enabled`.
#[cfg(not(feature = "f16"))]
#[track_caller]
pub(crate) fn panic_f16_feature_disabled() -> ! {
    const MESSAGE: &str = "f16 feature is not enabled";
    panic!("{}", MESSAGE)
}
22
/// `f16` feature enabled: expands to the given expression unchanged.
///
/// The bracketed identifier list is accepted only so that call sites look the
/// same under both feature configurations; it is ignored here.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($enabled_expr:expr, [$($unused:ident),*]) => {
        $enabled_expr
    };
}
29
/// `f16` feature disabled: the expression is dropped without being expanded
/// and the call site panics via `panic_f16_feature_disabled`.
///
/// Each identifier in the bracketed list is bound to `_` first, so variables
/// that only the discarded expression would have used do not trigger
/// unused-variable warnings.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($enabled_expr:expr, [$($unused:ident),*]) => {{
        $(let _ = $unused;)*
        panic_f16_feature_disabled()
    }};
}
39
40 impl Default for F16 {
41 fn default() -> Self {
42 f16_impl!(F16(F16Impl::default()), [])
43 }
44 }
45
46 impl From<F16Impl> for F16 {
47 fn from(v: F16Impl) -> Self {
48 F16(v)
49 }
50 }
51
52 impl From<F16> for F16Impl {
53 fn from(v: F16) -> Self {
54 v.0
55 }
56 }
57
/// For every listed source type (comma-terminated list), generates:
///
/// * `From<$src> for F16`, delegating to the backing type's own conversion
///   from `$src` (panics when the `f16` feature is disabled), and
/// * `ConvertTo<F16> for $src`, forwarding to that `From` impl.
macro_rules! impl_f16_from {
    ($($src:ident,)*) => {
        $(
            impl From<$src> for F16 {
                fn from(v: $src) -> Self {
                    f16_impl!(Self(v.into()), [v])
                }
            }

            impl ConvertTo<F16> for $src {
                fn to(self) -> F16 {
                    F16::from(self)
                }
            }
        )*
    };
}
75
/// For every listed destination type (comma-terminated list), generates:
///
/// * `From<F16> for $dst`, delegating to the backing type's own conversion
///   into `$dst` (panics when the `f16` feature is disabled), and
/// * `ConvertTo<$dst> for F16`, forwarding to that `From` impl.
macro_rules! impl_from_f16 {
    ($($dst:ident,)*) => {
        $(
            impl From<F16> for $dst {
                fn from(v: F16) -> Self {
                    f16_impl!(Self::from(v.0), [v])
                }
            }

            impl ConvertTo<$dst> for F16 {
                fn to(self) -> $dst {
                    <$dst>::from(self)
                }
            }
        )*
    };
}
93
94 impl_f16_from![i8, u8,];
95
96 impl_from_f16![f32, f64,];
97
/// Generates `ConvertTo<F16>` for each listed integer type by converting to
/// `f32` first and then narrowing that `f32` to `F16`.
macro_rules! impl_int_to_f16 {
    ($($int:ident),*) => {
        $(
            impl ConvertTo<F16> for $int {
                fn to(self) -> F16 {
                    // Going through f32 is correct: f16 overflows to infinity
                    // well before integers grow large enough for f32 to stop
                    // representing them properly, thanks to f32's wider
                    // mantissa.
                    let widened = self as f32;
                    <f32 as ConvertTo<F16>>::to(widened)
                }
            }
        )*
    };
}
112
/// Generates `ConvertTo<$int> for F16` for each listed integer type.
///
/// The value is widened to `f32` and then converted with an `as` cast, so the
/// result follows Rust's float-to-integer cast rules (truncation toward zero,
/// saturation at the integer's bounds, NaN mapped to zero).
macro_rules! impl_f16_to_int {
    ($($int:ident),*) => {
        $(
            impl ConvertTo<$int> for F16 {
                fn to(self) -> $int {
                    let wide: f32 = self.into();
                    wide as $int
                }
            }
        )*
    };
}
124
125 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
126 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
127
128 impl ConvertTo<F16> for f32 {
129 fn to(self) -> F16 {
130 f16_impl!(F16(F16Impl::from_f32(self)), [])
131 }
132 }
133
134 impl ConvertTo<F16> for f64 {
135 fn to(self) -> F16 {
136 f16_impl!(F16(F16Impl::from_f64(self)), [])
137 }
138 }
139
140 impl Neg for F16 {
141 type Output = Self;
142
143 fn neg(self) -> Self::Output {
144 f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), [])
145 }
146 }
147
/// Generates a binary operator and its compound-assignment counterpart for
/// `F16`.
///
/// Each row is `Trait, method, AssignTrait, assign_method;`. Both operands are
/// widened to `f32`, the operation is carried out in `f32`, and the result is
/// narrowed back to `F16`. The assigning form reuses the by-value operator.
macro_rules! impl_bin_op_using_f32 {
    ($($bin_trait:ident, $bin_fn:ident, $assign_trait:ident, $assign_fn:ident;)*) => {
        $(
            impl $bin_trait for F16 {
                type Output = Self;

                fn $bin_fn(self, rhs: Self) -> Self {
                    let lhs = f32::from(self);
                    let rhs = f32::from(rhs);
                    $bin_trait::$bin_fn(lhs, rhs).to()
                }
            }

            impl $assign_trait for F16 {
                fn $assign_fn(&mut self, rhs: Self) {
                    *self = $bin_trait::$bin_fn(*self, rhs);
                }
            }
        )*
    };
}
167
168 impl_bin_op_using_f32! {
169 Add, add, AddAssign, add_assign;
170 Sub, sub, SubAssign, sub_assign;
171 Mul, mul, MulAssign, mul_assign;
172 Div, div, DivAssign, div_assign;
173 Rem, rem, RemAssign, rem_assign;
174 }
175
176 impl Float for F16 {
177 type FloatEncoding = F16;
178 type BitsType = u16;
179 type SignedBitsType = i16;
180
181 fn abs(self) -> Self {
182 f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
183 }
184
185 fn trunc(self) -> Self {
186 f32::from(self).trunc().to()
187 }
188
189 fn ceil(self) -> Self {
190 f32::from(self).ceil().to()
191 }
192
193 fn floor(self) -> Self {
194 f32::from(self).floor().to()
195 }
196
197 fn round(self) -> Self {
198 f32::from(self).round().to()
199 }
200
201 #[cfg(feature = "fma")]
202 fn fma(self, a: Self, b: Self) -> Self {
203 (f64::from(self) * f64::from(a) + f64::from(b)).to()
204 }
205
206 fn is_nan(self) -> Self::Bool {
207 f16_impl!(self.0.is_nan(), [])
208 }
209
210 fn is_infinite(self) -> Self::Bool {
211 f16_impl!(self.0.is_infinite(), [])
212 }
213
214 fn is_finite(self) -> Self::Bool {
215 f16_impl!(self.0.is_finite(), [])
216 }
217
218 fn from_bits(v: Self::BitsType) -> Self {
219 #[cfg(feature = "f16")]
220 return F16(F16Impl::from_bits(v));
221 #[cfg(not(feature = "f16"))]
222 return F16(v);
223 }
224
225 fn to_bits(self) -> Self::BitsType {
226 #[cfg(feature = "f16")]
227 return self.0.to_bits();
228 #[cfg(not(feature = "f16"))]
229 return self.0;
230 }
231 }
232
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    /// `abs` clears the sign bit and leaves the other bits alone, including
    /// for zero and NaN encodings.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        // (input bits, expected bits)
        const CASES: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0x7E00),
        ];
        for &(input, expected) in CASES.iter() {
            assert_eq!(F16::from_bits(input).abs().to_bits(), expected);
        }
    }

    /// `neg` flips the sign bit and leaves the other bits alone, including
    /// for zero and NaN encodings.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        // (input bits, expected bits)
        const CASES: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0x8000),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0xFE00),
        ];
        for &(input, expected) in CASES.iter() {
            assert_eq!((-F16::from_bits(input)).to_bits(), expected);
        }
    }

    /// Checks `u32 -> F16` against a reference built purely with integer
    /// arithmetic, for every value from 0 up to (but excluding) 0x20000.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        // After left-aligning the value in 32 bits, the top 11 bits form the
        // significand (implicit one included) and the low 21 bits are rounded
        // away.
        const DISCARDED_MASK: u32 = 0x1FFFFF;
        const HALF_ULP: u32 = 0x100000;
        const ONE_ULP: u32 = 0x200000;
        const FRACTION_MASK: u16 = 0x3FF;
        const IMPLICIT_ONE: u16 = 0x400;
        const INFINITY_BITS: u16 = 0x7C00;

        let zero: F16 = 0u32.to();
        assert_eq!(zero.to_bits(), 0);

        for v in 1..0x20000u32 {
            let shift = v.leading_zeros();
            let aligned = v << shift;
            // round to nearest, ties to even
            let round_up = match (aligned & DISCARDED_MASK).cmp(&HALF_ULP) {
                Ordering::Greater => true,
                Ordering::Equal => (aligned & ONE_ULP) != 0,
                Ordering::Less => false,
            };
            let increment = if round_up { ONE_ULP } else { 0 };
            let (rounded, carry) = (aligned & !DISCARDED_MASK).overflowing_add(increment);
            // A carry out of bit 31 means the significand rolled over to the
            // next power of two.
            let mantissa = if carry {
                (rounded >> 22) as u16 + IMPLICIT_ONE
            } else {
                (rounded >> 21) as u16
            };
            assert_eq!(mantissa & !FRACTION_MASK, IMPLICIT_ONE);
            let exponent = 31 - shift as u16 + 15 + carry as u16;
            let expected = if exponent < 0x1F {
                (exponent << 10) + (mantissa & FRACTION_MASK)
            } else {
                INFINITY_BITS
            };
            let converted: F16 = v.to();
            let actual = converted.to_bits();
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }
}