From: Jacob Lifshay Date: Thu, 13 May 2021 04:49:52 +0000 (-0700) Subject: add trunc implementation X-Git-Url: https://git.libre-soc.org/?p=vector-math.git;a=commitdiff_plain;h=ebb72b1eb96f918ebc34fb74221562c4cf65365b add trunc implementation --- diff --git a/src/algorithms/base.rs b/src/algorithms/base.rs index 0b6dcb6..d387340 100644 --- a/src/algorithms/base.rs +++ b/src/algorithms/base.rs @@ -1,6 +1,6 @@ use crate::{ prim::{PrimFloat, PrimUInt}, - traits::{Context, Float, Make}, + traits::{Context, ConvertTo, Float, Make, Select, UInt}, }; pub fn abs< @@ -30,6 +30,30 @@ pub fn copy_sign< VecF::from_bits(mag_bits | sign_bit) } +pub fn trunc< + Ctx: Context, + VecF: Float + Make, + VecU: UInt + Make, + PrimF: PrimFloat, + PrimU: PrimUInt, +>( + ctx: Ctx, + v: VecF, +) -> VecF { + let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to()); + let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big + let small = v.abs().lt(ctx.make(PrimF::cvt_from(1))); + let out_of_range = big | small; + let small_value = ctx.make::(0.to()).copy_sign(v); + let out_of_range_value = small.select(small_value, v); + let exponent_field = v.extract_exponent_field(); + let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED); + let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK); + mask >>= right_shift_amount; + let in_range_value = VecF::from_bits(v.to_bits() & !mask); + out_of_range.select(out_of_range_value, in_range_value) +} + #[cfg(test)] mod tests { use super::*; @@ -147,4 +171,73 @@ mod tests { } } } + + fn same(a: F, b: F) -> bool { + if a.is_finite() && b.is_finite() { + a == b + } else { + a == b || (a.is_nan() && b.is_nan()) + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_trunc_f16() { + for bits in 0..=u16::MAX { + let v = F16::from_bits(bits); + let expected = v.trunc(); + let result = trunc(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_trunc_f32() { + for bits in (0..=u32::MAX).step_by(0x10000) { + let v = f32::from_bits(bits); + let expected = v.trunc(); + let result = trunc(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } + + #[test] + fn test_trunc_f64() { + for bits in (0..=u64::MAX).step_by(1 << 48) { + let v = f64::from_bits(bits); + let expected = v.trunc(); + let result = trunc(Scalar, Value(v)).0; + assert!( + same(expected, result), + "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}", + v=v, + v_bits=v.to_bits(), + expected=expected, + expected_bits=expected.to_bits(), + result=result, + result_bits=result.to_bits(), + ); + } + } } diff --git a/src/algorithms/trig_pi.rs b/src/algorithms/trig_pi.rs index 1dca80a..e776378 100644 --- a/src/algorithms/trig_pi.rs +++ b/src/algorithms/trig_pi.rs @@ -158,8 +158,7 @@ pub fn sin_cos_pi_impl< ) -> (VecF, VecF) { let two_f: VecF = ctx.make(2.0.to()); let one_half: VecF = ctx.make(0.5.to()); - let max_contiguous_integer: VecF = - ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to()); + let max_contiguous_integer: VecF = ctx.make(PrimF::max_contiguous_integer()); // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range let is_finite = x.is_finite(); diff --git a/src/f16.rs b/src/f16.rs index 280d00d..4609c58 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -1,4 +1,5 @@ use crate::{ + prim::PrimFloat, scalar::Value, traits::{ConvertFrom, ConvertTo, Float}, }; @@ -211,10 +212,7 @@ impl F16 { ) } pub fn trunc(self) -> Self { - #[cfg(feature = "std")] - return f32::from(self).trunc().to(); - #[cfg(not(feature = "std"))] - todo!(); + return PrimFloat::trunc(f32::from(self)).to(); } pub fn ceil(self) -> Self { #[cfg(feature = "std")] diff --git a/src/prim.rs b/src/prim.rs index 19f4270..7f9fa30 100644 --- a/src/prim.rs +++ b/src/prim.rs @@ -1,5 +1,6 @@ use crate::{ f16::F16, + scalar::{Scalar, Value}, traits::{ConvertFrom, ConvertTo}, }; use core::{fmt, hash, ops}; @@ -140,6 +141,11 @@ pub trait PrimFloat: fn from_bits(bits: Self::BitsType) -> Self; fn to_bits(self) -> Self::BitsType; fn abs(self) -> Self; + fn max_contiguous_integer() -> Self { + (Self::BitsType::cvt_from(1) << (Self::MANTISSA_FIELD_WIDTH + 1.to())).to() + } + fn is_finite(self) -> bool; + fn trunc(self) -> Self; } macro_rules! impl_float { @@ -190,7 +196,16 @@ macro_rules! impl_float { #[cfg(feature = "std")] return $float::abs(self); #[cfg(not(feature = "std"))] - todo!(); + return crate::algorithms::base::abs(Scalar, Value(self)).0; + } + fn is_finite(self) -> bool { + $float::is_finite(self) + } + fn trunc(self) -> Self { + #[cfg(feature = "std")] + return $float::trunc(self); + #[cfg(not(feature = "std"))] + return crate::algorithms::base::trunc(Scalar, Value(self)).0; } } }; diff --git a/src/scalar.rs b/src/scalar.rs index 4e50095..c1a1ec9 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -352,11 +352,17 @@ macro_rules! impl_float { #[cfg(not(feature = "std"))] return crate::algorithms::base::abs(Scalar, self); } + fn copy_sign(self, sign: Self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.copysign(sign.0)); + #[cfg(not(feature = "std"))] + return crate::algorithms::base::copy_sign(Scalar, self, sign); + } fn trunc(self) -> Self { #[cfg(feature = "std")] return Value(self.0.trunc()); #[cfg(not(feature = "std"))] - todo!(); + return crate::algorithms::base::trunc(Scalar, self); } fn ceil(self) -> Self { #[cfg(feature = "std")] diff --git a/src/traits.rs b/src/traits.rs index 2ec4815..c02b609 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -172,6 +172,9 @@ pub trait Float: + Compare + ConvertFrom; fn abs(self) -> Self; + fn copy_sign(self, sign: Self) -> Self { + crate::algorithms::base::copy_sign(self.ctx(), self, sign) + } fn trunc(self) -> Self; fn ceil(self) -> Self; fn floor(self) -> Self; @@ -218,6 +221,33 @@ pub trait Float: let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK); self.to_bits() & mask } + fn is_sign_negative(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .ne(self.to_bits() & mask) + } + fn is_sign_positive(self) -> Self::Bool { + let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK); + self.ctx() + .make::(0.to()) + .eq(self.to_bits() & mask) + } + fn extract_sign_field(self) -> Self::BitsType { + let shift = self.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + self.to_bits() >> shift + } + fn from_fields( + sign_field: Self::BitsType, + exponent_field: Self::BitsType, + mantissa_field: Self::BitsType, + ) -> Self { + let sign_shift = sign_field.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT); + let exponent_shift = sign_field.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT); + Self::from_bits( + (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field, + ) + } fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType { Self::SignedBitsType::cvt_from(exponent_field) - exponent_field @@ -229,14 +259,14 @@ pub trait Float: } } -pub trait Bool: Make + BitOps {} +pub trait Bool: Make + BitOps + Select {} -pub trait Select: Bool { +pub trait Select { fn select(self, true_v: T, false_v: T) -> T; } pub trait Compare: Make { - type Bool: Bool + Select; + type Bool: Bool + Select + Make; fn eq(self, rhs: Self) -> Self::Bool; fn ne(self, rhs: Self) -> Self::Bool; fn lt(self, rhs: Self) -> Self::Bool;