X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Falgorithms%2Fbase.rs;h=4ebd8493ee3d7dc12e0753e469db896253e6f14f;hb=8a170330691c442c16cf6a7c6606fc19493e9e81;hp=28a641df564d73f46dcd84f20e07374353bea427;hpb=aa592610b0802e56c6b5ea2ea70c3d742f36e797;p=vector-math.git diff --git a/src/algorithms/base.rs b/src/algorithms/base.rs index 28a641d..4ebd849 100644 --- a/src/algorithms/base.rs +++ b/src/algorithms/base.rs @@ -66,14 +66,54 @@ pub fn round_to_nearest_ties_to_even< ) -> VecF { let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big - let small = v.abs().le(ctx.make(PrimF::cvt_from(0.5))); - let out_of_range = big | small; - let small_value = ctx.make::(0.to()).copy_sign(v); - let out_of_range_value = small.select(small_value, v); let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to()); let offset_value: VecF = v.abs() + offset; let in_range_value = (offset_value - offset).copy_sign(v); - out_of_range.select(out_of_range_value, in_range_value) + big.select(v, in_range_value) +} + +pub fn floor< + Ctx: Context, + VecF: Float + Make, + VecU: UInt + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, +>( + ctx: Ctx, + v: VecF, +) -> VecF { + let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); + let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big + let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to()); + let offset_value: VecF = v.abs() + offset; + let rounded = (offset_value - offset).copy_sign(v); + let need_round_down = v.lt(rounded); + let in_range_value = need_round_down + .select(rounded - ctx.make(1.to()), rounded) + .copy_sign(v); + big.select(v, in_range_value) +} + +pub fn ceil< + Ctx: Context, + VecF: Float + Make, + VecU: UInt + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, +>( + ctx: Ctx, + v: VecF, +) -> VecF { + let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); + let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big + let offset = ctx.make((PrimU::cvt_from(1) << PrimF::MANTISSA_FIELD_WIDTH).to()); + let offset_value: VecF = v.abs() + offset; + let rounded = (offset_value - offset).copy_sign(v); + let need_round_up = v.gt(rounded); + let in_range_value = need_round_up + .select(rounded + ctx.make(1.to()), rounded) + .copy_sign(v); + big.select(v, in_range_value) } #[cfg(test)] @@ -198,7 +238,7 @@ mod tests { fn same(a: F, b: F) -> bool { if a.is_finite() && b.is_finite() { - a == b + a.to_bits() == b.to_bits() } else { a == b || (a.is_nan() && b.is_nan()) } @@ -295,13 +335,8 @@ mod tests { #[track_caller] fn case(v: f32, expected: f32) { let result = reference_round_to_nearest_ties_to_even(v); - let same = if expected.is_nan() { - result.is_nan() - } else { - expected.to_bits() == result.to_bits() - }; assert!( - same, + same(result, expected), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), @@ -362,7 +397,7 @@ mod tests { let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; assert!( - same(expected, result), + same(result, expected), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), @@ -381,7 +416,7 @@ mod tests { let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; assert!( - same(expected, result), + same(result, expected), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", v=v, v_bits=v.to_bits(), @@ -399,6 +434,128 @@ mod tests { let v = f64::from_bits(bits); let expected = reference_round_to_nearest_ties_to_even(v); let result = round_to_nearest_ties_to_even(Scalar, Value(v)).0; + assert!( + same(result, expected), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_floor_f16() { + for bits in 0..=u16::MAX { + let v = F16::from_bits(bits); + let expected = v.floor(); + let result = floor(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_floor_f32() { + for bits in (0..=u32::MAX).step_by(0x10000) { + let v = f32::from_bits(bits); + let expected = v.floor(); + let result = floor(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_floor_f64() { + for bits in (0..=u64::MAX).step_by(1 << 48) { + let v = f64::from_bits(bits); + let expected = v.floor(); + let result = floor(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_ceil_f16() { + for bits in 0..=u16::MAX { + let v = F16::from_bits(bits); + let expected = v.ceil(); + let result = ceil(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_ceil_f32() { + for bits in (0..=u32::MAX).step_by(0x10000) { + let v = f32::from_bits(bits); + let expected = v.ceil(); + let result = ceil(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_ceil_f64() { + for bits in (0..=u64::MAX).step_by(1 << 48) { + let v = f64::from_bits(bits); + let expected = v.ceil(); + let result = ceil(Scalar, Value(v)).0; assert!( same(expected, result), "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",