add trunc implementation
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 13 May 2021 04:49:52 +0000 (21:49 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 13 May 2021 04:49:52 +0000 (21:49 -0700)
src/algorithms/base.rs
src/algorithms/trig_pi.rs
src/f16.rs
src/prim.rs
src/scalar.rs
src/traits.rs

index 0b6dcb60993100b966d73f308ceb9084826cbfd1..d38734091a90b41ebabcb6d24c4ebbe4e49212d6 100644 (file)
@@ -1,6 +1,6 @@
 use crate::{
     prim::{PrimFloat, PrimUInt},
-    traits::{Context, Float, Make},
+    traits::{Context, ConvertTo, Float, Make, Select, UInt},
 };
 
 pub fn abs<
@@ -30,6 +30,30 @@ pub fn copy_sign<
     VecF::from_bits(mag_bits | sign_bit)
 }
 
+pub fn trunc<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
+    let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
+    let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
+    let out_of_range = big | small;
+    let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
+    let out_of_range_value = small.select(small_value, v);
+    let exponent_field = v.extract_exponent_field();
+    let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
+    let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
+    mask >>= right_shift_amount;
+    let in_range_value = VecF::from_bits(v.to_bits() & !mask);
+    out_of_range.select(out_of_range_value, in_range_value)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -147,4 +171,73 @@ mod tests {
             }
         }
     }
+
+    fn same<F: PrimFloat>(a: F, b: F) -> bool {
+        if a.is_finite() && b.is_finite() {
+            a == b
+        } else {
+            a == b || (a.is_nan() && b.is_nan())
+        }
+    }
+
+    #[test]
+    #[cfg_attr(
+        not(feature = "f16"),
+        should_panic(expected = "f16 feature is not enabled")
+    )]
+    fn test_trunc_f16() {
+        for bits in 0..=u16::MAX {
+            let v = F16::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f32() {
+        for bits in (0..=u32::MAX).step_by(0x10000) {
+            let v = f32::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
+
+    #[test]
+    fn test_trunc_f64() {
+        for bits in (0..=u64::MAX).step_by(1 << 48) {
+            let v = f64::from_bits(bits);
+            let expected = v.trunc();
+            let result = trunc(Scalar, Value(v)).0;
+            assert!(
+                same(expected, result),
+                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
+                v=v,
+                v_bits=v.to_bits(),
+                expected=expected,
+                expected_bits=expected.to_bits(),
+                result=result,
+                result_bits=result.to_bits(),
+            );
+        }
+    }
 }
index 1dca80a33522bc333f85e89909b96e37b9577ab3..e7763787910685fda9a765093a6d1ef4764dca1e 100644 (file)
@@ -158,8 +158,7 @@ pub fn sin_cos_pi_impl<
 ) -> (VecF, VecF) {
     let two_f: VecF = ctx.make(2.0.to());
     let one_half: VecF = ctx.make(0.5.to());
-    let max_contiguous_integer: VecF =
-        ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to());
+    let max_contiguous_integer: VecF = ctx.make(PrimF::max_contiguous_integer());
     // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer
     let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range
     let is_finite = x.is_finite();
index 280d00d4231f751192b12a5d715cd26734318fa6..4609c58264a7979bd4aa01f4b76501de553c874b 100644 (file)
@@ -1,4 +1,5 @@
 use crate::{
+    prim::PrimFloat,
     scalar::Value,
     traits::{ConvertFrom, ConvertTo, Float},
 };
@@ -211,10 +212,7 @@ impl F16 {
         )
     }
     pub fn trunc(self) -> Self {
-        #[cfg(feature = "std")]
-        return f32::from(self).trunc().to();
-        #[cfg(not(feature = "std"))]
-        todo!();
+        return PrimFloat::trunc(f32::from(self)).to();
     }
     pub fn ceil(self) -> Self {
         #[cfg(feature = "std")]
index 19f4270b2185d9099d3741e23be9f8b88a8326f5..7f9fa30fd163464a29185c494292cb8d5787611e 100644 (file)
@@ -1,5 +1,6 @@
 use crate::{
     f16::F16,
+    scalar::{Scalar, Value},
     traits::{ConvertFrom, ConvertTo},
 };
 use core::{fmt, hash, ops};
@@ -140,6 +141,11 @@ pub trait PrimFloat:
     fn from_bits(bits: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
     fn abs(self) -> Self;
+    fn max_contiguous_integer() -> Self {
+        (Self::BitsType::cvt_from(1) << (Self::MANTISSA_FIELD_WIDTH + 1.to())).to()
+    }
+    fn is_finite(self) -> bool;
+    fn trunc(self) -> Self;
 }
 
 macro_rules! impl_float {
@@ -190,7 +196,16 @@ macro_rules! impl_float {
                 #[cfg(feature = "std")]
                 return $float::abs(self);
                 #[cfg(not(feature = "std"))]
-                todo!();
+                return crate::algorithms::base::abs(Scalar, Value(self)).0;
+            }
+            fn is_finite(self) -> bool {
+                $float::is_finite(self)
+            }
+            fn trunc(self) -> Self {
+                #[cfg(feature = "std")]
+                return $float::trunc(self);
+                #[cfg(not(feature = "std"))]
+                return crate::algorithms::base::trunc(Scalar, Value(self)).0;
             }
         }
     };
index 4e5009598aff43900ae158d35d4918c8e0d87140..c1a1ec94c2b59608d8bd40476670185aa5734434 100644 (file)
@@ -352,11 +352,17 @@ macro_rules! impl_float {
                 #[cfg(not(feature = "std"))]
                 return crate::algorithms::base::abs(Scalar, self);
             }
+            fn copy_sign(self, sign: Self) -> Self {
+                #[cfg(feature = "std")]
+                return Value(self.0.copysign(sign.0));
+                #[cfg(not(feature = "std"))]
+                return crate::algorithms::base::copy_sign(Scalar, self, sign);
+            }
             fn trunc(self) -> Self {
                 #[cfg(feature = "std")]
                 return Value(self.0.trunc());
                 #[cfg(not(feature = "std"))]
-                todo!();
+                return crate::algorithms::base::trunc(Scalar, self);
             }
             fn ceil(self) -> Self {
                 #[cfg(feature = "std")]
index 2ec48157c3b691883936c6d2e485e8b042c3e1b7..c02b60919cddc332266ce23b36318f2a45ec9e76 100644 (file)
@@ -172,6 +172,9 @@ pub trait Float:
         + Compare<Bool = Self::Bool>
         + ConvertFrom<Self>;
     fn abs(self) -> Self;
+    fn copy_sign(self, sign: Self) -> Self {
+        crate::algorithms::base::copy_sign(self.ctx(), self, sign)
+    }
     fn trunc(self) -> Self;
     fn ceil(self) -> Self;
     fn floor(self) -> Self;
@@ -218,6 +221,33 @@ pub trait Float:
         let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK);
         self.to_bits() & mask
     }
