add count_leading_zeros, count_trailing_zeros, and count_ones implementations
[vector-math.git] / src / algorithms / base.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn abs<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
10 PrimU: PrimUInt,
11 >(
12 ctx: Ctx,
13 x: VecF,
14 ) -> VecF {
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
16 }
17
18 pub fn copy_sign<
19 Ctx: Context,
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
22 PrimU: PrimUInt,
23 >(
24 ctx: Ctx,
25 mag: VecF,
26 sign: VecF,
27 ) -> VecF {
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
31 }
32
33 pub fn trunc<
34 Ctx: Context,
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
38 PrimU: PrimUInt,
39 >(
40 ctx: Ctx,
41 v: VecF,
42 ) -> VecF {
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
55 }
56
57 pub fn round_to_nearest_ties_to_even<
58 Ctx: Context,
59 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
60 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
61 PrimF: PrimFloat<BitsType = PrimU>,
62 PrimU: PrimUInt,
63 >(
64 ctx: Ctx,
65 v: VecF,
66 ) -> VecF {
67 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
68 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
69 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
70 let offset_value: VecF = v.abs() + offset;
71 let in_range_value = (offset_value - offset).copy_sign(v);
72 big.select(v, in_range_value)
73 }
74
75 pub fn floor<
76 Ctx: Context,
77 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
78 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
79 PrimF: PrimFloat<BitsType = PrimU>,
80 PrimU: PrimUInt,
81 >(
82 ctx: Ctx,
83 v: VecF,
84 ) -> VecF {
85 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
86 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
87 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
88 let offset_value: VecF = v.abs() + offset;
89 let rounded = (offset_value - offset).copy_sign(v);
90 let need_round_down = v.lt(rounded);
91 let in_range_value = need_round_down
92 .select(rounded - ctx.make(1.to()), rounded)
93 .copy_sign(v);
94 big.select(v, in_range_value)
95 }
96
97 pub fn ceil<
98 Ctx: Context,
99 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
100 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
101 PrimF: PrimFloat<BitsType = PrimU>,
102 PrimU: PrimUInt,
103 >(
104 ctx: Ctx,
105 v: VecF,
106 ) -> VecF {
107 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
108 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
109 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
110 let offset_value: VecF = v.abs() + offset;
111 let rounded = (offset_value - offset).copy_sign(v);
112 let need_round_up = v.gt(rounded);
113 let in_range_value = need_round_up
114 .select(rounded + ctx.make(1.to()), rounded)
115 .copy_sign(v);
116 big.select(v, in_range_value)
117 }
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        prim::PrimSInt,
        scalar::{Scalar, Value},
        traits::ConvertFrom,
    };

    /// Checks two floats for equivalence. Finite values must match bit for bit,
    /// so `0.0` and `-0.0` are distinct. Non-finite values only need to compare
    /// equal or both be NaN (NaN payloads are ignored).
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        match (a.is_finite(), b.is_finite()) {
            (true, true) => a.to_bits() == b.to_bits(),
            _ => a == b || (a.is_nan() && b.is_nan()),
        }
    }

    /// Asserts `same(expected, result)`. On failure it prints the input and
    /// both outputs, each with its raw bit pattern.
    macro_rules! assert_same {
        ($v:expr, $expected:expr, $result:expr) => {{
            let v = $v;
            let expected = $expected;
            let result = $result;
            assert!(
                same(expected, result),
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v = v,
                v_bits = v.to_bits(),
                expected = expected,
                expected_bits = expected.to_bits(),
                result = result,
                result_bits = result.to_bits(),
            );
        }};
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_eq!(v.abs().to_bits(), abs(Scalar, Value(v)).0.to_bits());
        }
    }

    #[test]
    fn test_abs_f32() {
        for bits in (0..=u32::MAX).step_by(10001) {
            let v = f32::from_bits(bits);
            assert_eq!(v.abs().to_bits(), abs(Scalar, Value(v)).0.to_bits());
        }
    }

    #[test]
    fn test_abs_f64() {
        for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
            let v = f64::from_bits(bits);
            assert_eq!(v.abs().to_bits(), abs(Scalar, Value(v)).0.to_bits());
        }
    }

    // The copy_sign tests try every combination of the lowest and the highest
    // nibble of both operands. That varies the sign bit, the top exponent bits
    // and the low mantissa bits: 16^4 = 65536 cases per width.

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copy_sign_f16() {
        #[track_caller]
        fn check(mag_bits: u16, sign_bits: u16) {
            let mag = F16::from_bits(mag_bits);
            let sign = F16::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for combo in 0..0x1_0000u32 {
            let nibble = |index: u32| ((combo >> (4 * index)) & 0xF) as u16;
            check(
                nibble(0) | (nibble(1) << (16 - 4)),
                nibble(2) | (nibble(3) << (16 - 4)),
            );
        }
    }

    #[test]
    fn test_copy_sign_f32() {
        #[track_caller]
        fn check(mag_bits: u32, sign_bits: u32) {
            let mag = f32::from_bits(mag_bits);
            let sign = f32::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for combo in 0..0x1_0000u32 {
            let nibble = |index: u32| (combo >> (4 * index)) & 0xF;
            check(
                nibble(0) | (nibble(1) << (32 - 4)),
                nibble(2) | (nibble(3) << (32 - 4)),
            );
        }
    }

    #[test]
    fn test_copy_sign_f64() {
        #[track_caller]
        fn check(mag_bits: u64, sign_bits: u64) {
            let mag = f64::from_bits(mag_bits);
            let sign = f64::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for combo in 0..0x1_0000u32 {
            let nibble = |index: u32| ((combo >> (4 * index)) & 0xF) as u64;
            check(
                nibble(0) | (nibble(1) << (64 - 4)),
                nibble(2) | (nibble(3) << (64 - 4)),
            );
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_trunc_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    /// Straightforward scalar reference for round-half-to-even.
    ///
    /// Truncates through the signed integer type `S`, then uses the leftover
    /// fraction and the parity of the truncated value to pick a neighbour.
    /// Values whose magnitude is not below `S::MAX` (as a float), infinities,
    /// and NaNs are returned unchanged.
    fn reference_round_to_nearest_ties_to_even<
        F: PrimFloat<BitsType = U, SignedBitsType = S>,
        U: PrimUInt,
        S: PrimSInt + ConvertFrom<F>,
    >(
        v: F,
    ) -> F {
        let fits_in_int = v.abs() < F::cvt_from(S::MAX);
        if !fits_in_int {
            return v;
        }
        let truncated: S = v.to();
        let truncated_f: F = truncated.to();
        let fraction: F = v - truncated_f;
        let half: F = 0.5.to();
        let is_even = truncated % 2.to() == 0.to();
        let is_tie = fraction.abs() == half;
        if fraction.abs() < half || (is_even && is_tie) {
            truncated_f.copy_sign(v)
        } else if fraction < 0.0.to() {
            truncated_f - 1.0.to()
        } else {
            truncated_f + 1.0.to()
        }
    }

    #[test]
    fn test_reference_round_to_nearest_ties_to_even() {
        #[track_caller]
        fn case(v: f32, expected: f32) {
            assert_same!(v, expected, reference_round_to_nearest_ties_to_even(v));
        }
        let table: &[(f32, f32)] = &[
            (0.0, 0.0),
            (-0.0, -0.0),
            (0.499, 0.0),
            (-0.499, -0.0),
            (0.5, 0.0),
            (-0.5, -0.0),
            (0.501, 1.0),
            (-0.501, -1.0),
            (1.0, 1.0),
            (-1.0, -1.0),
            (1.499, 1.0),
            (-1.499, -1.0),
            (1.5, 2.0),
            (-1.5, -2.0),
            (1.501, 2.0),
            (-1.501, -2.0),
            (2.0, 2.0),
            (-2.0, -2.0),
            (2.499, 2.0),
            (-2.499, -2.0),
            (2.5, 2.0),
            (-2.5, -2.0),
            (2.501, 3.0),
            (-2.501, -3.0),
            (f32::INFINITY, f32::INFINITY),
            (-f32::INFINITY, -f32::INFINITY),
            (f32::NAN, f32::NAN),
            (1e30, 1e30),
            (-1e30, -1e30),
        ];
        for &(v, expected) in table {
            case(v, expected);
        }
        // Values at and around the i32 conversion boundary must be passed through.
        let i32_max = i32::MAX as f32;
        let below = f32::from_bits(i32_max.to_bits() - 1);
        let above = f32::from_bits(i32_max.to_bits() + 1);
        for &x in &[i32_max, below, above] {
            case(x, x);
            case(-x, -x);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_round_to_nearest_ties_to_even_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0
            );
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0
            );
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0
            );
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_floor_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_ceil_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }
}