b5d84d5a5020a12f7615e2c644a215593779167c
[vector-math.git] / src / f16.rs
1 use crate::{
2 prim::PrimFloat,
3 scalar::Value,
4 traits::{ConvertFrom, ConvertTo, Float},
5 };
6 use core::{
7 fmt,
8 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
9 };
10
11 #[cfg(feature = "f16")]
12 use half::f16 as F16Impl;
13
/// Storage used when the `f16` feature is off: just the raw binary16 bit
/// pattern, with no floating-point arithmetic available.
#[cfg(not(feature = "f16"))]
type F16Impl = u16;

/// An IEEE 754 binary16 ("half precision") floating-point value.
///
/// With the `f16` feature enabled this wraps `half::f16`; without it only the
/// raw bits are stored and most operations panic.
///
/// Note: the derived `PartialEq`/`PartialOrd` compare the inner representation,
/// so without the `f16` feature they compare raw bit patterns, not float values.
#[repr(transparent)]
#[derive(Copy, Clone, PartialOrd, PartialEq)]
pub struct F16(F16Impl);
20
/// Diverges with the panic raised by every `F16` operation that needs the
/// `f16` feature when that feature is disabled.
///
/// Because of `#[track_caller]`, the reported panic location is the caller's
/// rather than this helper's.
#[track_caller]
#[cfg(not(feature = "f16"))]
pub(crate) fn panic_f16_feature_disabled() -> ! {
    panic!("f16 feature is not enabled");
}
26
/// Feature-enabled variant: expands to `$enabled` unchanged.
///
/// The bracketed identifier list is accepted and ignored, so call sites are
/// identical to those of the feature-disabled variant.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($enabled:expr, [$($unused:ident),*]) => {
        $enabled
    };
}

