add ceil and floor
[vector-math.git] / src / algorithms / base.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn abs<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
10 PrimU: PrimUInt,
11 >(
12 ctx: Ctx,
13 x: VecF,
14 ) -> VecF {
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
16 }
17
18 pub fn copy_sign<
19 Ctx: Context,
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
22 PrimU: PrimUInt,
23 >(
24 ctx: Ctx,
25 mag: VecF,
26 sign: VecF,
27 ) -> VecF {
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
31 }
32
33 pub fn trunc<
34 Ctx: Context,
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
38 PrimU: PrimUInt,
39 >(
40 ctx: Ctx,
41 v: VecF,
42 ) -> VecF {
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
55 }
56
57 pub fn round_to_nearest_ties_to_even<
58 Ctx: Context,
59 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
60 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
61 PrimF: PrimFloat<BitsType = PrimU>,
62 PrimU: PrimUInt,
63 >(
64 ctx: Ctx,
65 v: VecF,
66 ) -> VecF {
67 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
68 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
69 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
70 let offset_value: VecF = v.abs() + offset;
71 let in_range_value = (offset_value - offset).copy_sign(v);
72 big.select(v, in_range_value)
73 }
74
75 pub fn floor<
76 Ctx: Context,
77 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
78 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
79 PrimF: PrimFloat<BitsType = PrimU>,
80 PrimU: PrimUInt,
81 >(
82 ctx: Ctx,
83 v: VecF,
84 ) -> VecF {
85 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
86 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
87 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
88 let offset_value: VecF = v.abs() + offset;
89 let rounded = (offset_value - offset).copy_sign(v);
90 let need_round_down = v.lt(rounded);
91 let in_range_value = need_round_down.select(rounded - ctx.make(1.to()), rounded).copy_sign(v);
92 big.select(v, in_range_value)
93 }
94
95 pub fn ceil<
96 Ctx: Context,
97 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
98 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
99 PrimF: PrimFloat<BitsType = PrimU>,
100 PrimU: PrimUInt,
101 >(
102 ctx: Ctx,
103 v: VecF,
104 ) -> VecF {
105 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
106 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
107 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
108 let offset_value: VecF = v.abs() + offset;
109 let rounded = (offset_value - offset).copy_sign(v);
110 let need_round_up = v.gt(rounded);
111 let in_range_value = need_round_up.select(rounded + ctx.make(1.to()), rounded).copy_sign(v);
112 big.select(v, in_range_value)
113 }
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        prim::PrimSInt,
        scalar::{Scalar, Value},
        traits::ConvertFrom,
    };

    /// Returns `true` when `a` and `b` count as the same test result.
    ///
    /// Two finite values must match bit-for-bit; otherwise the values must
    /// compare equal or both be NaN.
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        if a.is_finite() && b.is_finite() {
            a.to_bits() == b.to_bits()
        } else {
            a == b || (a.is_nan() && b.is_nan())
        }
    }

    // Asserts that `$expected` and `$result` are `same`, reporting the input
    // `$v` and both values together with their bit patterns on failure.
    //
    // This is a macro rather than a generic helper so that it works for every
    // float type the tests already format, without spelling out extra bounds.
    macro_rules! assert_same {
        ($v:expr, $expected:expr, $result:expr $(,)?) => {{
            let (v, expected, result) = ($v, $expected, $result);
            assert!(
                same(expected, result),
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v = v,
                v_bits = v.to_bits(),
                expected = expected,
                expected_bits = expected.to_bits(),
                result = result,
                result_bits = result.to_bits(),
            );
        }};
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f32() {
        for bits in (0..=u32::MAX).step_by(10001) {
            let v = f32::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f64() {
        for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
            let v = f64::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copy_sign_f16() {
        #[track_caller]
        fn check(mag_bits: u16, sign_bits: u16) {
            let mag = F16::from_bits(mag_bits);
            let sign = F16::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        // Exercise every combination of the lowest and highest four bits of
        // both operands.
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (16 - 4)),
                            sign_low_bits | (sign_high_bits << (16 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f32() {
        #[track_caller]
        fn check(mag_bits: u32, sign_bits: u32) {
            let mag = f32::from_bits(mag_bits);
            let sign = f32::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        // Exercise every combination of the lowest and highest four bits of
        // both operands.
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (32 - 4)),
                            sign_low_bits | (sign_high_bits << (32 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f64() {
        #[track_caller]
        fn check(mag_bits: u64, sign_bits: u64) {
            let mag = f64::from_bits(mag_bits);
            let sign = f64::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        // Exercise every combination of the lowest and highest four bits of
        // both operands.
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (64 - 4)),
                            sign_low_bits | (sign_high_bits << (64 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_trunc_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_trunc_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.trunc(), trunc(Scalar, Value(v)).0);
        }
    }

    /// Reference implementation of round-to-nearest, ties-to-even, built on a
    /// conversion to the signed integer type `S` and back.
    ///
    /// Inputs whose magnitude is not below `S::MAX` (this includes NaN, since
    /// the comparison is false) are returned unchanged.
    fn reference_round_to_nearest_ties_to_even<
        F: PrimFloat<BitsType = U, SignedBitsType = S>,
        U: PrimUInt,
        S: PrimSInt + ConvertFrom<F>,
    >(
        v: F,
    ) -> F {
        if v.abs() < F::cvt_from(S::MAX) {
            let int_value: S = v.to();
            let int_value_f: F = int_value.to();
            let remainder: F = v - int_value_f;
            if remainder.abs() < 0.5.to()
                || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
            {
                // Either closer to `int_value`, or an exact tie where
                // `int_value` is the even neighbour; `copy_sign` keeps the
                // sign of `v` on a zero result.
                int_value_f.copy_sign(v)
            } else if remainder < 0.0.to() {
                int_value_f - 1.0.to()
            } else {
                int_value_f + 1.0.to()
            }
        } else {
            v
        }
    }

    #[test]
    fn test_reference_round_to_nearest_ties_to_even() {
        #[track_caller]
        fn case(v: f32, expected: f32) {
            assert_same!(v, expected, reference_round_to_nearest_ties_to_even(v));
        }
        case(0.0, 0.0);
        case(-0.0, -0.0);
        case(0.499, 0.0);
        case(-0.499, -0.0);
        case(0.5, 0.0);
        case(-0.5, -0.0);
        case(0.501, 1.0);
        case(-0.501, -1.0);
        case(1.0, 1.0);
        case(-1.0, -1.0);
        case(1.499, 1.0);
        case(-1.499, -1.0);
        case(1.5, 2.0);
        case(-1.5, -2.0);
        case(1.501, 2.0);
        case(-1.501, -2.0);
        case(2.0, 2.0);
        case(-2.0, -2.0);
        case(2.499, 2.0);
        case(-2.499, -2.0);
        case(2.5, 2.0);
        case(-2.5, -2.0);
        case(2.501, 3.0);
        case(-2.501, -3.0);
        case(f32::INFINITY, f32::INFINITY);
        case(-f32::INFINITY, -f32::INFINITY);
        case(f32::NAN, f32::NAN);
        case(1e30, 1e30);
        case(-1e30, -1e30);
        // Values right around the integer-conversion limit.
        let i32_max = i32::MAX as f32;
        let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
        let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
        case(i32_max, i32_max);
        case(-i32_max, -i32_max);
        case(i32_max_prev, i32_max_prev);
        case(-i32_max_prev, -i32_max_prev);
        case(i32_max_next, i32_max_next);
        case(-i32_max_next, -i32_max_next);
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_round_to_nearest_ties_to_even_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0,
            );
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0,
            );
        }
    }

    #[test]
    fn test_round_to_nearest_ties_to_even_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(
                v,
                reference_round_to_nearest_ties_to_even(v),
                round_to_nearest_ties_to_even(Scalar, Value(v)).0,
            );
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_floor_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_floor_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.floor(), floor(Scalar, Value(v)).0);
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_ceil_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }

    #[test]
    fn test_ceil_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            assert_same!(v, v.ceil(), ceil(Scalar, Value(v)).0);
        }
    }
}