clarify explanatory comment
[vector-math.git] / src / f16.rs
1 use core::ops::{
2 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
3 };
4
5 use crate::traits::{ConvertTo, Float};
6
7 #[cfg(feature = "f16")]
8 use half::f16 as F16Impl;
9
/// Stand-in for the real half-precision backend when the `f16` feature is
/// disabled.
///
/// It has no variants, so no value of this type (and therefore no [`F16`])
/// can ever be constructed in that configuration.
#[cfg(not(feature = "f16"))]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum F16Impl {}

/// 16-bit floating-point scalar.
///
/// With the `f16` feature enabled this is a `repr(transparent)` wrapper
/// around `half::f16`. Without the feature it wraps the empty `F16Impl`
/// enum. The type is then uninhabited, and the constructors routed through
/// `f16_impl!` panic with "f16 feature is not enabled".
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[cfg_attr(feature = "f16", repr(transparent))]
pub struct F16(F16Impl);
17
/// Evaluates `$body`, the operation that needs the real `half::f16` backend.
///
/// The bracketed identifiers list the locals `$body` mentions. With the
/// feature enabled they are ignored and `$body` is expanded as-is.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($body:expr, [$($captured:ident),*]) => {
        $body
    };
}

/// Fallback used when the `f16` feature is disabled.
///
/// `$body` is discarded without being type-checked, so it may reference
/// backend items that do not exist in this configuration. Each listed local
/// is consumed with `let _ = ...` so it does not trigger unused-variable
/// warnings. The expansion then panics.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($body:expr, [$($captured:ident),*]) => {{
        $(let _ = $captured;)*
        panic!("f16 feature is not enabled")
    }};
}
34
35 impl From<F16Impl> for F16 {
36 fn from(v: F16Impl) -> Self {
37 F16(v)
38 }
39 }
40
41 impl From<F16> for F16Impl {
42 fn from(v: F16) -> Self {
43 v.0
44 }
45 }
46
47 macro_rules! impl_f16_from {
48 ($($ty:ident,)*) => {
49 $(
50 impl From<$ty> for F16 {
51 fn from(v: $ty) -> Self {
52 f16_impl!(F16(F16Impl::from(v)), [v])
53 }
54 }
55
56 impl ConvertTo<F16> for $ty {
57 fn to(self) -> F16 {
58 self.into()
59 }
60 }
61 )*
62 };
63 }
64
65 macro_rules! impl_from_f16 {
66 ($($ty:ident,)*) => {
67 $(
68 impl From<F16> for $ty {
69 fn from(v: F16) -> Self {
70 #[cfg(feature = "f16")]
71 return v.0.into();
72 #[cfg(not(feature = "f16"))]
73 match v.0 {}
74 }
75 }
76
77 impl ConvertTo<$ty> for F16 {
78 fn to(self) -> $ty {
79 self.into()
80 }
81 }
82 )*
83 };
84 }
85
86 impl_f16_from![i8, u8,];
87
88 impl_from_f16![f32, f64,];
89
90 macro_rules! impl_int_to_f16 {
91 ($($int:ident),*) => {
92 $(
93 impl ConvertTo<F16> for $int {
94 fn to(self) -> F16 {
95 // f32 has enough mantissa bits such that f16 overflows to
96 // infinity before f32 stops being able to properly
97 // represent integer values, making the below conversion correct.
98 (self as f32).to()
99 }
100 }
101 )*
102 };
103 }
104
105 macro_rules! impl_f16_to_int {
106 ($($int:ident),*) => {
107 $(
108 impl ConvertTo<$int> for F16 {
109 fn to(self) -> $int {
110 f32::from(self) as $int
111 }
112 }
113 )*
114 };
115 }
116
117 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
118 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
119
120 impl ConvertTo<F16> for f32 {
121 fn to(self) -> F16 {
122 f16_impl!(F16(F16Impl::from_f32(self)), [])
123 }
124 }
125
126 impl ConvertTo<F16> for f64 {
127 fn to(self) -> F16 {
128 f16_impl!(F16(F16Impl::from_f64(self)), [])
129 }
130 }
131
132 impl Neg for F16 {
133 type Output = Self;
134
135 fn neg(self) -> Self::Output {
136 Self::from_bits(self.to_bits() ^ 0x8000)
137 }
138 }
139
140 macro_rules! impl_bin_op_using_f32 {
141 ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => {
142 $(
143 impl $op for F16 {
144 type Output = Self;
145
146 fn $op_fn(self, rhs: Self) -> Self::Output {
147 f32::from(self).$op_fn(f32::from(rhs)).to()
148 }
149 }
150
151 impl $op_assign for F16 {
152 fn $op_assign_fn(&mut self, rhs: Self) {
153 *self = (*self).$op_fn(rhs);
154 }
155 }
156 )*
157 };
158 }
159
160 impl_bin_op_using_f32! {
161 Add, add, AddAssign, add_assign;
162 Sub, sub, SubAssign, sub_assign;
163 Mul, mul, MulAssign, mul_assign;
164 Div, div, DivAssign, div_assign;
165 Rem, rem, RemAssign, rem_assign;
166 }
167
168 impl Float<u32> for F16 {
169 type BitsType = u16;
170
171 fn abs(self) -> Self {
172 Self::from_bits(self.to_bits() & 0x7FFF)
173 }
174
175 fn trunc(self) -> Self {
176 f32::from(self).trunc().to()
177 }
178
179 fn ceil(self) -> Self {
180 f32::from(self).ceil().to()
181 }
182
183 fn floor(self) -> Self {
184 f32::from(self).floor().to()
185 }
186
187 fn round(self) -> Self {
188 f32::from(self).round().to()
189 }
190
191 #[cfg(feature = "fma")]
192 fn fma(self, a: Self, b: Self) -> Self {
193 (f64::from(self) * f64::from(a) + f64::from(b)).to()
194 }
195
196 fn is_nan(self) -> Self::Bool {
197 f16_impl!(self.0.is_nan(), [])
198 }
199
200 fn is_infinite(self) -> Self::Bool {
201 f16_impl!(self.0.is_infinite(), [])
202 }
203
204 fn is_finite(self) -> Self::Bool {
205 f16_impl!(self.0.is_finite(), [])
206 }
207
208 fn from_bits(v: Self::BitsType) -> Self {
209 f16_impl!(F16(F16Impl::from_bits(v)), [v])
210 }
211
212 fn to_bits(self) -> Self::BitsType {
213 f16_impl!(self.0.to_bits(), [])
214 }
215 }
216
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    /// Bits rounded away once a `u32` is normalized so its highest set bit
    /// is bit 31. The f16 significand (implicit bit plus 10 stored bits)
    /// then occupies bits 31..=21.
    const ROUND_MASK: u32 = 0x1F_FFFF;
    /// Exactly half a unit in the last kept place.
    const HALF_ULP: u32 = 0x10_0000;
    /// One unit in the last kept place (bit 21).
    const ULP: u32 = 0x20_0000;

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        // (input bits, expected bits after `abs`)
        let cases: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0x7E00),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(F16::from_bits(input).abs().to_bits(), expected);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        // (input bits, expected bits after `neg`)
        let cases: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0x8000),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0xFE00),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(F16::from_bits(input).neg().to_bits(), expected);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        assert_eq!(F16::to_bits(0u32.to()), 0);
        for v in 1..0x20000u32 {
            let leading_zeros = v.leading_zeros();
            let normalized = v << leading_zeros;
            // round to nearest, ties to even
            let round_up = match (normalized & ROUND_MASK).cmp(&HALF_ULP) {
                Ordering::Less => false,
                Ordering::Equal => (normalized & ULP) != 0,
                Ordering::Greater => true,
            };
            let increment = if round_up { ULP } else { 0 };
            let (rounded, carry) = (normalized & !ROUND_MASK).overflowing_add(increment);
            // A carry out of bit 31 means the significand rounded up to the
            // next power of two. It becomes 0x400 and the exponent grows by one.
            let mantissa = if carry {
                (rounded >> 22) as u16 + 0x400
            } else {
                (rounded >> 21) as u16
            };
            assert_eq!((mantissa & !0x3FF), 0x400);
            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
            // A biased exponent of 0x1F or more is out of range for f16: infinity.
            let expected = if exponent < 0x1F {
                (mantissa & 0x3FF) + (exponent << 10)
            } else {
                0x7C00
            };
            let actual = F16::to_bits(v.to());
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }
}