28a641df564d73f46dcd84f20e07374353bea427
[vector-math.git] / src / algorithms / base.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn abs<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
10 PrimU: PrimUInt,
11 >(
12 ctx: Ctx,
13 x: VecF,
14 ) -> VecF {
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
16 }
17
18 pub fn copy_sign<
19 Ctx: Context,
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
22 PrimU: PrimUInt,
23 >(
24 ctx: Ctx,
25 mag: VecF,
26 sign: VecF,
27 ) -> VecF {
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
31 }
32
33 pub fn trunc<
34 Ctx: Context,
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
38 PrimU: PrimUInt,
39 >(
40 ctx: Ctx,
41 v: VecF,
42 ) -> VecF {
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
55 }
56
57 pub fn round_to_nearest_ties_to_even<
58 Ctx: Context,
59 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
60 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
61 PrimF: PrimFloat<BitsType = PrimU>,
62 PrimU: PrimUInt,
63 >(
64 ctx: Ctx,
65 v: VecF,
66 ) -> VecF {
67 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
68 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
69 let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5)));
70 let out_of_range = big | small;
71 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
72 let out_of_range_value = small.select(small_value, v);
73 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
74 let offset_value: VecF = v.abs() + offset;
75 let in_range_value = (offset_value - offset).copy_sign(v);
76 out_of_range.select(out_of_range_value, in_range_value)
77 }
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        prim::PrimSInt,
        scalar::{Scalar, Value},
        traits::ConvertFrom,
    };

    // Shared failure report for the rounding tests: asserts `$ok` and, on
    // failure, prints the input, the expected value and the actual value
    // together with their raw bit patterns. The values are taken as
    // identifiers so that nothing is evaluated more than once.
    macro_rules! assert_case {
        ($ok:expr, $v:ident, $expected:ident, $result:ident $(,)?) => {
            assert!(
                $ok,
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v = $v,
                v_bits = $v.to_bits(),
                expected = $expected,
                expected_bits = $expected.to_bits(),
                result = $result,
                result_bits = $result.to_bits(),
            )
        };
    }

    /// Checks `abs` against `F16::abs` for every 16-bit pattern.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs_f16() {
        for input in (0..=u16::MAX).map(F16::from_bits) {
            let want = input.abs().to_bits();
            let got = abs(Scalar, Value(input)).0.to_bits();
            assert_eq!(want, got);
        }
    }

    /// Checks `abs` against `f32::abs` on a strided sample of bit patterns.
    #[test]
    fn test_abs_f32() {
        for input in (0..=u32::MAX).step_by(10001).map(f32::from_bits) {
            let want = input.abs().to_bits();
            let got = abs(Scalar, Value(input)).0.to_bits();
            assert_eq!(want, got);
        }
    }

    /// Checks `abs` against `f64::abs` on a strided sample of bit patterns.
    #[test]
    fn test_abs_f64() {
        for input in (0..=u64::MAX)
            .step_by(100_000_000_000_001)
            .map(f64::from_bits)
        {
            let want = input.abs().to_bits();
            let got = abs(Scalar, Value(input)).0.to_bits();
            assert_eq!(want, got);
        }
    }

    /// Checks `copy_sign` against `F16::copysign` for every combination of the
    /// lowest four and highest four bits of both operands.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copy_sign_f16() {
        #[track_caller]
        fn check(mag_bits: u16, sign_bits: u16) {
            let mag = F16::from_bits(mag_bits);
            let sign = F16::from_bits(sign_bits);
            let want = mag.copysign(sign);
            let got = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(want.to_bits(), got.to_bits());
        }
        // Shift that moves a 4-bit pattern into the top nibble.
        let top_shift = 16 - 4;
        for mag_low in 0..16 {
            for mag_high in 0..16 {
                let mag_bits = mag_low | (mag_high << top_shift);
                for sign_low in 0..16 {
                    for sign_high in 0..16 {
                        check(mag_bits, sign_low | (sign_high << top_shift));
                    }
                }
            }
        }
    }

    /// Checks `copy_sign` against `f32::copysign` for every combination of the
    /// lowest four and highest four bits of both operands.
    #[test]
    fn test_copy_sign_f32() {
        #[track_caller]
        fn check(mag_bits: u32, sign_bits: u32) {
            let mag = f32::from_bits(mag_bits);
            let sign = f32::from_bits(sign_bits);
            let want = mag.copysign(sign);
            let got = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(want.to_bits(), got.to_bits());
        }
        // Shift that moves a 4-bit pattern into the top nibble.
        let top_shift = 32 - 4;
        for mag_low in 0..16 {
            for mag_high in 0..16 {
                let mag_bits = mag_low | (mag_high << top_shift);
                for sign_low in 0..16 {
                    for sign_high in 0..16 {
                        check(mag_bits, sign_low | (sign_high << top_shift));
                    }
                }
            }
        }
    }

    /// Checks `copy_sign` against `f64::copysign` for every combination of the
    /// lowest four and highest four bits of both operands.
    #[test]
    fn test_copy_sign_f64() {
        #[track_caller]
        fn check(mag_bits: u64, sign_bits: u64) {
            let mag = f64::from_bits(mag_bits);
            let sign = f64::from_bits(sign_bits);
            let want = mag.copysign(sign);
            let got = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(want.to_bits(), got.to_bits());
        }
        // Shift that moves a 4-bit pattern into the top nibble.
        let top_shift = 64 - 4;
        for mag_low in 0..16 {
            for mag_high in 0..16 {
                let mag_bits = mag_low | (mag_high << top_shift);
                for sign_low in 0..16 {
                    for sign_high in 0..16 {
                        check(mag_bits, sign_low | (sign_high << top_shift));
                    }
                }
            }
        }
    }

    /// Returns `true` when both values are NaN, and `a == b` otherwise.
    ///
    /// The comparison uses `==` rather than the raw bits, so values that
    /// compare equal are accepted even if their bit patterns differ.
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => true,
            _ => a == b,
        }
    }

    /// Checks `trunc` against `F16::trunc` for every 16-bit pattern.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_trunc_f16() {
        for bits in 0..=u16::MAX {
            let input = F16::from_bits(bits);
            let want = input.trunc();
            let got = trunc(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }

    /// Checks `trunc` against `f32::trunc` on a strided sample of bit patterns.
    #[test]
    fn test_trunc_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let input = f32::from_bits(bits);
            let want = input.trunc();
            let got = trunc(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }

    /// Checks `trunc` against `f64::trunc` on a strided sample of bit patterns.
    #[test]
    fn test_trunc_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let input = f64::from_bits(bits);
            let want = input.trunc();
            let got = trunc(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }

    /// Scalar reference for `round_to_nearest_ties_to_even`, built on a
    /// float -> signed integer -> float round trip.
    ///
    /// Inputs whose magnitude is not below `S::MAX` converted to `F` (which
    /// includes NaN, since the comparison is then false) are returned as-is.
    fn reference_round_to_nearest_ties_to_even<
        F: PrimFloat<BitsType = U, SignedBitsType = S>,
        U: PrimUInt,
        S: PrimSInt + ConvertFrom<F>,
    >(
        v: F,
    ) -> F {
        let fits_in_int = v.abs() < F::cvt_from(S::MAX);
        if !fits_in_int {
            return v;
        }
        let whole: S = v.to();
        let whole_f: F = whole.to();
        let remainder: F = v - whole_f;
        let distance = remainder.abs();
        let half: F = 0.5.to();
        if distance < half || (whole % 2.to() == 0.to() && distance == half) {
            // Already the nearest integer, or an exact tie on an even integer;
            // `copy_sign` keeps the sign of `v` on a zero result.
            whole_f.copy_sign(v)
        } else if remainder < 0.0.to() {
            whole_f - 1.0.to()
        } else {
            whole_f + 1.0.to()
        }
    }

    /// Pins the reference implementation to hand-written expectations.
    #[test]
    fn test_reference_round_to_nearest_ties_to_even() {
        #[track_caller]
        fn case(v: f32, expected: f32) {
            let result = reference_round_to_nearest_ties_to_even(v);
            // A NaN expectation only requires a NaN result; anything else
            // must match bit-for-bit.
            let bit_equal = expected.to_bits() == result.to_bits();
            let matches = if expected.is_nan() {
                result.is_nan()
            } else {
                bit_equal
            };
            assert_case!(matches, v, expected, result);
        }
        case(0.0, 0.0);
        case(-0.0, -0.0);
        case(0.499, 0.0);
        case(-0.499, -0.0);
        case(0.5, 0.0);
        case(-0.5, -0.0);
        case(0.501, 1.0);
        case(-0.501, -1.0);
        case(1.0, 1.0);
        case(-1.0, -1.0);
        case(1.499, 1.0);
        case(-1.499, -1.0);
        case(1.5, 2.0);
        case(-1.5, -2.0);
        case(1.501, 2.0);
        case(-1.501, -2.0);
        case(2.0, 2.0);
        case(-2.0, -2.0);
        case(2.499, 2.0);
        case(-2.499, -2.0);
        case(2.5, 2.0);
        case(-2.5, -2.0);
        case(2.501, 3.0);
        case(-2.501, -3.0);
        case(f32::INFINITY, f32::INFINITY);
        case(-f32::INFINITY, -f32::INFINITY);
        case(f32::NAN, f32::NAN);
        case(1e30, 1e30);
        case(-1e30, -1e30);
        // Values at and immediately around `i32::MAX as f32`.
        let i32_max = i32::MAX as f32;
        let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
        let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
        case(i32_max, i32_max);
        case(-i32_max, -i32_max);
        case(i32_max_prev, i32_max_prev);
        case(-i32_max_prev, -i32_max_prev);
        case(i32_max_next, i32_max_next);
        case(-i32_max_next, -i32_max_next);
    }

    /// Checks `round_to_nearest_ties_to_even` against the reference for every
    /// 16-bit pattern.
    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_round_to_nearest_ties_to_even_f16() {
        for bits in 0..=u16::MAX {
            let input = F16::from_bits(bits);
            let want = reference_round_to_nearest_ties_to_even(input);
            let got = round_to_nearest_ties_to_even(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }

    /// Checks `round_to_nearest_ties_to_even` against the reference on a
    /// strided sample of `f32` bit patterns.
    #[test]
    fn test_round_to_nearest_ties_to_even_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let input = f32::from_bits(bits);
            let want = reference_round_to_nearest_ties_to_even(input);
            let got = round_to_nearest_ties_to_even(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }

    /// Checks `round_to_nearest_ties_to_even` against the reference on a
    /// strided sample of `f64` bit patterns.
    #[test]
    fn test_round_to_nearest_ties_to_even_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let input = f64::from_bits(bits);
            let want = reference_round_to_nearest_ties_to_even(input);
            let got = round_to_nearest_ties_to_even(Scalar, Value(input)).0;
            assert_case!(same(want, got), input, want, got);
        }
    }
}