2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
57 pub fn round_to_nearest_ties_to_even<
59 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
60 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
61 PrimF: PrimFloat<BitsType = PrimU>,
67 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
68 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
69 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
70 let offset_value: VecF = v.abs() + offset;
71 let in_range_value = (offset_value - offset).copy_sign(v);
72 big.select(v, in_range_value)
77 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
78 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
79 PrimF: PrimFloat<BitsType = PrimU>,
85 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
86 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
87 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
88 let offset_value: VecF = v.abs() + offset;
89 let rounded = (offset_value - offset).copy_sign(v);
90 let need_round_down = v.lt(rounded);
91 let in_range_value = need_round_down
92 .select(rounded - ctx.make(1.to()), rounded)
94 big.select(v, in_range_value)
99 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
100 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
101 PrimF: PrimFloat<BitsType = PrimU>,
107 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
108 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
109 let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to());
110 let offset_value: VecF = v.abs() + offset;
111 let rounded = (offset_value - offset).copy_sign(v);
112 let need_round_up = v.gt(rounded);
113 let in_range_value = need_round_up
114 .select(rounded + ctx.make(1.to()), rounded)
116 big.select(v, in_range_value)
119 pub fn initial_rsqrt_approximation<
121 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
122 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
123 PrimF: PrimFloat<BitsType = PrimU>,
129 // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
130 // where `CONST` is optimized for use for Goldschmidt's algorithm.
131 // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
132 // but using different constants.
133 const FACTOR: f64 = -0.5;
134 const TERM: f64 = 1.6;
135 v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
138 /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
139 pub fn sqrt_rsqrt_kernel<
141 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
142 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
143 PrimF: PrimFloat<BitsType = PrimU>,
148 iteration_count: usize,
150 /// based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
151 let y = initial_rsqrt_approximation(ctx, v);
153 let one_half: VecF = ctx.make(0.5.to());
154 let mut neg_h = y * -one_half;
155 for _ in 0..iteration_count {
156 let r = x.mul_add_fast(neg_h, one_half);
157 x = x.mul_add_fast(r, x);
158 neg_h = neg_h.mul_add_fast(r, neg_h);
160 (x, neg_h * ctx.make(-2.to()))
169 scalar::{Scalar, Value},
175 not(feature = "f16"),
176 should_panic(expected = "f16 feature is not enabled")
179 for bits in 0..=u16::MAX {
180 let v = F16::from_bits(bits);
181 let expected = v.abs();
182 let result = abs(Scalar, Value(v)).0;
183 assert_eq!(expected.to_bits(), result.to_bits());
189 for bits in (0..=u32::MAX).step_by(10001) {
190 let v = f32::from_bits(bits);
191 let expected = v.abs();
192 let result = abs(Scalar, Value(v)).0;
193 assert_eq!(expected.to_bits(), result.to_bits());
199 for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
200 let v = f64::from_bits(bits);
201 let expected = v.abs();
202 let result = abs(Scalar, Value(v)).0;
203 assert_eq!(expected.to_bits(), result.to_bits());
209 not(feature = "f16"),
210 should_panic(expected = "f16 feature is not enabled")
212 fn test_copy_sign_f16() {
214 fn check(mag_bits: u16, sign_bits: u16) {
215 let mag = F16::from_bits(mag_bits);
216 let sign = F16::from_bits(sign_bits);
217 let expected = mag.copysign(sign);
218 let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
219 assert_eq!(expected.to_bits(), result.to_bits());
221 for mag_low_bits in 0..16 {
222 for mag_high_bits in 0..16 {
223 for sign_low_bits in 0..16 {
224 for sign_high_bits in 0..16 {
226 mag_low_bits | (mag_high_bits << (16 - 4)),
227 sign_low_bits | (sign_high_bits << (16 - 4)),
236 fn test_copy_sign_f32() {
238 fn check(mag_bits: u32, sign_bits: u32) {
239 let mag = f32::from_bits(mag_bits);
240 let sign = f32::from_bits(sign_bits);
241 let expected = mag.copysign(sign);
242 let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
243 assert_eq!(expected.to_bits(), result.to_bits());
245 for mag_low_bits in 0..16 {
246 for mag_high_bits in 0..16 {
247 for sign_low_bits in 0..16 {
248 for sign_high_bits in 0..16 {
250 mag_low_bits | (mag_high_bits << (32 - 4)),
251 sign_low_bits | (sign_high_bits << (32 - 4)),
260 fn test_copy_sign_f64() {
262 fn check(mag_bits: u64, sign_bits: u64) {
263 let mag = f64::from_bits(mag_bits);
264 let sign = f64::from_bits(sign_bits);
265 let expected = mag.copysign(sign);
266 let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
267 assert_eq!(expected.to_bits(), result.to_bits());
269 for mag_low_bits in 0..16 {
270 for mag_high_bits in 0..16 {
271 for sign_low_bits in 0..16 {
272 for sign_high_bits in 0..16 {
274 mag_low_bits | (mag_high_bits << (64 - 4)),
275 sign_low_bits | (sign_high_bits << (64 - 4)),
283 fn same<F: PrimFloat>(a: F, b: F) -> bool {
284 if a.is_finite() && b.is_finite() {
285 a.to_bits() == b.to_bits()
287 a == b || (a.is_nan() && b.is_nan())
293 not(feature = "f16"),
294 should_panic(expected = "f16 feature is not enabled")
296 fn test_trunc_f16() {
297 for bits in 0..=u16::MAX {
298 let v = F16::from_bits(bits);
299 let expected = v.trunc();
300 let result = trunc(Scalar, Value(v)).0;
302 same(expected, result),
303 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
307 expected_bits=expected.to_bits(),
309 result_bits=result.to_bits(),
315 fn test_trunc_f32() {
316 for bits in (0..=u32::MAX).step_by(0x10000) {
317 let v = f32::from_bits(bits);
318 let expected = v.trunc();
319 let result = trunc(Scalar, Value(v)).0;
321 same(expected, result),
322 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
326 expected_bits=expected.to_bits(),
328 result_bits=result.to_bits(),
334 fn test_trunc_f64() {
335 for bits in (0..=u64::MAX).step_by(1 << 48) {
336 let v = f64::from_bits(bits);
337 let expected = v.trunc();
338 let result = trunc(Scalar, Value(v)).0;
340 same(expected, result),
341 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
345 expected_bits=expected.to_bits(),
347 result_bits=result.to_bits(),
352 fn reference_round_to_nearest_ties_to_even<
353 F: PrimFloat<BitsType = U, SignedBitsType = S>,
355 S: PrimSInt + ConvertFrom<F>,
359 if v.abs() < F::cvt_from(S::MAX) {
360 let int_value: S = v.to();
361 let int_value_f: F = int_value.to();
362 let remainder: F = v - int_value_f;
363 if remainder.abs() < 0.5.to()
364 || (int_value % 2.to() == 0.to() && remainder.abs() == 0.5.to())
366 int_value_f.copy_sign(v)
367 } else if remainder < 0.0.to() {
368 int_value_f - 1.0.to()
370 int_value_f + 1.0.to()
378 fn test_reference_round_to_nearest_ties_to_even() {
380 fn case(v: f32, expected: f32) {
381 let result = reference_round_to_nearest_ties_to_even(v);
383 same(result, expected),
384 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
388 expected_bits=expected.to_bits(),
390 result_bits=result.to_bits(),
417 case(f32::INFINITY, f32::INFINITY);
418 case(-f32::INFINITY, -f32::INFINITY);
419 case(f32::NAN, f32::NAN);
422 let i32_max = i32::MAX as f32;
423 let i32_max_prev = f32::from_bits(i32_max.to_bits() - 1);
424 let i32_max_next = f32::from_bits(i32_max.to_bits() + 1);
425 case(i32_max, i32_max);
426 case(-i32_max, -i32_max);
427 case(i32_max_prev, i32_max_prev);
428 case(-i32_max_prev, -i32_max_prev);
429 case(i32_max_next, i32_max_next);
430 case(-i32_max_next, -i32_max_next);
435 not(feature = "f16"),
436 should_panic(expected = "f16 feature is not enabled")
438 fn test_round_to_nearest_ties_to_even_f16() {
439 for bits in 0..=u16::MAX {
440 let v = F16::from_bits(bits);
441 let expected = reference_round_to_nearest_ties_to_even(v);
442 let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
444 same(result, expected),
445 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
449 expected_bits=expected.to_bits(),
451 result_bits=result.to_bits(),
457 fn test_round_to_nearest_ties_to_even_f32() {
458 for bits in (0..=u32::MAX).step_by(0x10000) {
459 let v = f32::from_bits(bits);
460 let expected = reference_round_to_nearest_ties_to_even(v);
461 let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
463 same(result, expected),
464 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
468 expected_bits=expected.to_bits(),
470 result_bits=result.to_bits(),
476 fn test_round_to_nearest_ties_to_even_f64() {
477 for bits in (0..=u64::MAX).step_by(1 << 48) {
478 let v = f64::from_bits(bits);
479 let expected = reference_round_to_nearest_ties_to_even(v);
480 let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0;
482 same(result, expected),
483 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
487 expected_bits=expected.to_bits(),
489 result_bits=result.to_bits(),
496 not(feature = "f16"),
497 should_panic(expected = "f16 feature is not enabled")
499 fn test_floor_f16() {
500 for bits in 0..=u16::MAX {
501 let v = F16::from_bits(bits);
502 let expected = v.floor();
503 let result = floor(Scalar, Value(v)).0;
505 same(expected, result),
506 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
510 expected_bits=expected.to_bits(),
512 result_bits=result.to_bits(),
518 fn test_floor_f32() {
519 for bits in (0..=u32::MAX).step_by(0x10000) {
520 let v = f32::from_bits(bits);
521 let expected = v.floor();
522 let result = floor(Scalar, Value(v)).0;
524 same(expected, result),
525 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
529 expected_bits=expected.to_bits(),
531 result_bits=result.to_bits(),
537 fn test_floor_f64() {
538 for bits in (0..=u64::MAX).step_by(1 << 48) {
539 let v = f64::from_bits(bits);
540 let expected = v.floor();
541 let result = floor(Scalar, Value(v)).0;
543 same(expected, result),
544 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
548 expected_bits=expected.to_bits(),
550 result_bits=result.to_bits(),
557 not(feature = "f16"),
558 should_panic(expected = "f16 feature is not enabled")
561 for bits in 0..=u16::MAX {
562 let v = F16::from_bits(bits);
563 let expected = v.ceil();
564 let result = ceil(Scalar, Value(v)).0;
566 same(expected, result),
567 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
571 expected_bits=expected.to_bits(),
573 result_bits=result.to_bits(),
580 for bits in (0..=u32::MAX).step_by(0x10000) {
581 let v = f32::from_bits(bits);
582 let expected = v.ceil();
583 let result = ceil(Scalar, Value(v)).0;
585 same(expected, result),
586 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
590 expected_bits=expected.to_bits(),
592 result_bits=result.to_bits(),
599 for bits in (0..=u64::MAX).step_by(1 << 48) {
600 let v = f64::from_bits(bits);
601 let expected = v.ceil();
602 let result = ceil(Scalar, Value(v)).0;
604 same(expected, result),
605 "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
609 expected_bits=expected.to_bits(),
611 result_bits=result.to_bits(),
616 // TODO: add tests for `sqrt_rsqrt_kernel`