e9541b416f1546e5f7195baa01f3a528f7e6bdf5
[vector-math.git] / src / f16.rs
1 use core::ops::{
2 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
3 };
4
5 use crate::{
6 scalar::Value,
7 traits::{ConvertFrom, ConvertTo, Float},
8 };
9
10 #[cfg(feature = "f16")]
11 use half::f16 as F16Impl;
12
13 #[cfg(not(feature = "f16"))]
14 type F16Impl = u16;
15
16 #[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
17 #[repr(transparent)]
18 pub struct F16(F16Impl);
19
/// Aborts the current operation because it needs the `f16` feature.
///
/// Only compiled when the `f16` feature is disabled. `#[track_caller]` makes
/// the reported panic location the caller rather than this helper.
///
/// # Panics
///
/// Always panics with the message `f16 feature is not enabled`.
#[cfg(not(feature = "f16"))]
#[track_caller]
pub(crate) fn panic_f16_feature_disabled() -> ! {
    const MESSAGE: &str = "f16 feature is not enabled";
    panic!("{}", MESSAGE)
}
25
/// `f16`-enabled flavour of `f16_impl!`: expands to the given expression.
///
/// The bracketed identifier list is accepted only so call sites look the same
/// for both flavours of the macro; it is ignored here.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($body:expr, [$($unused:ident),*]) => {
        $body
    };
}
32
/// `f16`-disabled flavour of `f16_impl!`: drops the expression and panics.
///
/// Each identifier in the bracketed list is bound to `_` so the enclosing
/// function's parameters count as used, then
/// `panic_f16_feature_disabled()` is called.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($body:expr, [$($unused:ident),*]) => {{
        $(
            let _ = $unused;
        )*
        panic_f16_feature_disabled()
    }};
}
42
43 impl Default for F16 {
44 fn default() -> Self {
45 f16_impl!(F16(F16Impl::default()), [])
46 }
47 }
48
49 impl From<F16Impl> for F16 {
50 fn from(v: F16Impl) -> Self {
51 F16(v)
52 }
53 }
54
55 impl From<F16> for F16Impl {
56 fn from(v: F16) -> Self {
57 v.0
58 }
59 }
60
/// For each listed type, implements `From<T> for F16` via the backing type's
/// own `From<T>`, plus a `ConvertFrom<T> for F16` that forwards to it.
///
/// The type list must be comma-terminated (`impl_f16_from![i8, u8,]`).
macro_rules! impl_f16_from {
    ($($src:ident,)*) => {
        $(
            impl From<$src> for F16 {
                fn from(value: $src) -> F16 {
                    f16_impl!(F16(<F16Impl as From<$src>>::from(value)), [value])
                }
            }

            impl ConvertFrom<$src> for F16 {
                fn cvt_from(value: $src) -> F16 {
                    <F16 as From<$src>>::from(value)
                }
            }
        )*
    };
}
78
/// For each listed type, implements `From<F16> for T` by converting the
/// wrapped backing value, plus a `ConvertFrom<F16> for T` that forwards to it.
///
/// The type list must be comma-terminated (`impl_from_f16![f32, f64,]`).
macro_rules! impl_from_f16 {
    ($($dst:ident,)*) => {
        $(
            impl From<F16> for $dst {
                fn from(value: F16) -> $dst {
                    f16_impl!(<$dst as From<F16Impl>>::from(value.0), [value])
                }
            }

            impl ConvertFrom<F16> for $dst {
                fn cvt_from(value: F16) -> $dst {
                    <$dst as From<F16>>::from(value)
                }
            }
        )*
    };
}
96
97 impl_f16_from![i8, u8,];
98
99 impl_from_f16![f32, f64,];
100
/// For each listed integer type, implements `ConvertFrom<int> for F16` by
/// casting the integer to `f32` and then converting that `f32` to `F16`.
macro_rules! impl_int_to_f16 {
    ($($src:ident),*) => {
        $(
            impl ConvertFrom<$src> for F16 {
                fn cvt_from(value: $src) -> F16 {
                    // Going through f32 is correct: f16 overflows to infinity
                    // long before f32 runs out of mantissa bits to represent
                    // integer values exactly.
                    let widened = value as f32;
                    <F16 as ConvertFrom<f32>>::cvt_from(widened)
                }
            }
        )*
    };
}
115
/// For each listed integer type, implements `ConvertFrom<F16> for int` by
/// converting the `F16` to `f32` and then applying an `as` cast.
macro_rules! impl_f16_to_int {
    ($($dst:ident),*) => {
        $(
            impl ConvertFrom<F16> for $dst {
                fn cvt_from(value: F16) -> $dst {
                    let widened = <f32 as From<F16>>::from(value);
                    widened as $dst
                }
            }
        )*
    };
}
127
128 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
129 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
130
131 impl ConvertFrom<f32> for F16 {
132 fn cvt_from(v: f32) -> Self {
133 f16_impl!(F16(F16Impl::from_f32(v)), [v])
134 }
135 }
136
137 impl ConvertFrom<f64> for F16 {
138 fn cvt_from(v: f64) -> Self {
139 f16_impl!(F16(F16Impl::from_f64(v)), [v])
140 }
141 }
142
143 impl Neg for F16 {
144 type Output = Self;
145
146 fn neg(self) -> Self::Output {
147 f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), [])
148 }
149 }
150
/// Implements a binary operator trait and its `*Assign` counterpart for `F16`.
///
/// Each entry is `Trait, method, AssignTrait, assign_method;`. Both operands
/// are converted to `f32`, the operator is applied there, and the result is
/// converted back with `.to()`. The assigning form applies the by-value
/// operator to a copy of `*self` and stores the result.
macro_rules! impl_bin_op_using_f32 {
    ($($op:ident, $op_fn:ident, $op_assign:ident, $op_assign_fn:ident;)*) => {
        $(
            impl $op for F16 {
                type Output = F16;

                fn $op_fn(self, rhs: F16) -> F16 {
                    let lhs_wide = f32::from(self);
                    let rhs_wide = f32::from(rhs);
                    lhs_wide.$op_fn(rhs_wide).to()
                }
            }

            impl $op_assign for F16 {
                fn $op_assign_fn(&mut self, rhs: F16) {
                    let updated = $op::$op_fn(*self, rhs);
                    *self = updated;
                }
            }
        )*
    };
}
170
171 impl_bin_op_using_f32! {
172 Add, add, AddAssign, add_assign;
173 Sub, sub, SubAssign, sub_assign;
174 Mul, mul, MulAssign, mul_assign;
175 Div, div, DivAssign, div_assign;
176 Rem, rem, RemAssign, rem_assign;
177 }
178
179 impl F16 {
180 pub fn from_bits(v: u16) -> Self {
181 #[cfg(feature = "f16")]
182 return F16(F16Impl::from_bits(v));
183 #[cfg(not(feature = "f16"))]
184 return F16(v);
185 }
186 pub fn to_bits(self) -> u16 {
187 #[cfg(feature = "f16")]
188 return self.0.to_bits();
189 #[cfg(not(feature = "f16"))]
190 return self.0;
191 }
192 pub fn abs(self) -> Self {
193 f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
194 }
195 pub fn trunc(self) -> Self {
196 f32::from(self).trunc().to()
197 }
198
199 pub fn ceil(self) -> Self {
200 f32::from(self).ceil().to()
201 }
202
203 pub fn floor(self) -> Self {
204 f32::from(self).floor().to()
205 }
206
207 pub fn round(self) -> Self {
208 f32::from(self).round().to()
209 }
210
211 #[cfg(feature = "fma")]
212 pub fn fma(self, a: Self, b: Self) -> Self {
213 (f64::from(self) * f64::from(a) + f64::from(b)).to()
214 }
215
216 pub fn is_nan(self) -> bool {
217 f16_impl!(self.0.is_nan(), [])
218 }
219
220 pub fn is_infinite(self) -> bool {
221 f16_impl!(self.0.is_infinite(), [])
222 }
223
224 pub fn is_finite(self) -> bool {
225 f16_impl!(self.0.is_finite(), [])
226 }
227 }
228
229 impl Float for Value<F16> {
230 type FloatEncoding = F16;
231 type BitsType = Value<u16>;
232 type SignedBitsType = Value<i16>;
233
234 fn abs(self) -> Self {
235 Value(self.0.abs())
236 }
237
238 fn trunc(self) -> Self {
239 Value(self.0.trunc())
240 }
241
242 fn ceil(self) -> Self {
243 Value(self.0.ceil())
244 }
245
246 fn floor(self) -> Self {
247 Value(self.0.floor())
248 }
249
250 fn round(self) -> Self {
251 Value(self.0.round())
252 }
253
254 #[cfg(feature = "fma")]
255 fn fma(self, a: Self, b: Self) -> Self {
256 Value(self.0.fma(a.0, b.0))
257 }
258
259 fn is_nan(self) -> Self::Bool {
260 Value(self.0.is_nan())
261 }
262
263 fn is_infinite(self) -> Self::Bool {
264 Value(self.0.is_infinite())
265 }
266
267 fn is_finite(self) -> Self::Bool {
268 Value(self.0.is_finite())
269 }
270
271 fn from_bits(v: Self::BitsType) -> Self {
272 Value(F16::from_bits(v.0))
273 }
274
275 fn to_bits(self) -> Self::BitsType {
276 Value(self.0.to_bits())
277 }
278 }
279
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    /// `abs` must clear the sign bit and leave all other bits alone.
    /// Without the `f16` feature the first call is expected to panic.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        // (input bits, expected bits)
        const CASES: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0x7E00),
        ];
        for &(input, expected) in CASES.iter() {
            assert_eq!(F16::from_bits(input).abs().to_bits(), expected);
        }
    }

    /// `neg` must flip the sign bit and leave all other bits alone.
    /// Without the `f16` feature the first call is expected to panic.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        // (input bits, expected bits)
        const CASES: [(u16, u16); 5] = [
            (0x8000, 0),
            (0, 0x8000),
            (0x8ABC, 0xABC),
            (0xFE00, 0x7E00),
            (0x7E00, 0xFE00),
        ];
        for &(input, expected) in CASES.iter() {
            assert_eq!(F16::from_bits(input).neg().to_bits(), expected);
        }
    }

    /// Checks `u32 -> F16` conversion against a bit-level reference that
    /// rounds to nearest, ties to even, and saturates to infinity (0x7C00).
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        // With the leading one shifted up to bit 31, the 11 significand bits
        // occupy bits 31..=21 and the low 21 bits are rounded away.
        const DISCARDED_BITS: u32 = 0x1FFFFF;
        const HALF_WAY: u32 = 0x100000;
        const KEPT_LSB: u32 = 0x200000;

        assert_eq!(F16::to_bits(0u32.to()), 0);
        for v in 1..0x20000u32 {
            let leading_zeros = v.leading_zeros();
            let normalized = v << leading_zeros;
            // round to nearest, ties to even
            let round_up = match (normalized & DISCARDED_BITS).cmp(&HALF_WAY) {
                Ordering::Greater => true,
                Ordering::Equal => (normalized & KEPT_LSB) != 0,
                Ordering::Less => false,
            };
            let increment = if round_up { KEPT_LSB } else { 0 };
            let (rounded, carry) = (normalized & !DISCARDED_BITS).overflowing_add(increment);
            // A carry out of bit 31 means the significand rolled over to the
            // next power of two.
            let mantissa = if carry {
                (rounded >> 22) as u16 + 0x400
            } else {
                (rounded >> 21) as u16
            };
            assert_eq!((mantissa & !0x3FF), 0x400);
            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
            let expected = if exponent >= 0x1F {
                0x7C00
            } else {
                (exponent << 10) | (mantissa & 0x3FF)
            };
            let actual = F16::to_bits(v.to());
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }
}