b6e342604962cf666677a7454a47e450fd8f2f29
[vector-math.git] / src / algorithms / base.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn abs<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
10 PrimU: PrimUInt,
11 >(
12 ctx: Ctx,
13 x: VecF,
14 ) -> VecF {
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
16 }
17
18 pub fn copy_sign<
19 Ctx: Context,
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
22 PrimU: PrimUInt,
23 >(
24 ctx: Ctx,
25 mag: VecF,
26 sign: VecF,
27 ) -> VecF {
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
31 }
32
33 pub fn trunc<
34 Ctx: Context,
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
38 PrimU: PrimUInt,
39 >(
40 ctx: Ctx,
41 v: VecF,
42 ) -> VecF {
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
55 }
56
57 pub fn round_to_nearest_ties_to_even<
58 Ctx: Context,
59 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
60 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
61 PrimF: PrimFloat<BitsType = PrimU>,
62 PrimU: PrimUInt,
63 >(
64 ctx: Ctx,
65 v: VecF,
66 ) -> VecF {
67 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
68 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
69 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
70 let offset_value: VecF = v.abs() + offset;
71 let in_range_value = (offset_value - offset).copy_sign(v);
72 big.select(v, in_range_value)
73 }
74
75 pub fn floor<
76 Ctx: Context,
77 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
78 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
79 PrimF: PrimFloat<BitsType = PrimU>,
80 PrimU: PrimUInt,
81 >(
82 ctx: Ctx,
83 v: VecF,
84 ) -> VecF {
85 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
86 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
87 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
88 let offset_value: VecF = v.abs() + offset;
89 let rounded = (offset_value - offset).copy_sign(v);
90 let need_round_down = v.lt(rounded);
91 let in_range_value = need_round_down
92 .select(rounded - ctx.make(1.to()), rounded)
93 .copy_sign(v);
94 big.select(v, in_range_value)
95 }
96
97 pub fn ceil<
98 Ctx: Context,
99 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
100 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
101 PrimF: PrimFloat<BitsType = PrimU>,
102 PrimU: PrimUInt,
103 >(
104 ctx: Ctx,
105 v: VecF,
106 ) -> VecF {
107 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
108 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
109 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
110 let offset_value: VecF = v.abs() + offset;
111 let rounded = (offset_value - offset).copy_sign(v);
112 let need_round_up = v.gt(rounded);
113 let in_range_value = need_round_up
114 .select(rounded + ctx.make(1.to()), rounded)
115 .copy_sign(v);
116 big.select(v, in_range_value)
117 }
118
119 pub fn initial_rsqrt_approximation<
120 Ctx: Context,
121 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
122 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
123 PrimF: PrimFloat<BitsType = PrimU>,
124 PrimU: PrimUInt,
125 >(
126 ctx: Ctx,
127 v: VecF,
128 ) -> VecF {
129 // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
130 // where `CONST` is optimized for use for Goldschmidt's algorithm.
131 // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
132 // but using different constants.
133 const FACTOR: f64 = -0.5;
134 const TERM: f64 = 1.6;
135 v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
136 }
137
138 /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
139 pub fn sqrt_rsqrt_kernel<
140 Ctx: Context,
141 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
142 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
143 PrimF: PrimFloat<BitsType = PrimU>,
144 PrimU: PrimUInt,
145 >(
146 ctx: Ctx,
147 v: VecF,
148 iteration_count: usize,
149 ) -> (VecF, VecF) {
150 // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
151 let y = initial_rsqrt_approximation(ctx, v);
152 let mut x = v * y;
153 let one_half: VecF = ctx.make(0.5.to());
154 let mut neg_h = y * -one_half;
155 for _ in 0..iteration_count {
156 let r = x.mul_add_fast(neg_h, one_half);
157 x = x.mul_add_fast(r, x);
158 neg_h = neg_h.mul_add_fast(r, neg_h);
159 }
160 (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
161 }
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        prim::PrimSInt,
        scalar::{Scalar, Value},
        traits::ConvertFrom,
    };
    use std::fmt::{Display, UpperHex};

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f32() {
        for bits in (0..=u32::MAX).step_by(10001) {
            let v = f32::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f64() {
        for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
            let v = f64::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copy_sign_f16() {
        #[track_caller]
        fn check(mag_bits: u16, sign_bits: u16) {
            let mag = F16::from_bits(mag_bits);
            let sign = F16::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (16 - 4)),
                            sign_low_bits | (sign_high_bits << (16 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f32() {
        #[track_caller]
        fn check(mag_bits: u32, sign_bits: u32) {
            let mag = f32::from_bits(mag_bits);
            let sign = f32::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (32 - 4)),
                            sign_low_bits | (sign_high_bits << (32 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f64() {
        #[track_caller]
        fn check(mag_bits: u64, sign_bits: u64) {
            let mag = f64::from_bits(mag_bits);
            let sign = f64::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (64 - 4)),
                            sign_low_bits | (sign_high_bits << (64 - 4)),
                        );
                    }
                }
            }
        }
    }

    /// Float equality for tests.
    ///
    /// Two finite values must match bit for bit, so `0.0` and `-0.0` differ.
    /// Infinities must compare equal, and any NaN matches any other NaN.
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        if a.is_finite() && b.is_finite() {
            a.to_bits() == b.to_bits()
        } else {
            a == b || (a.is_nan() && b.is_nan())
        }
    }

    /// Asserts that `result` is [`same`] as `expected` for input `v`.
    ///
    /// On failure, the panic message reports all three values and their raw bits.
    #[track_caller]
    fn assert_same<F, U>(v: F, expected: F, result: F)
    where
        F: PrimFloat<BitsType = U> + Display,
        U: PrimUInt + UpperHex,
    {
        assert!(
            same(expected, result),
            "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
            v = v,
            v_bits = v.to_bits(),
            expected = expected,
            expected_bits = expected.to_bits(),
            result = result,
            result_bits = result.to_bits(),
        );
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_trunc_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    /// Straightforward reference implementation of round-half-to-even.
    ///
    /// It truncates through the signed integer type and then fixes up the result
    /// based on the remainder. Values whose magnitude is not below `S::MAX`
    /// (including infinities and NaN) are returned unchanged.
    fn reference_round_to_nearest_ties_to_even<
        F: PrimFloat<BitsType = U, SignedBitsType = S>,
        U: PrimUInt,
        S: PrimSInt + ConvertFrom<F>,
    >(
        v: F,
    ) -> F {
        if v.abs() < F::cvt_from(S::MAX) {
            let int_value: S = v.to();
            let int_value_f: F = int_value.to();
            let remainder: F = v - int_value_f;
            if remainder.abs() < 0.5.to()
                || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
            {
                int_value_f.copy_sign(v)
            } else if remainder < 0.0.to() {
                int_value_f - 1.0.to()
            } else {
                int_value_f + 1.0.to()
            }
        } else {
            v
        }
    }

    #[test]
    fn test_reference_round_to_nearest_ties_to_even() {
        #[track_caller]
        fn case(v: f32, expected: f32) {
            let result = reference_round_to_nearest_ties_to_even(v);
            assert_same(v, expected, result);
        }
        case(0.0, 0.0);
        case(-0.0, -0.0);
        case(0.499, 0.0);
        case(-0.499, -0.0);
        case(0.5, 0.0);
        case(-0.5, -0.0);
        case(0.501, 1.0);
        case(-0.501, -1.0);
        case(1.0, 1.0);
        case(-1.0, -1.0);
        case(1.499, 1.0);
        case(-1.499, -1.0);
        case(1.5, 2.0);
        case(-1.5, -2.0);
        case(1.501, 2.0);
        case(-1.501, -2.0);
        case(2.0, 2.0);
        case(-2.0, -2.0);
        case(2.499, 2.0);
        case(-2.499, -2.0);
        case(2.5, 2.0);
        case(-2.5, -2.0);
        case(2.501, 3.0);
        case(-2.501, -3.0);
        case(f32::INFINITY, f32::INFINITY);
        case(-f32::INFINITY, -f32::INFINITY);
        case(f32::NAN, f32::NAN);
        case(1e30, 1e30);
        case(-1e30, -1e30);
        let i32_max = i32::MAX as f32;
        let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
        let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
        case(i32_max, i32_max);
        case(-i32_max, -i32_max);
        case(i32_max_prev, i32_max_prev);
        case(-i32_max_prev, -i32_max_prev);
        case(i32_max_next, i32_max_next);
        case(-i32_max_next, -i32_max_next);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_round_to_nearest_ties_to_even_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected = reference_round_to_nearest_ties_to_even(v);
            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
            assert_same(v, expected, result);
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            let expected = reference_round_to_nearest_ties_to_even(v);
            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
            assert_same(v, expected, result);
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            let expected = reference_round_to_nearest_ties_to_even(v);
            let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
            assert_same(v, expected, result);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_floor_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_ceil_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    // The `sqrt_rsqrt_kernel` tests sample `v = 0.5, 0.51, ..., 2.0`. The kernel
    // is seeded with the linear estimate `1.6 - 0.5 * v` of `1 / sqrt(v)`, which
    // is only accurate near 1. On `[0.5, 2]`, the initial Goldschmidt error
    // `1 - v * seed^2` stays within about 0.28, and each iteration roughly
    // squares it. The chosen iteration counts therefore converge well below
    // each format's rounding error, and the tolerances leave room for
    // accumulated rounding.

    #[test]
    fn test_sqrt_rsqrt_kernel_f32() {
        const ITERATION_COUNT: usize = 5;
        const TOLERANCE: f32 = 1e-5;
        for i in 0..=150 {
            let v = 0.5 + i as f32 / 100.0;
            let (sqrt_result, rsqrt_result) = sqrt_rsqrt_kernel(Scalar, Value(v), ITERATION_COUNT);
            let (sqrt_result, rsqrt_result) = (sqrt_result.0, rsqrt_result.0);
            let expected_sqrt = v.sqrt();
            let expected_rsqrt = 1.0 / expected_sqrt;
            let sqrt_error = ((sqrt_result - expected_sqrt) / expected_sqrt).abs();
            let rsqrt_error = ((rsqrt_result - expected_rsqrt) / expected_rsqrt).abs();
            assert!(
                sqrt_error < TOLERANCE && rsqrt_error < TOLERANCE,
                "case failed: v={v}, sqrt={sqrt}, expected_sqrt={expected_sqrt}, rsqrt={rsqrt}, expected_rsqrt={expected_rsqrt}",
                v = v,
                sqrt = sqrt_result,
                expected_sqrt = expected_sqrt,
                rsqrt = rsqrt_result,
                expected_rsqrt = expected_rsqrt,
            );
        }
    }

    #[test]
    fn test_sqrt_rsqrt_kernel_f64() {
        const ITERATION_COUNT: usize = 6;
        const TOLERANCE: f64 = 1e-13;
        for i in 0..=150 {
            let v = 0.5 + i as f64 / 100.0;
            let (sqrt_result, rsqrt_result) = sqrt_rsqrt_kernel(Scalar, Value(v), ITERATION_COUNT);
            let (sqrt_result, rsqrt_result) = (sqrt_result.0, rsqrt_result.0);
            let expected_sqrt = v.sqrt();
            let expected_rsqrt = 1.0 / expected_sqrt;
            let sqrt_error = ((sqrt_result - expected_sqrt) / expected_sqrt).abs();
            let rsqrt_error = ((rsqrt_result - expected_rsqrt) / expected_rsqrt).abs();
            assert!(
                sqrt_error < TOLERANCE && rsqrt_error < TOLERANCE,
                "case failed: v={v}, sqrt={sqrt}, expected_sqrt={expected_sqrt}, rsqrt={rsqrt}, expected_rsqrt={expected_rsqrt}",
                v = v,
                sqrt = sqrt_result,
                expected_sqrt = expected_sqrt,
                rsqrt = rsqrt_result,
                expected_rsqrt = expected_rsqrt,
            );
        }
    }
}