add ilogb
[vector-math.git] / src / f16.rs
1 use core::ops::{
2 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
3 };
4
5 use crate::traits::{ConvertTo, Float};
6
7 #[cfg(feature = "f16")]
8 use half::f16 as F16Impl;
9
// Without the `f16` feature there is no arithmetic backend: the value is kept as its raw
// 16-bit pattern, and anything routed through `f16_impl!` panics instead of computing.
#[cfg(not(feature = "f16"))]
type F16Impl = u16;

/// A 16-bit (IEEE 754 binary16) floating-point number.
///
/// Backed by `half::f16` when the `f16` feature is enabled, and by the raw `u16` bit
/// pattern otherwise.
///
/// NOTE(review): without the `f16` feature the derived `PartialEq`/`PartialOrd` compare the
/// raw bits, not float values (so `-0.0 != 0.0` and identical NaN bit patterns compare equal).
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct F16(F16Impl);
16
// Feature enabled: `half::f16` backs every operation, so the expression is emitted as-is and
// the bracketed variable list is ignored.
#[cfg(feature = "f16")]
macro_rules! f16_impl {
    ($body:expr, [$($used:ident),*]) => {
        $body
    };
}

// Feature disabled: the expression is discarded without being type-checked (it may name
// `half::f16`-only methods), each listed variable is mentioned via `let _ = ...` so it is not
// reported as unused, and the expansion panics.
#[cfg(not(feature = "f16"))]
macro_rules! f16_impl {
    ($body:expr, [$($used:ident),*]) => {{
        $(let _ = $used;)*
        panic!("f16 feature is not enabled")
    }};
}
33
34 impl From<F16Impl> for F16 {
35 fn from(v: F16Impl) -> Self {
36 F16(v)
37 }
38 }
39
40 impl From<F16> for F16Impl {
41 fn from(v: F16) -> Self {
42 v.0
43 }
44 }
45
// For each listed source type, implements `From<T> for F16` by delegating to the backing
// type's own `From<T>` impl, plus the matching `ConvertTo<F16>`.
// Every type must be followed by a comma, e.g. `impl_f16_from![i8, u8,]`.
macro_rules! impl_f16_from {
    ($($src:ident,)*) => {
        $(
            impl From<$src> for F16 {
                fn from(value: $src) -> Self {
                    f16_impl!(Self(F16Impl::from(value)), [value])
                }
            }

            impl ConvertTo<F16> for $src {
                fn to(self) -> F16 {
                    F16::from(self)
                }
            }
        )*
    };
}
63
// For each listed destination type, implements `From<F16>` by converting the backing value
// with its own `Into` impl, plus the matching `ConvertTo<T> for F16`.
// Every type must be followed by a comma, e.g. `impl_from_f16![f32, f64,]`.
macro_rules! impl_from_f16 {
    ($($dst:ident,)*) => {
        $(
            impl From<F16> for $dst {
                fn from(value: F16) -> Self {
                    f16_impl!(Into::<$dst>::into(value.0), [value])
                }
            }

            impl ConvertTo<$dst> for F16 {
                fn to(self) -> $dst {
                    <$dst>::from(self)
                }
            }
        )*
    };
}
81
82 impl_f16_from![i8, u8,];
83
84 impl_from_f16![f32, f64,];
85
// Implements `ConvertTo<F16>` for each listed integer type by widening to `f32` first.
macro_rules! impl_int_to_f16 {
    ($($int:ident),*) => {
        $(
            impl ConvertTo<F16> for $int {
                fn to(self) -> F16 {
                    // Going through f32 is correct: f32 keeps integers exact well past the
                    // point where f16 has already overflowed to infinity, so the only rounding
                    // that matters is the final f32 -> f16 step.
                    let widened = self as f32;
                    <f32 as ConvertTo<F16>>::to(widened)
                }
            }
        )*
    };
}
100
// Implements `ConvertTo<int>` for `F16` for each listed integer type: the value is widened to
// `f32`, then narrowed with an `as` cast (truncates toward zero, saturates out-of-range values
// and maps NaN to 0).
macro_rules! impl_f16_to_int {
    ($($int:ident),*) => {
        $(
            impl ConvertTo<$int> for F16 {
                fn to(self) -> $int {
                    let widened: f32 = self.into();
                    widened as $int
                }
            }
        )*
    };
}
112
113 impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128];
114 impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128];
115
116 impl ConvertTo<F16> for f32 {
117 fn to(self) -> F16 {
118 f16_impl!(F16(F16Impl::from_f32(self)), [])
119 }
120 }
121
122 impl ConvertTo<F16> for f64 {
123 fn to(self) -> F16 {
124 f16_impl!(F16(F16Impl::from_f64(self)), [])
125 }
126 }
127
128 impl Neg for F16 {
129 type Output = Self;
130
131 fn neg(self) -> Self::Output {
132 f16_impl!(Self::from_bits(self.to_bits() ^ 0x8000), [])
133 }
134 }
135
// Implements a binary operator and its compound-assignment form for `F16`. Both operands are
// widened to `f32`, the `f32` operator is applied, and the result is narrowed back to `F16`.
// Each entry has the form `Trait, method, AssignTrait, assign_method;`.
macro_rules! impl_bin_op_using_f32 {
    ($($trait_:ident, $method:ident, $assign_trait:ident, $assign_method:ident;)*) => {
        $(
            impl $trait_ for F16 {
                type Output = Self;

                fn $method(self, rhs: Self) -> Self::Output {
                    let lhs_wide = f32::from(self);
                    let rhs_wide = f32::from(rhs);
                    <f32 as ConvertTo<F16>>::to(lhs_wide.$method(rhs_wide))
                }
            }

            impl $assign_trait for F16 {
                fn $assign_method(&mut self, rhs: Self) {
                    *self = $trait_::$method(*self, rhs);
                }
            }
        )*
    };
}
155
156 impl_bin_op_using_f32! {
157 Add, add, AddAssign, add_assign;
158 Sub, sub, SubAssign, sub_assign;
159 Mul, mul, MulAssign, mul_assign;
160 Div, div, DivAssign, div_assign;
161 Rem, rem, RemAssign, rem_assign;
162 }
163
164 impl Float<u32> for F16 {
165 type FloatEncoding = F16;
166 type BitsType = u16;
167 type SignedBitsType = i16;
168
169 fn abs(self) -> Self {
170 f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
171 }
172
173 fn trunc(self) -> Self {
174 f32::from(self).trunc().to()
175 }
176
177 fn ceil(self) -> Self {
178 f32::from(self).ceil().to()
179 }
180
181 fn floor(self) -> Self {
182 f32::from(self).floor().to()
183 }
184
185 fn round(self) -> Self {
186 f32::from(self).round().to()
187 }
188
189 #[cfg(feature = "fma")]
190 fn fma(self, a: Self, b: Self) -> Self {
191 (f64::from(self) * f64::from(a) + f64::from(b)).to()
192 }
193
194 fn is_nan(self) -> Self::Bool {
195 f16_impl!(self.0.is_nan(), [])
196 }
197
198 fn is_infinite(self) -> Self::Bool {
199 f16_impl!(self.0.is_infinite(), [])
200 }
201
202 fn is_finite(self) -> Self::Bool {
203 f16_impl!(self.0.is_finite(), [])
204 }
205
206 fn from_bits(v: Self::BitsType) -> Self {
207 #[cfg(feature = "f16")]
208 return F16(F16Impl::from_bits(v));
209 #[cfg(not(feature = "f16"))]
210 return F16(v);
211 }
212
213 fn to_bits(self) -> Self::BitsType {
214 #[cfg(feature = "f16")]
215 return self.0.to_bits();
216 #[cfg(not(feature = "f16"))]
217 return self.0;
218 }
219 }
220
#[cfg(test)]
mod tests {
    use super::*;
    use core::cmp::Ordering;

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs() {
        assert_eq!(F16::from_bits(0x8000).abs().to_bits(), 0);
        assert_eq!(F16::from_bits(0).abs().to_bits(), 0);
        assert_eq!(F16::from_bits(0x8ABC).abs().to_bits(), 0xABC);
        assert_eq!(F16::from_bits(0xFE00).abs().to_bits(), 0x7E00);
        assert_eq!(F16::from_bits(0x7E00).abs().to_bits(), 0x7E00);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_neg() {
        assert_eq!(F16::from_bits(0x8000).neg().to_bits(), 0);
        assert_eq!(F16::from_bits(0).neg().to_bits(), 0x8000);
        assert_eq!(F16::from_bits(0x8ABC).neg().to_bits(), 0xABC);
        assert_eq!(F16::from_bits(0xFE00).neg().to_bits(), 0x7E00);
        assert_eq!(F16::from_bits(0x7E00).neg().to_bits(), 0xFE00);
    }

    /// Checks every integer in `1..0x20000` against an independently computed f16 encoding.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_int_to_f16() {
        assert_eq!(F16::to_bits(0u32.to()), 0);
        for v in 1..0x20000u32 {
            // Normalize so the leading one sits in bit 31. Bits 31..=21 then hold the 11-bit
            // significand (implicit one included); bits 20..=0 are what gets rounded away.
            let leading_zeros = v.leading_zeros();
            let shifted_v = v << leading_zeros;
            // round to nearest, ties to even
            let round_up = match (shifted_v & 0x1FFFFF).cmp(&0x100000) {
                Ordering::Less => false,
                Ordering::Equal => (shifted_v & 0x200000) != 0,
                Ordering::Greater => true,
            };
            let increment = if round_up { 0x200000 } else { 0 };
            let (rounded, carry) = (shifted_v & !0x1FFFFF).overflowing_add(increment);
            // A carry out of bit 31 means the significand rounded up to the next power of two:
            // it becomes 1.0 again and the exponent is bumped by one below.
            let mantissa = if carry {
                (rounded >> 22) as u16 + 0x400
            } else {
                (rounded >> 21) as u16
            };
            assert_eq!(mantissa & !0x3FF, 0x400);
            let exponent = 31 - leading_zeros as u16 + 15 + carry as u16;
            // Biased exponent 0x1F is reserved for infinity/NaN, so such values become +inf.
            let expected = if exponent < 0x1F {
                (mantissa & 0x3FF) + (exponent << 10)
            } else {
                0x7C00
            };
            let actual = F16::to_bits(v.to());
            assert_eq!(
                actual, expected,
                "actual = {:#X}, expected = {:#X}, v = {:#X}",
                actual, expected, v
            );
        }
    }

    /// F16 -> integer conversions truncate toward zero, saturate, and map NaN to zero.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_f16_to_int() {
        // 1.0 and -1.5 truncate toward zero.
        assert_eq!(ConvertTo::<i32>::to(F16::from_bits(0x3C00)), 1);
        assert_eq!(ConvertTo::<i32>::to(F16::from_bits(0xBE00)), -1);
        // Largest finite f16 (65504) converts exactly.
        assert_eq!(ConvertTo::<i32>::to(F16::from_bits(0x7BFF)), 65504);
        // Out-of-range values saturate: +inf -> MAX, -1.0 -> 0 for unsigned.
        assert_eq!(ConvertTo::<u8>::to(F16::from_bits(0x7C00)), u8::MAX);
        assert_eq!(ConvertTo::<u8>::to(F16::from_bits(0xBC00)), 0);
        // NaN maps to zero.
        assert_eq!(ConvertTo::<i32>::to(F16::from_bits(0x7E00)), 0);
    }
}