+    fn is_sign_negative(self) -> Self::Bool {
+        let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK);
+        self.ctx()
+            .make::<Self::BitsType>(0.to())
+            .ne(self.to_bits() & mask)
+    }
+    fn is_sign_positive(self) -> Self::Bool {
+        let mask = self.ctx().make(Self::PrimFloat::SIGN_FIELD_MASK);
+        self.ctx()
+            .make::<Self::BitsType>(0.to())
+            .eq(self.to_bits() & mask)
+    }
+    fn extract_sign_field(self) -> Self::BitsType {
+        let shift = self.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT);
+        self.to_bits() >> shift
+    }
+    fn from_fields(
+        sign_field: Self::BitsType,
+        exponent_field: Self::BitsType,
+        mantissa_field: Self::BitsType,
+    ) -> Self {
+        let sign_shift = sign_field.ctx().make(Self::PrimFloat::SIGN_FIELD_SHIFT);
+        let exponent_shift = sign_field.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT);
+        Self::from_bits(
+            (sign_field << sign_shift) | (exponent_field << exponent_shift) | mantissa_field,
+        )
+    }
     fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
         Self::SignedBitsType::cvt_from(exponent_field)
             - exponent_field
@@ -229,14 +259,14 @@ pub trait Float:
     }
 }
 
-pub trait Bool: Make + BitOps {}
+pub trait Bool: Make<Prim = bool> + BitOps + Select<Self> {}
 
-pub trait Select<T>: Bool {
+pub trait Select<T> {
     fn select(self, true_v: T, false_v: T) -> T;
 }
 
 pub trait Compare: Make {
-    type Bool: Bool + Select<Self>;
+    type Bool: Bool + Select<Self> + Make<Context = Self::Context>;
     fn eq(self, rhs: Self) -> Self::Bool;
     fn ne(self, rhs: Self) -> Self::Bool;
     fn lt(self, rhs: Self) -> Self::Bool;