/// Feature-disabled variant: drops `$enabled` without emitting it (so it is
/// neither evaluated nor type-checked). It mentions each listed identifier
/// with `let _ = ident;` so those bindings count as used, then panics through
/// `panic_f16_feature_disabled`.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($enabled:expr, [$($unused:ident),*]) => {{
        $(let _ = $unused;)*
        panic_f16_feature_disabled()
    }};
}
43
44 impl fmt::Display for F16 {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f16_impl!(self.0.fmt(f), [f])
47 }
48 }
49
50 impl fmt::Debug for F16 {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f16_impl!(self.0.fmt(f), [f])
53 }
54 }
55
56 impl Default for F16 {
57 fn default() -> Self {
58 f16_impl!(F16(F16Impl::default()), [])
59 }
60 }
61
62 impl From<F16Impl> for F16 {
63 fn from(v: F16Impl) -> Self {
64 F16(v)
65 }
66 }
67
68 impl From<F16> for F16Impl {
69 fn from(v: F16) -> Self {
70 v.0
71 }
72 }
73
/// For each listed source type, implements `From<$src> for F16` by delegating
/// to the backing type's `From<$src>`, plus a `ConvertFrom<$src>` that forwards
/// to that `From` impl. Each type in the list must be followed by a comma.
macro_rules! impl_f16_from {
    ($($src:ident,)*) => {
        $(
            impl From<$src> for F16 {
                fn from(value: $src) -> Self {
                    f16_impl!(Self(<F16Impl as From<$src>>::from(value)), [value])
                }
            }

            impl ConvertFrom<$src> for F16 {
                fn cvt_from(value: $src) -> F16 {
                    <Self as From<$src>>::from(value)
                }
            }
        )*
    };
}
91
/// For each listed destination type, implements `From<F16> for $dst` by
/// converting the backing value with `$dst`'s `From<F16Impl>`, plus a
/// `ConvertFrom<F16>` that forwards to that `From` impl. Each type in the list
/// must be followed by a comma.
macro_rules! impl_from_f16 {
    ($($dst:ident,)*) => {
        $(
            impl From<F16> for $dst {
                fn from(value: F16) -> Self {
                    f16_impl!(<$dst as From<F16Impl>>::from(value.0), [value])
                }
            }

            impl ConvertFrom<F16> for $dst {
                fn cvt_from(value: F16) -> Self {
                    <$dst as From<F16>>::from(value)
                }
            }
        )*
    };
}
109
110 impl_f16_from![i8, u8,];
111
112 impl_from_f16![f32, f64,];
113
/// Implements `ConvertFrom<$src> for F16` for each listed integer type by first
/// casting to `f32`.
///
/// Every integer small enough to be finite in `F16` is exactly representable
/// in `f32`, so only the final `f32 -> F16` step rounds. Integers too large for
/// `f32` to hold exactly may be rounded by the cast, but they overflow `F16` to
/// infinity either way, so the result is still correct.
macro_rules! impl_int_to_f16 {
    ($($src:ident),*) => {
        $(
            impl ConvertFrom<$src> for F16 {
                fn cvt_from(v: $src) -> Self {
                    let widened = v as f32;
                    <F16 as ConvertFrom<f32>>::cvt_from(widened)
                }
            }
        )*
    };
}
128
/// Implements `ConvertFrom<F16> for $dst` for each listed integer type by
/// widening to `f32` and then applying an `as` cast. The cast truncates toward
/// zero, saturates at the integer type's bounds, and maps NaN to 0.
macro_rules! impl_f16_to_int {
    ($($dst:ident),*) => {
        $(
            impl ConvertFrom<F16> for $dst {
                fn cvt_from(v: F16) -> Self {
                    let widened = <f32 as From<F16>>::from(v);
                    widened as $dst
                }
            }
        )*
    };
}
140
141 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
142 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
143
144 impl ConvertFrom<f32> for F16 {
145 fn cvt_from(v: f32) -> Self {
146 f16_impl!(F16(F16Impl::from_f32(v)), [v])
147 }
148 }
149
150 impl ConvertFrom<f64> for F16 {
151 fn cvt_from(v: f64) -> Self {
152 f16_impl!(F16(F16Impl::from_f64(v)), [v])
153 }
154 }
155
156 impl Neg for F16 {
157 type Output = Self;
158
159 fn neg(self) -> Self::Output {
160 f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), [])
161 }
162 }
163
/// Implements a binary operator and its compound-assignment form for `F16`.
///
/// Both operands are widened to `f32`, the operation runs in `f32`, and the
/// result is converted back to `F16`. The compound-assignment form reuses the
/// binary operator.
///
/// NOTE(review): f32's 24-bit significand meets the classic `p' >= 2p + 2`
/// condition for binary16 (p = 11), so the double rounding here is believed to
/// be harmless for +, -, *, /; `%` is exact in f32. Confirm before relying on it.
macro_rules! impl_bin_op_using_f32 {
    ($($trait_name:ident, $method:ident, $assign_trait:ident, $assign_method:ident;)*) => {
        $(
            impl $trait_name for F16 {
                type Output = Self;

                fn $method(self, rhs: Self) -> Self::Output {
                    let lhs_wide = f32::from(self);
                    let rhs_wide = f32::from(rhs);
                    let result: f32 = $trait_name::$method(lhs_wide, rhs_wide);
                    result.to()
                }
            }

            impl $assign_trait for F16 {
                fn $assign_method(&mut self, rhs: Self) {
                    *self = $trait_name::$method(*self, rhs);
                }
            }
        )*
    };
}
183
184 impl_bin_op_using_f32! {
185 Add, add, AddAssign, add_assign;
186 Sub, sub, SubAssign, sub_assign;
187 Mul, mul, MulAssign, mul_assign;
188 Div, div, DivAssign, div_assign;
189 Rem, rem, RemAssign, rem_assign;
190 }
191
192 impl F16 {
193 pub fn from_bits(v: u16) -> Self {
194 #[cfg(feature = "f16")]
195 return F16(F16Impl::from_bits(v));
196 #[cfg(not(feature = "f16"))]
197 return F16(v);
198 }
199 pub fn to_bits(self) -> u16 {
200 #[cfg(feature = "f16")]
201 return self.0.to_bits();
202 #[cfg(not(feature = "f16"))]
203 return self.0;
204 }
205 pub fn abs(self) -> Self {
206 f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
207 }
208 pub fn copysign(self, sign: Self) -> Self {
209 f16_impl!(
210 Self::from_bits((self.to_bits() & 0x7FFF) | (sign.to_bits() & 0x8000)),
211 [sign]
212 )
213 }
214 pub fn trunc(self) -> Self {
215 return PrimFloat::trunc(f32::from(self)).to();
216 }
217 pub fn ceil(self) -> Self {
218 return PrimFloat::ceil(f32::from(self)).to();
219 }
220 pub fn floor(self) -> Self {
221 return PrimFloat::floor(f32::from(self)).to();
222 }
223 /// round to nearest, ties to unspecified
224 pub fn round(self) -> Self {
225 return PrimFloat::round(f32::from(self)).to();
226 }
227 #[cfg(feature = "fma")]
228 pub fn fma(self, a: Self, b: Self) -> Self {
229 (f64::from(self) * f64::from(a) + f64::from(b)).to()
230 }
231
232 pub fn is_nan(self) -> bool {
233 f16_impl!(self.0.is_nan(), [])
234 }
235
236 pub fn is_infinite(self) -> bool {
237 f16_impl!(self.0.is_infinite(), [])
238 }
239
240 pub fn is_finite(self) -> bool {
241 f16_impl!(self.0.is_finite(), [])
242 }
243 }
244
245 impl Float for Value<F16> {
246 type PrimFloat = F16;
247 type BitsType = Value<u16>;
248 type SignedBitsType = Value<i16>;
249
250 fn abs(self) -> Self {
251 Value(self.0.abs())
252 }
253
254 fn trunc(self) -> Self {
255 Value(self.0.trunc())
256 }
257
258 fn ceil(self) -> Self {
259 Value(self.0.ceil())
260 }
261
262 fn floor(self) -> Self {
263 Value(self.0.floor())
264 }
265
266 fn round(self) -> Self {
267 Value(self.0.round())
268 }
269
270 #[cfg(feature = "fma")]
271 fn fma(self, a: Self, b: Self) -> Self {
272 Value(self.0.fma(a.0, b.0))
273 }
274
275 fn is_nan(self) -> Self::Bool {
276 Value(self.0.is_nan())
277 }
278
279 fn is_infinite(self) -> Self::Bool {
280 Value(self.0.is_infinite())
281 }
282
283 fn is_finite(self) -> Self::Bool {
284 Value(self.0.is_finite())
285 }
286
287 fn from_bits(v: Self::BitsType) -> Self {
288 Value(F16::from_bits(v.0))
289 }
290
291 fn to_bits(self) -> Self::BitsType {
292 Value(self.0.to_bits())
293 }
294 }
295
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    /// Independently computes the expected binary16 bit pattern for a nonzero
    /// `u32`, using round-to-nearest, ties-to-even, and overflowing to +infinity.
    fn reference_u32_to_f16_bits(v: u32) -> u16 {
        // Normalize so the highest set bit lands on bit 31. binary16 keeps 11
        // significant bits (31..=21); the low 21 bits are rounded away.
        let leading_zeros = v.leading_zeros();
        let normalized = v << leading_zeros;
        let dropped_mask: u32 = 0x1F_FFFF;
        let half_way: u32 = 0x10_0000;
        let kept_lsb: u32 = 0x20_0000;

        let round_up = match (normalized & dropped_mask).cmp(&half_way) {
            Ordering::Less => false,
            // Exact tie: round to even, i.e. up only when the kept LSB is 1.
            Ordering::Equal => (normalized & kept_lsb) != 0,
            Ordering::Greater => true,
        };
        let increment = if round_up { kept_lsb } else { 0 };
        let (rounded, carry) = (normalized & !dropped_mask).overflowing_add(increment);

        // On carry the kept bits wrapped to zero: the significand becomes 1.0
        // and the exponent is bumped by one.
        let mantissa = if carry {
            (rounded >> 22) as u16 + 0x400
        } else {
            (rounded >> 21) as u16
        };
        assert_eq!(mantissa & !0x3FF, 0x400);

        let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
        if exponent >= 0x1F {
            0x7C00
        } else {
            (mantissa & 0x3FF) + (exponent << 10)
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        let cases: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0x7E00),
        ];
        for &(input, expected) in &cases {
            assert_eq!(F16::from_bits(input).abs().to_bits(), expected);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        let cases: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0x8000),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0xFE00),
        ];
        for &(input, expected) in &cases {
            assert_eq!((-F16::from_bits(input)).to_bits(), expected);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        assert_eq!(F16::to_bits(0u32.to()), 0);
        for v in 1..0x20000u32 {
            let expected = reference_u32_to_f16_bits(v);
            let actual = F16::to_bits(v.to());
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }
}