add ilogb
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 4 May 2021 04:27:18 +0000 (21:27 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 4 May 2021 04:27:18 +0000 (21:27 -0700)
src/algorithms.rs [new file with mode: 0644]
src/algorithms/ilogb.rs [new file with mode: 0644]
src/f16.rs
src/ieee754.rs [new file with mode: 0644]
src/ir.rs
src/lib.rs
src/scalar.rs
src/traits.rs

diff --git a/src/algorithms.rs b/src/algorithms.rs
new file mode 100644 (file)
index 0000000..61191a6
--- /dev/null
@@ -0,0 +1 @@
+pub mod ilogb;
diff --git a/src/algorithms/ilogb.rs b/src/algorithms/ilogb.rs
new file mode 100644 (file)
index 0000000..2956c8f
--- /dev/null
@@ -0,0 +1,122 @@
+use crate::{
+    f16::F16,
+    ieee754::FloatEncoding,
+    traits::{Compare, Context, ConvertTo, Float, Select},
+};
+
+pub const DEFAULT_NAN_RESULT: i16 = i16::MIN + 1;
+pub const DEFAULT_OVERFLOW_RESULT: i16 = i16::MAX;
+pub const DEFAULT_UNDERFLOW_RESULT: i16 = i16::MIN;
+
+macro_rules! impl_ilogb {
+    (
+        #[prim = $prim:ident]
+        #[prim_signed_bits = $prim_signed_bits:ident]
+        #[ilogb = $ilogb:ident]
+        #[nan = $NAN_RESULT:ident]
+        #[overflow = $OVERFLOW_RESULT:ident]
+        #[underflow = $UNDERFLOW_RESULT:ident]
+        fn $ilogb_extended:ident($vector_float:ident, $scalar_signed_bits:ident) -> $vector_signed_bits:ident;
+    ) => {
+        pub const $NAN_RESULT: $prim_signed_bits = $prim_signed_bits::MIN + 1;
+        pub const $OVERFLOW_RESULT: $prim_signed_bits = $prim_signed_bits::MAX;
+        pub const $UNDERFLOW_RESULT: $prim_signed_bits = $prim_signed_bits::MIN;
+
+        pub fn $ilogb_extended<Ctx: Context>(
+            ctx: Ctx,
+            arg: Ctx::$vector_float,
+            nan_result: Ctx::$scalar_signed_bits,
+            overflow_result: Ctx::$scalar_signed_bits,
+            underflow_result: Ctx::$scalar_signed_bits,
+        ) -> Ctx::$vector_signed_bits {
+            let is_finite = arg.is_finite();
+            let is_zero_subnormal = arg.is_zero_or_subnormal();
+            let is_nan = arg.is_nan();
+            let inf_nan_result: Ctx::$vector_signed_bits =
+                is_nan.select(nan_result.into(), overflow_result.into());
+            let scale_factor: $prim = (1u64 << $prim::MANTISSA_FIELD_WIDTH).to();
+            let scaled = arg * ctx.make(scale_factor);
+            let scaled_exponent = scaled.extract_exponent_unbiased();
+            let exponent = arg.extract_exponent_unbiased();
+            let normal_inf_nan_result = is_finite.select(exponent, inf_nan_result);
+            let is_zero = arg.eq(ctx.make($prim::from(0u8)));
+            let zero_subnormal_result = is_zero.select(
+                underflow_result.into(),
+                scaled_exponent - ctx.make($prim::MANTISSA_FIELD_WIDTH.to()),
+            );
+            is_zero_subnormal.select(zero_subnormal_result, normal_inf_nan_result)
+        }
+
+        pub fn $ilogb<Ctx: Context>(ctx: Ctx, arg: Ctx::$vector_float) -> Ctx::$vector_signed_bits {
+            $ilogb_extended(
+                ctx,
+                arg,
+                ctx.make($NAN_RESULT),
+                ctx.make($OVERFLOW_RESULT),
+                ctx.make($UNDERFLOW_RESULT),
+            )
+        }
+    };
+}
+
+impl_ilogb! {
+    #[prim = F16]
+    #[prim_signed_bits = i16]
+    #[ilogb = ilogb_f16]
+    #[nan = ILOGB_NAN_RESULT_F16]
+    #[overflow = ILOGB_OVERFLOW_RESULT_F16]
+    #[underflow = ILOGB_UNDERFLOW_RESULT_F16]
+    fn ilogb_f16_extended(VecF16, I16) -> VecI16;
+}
+
+impl_ilogb! {
+    #[prim = f32]
+    #[prim_signed_bits = i32]
+    #[ilogb = ilogb_f32]
+    #[nan = ILOGB_NAN_RESULT_F32]
+    #[overflow = ILOGB_OVERFLOW_RESULT_F32]
+    #[underflow = ILOGB_UNDERFLOW_RESULT_F32]
+    fn ilogb_f32_extended(VecF32, I32) -> VecI32;
+}
+
+impl_ilogb! {
+    #[prim = f64]
+    #[prim_signed_bits = i64]
+    #[ilogb = ilogb_f64]
+    #[nan = ILOGB_NAN_RESULT_F64]
+    #[overflow = ILOGB_OVERFLOW_RESULT_F64]
+    #[underflow = ILOGB_UNDERFLOW_RESULT_F64]
+    fn ilogb_f64_extended(VecF64, I64) -> VecI64;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::scalar::Scalar;
+
+    #[test]
+    fn test_ilogb_f32() {
+        assert_eq!(ilogb_f32(Scalar, 0f32), ILOGB_UNDERFLOW_RESULT_F32);
+        assert_eq!(ilogb_f32(Scalar, 1f32), 0);
+        assert_eq!(ilogb_f32(Scalar, 2f32), 1);
+        assert_eq!(ilogb_f32(Scalar, 3f32), 1);
+        assert_eq!(ilogb_f32(Scalar, 3.99999f32), 1);
+        assert_eq!(ilogb_f32(Scalar, 0.5f32), -1);
+        assert_eq!(ilogb_f32(Scalar, 0.5f32.powi(130)), -130);
+        assert_eq!(ilogb_f32(Scalar, f32::INFINITY), ILOGB_OVERFLOW_RESULT_F32);
+        assert_eq!(ilogb_f32(Scalar, f32::NAN), ILOGB_NAN_RESULT_F32);
+    }
+
+    #[test]
+    fn test_ilogb_f64() {
+        assert_eq!(ilogb_f64(Scalar, 0f64), ILOGB_UNDERFLOW_RESULT_F64);
+        assert_eq!(ilogb_f64(Scalar, 1f64), 0);
+        assert_eq!(ilogb_f64(Scalar, 2f64), 1);
+        assert_eq!(ilogb_f64(Scalar, 3f64), 1);
+        assert_eq!(ilogb_f64(Scalar, 3.99999f64), 1);
+        assert_eq!(ilogb_f64(Scalar, 0.5f64), -1);
+        assert_eq!(ilogb_f64(Scalar, 0.5f64.powi(1030)), -1030);
+        assert_eq!(ilogb_f64(Scalar, f64::INFINITY), ILOGB_OVERFLOW_RESULT_F64);
+        assert_eq!(ilogb_f64(Scalar, f64::NAN), ILOGB_NAN_RESULT_F64);
+    }
+}
index 20ed902ff6e0afc542f93a869bb40707c2d6f47f..bc40c782a20f514d5658463c6f5b098d80a72d56 100644 (file)
@@ -162,7 +162,9 @@ impl_bin_op_using_f32! {
 }
 
 impl Float<u32> for F16 {
+    type FloatEncoding = F16;
     type BitsType = u16;
+    type SignedBitsType = i16;
 
     fn abs(self) -> Self {
         f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
diff --git a/src/ieee754.rs b/src/ieee754.rs
new file mode 100644 (file)
index 0000000..0da587d
--- /dev/null
@@ -0,0 +1,88 @@
+use crate::{
+    f16::F16,
+    scalar::Scalar,
+    traits::{Float, Make},
+};
+
+mod sealed {
+    use crate::f16::F16;
+
+    pub trait Sealed {}
+    impl Sealed for F16 {}
+    impl Sealed for f32 {}
+    impl Sealed for f64 {}
+}
+
+pub trait FloatEncoding:
+    sealed::Sealed + Copy + 'static + Send + Sync + Float<u32> + Make<Context = Scalar>
+{
+    const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
+    const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
+    const SIGN_FIELD_WIDTH: u32;
+    const EXPONENT_FIELD_WIDTH: u32;
+    const MANTISSA_FIELD_WIDTH: u32;
+    const SIGN_FIELD_SHIFT: u32;
+    const EXPONENT_FIELD_SHIFT: u32;
+    const MANTISSA_FIELD_SHIFT: u32;
+    const SIGN_FIELD_MASK: Self::BitsType;
+    const EXPONENT_FIELD_MASK: Self::BitsType;
+    const MANTISSA_FIELD_MASK: Self::BitsType;
+    const IMPLICIT_MANTISSA_BIT: Self::BitsType;
+    const ZERO_SUBNORMAL_EXPONENT: Self::BitsType;
+    const NAN_INFINITY_EXPONENT: Self::BitsType;
+    const INFINITY_BITS: Self::BitsType;
+    const NAN_BITS: Self::BitsType;
+}
+
+macro_rules! impl_float_encoding {
+    (
+        impl FloatEncoding for $float:ident {
+            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal;
+            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal;
+        }
+    ) => {
+        impl FloatEncoding for $float {
+            const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
+                (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
+            const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
+            const SIGN_FIELD_WIDTH: u32 = 1;
+            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width;
+            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width;
+            const SIGN_FIELD_SHIFT: u32 = Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
+            const EXPONENT_FIELD_SHIFT: u32 = Self::MANTISSA_FIELD_WIDTH;
+            const MANTISSA_FIELD_SHIFT: u32 = 0;
+            const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
+            const EXPONENT_FIELD_MASK: Self::BitsType =
+                ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
+            const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1;
+            const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH;
+            const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0;
+            const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1;
+            const INFINITY_BITS: Self::BitsType =
+                Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT;
+            const NAN_BITS: Self::BitsType =
+                Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1));
+        }
+    };
+}
+
+impl_float_encoding! {
+    impl FloatEncoding for F16 {
+        const EXPONENT_FIELD_WIDTH: u32 = 5;
+        const MANTISSA_FIELD_WIDTH: u32 = 10;
+    }
+}
+
+impl_float_encoding! {
+    impl FloatEncoding for f32 {
+        const EXPONENT_FIELD_WIDTH: u32 = 8;
+        const MANTISSA_FIELD_WIDTH: u32 = 23;
+    }
+}
+
+impl_float_encoding! {
+    impl FloatEncoding for f64 {
+        const EXPONENT_FIELD_WIDTH: u32 = 11;
+        const MANTISSA_FIELD_WIDTH: u32 = 52;
+    }
+}
index 3eaccb987d6fa0f239d208470350f92046e84de1..f1de05aa7d6bad699471d3afcbd4a5bd86724ac2 100644 (file)
--- a/src/ir.rs
+++ b/src/ir.rs
@@ -229,14 +229,14 @@ impl fmt::Display for ScalarConstant {
         match self {
             ScalarConstant::Bool(false) => write!(f, "false"),
             ScalarConstant::Bool(true) => write!(f, "true"),
-            ScalarConstant::U8(v) => write!(f, "{}_u8", v),
-            ScalarConstant::U16(v) => write!(f, "{}_u16", v),
-            ScalarConstant::U32(v) => write!(f, "{}_u32", v),
-            ScalarConstant::U64(v) => write!(f, "{}_u64", v),
-            ScalarConstant::I8(v) => write!(f, "{}_i8", v),
-            ScalarConstant::I16(v) => write!(f, "{}_i16", v),
-            ScalarConstant::I32(v) => write!(f, "{}_i32", v),
-            ScalarConstant::I64(v) => write!(f, "{}_i64", v),
+            ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
+            ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
+            ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
+            ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
+            ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
+            ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
+            ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
+            ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
             ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
             ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
             ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
@@ -399,6 +399,9 @@ make_enum! {
         Not,
         Shl,
         Shr,
+        CountSetBits,
+        CountLeadingZeros,
+        CountTrailingZeros,
         Neg,
         Abs,
         Trunc,
@@ -685,7 +688,7 @@ impl_ir_function_maker_io!(
     in11: In11,
 );
 
-pub trait IrValue<'ctx>: Copy {
+pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
     const TYPE: Type;
     fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
     fn make_input<N: Borrow<str> + Into<String>>(
@@ -695,7 +698,6 @@ pub trait IrValue<'ctx>: Copy {
         let input = ctx.make_input(name, Self::TYPE);
         (Self::new(ctx, input.into()), input)
     }
-    fn ctx(self) -> &'ctx IrContext<'ctx>;
     fn value(self) -> Value<'ctx>;
 }
 
@@ -713,9 +715,6 @@ macro_rules! ir_value {
                 assert_eq!(value.ty(), Self::TYPE);
                 Self { ctx, value }
             }
-            fn ctx(self) -> &'ctx IrContext<'ctx> {
-                self.ctx
-            }
             fn value(self) -> Value<'ctx> {
                 self.value
             }
@@ -725,10 +724,13 @@ macro_rules! ir_value {
             pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
         }
 
-        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $name<'ctx> {
+        impl<'ctx> Make for $name<'ctx> {
             type Prim = $prim;
-
-            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
+            type Context = &'ctx IrContext<'ctx>;
+            fn ctx(self) -> Self::Context {
+                self.ctx
+            }
+            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                 let value: ScalarConstant = $make;
                 let value = value.into();
                 Self { value, ctx }
@@ -747,9 +749,6 @@ macro_rules! ir_value {
                 assert_eq!(value.ty(), Self::TYPE);
                 Self { ctx, value }
             }
-            fn ctx(self) -> &'ctx IrContext<'ctx> {
-                self.ctx
-            }
             fn value(self) -> Value<'ctx> {
                 self.value
             }
@@ -761,10 +760,13 @@ macro_rules! ir_value {
             };
         }
 
-        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $vec_name<'ctx> {
+        impl<'ctx> Make for $vec_name<'ctx> {
             type Prim = $prim;
-
-            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
+            type Context = &'ctx IrContext<'ctx>;
+            fn ctx(self) -> Self::Context {
+                self.ctx
+            }
+            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                 let element = $make;
                 Self {
                     value: VectorSplatConstant { element }.into(),
@@ -1120,6 +1122,43 @@ macro_rules! impl_neg {
     };
 }
 
+macro_rules! impl_int_trait {
+    ($ty:ident, $u32:ident) => {
+        impl<'ctx> Int<$u32<'ctx>> for $ty<'ctx> {
+            fn leading_zeros(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+            fn trailing_zeros(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+            fn count_ones(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountSetBits, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+        }
+    };
+}
+
 macro_rules! impl_integer_ops {
     ($scalar:ident, $vec:ident) => {
         impl_bit_ops!($scalar);
@@ -1128,9 +1167,8 @@ macro_rules! impl_integer_ops {
         impl_bit_ops!($vec);
         impl_number_ops!($vec, IrVecBool);
         impl_shift_ops!($vec, IrVecU32);
-
-        impl<'ctx> Int<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> Int<IrVecU32<'ctx>> for $vec<'ctx> {}
+        impl_int_trait!($scalar, IrU32);
+        impl_int_trait!($vec, IrVecU32);
     };
 }
 
@@ -1165,9 +1203,11 @@ impl_sint_ops!(IrI32, IrVecI32);
 impl_sint_ops!(IrI64, IrVecI64);
 
 macro_rules! impl_float {
-    ($float:ident, $bits:ident, $u32:ident) => {
+    ($float:ident, $bits:ident, $signed_bits:ident, $u32:ident) => {
         impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
+            type FloatEncoding = <$float<'ctx> as Make>::Prim;
             type BitsType = $bits<'ctx>;
+            type SignedBitsType = $signed_bits<'ctx>;
             fn abs(self) -> Self {
                 let value = self
                     .ctx
@@ -1285,19 +1325,19 @@ macro_rules! impl_float {
 }
 
 macro_rules! impl_float_ops {
-    ($scalar:ident, $scalar_bits:ident, $vec:ident, $vec_bits:ident) => {
+    ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
         impl_number_ops!($scalar, IrBool);
         impl_number_ops!($vec, IrVecBool);
         impl_neg!($scalar);
         impl_neg!($vec);
-        impl_float!($scalar, $scalar_bits, IrU32);
-        impl_float!($vec, $vec_bits, IrVecU32);
+        impl_float!($scalar, $scalar_bits, $scalar_signed_bits, IrU32);
+        impl_float!($vec, $vec_bits, $vec_signed_bits, IrVecU32);
     };
 }
 
-impl_float_ops!(IrF16, IrU16, IrVecF16, IrVecU16);
-impl_float_ops!(IrF32, IrU32, IrVecF32, IrVecU32);
-impl_float_ops!(IrF64, IrU64, IrVecF64, IrVecU64);
+impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
+impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
+impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
 
 ir_value!(
     IrBool,
@@ -1540,6 +1580,8 @@ impl<'ctx> Context for &'ctx IrContext<'ctx> {
 
 #[cfg(test)]
 mod tests {
+    use crate::algorithms;
+
     use super::*;
     use std::println;
 
@@ -1568,6 +1610,52 @@ function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
     op_5: vec<F64> = Cast op_4
     Return op_5
 }
+"
+        );
+    }
+
+    #[test]
+    fn test_display_ilogb_f32() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
+                algorithms::ilogb::ilogb_f32;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<F32>) -> vec<I32> {
+    op_0: vec<Bool> = IsFinite in<arg_0>
+    op_1: vec<U32> = ToBits in<arg_0>
+    op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
+    op_3: vec<U32> = Shr op_2, splat(0x17_u32)
+    op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
+    op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
+    op_6: vec<I32> = Splat 0x80000001_i32
+    op_7: vec<I32> = Splat 0x7FFFFFFF_i32
+    op_8: vec<I32> = Select op_5, op_6, op_7
+    op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
+    op_10: vec<U32> = ToBits op_9
+    op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
+    op_12: vec<U32> = Shr op_11, splat(0x17_u32)
+    op_13: vec<I32> = Cast op_12
+    op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
+    op_15: vec<U32> = ToBits in<arg_0>
+    op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
+    op_17: vec<U32> = Shr op_16, splat(0x17_u32)
+    op_18: vec<I32> = Cast op_17
+    op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
+    op_20: vec<I32> = Select op_0, op_19, op_8
+    op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
+    op_22: vec<I32> = Splat 0x80000000_i32
+    op_23: vec<I32> = Sub op_14, splat(0x17_i32)
+    op_24: vec<I32> = Select op_21, op_22, op_23
+    op_25: vec<I32> = Select op_4, op_24, op_20
+    Return op_25
+}
 "
         );
     }
index 06bfb803ff47beaff17d4ca1d73b90febcd59e59..1b3add31dc4ccf23ecdc99f184ba12cf39500ba8 100644 (file)
@@ -4,7 +4,9 @@
 #[cfg(any(feature = "std", test))]
 extern crate std;
 
+pub mod algorithms;
 pub mod f16;
+pub mod ieee754;
 #[cfg(feature = "ir")]
 pub mod ir;
 pub mod scalar;
index c6794e260c04359502dca30f3e7552bca34a6172..d1e137d0a747c8fc4fe586ba960923675ab179c0 100644 (file)
@@ -3,60 +3,84 @@ use crate::traits::{Context, Make};
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)]
 pub struct Scalar;
 
-impl Context for Scalar {
-    type Bool = bool;
+macro_rules! impl_context {
+    (
+        impl Context for Scalar {
+            $(type $name:ident = $ty:ty;)*
+            #[vec]
+            $(type $vec_name:ident = $vec_ty:ty;)*
+        }
+    ) => {
+        impl Context for Scalar {
+            $(type $name = $ty;)*
+            $(type $vec_name = $vec_ty;)*
+        }
+
+        $(
+            impl Make for $ty {
+                type Prim = $ty;
+                type Context = Scalar;
+                fn ctx(self) -> Self::Context {
+                    Scalar
+                }
+                fn make(_ctx: Self::Context, v: Self::Prim) -> Self {
+                    v
+                }
+            }
+        )*
+    };
+}
 
-    type U8 = u8;
+impl_context! {
+    impl Context for Scalar {
+        type Bool = bool;
 
-    type I8 = i8;
+        type U8 = u8;
 
-    type U16 = u16;
+        type I8 = i8;
 
-    type I16 = i16;
+        type U16 = u16;
 
-    type F16 = crate::f16::F16;
+        type I16 = i16;
 
-    type U32 = u32;
+        type F16 = crate::f16::F16;
 
-    type I32 = i32;
+        type U32 = u32;
 
-    type F32 = f32;
+        type I32 = i32;
 
-    type U64 = u64;
+        type F32 = f32;
 
-    type I64 = i64;
+        type U64 = u64;
 
-    type F64 = f64;
+        type I64 = i64;
 
-    type VecBool = bool;
+        type F64 = f64;
 
-    type VecU8 = u8;
+        #[vec]
 
-    type VecI8 = i8;
+        type VecBool = bool;
 
-    type VecU16 = u16;
+        type VecU8 = u8;
 
-    type VecI16 = i16;
+        type VecI8 = i8;
 
-    type VecF16 = crate::f16::F16;
+        type VecU16 = u16;
 
-    type VecU32 = u32;
+        type VecI16 = i16;
 
-    type VecI32 = i32;
+        type VecF16 = crate::f16::F16;
 
-    type VecF32 = f32;
+        type VecU32 = u32;
 
-    type VecU64 = u64;
+        type VecI32 = i32;
 
-    type VecI64 = i64;
+        type VecF32 = f32;
 
-    type VecF64 = f64;
-}
+        type VecU64 = u64;
 
-impl<T> Make<Scalar> for T {
-    type Prim = T;
+        type VecI64 = i64;
 
-    fn make(_ctx: Scalar, v: Self::Prim) -> Self {
-        v
+        type VecF64 = f64;
     }
 }
index 7837bfe1a7a03f255e8bfae5bb6f64885f1fb555..942c67f827734139aa610d17c9ef9400e729fdf8 100644 (file)
@@ -3,7 +3,7 @@ use core::ops::{
     Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
 };
 
-use crate::f16::F16;
+use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
 
 #[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
 macro_rules! make_float_type {
@@ -43,10 +43,10 @@ macro_rules! make_float_type {
             },)*
         ]
     ) => {
-        type $float: Float<Self::$u32, BitsType = Self::$uint>
+        type $float: Float<Self::$u32, BitsType = Self::$uint, SignedBitsType = Self::$int, FloatEncoding = $float_prim>
             $(+ From<Self::$float_scalar>)?
             + Compare<Bool = Self::$bool>
-            + Make<Self, Prim = $float_prim>
+            + Make<Context = Self, Prim = $float_prim>
             $(+ ConvertTo<Self::$uint_smaller>)*
             $(+ ConvertTo<Self::$int_smaller>)*
             $($(+ ConvertTo<Self::$float_smaller>)?)*
@@ -113,7 +113,7 @@ macro_rules! make_uint_int_float_type {
         type $uint: UInt<Self::$u32>
             $(+ From<Self::$uint_scalar>)?
             + Compare<Bool = Self::$bool>
-            + Make<Self, Prim = $uint_prim>
+            + Make<Context = Self, Prim = $uint_prim>
             $(+ ConvertTo<Self::$uint_smaller>)*
             $(+ ConvertTo<Self::$int_smaller>)*
             $($(+ ConvertTo<Self::$float_smaller>)?)*
@@ -125,7 +125,7 @@ macro_rules! make_uint_int_float_type {
         type $int: SInt<Self::$u32>
             $(+ From<Self::$int_scalar>)?
             + Compare<Bool = Self::$bool>
-            + Make<Self, Prim = $int_prim>
+            + Make<Context = Self, Prim = $int_prim>
             $(+ ConvertTo<Self::$uint_smaller>)*
             $(+ ConvertTo<Self::$int_smaller>)*
             $($(+ ConvertTo<Self::$float_smaller>)?)*
@@ -268,8 +268,19 @@ macro_rules! make_types {
     ) => {
         type $Bool: Bool
             $(+ From<Self::$ScalarBool>)?
-            + Make<Self, Prim = bool>
-            + Select<Self::$Bool>;
+            + Make<Context = Self, Prim = bool>
+            + Select<Self::$Bool>
+            + Select<Self::$U8>
+            + Select<Self::$U16>
+            + Select<Self::$U32>
+            + Select<Self::$U64>
+            + Select<Self::$I8>
+            + Select<Self::$I16>
+            + Select<Self::$I32>
+            + Select<Self::$I64>
+            + Select<Self::$F16>
+            + Select<Self::$F32>
+            + Select<Self::$F64>;
         make_uint_int_float_types! {
             #[u32 = $U32]
             #[bool = $Bool]
@@ -398,14 +409,16 @@ pub trait Context: Copy {
         #[scalar = F64]
         type VecF64;
     }
-    fn make<T: Make<Self>>(self, v: T::Prim) -> T {
+    fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
         T::make(self, v)
     }
 }
 
-pub trait Make<Context>: Sized {
-    type Prim;
-    fn make(ctx: Context, v: Self::Prim) -> Self;
+pub trait Make: Copy {
+    type Prim: Copy;
+    type Context: Context;
+    fn ctx(self) -> Self::Context;
+    fn make(ctx: Self::Context, v: Self::Prim) -> Self;
 }
 
 pub trait ConvertTo<T> {
@@ -502,16 +515,53 @@ pub trait Int<ShiftRhs>:
     + ShlAssign<ShiftRhs>
     + ShrAssign<ShiftRhs>
 {
+    fn leading_zeros(self) -> Self;
+    fn leading_ones(self) -> Self {
+        self.not().leading_zeros()
+    }
+    fn trailing_zeros(self) -> Self;
+    fn trailing_ones(self) -> Self {
+        self.not().trailing_zeros()
+    }
+    fn count_zeros(self) -> Self {
+        self.not().count_ones()
+    }
+    fn count_ones(self) -> Self;
 }
 
 pub trait UInt<ShiftRhs>: Int<ShiftRhs> {}
 
 pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
 
+macro_rules! impl_int {
+    ($ty:ident) => {
+        impl Int<u32> for $ty {
+            fn leading_zeros(self) -> Self {
+                self.leading_zeros() as Self
+            }
+            fn leading_ones(self) -> Self {
+                self.leading_ones() as Self
+            }
+            fn trailing_zeros(self) -> Self {
+                self.trailing_zeros() as Self
+            }
+            fn trailing_ones(self) -> Self {
+                self.trailing_ones() as Self
+            }
+            fn count_zeros(self) -> Self {
+                self.count_zeros() as Self
+            }
+            fn count_ones(self) -> Self {
+                self.count_ones() as Self
+            }
+        }
+    };
+}
+
 macro_rules! impl_uint {
     ($($ty:ident),*) => {
         $(
-            impl Int<u32> for $ty {}
+            impl_int!($ty);
             impl UInt<u32> for $ty {}
         )*
     };
@@ -519,19 +569,29 @@ macro_rules! impl_uint {
 
 impl_uint![u8, u16, u32, u64];
 
-macro_rules! impl_int {
+macro_rules! impl_sint {
     ($($ty:ident),*) => {
         $(
-            impl Int<u32> for $ty {}
+            impl_int!($ty);
             impl SInt<u32> for $ty {}
         )*
     };
 }
 
-impl_int![i8, i16, i32, i64];
+impl_sint![i8, i16, i32, i64];
 
-pub trait Float<BitsShiftRhs>: Number + Neg<Output = Self> {
-    type BitsType: UInt<BitsShiftRhs>;
+pub trait Float<BitsShiftRhs: Make<Context = Self::Context, Prim = u32>>:
+    Number + Neg<Output = Self>
+{
+    type FloatEncoding: FloatEncoding + Make<Context = Scalar, Prim = <Self as Make>::Prim>;
+    type BitsType: UInt<BitsShiftRhs>
+        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::BitsType>
+        + ConvertTo<Self::SignedBitsType>
+        + Compare<Bool = Self::Bool>;
+    type SignedBitsType: SInt<BitsShiftRhs>
+        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::SignedBitsType>
+        + ConvertTo<Self::BitsType>
+        + Compare<Bool = Self::Bool>;
     fn abs(self) -> Self;
     fn trunc(self) -> Self;
     fn ceil(self) -> Self;
@@ -539,17 +599,59 @@ pub trait Float<BitsShiftRhs>: Number + Neg<Output = Self> {
     fn round(self) -> Self;
     #[cfg(feature = "fma")]
     fn fma(self, a: Self, b: Self) -> Self;
-    fn is_nan(self) -> Self::Bool;
-    fn is_infinite(self) -> Self::Bool;
+    fn is_nan(self) -> Self::Bool {
+        self.ne(self)
+    }
+    fn is_infinite(self) -> Self::Bool {
+        self.abs().eq(Self::infinity(self.ctx()))
+    }
+    fn infinity(ctx: Self::Context) -> Self {
+        Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS))
+    }
+    fn nan(ctx: Self::Context) -> Self {
+        Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS))
+    }
     fn is_finite(self) -> Self::Bool;
+    fn is_zero_or_subnormal(self) -> Self::Bool {
+        self.extract_exponent_field().eq(self
+            .ctx()
+            .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT))
+    }
     fn from_bits(v: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
+    fn extract_exponent_field(self) -> Self::BitsType {
+        let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK);
+        let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT);
+        (self.to_bits() & mask) >> shift
+    }
+    fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
+        Self::sub_exponent_bias(self.extract_exponent_field())
+    }
+    fn extract_mantissa_field(self) -> Self::BitsType {
+        let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK);
+        self.to_bits() & mask
+    }
+    fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
+        exponent_field.to()
+            - exponent_field
+                .ctx()
+                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)
+    }
+    fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
+        (exponent
+            + exponent
+                .ctx()
+                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED))
+        .to()
+    }
 }
 
 macro_rules! impl_float {
-    ($ty:ty, $bits:ty) => {
+    ($ty:ty, $bits:ty, $signed_bits:ty) => {
         impl Float<u32> for $ty {
+            type FloatEncoding = $ty;
             type BitsType = $bits;
+            type SignedBitsType = $signed_bits;
             fn abs(self) -> Self {
                 #[cfg(feature = "std")]
                 return self.abs();
@@ -603,10 +705,10 @@ macro_rules! impl_float {
     };
 }
 
-impl_float!(f32, u32);
-impl_float!(f64, u64);
+impl_float!(f32, u32, i32);
+impl_float!(f64, u64, i64);
 
-pub trait Bool: BitOps {}
+pub trait Bool: Make + BitOps {}
 
 impl Bool for bool {}
 
@@ -623,7 +725,7 @@ impl<T> Select<T> for bool {
         }
     }
 }
-pub trait Compare: Copy {
+pub trait Compare: Make {
     type Bool: Bool + Select<Self>;
     fn eq(self, rhs: Self) -> Self::Bool;
     fn ne(self, rhs: Self) -> Self::Bool;