-use crate::traits::{Context, Float};
+use crate::{
+ prim::{PrimFloat, PrimUInt},
+ traits::{Context, Float, Make},
+};
-pub fn abs_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 {
- Ctx::VecF16::from_bits(x.to_bits() & ctx.make(0x7FFFu16))
+pub fn abs<
+ Ctx: Context,
+ VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+ PrimF: PrimFloat<BitsType = PrimU>,
+ PrimU: PrimUInt,
+>(
+ ctx: Ctx,
+ x: VecF,
+) -> VecF {
+ VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
}
-pub fn abs_f32<Ctx: Context>(ctx: Ctx, x: Ctx::VecF32) -> Ctx::VecF32 {
- Ctx::VecF32::from_bits(x.to_bits() & ctx.make(!(1u32 << 31)))
-}
-
-pub fn abs_f64<Ctx: Context>(ctx: Ctx, x: Ctx::VecF64) -> Ctx::VecF64 {
- Ctx::VecF64::from_bits(x.to_bits() & ctx.make(!(1u64 << 63)))
+pub fn copy_sign<
+ Ctx: Context,
+ VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+ PrimF: PrimFloat<BitsType = PrimU>,
+ PrimU: PrimUInt,
+>(
+ ctx: Ctx,
+ mag: VecF,
+ sign: VecF,
+) -> VecF {
+ let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
+ let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
+ VecF::from_bits(mag_bits | sign_bit)
}
#[cfg(test)]
for bits in 0..=u16::MAX {
let v = F16::from_bits(bits);
let expected = v.abs();
- let result = abs_f16(Scalar, Value(v)).0;
+ let result = abs(Scalar, Value(v)).0;
assert_eq!(expected.to_bits(), result.to_bits());
}
}
for bits in (0..=u32::MAX).step_by(10001) {
let v = f32::from_bits(bits);
let expected = v.abs();
- let result = abs_f32(Scalar, Value(v)).0;
+ let result = abs(Scalar, Value(v)).0;
assert_eq!(expected.to_bits(), result.to_bits());
}
}
for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
let v = f64::from_bits(bits);
let expected = v.abs();
- let result = abs_f64(Scalar, Value(v)).0;
+ let result = abs(Scalar, Value(v)).0;
+ assert_eq!(expected.to_bits(), result.to_bits());
+ }
+ }
+
+ #[test]
+ #[cfg_attr(
+ not(feature = "f16"),
+ should_panic(expected = "f16 feature is not enabled")
+ )]
+ fn test_copy_sign_f16() {
+ #[track_caller]
+ fn check(mag_bits: u16, sign_bits: u16) {
+ let mag = F16::from_bits(mag_bits);
+ let sign = F16::from_bits(sign_bits);
+ let expected = mag.copysign(sign);
+ let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+ assert_eq!(expected.to_bits(), result.to_bits());
+ }
+ for mag_low_bits in 0..16 {
+ for mag_high_bits in 0..16 {
+ for sign_low_bits in 0..16 {
+ for sign_high_bits in 0..16 {
+ check(
+ mag_low_bits | (mag_high_bits << (16 - 4)),
+ sign_low_bits | (sign_high_bits << (16 - 4)),
+ );
+ }
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn test_copy_sign_f32() {
+ #[track_caller]
+ fn check(mag_bits: u32, sign_bits: u32) {
+ let mag = f32::from_bits(mag_bits);
+ let sign = f32::from_bits(sign_bits);
+ let expected = mag.copysign(sign);
+ let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
+ assert_eq!(expected.to_bits(), result.to_bits());
+ }
+ for mag_low_bits in 0..16 {
+ for mag_high_bits in 0..16 {
+ for sign_low_bits in 0..16 {
+ for sign_high_bits in 0..16 {
+ check(
+ mag_low_bits | (mag_high_bits << (32 - 4)),
+ sign_low_bits | (sign_high_bits << (32 - 4)),
+ );
+ }
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn test_copy_sign_f64() {
+ #[track_caller]
+ fn check(mag_bits: u64, sign_bits: u64) {
+ let mag = f64::from_bits(mag_bits);
+ let sign = f64::from_bits(sign_bits);
+ let expected = mag.copysign(sign);
+ let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
assert_eq!(expected.to_bits(), result.to_bits());
}
+ for mag_low_bits in 0..16 {
+ for mag_high_bits in 0..16 {
+ for sign_low_bits in 0..16 {
+ for sign_high_bits in 0..16 {
+ check(
+ mag_low_bits | (mag_high_bits << (64 - 4)),
+ sign_low_bits | (sign_high_bits << (64 - 4)),
+ );
+ }
+ }
+ }
+ }
}
}