refactor to easily allow algorithms generic over f16/32/64
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 10 May 2021 00:41:54 +0000 (17:41 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 10 May 2021 00:41:54 +0000 (17:41 -0700)
12 files changed:
.gitlab-ci.yml
Cargo.toml
src/algorithms/ilogb.rs
src/algorithms/trig_pi.rs
src/f16.rs
src/ieee754.rs [deleted file]
src/ir.rs
src/lib.rs
src/prim.rs [new file with mode: 0644]
src/scalar.rs
src/traits.rs
vector-math-proc-macro/src/lib.rs

index c7adb6b8e72ebd0870664c1cec84bf85cf9e7705..4d9c5410119abdf027dcb1e39e48cd05482c8159 100644 (file)
@@ -11,6 +11,26 @@ rust-latest:
         matrix:
             - FEATURES: ["", "fma,ir", "f16,ir", "fma,f16,ir"]
 
+rust-latest-release:
+    stage: build
+    image: rust:latest
+    script:
+        - cargo build --verbose --release --no-default-features --features="$FEATURES"
+        - cargo test --verbose --release --no-default-features --features="$FEATURES"
+    parallel:
+        matrix:
+            - FEATURES:
+                  [
+                      "",
+                      "fma,ir",
+                      "f16,ir",
+                      "fma,f16,ir",
+                      "full_tests",
+                      "full_tests,fma",
+                      "full_tests,fma,f16",
+                      "full_tests,f16",
+                  ]
+
 rust-nightly:
     stage: build
     image: rustlang/rust:nightly
index d83abb80cfca0f04b88fd26385efe25086c5e0a7..e1b82f35bafb78515defd7b66891324666357205 100644 (file)
@@ -18,6 +18,7 @@ fma = ["std"]
 std = []
 ir = ["std", "typed-arena"]
 stdsimd = ["core_simd"]
+# enable slow tests
 full_tests = []
 
 [workspace]
index 7da2733bb0e5f6fc71b9078562a90b67b62f2f12..36d9d54ff975275f38d44087875a2eeaa18086c2 100644 (file)
@@ -1,6 +1,6 @@
 use crate::{
     f16::F16,
-    ieee754::FloatEncoding,
+    prim::PrimFloat,
     traits::{Compare, Context, ConvertTo, Float, Select},
 };
 
index 38104a655ee23870c5c04cfc7c1dc259d6cd32f9..5b07c2a3994593cc35f61badd9d866c4f8fc3af3 100644 (file)
@@ -1,7 +1,7 @@
 use crate::{
     f16::F16,
-    ieee754::FloatEncoding,
-    traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Select},
+    prim::PrimFloat,prim::PrimSInt,prim::PrimUInt,
+    traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Make, Select},
 };
 
 mod consts {
@@ -87,39 +87,58 @@ pub fn cos_pi_kernel_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16
 
 /// computes `(sin(pi * x), cos(pi * x))`
 /// not guaranteed to give correct sign for zero results
-/// has an error of up to 2ULP
-pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
-    let two_f16: Ctx::VecF16 = ctx.make(2.0.to());
-    let one_half: Ctx::VecF16 = ctx.make(0.5.to());
-    let max_contiguous_integer: Ctx::VecF16 =
-        ctx.make((1u16 << (F16::MANTISSA_FIELD_WIDTH + 1)).to());
+/// inherits error from `sin_pi_kernel` and `cos_pi_kernel`
+pub fn sin_cos_pi_impl<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+    SinPiKernel: FnOnce(Ctx, VecF) -> VecF,
+    CosPiKernel: FnOnce(Ctx, VecF) -> VecF,
+>(
+    ctx: Ctx,
+    x: VecF,
+    sin_pi_kernel: SinPiKernel,
+    cos_pi_kernel: CosPiKernel,
+) -> (VecF, VecF) {
+    let two_f: VecF = ctx.make(2.0.to());
+    let one_half: VecF = ctx.make(0.5.to());
+    let max_contiguous_integer: VecF =
+        ctx.make((PrimU::cvt_from(1) << (PrimF::MANTISSA_FIELD_WIDTH + 1.to())).to());
     // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer
     let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range
     let is_finite = x.is_finite();
-    let nan: Ctx::VecF16 = ctx.make(f32::NAN.to());
-    let zero_f16: Ctx::VecF16 = ctx.make(0.to());
-    let one_f16: Ctx::VecF16 = ctx.make(1.to());
-    let zero_i16: Ctx::VecI16 = ctx.make(0.to());
-    let one_i16: Ctx::VecI16 = ctx.make(1.to());
-    let two_i16: Ctx::VecI16 = ctx.make(2.to());
-    let out_of_range_sin = is_finite.select(zero_f16, nan);
-    let out_of_range_cos = is_finite.select(one_f16, nan);
-    let xi = (x * two_f16).round();
+    let nan: VecF = ctx.make(f32::NAN.to());
+    let zero_f: VecF = ctx.make(0.to());
+    let one_f: VecF = ctx.make(1.to());
+    let zero_i: VecF::SignedBitsType = ctx.make(0.to());
+    let one_i: VecF::SignedBitsType = ctx.make(1.to());
+    let two_i: VecF::SignedBitsType = ctx.make(2.to());
+    let out_of_range_sin = is_finite.select(zero_f, nan);
+    let out_of_range_cos = is_finite.select(one_f, nan);
+    let xi = (x * two_f).round();
     let xk = x - xi * one_half;
-    let sk = sin_pi_kernel_f16(ctx, xk);
-    let ck = cos_pi_kernel_f16(ctx, xk);
-    let xi = Ctx::VecI16::cvt_from(xi);
-    let bit_0_clear = (xi & one_i16).eq(zero_i16);
+    let sk = sin_pi_kernel(ctx, xk);
+    let ck = cos_pi_kernel(ctx, xk);
+    let xi = VecF::SignedBitsType::cvt_from(xi);
+    let bit_0_clear = (xi & one_i).eq(zero_i);
     let st = bit_0_clear.select(sk, ck);
     let ct = bit_0_clear.select(ck, sk);
-    let s = (xi & two_i16).eq(zero_i16).select(st, -st);
-    let c = ((xi + one_i16) & two_i16).eq(zero_i16).select(ct, -ct);
+    let s = (xi & two_i).eq(zero_i).select(st, -st);
+    let c = ((xi + one_i) & two_i).eq(zero_i).select(ct, -ct);
     (
         in_range.select(s, out_of_range_sin),
         in_range.select(c, out_of_range_cos),
     )
 }
 
+/// computes `(sin(pi * x), cos(pi * x))`
+/// not guaranteed to give correct sign for zero results
+/// has an error of up to 2ULP
+pub fn sin_cos_pi_f16<Ctx: Context>(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) {
+    sin_cos_pi_impl(ctx, x, sin_pi_kernel_f16, cos_pi_kernel_f16)
+}
+
 /// computes `sin(pi * x)`
 /// not guaranteed to give correct sign for zero results
 /// has an error of up to 2ULP
index e9541b416f1546e5f7195baa01f3a528f7e6bdf5..5253fef282e136c966950c2d3c4f536bc120fa0c 100644 (file)
@@ -1,11 +1,11 @@
-use core::ops::{
-    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign,
-};
-
 use crate::{
     scalar::Value,
     traits::{ConvertFrom, ConvertTo, Float},
 };
+use core::{
+    fmt,
+    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
+};
 
 #[cfg(feature = "f16")]
 use half::f16 as F16Impl;
@@ -13,7 +13,7 @@ use half::f16 as F16Impl;
 #[cfg(not(feature = "f16"))]
 type F16Impl = u16;
 
-#[derive(Clone, Copy, PartialEq, PartialOrd, Debug)]
+#[derive(Clone, Copy, PartialEq, PartialOrd)]
 #[repr(transparent)]
 pub struct F16(F16Impl);
 
@@ -40,6 +40,18 @@ macro_rules! f16_impl {
     };
 }
 
+impl fmt::Display for F16 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f16_impl!(self.0.fmt(f), [f])
+    }
+}
+
+impl fmt::Debug for F16 {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f16_impl!(self.0.fmt(f), [f])
+    }
+}
+
 impl Default for F16 {
     fn default() -> Self {
         f16_impl!(F16(F16Impl::default()), [])
@@ -193,21 +205,29 @@ impl F16 {
         f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), [])
     }
     pub fn trunc(self) -> Self {
-        f32::from(self).trunc().to()
+        #[cfg(feature = "std")]
+        return f32::from(self).trunc().to();
+        #[cfg(not(feature = "std"))]
+        todo!();
     }
-
     pub fn ceil(self) -> Self {
-        f32::from(self).ceil().to()
+        #[cfg(feature = "std")]
+        return f32::from(self).ceil().to();
+        #[cfg(not(feature = "std"))]
+        todo!();
     }
-
     pub fn floor(self) -> Self {
-        f32::from(self).floor().to()
+        #[cfg(feature = "std")]
+        return f32::from(self).floor().to();
+        #[cfg(not(feature = "std"))]
+        todo!();
     }
-
     pub fn round(self) -> Self {
-        f32::from(self).round().to()
+        #[cfg(feature = "std")]
+        return f32::from(self).round().to();
+        #[cfg(not(feature = "std"))]
+        todo!();
     }
-
     #[cfg(feature = "fma")]
     pub fn fma(self, a: Self, b: Self) -> Self {
         (f64::from(self) * f64::from(a) + f64::from(b)).to()
@@ -227,7 +247,7 @@ impl F16 {
 }
 
 impl Float for Value<F16> {
-    type FloatEncoding = F16;
+    type PrimFloat = F16;
     type BitsType = Value<u16>;
     type SignedBitsType = Value<i16>;
 
diff --git a/src/ieee754.rs b/src/ieee754.rs
deleted file mode 100644 (file)
index 3d70468..0000000
+++ /dev/null
@@ -1,95 +0,0 @@
-use crate::f16::F16;
-
-mod sealed {
-    use crate::f16::F16;
-
-    pub trait Sealed {}
-    impl Sealed for F16 {}
-    impl Sealed for f32 {}
-    impl Sealed for f64 {}
-}
-
-pub trait FloatEncoding: sealed::Sealed + Copy + 'static + Send + Sync {
-    type BitsType;
-    type SignedBitsType;
-    const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
-    const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
-    const SIGN_FIELD_WIDTH: Self::BitsType;
-    const EXPONENT_FIELD_WIDTH: Self::BitsType;
-    const MANTISSA_FIELD_WIDTH: Self::BitsType;
-    const SIGN_FIELD_SHIFT: Self::BitsType;
-    const EXPONENT_FIELD_SHIFT: Self::BitsType;
-    const MANTISSA_FIELD_SHIFT: Self::BitsType;
-    const SIGN_FIELD_MASK: Self::BitsType;
-    const EXPONENT_FIELD_MASK: Self::BitsType;
-    const MANTISSA_FIELD_MASK: Self::BitsType;
-    const IMPLICIT_MANTISSA_BIT: Self::BitsType;
-    const ZERO_SUBNORMAL_EXPONENT: Self::BitsType;
-    const NAN_INFINITY_EXPONENT: Self::BitsType;
-    const INFINITY_BITS: Self::BitsType;
-    const NAN_BITS: Self::BitsType;
-}
-
-macro_rules! impl_float_encoding {
-    (
-        impl FloatEncoding for $float:ident {
-            type BitsType = $bits_type:ident;
-            type SignedBitsType = $signed_bits_type:ident;
-            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal;
-            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal;
-        }
-    ) => {
-        impl FloatEncoding for $float {
-            type BitsType = $bits_type;
-            type SignedBitsType = $signed_bits_type;
-            const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
-                (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
-            const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
-            const SIGN_FIELD_WIDTH: Self::BitsType = 1;
-            const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width;
-            const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width;
-            const SIGN_FIELD_SHIFT: Self::BitsType =
-                Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
-            const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH;
-            const MANTISSA_FIELD_SHIFT: Self::BitsType = 0;
-            const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
-            const EXPONENT_FIELD_MASK: Self::BitsType =
-                ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
-            const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1;
-            const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH;
-            const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0;
-            const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1;
-            const INFINITY_BITS: Self::BitsType =
-                Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT;
-            const NAN_BITS: Self::BitsType =
-                Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1));
-        }
-    };
-}
-
-impl_float_encoding! {
-    impl FloatEncoding for F16 {
-        type BitsType = u16;
-        type SignedBitsType = i16;
-        const EXPONENT_FIELD_WIDTH: u32 = 5;
-        const MANTISSA_FIELD_WIDTH: u32 = 10;
-    }
-}
-
-impl_float_encoding! {
-    impl FloatEncoding for f32 {
-        type BitsType = u32;
-        type SignedBitsType = i32;
-        const EXPONENT_FIELD_WIDTH: u32 = 8;
-        const MANTISSA_FIELD_WIDTH: u32 = 23;
-    }
-}
-
-impl_float_encoding! {
-    impl FloatEncoding for f64 {
-        type BitsType = u64;
-        type SignedBitsType = i64;
-        const EXPONENT_FIELD_WIDTH: u32 = 11;
-        const MANTISSA_FIELD_WIDTH: u32 = 52;
-    }
-}
index e2b4a0ecd72cb58421cf07b060a023e3ef640020..279902041a39bfcdee4f2385cd550f8891d16374 100644 (file)
--- a/src/ir.rs
+++ b/src/ir.rs
@@ -1220,40 +1220,41 @@ macro_rules! impl_integer_ops {
     };
 }
 
-macro_rules! impl_uint_ops {
-    ($scalar:ident, $vec:ident) => {
-        impl_integer_ops!($scalar, $vec);
-
-        impl<'ctx> UInt for $scalar<'ctx> {}
-        impl<'ctx> UInt for $vec<'ctx> {}
-    };
-}
+macro_rules! impl_uint_sint_ops {
+    ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => {
+        impl_integer_ops!($uint_scalar, $uint_vec);
+        impl_integer_ops!($sint_scalar, $sint_vec);
+        impl_neg!($sint_scalar);
+        impl_neg!($sint_vec);
 
-impl_uint_ops!(IrU8, IrVecU8);
-impl_uint_ops!(IrU16, IrVecU16);
-impl_uint_ops!(IrU32, IrVecU32);
-impl_uint_ops!(IrU64, IrVecU64);
-
-macro_rules! impl_sint_ops {
-    ($scalar:ident, $vec:ident) => {
-        impl_integer_ops!($scalar, $vec);
-        impl_neg!($scalar);
-        impl_neg!($vec);
-
-        impl<'ctx> SInt for $scalar<'ctx> {}
-        impl<'ctx> SInt for $vec<'ctx> {}
+        impl<'ctx> UInt for $uint_scalar<'ctx> {
+            type PrimUInt = Self::Prim;
+            type SignedType = $sint_scalar<'ctx>;
+        }
+        impl<'ctx> UInt for $uint_vec<'ctx> {
+            type PrimUInt = Self::Prim;
+            type SignedType = $sint_vec<'ctx>;
+        }
+        impl<'ctx> SInt for $sint_scalar<'ctx> {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = $uint_scalar<'ctx>;
+        }
+        impl<'ctx> SInt for $sint_vec<'ctx> {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = $uint_vec<'ctx>;
+        }
     };
 }
 
-impl_sint_ops!(IrI8, IrVecI8);
-impl_sint_ops!(IrI16, IrVecI16);
-impl_sint_ops!(IrI32, IrVecI32);
-impl_sint_ops!(IrI64, IrVecI64);
+impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8);
+impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16);
+impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32);
+impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64);
 
 macro_rules! impl_float {
     ($float:ident, $bits:ident, $signed_bits:ident) => {
         impl<'ctx> Float for $float<'ctx> {
-            type FloatEncoding = <$float<'ctx> as Make>::Prim;
+            type PrimFloat = <$float<'ctx> as Make>::Prim;
             type BitsType = $bits<'ctx>;
             type SignedBitsType = $signed_bits<'ctx>;
             fn abs(self) -> Self {
index cecc2e4566d62787705f20076e5291b873080add..60aa0824fbdb0fd8d3354206e91d5a4f9f335778 100644 (file)
@@ -6,9 +6,9 @@ extern crate std;
 
 pub mod algorithms;
 pub mod f16;
-pub mod ieee754;
 #[cfg(feature = "ir")]
 pub mod ir;
+pub mod prim;
 pub mod scalar;
 #[cfg(feature = "stdsimd")]
 pub mod stdsimd;
diff --git a/src/prim.rs b/src/prim.rs
new file mode 100644 (file)
index 0000000..b2d2ebb
--- /dev/null
@@ -0,0 +1,205 @@
+use crate::{
+    f16::F16,
+    traits::{ConvertFrom, ConvertTo},
+};
+use core::{fmt, hash, ops};
+
+mod sealed {
+    use crate::f16::F16;
+
+    pub trait Sealed {}
+    impl Sealed for F16 {}
+    impl Sealed for f32 {}
+    impl Sealed for f64 {}
+    impl Sealed for u8 {}
+    impl Sealed for u16 {}
+    impl Sealed for u32 {}
+    impl Sealed for u64 {}
+    impl Sealed for i8 {}
+    impl Sealed for i16 {}
+    impl Sealed for i32 {}
+    impl Sealed for i64 {}
+}
+
+pub trait PrimBase:
+    sealed::Sealed
+    + Copy
+    + 'static
+    + Send
+    + Sync
+    + PartialOrd
+    + fmt::Debug
+    + fmt::Display
+    + ops::Add<Output = Self>
+    + ops::Sub<Output = Self>
+    + ops::Mul<Output = Self>
+    + ops::Div<Output = Self>
+    + ops::Rem<Output = Self>
+    + ops::AddAssign
+    + ops::SubAssign
+    + ops::MulAssign
+    + ops::DivAssign
+    + ops::RemAssign
+    + ConvertFrom<i8>
+    + ConvertFrom<u8>
+    + ConvertFrom<i16>
+    + ConvertFrom<u16>
+    + ConvertFrom<F16>
+    + ConvertFrom<i32>
+    + ConvertFrom<u32>
+    + ConvertFrom<f32>
+    + ConvertFrom<i64>
+    + ConvertFrom<u64>
+    + ConvertFrom<f64>
+    + ConvertTo<i8>
+    + ConvertTo<u8>
+    + ConvertTo<i16>
+    + ConvertTo<u16>
+    + ConvertTo<F16>
+    + ConvertTo<i32>
+    + ConvertTo<u32>
+    + ConvertTo<f32>
+    + ConvertTo<i64>
+    + ConvertTo<u64>
+    + ConvertTo<f64>
+{
+}
+
+pub trait PrimInt:
+    PrimBase
+    + Ord
+    + hash::Hash
+    + fmt::Binary
+    + fmt::LowerHex
+    + fmt::Octal
+    + fmt::UpperHex
+    + ops::BitAnd<Output = Self>
+    + ops::BitOr<Output = Self>
+    + ops::BitXor<Output = Self>
+    + ops::Shl<Output = Self>
+    + ops::Shr<Output = Self>
+    + ops::Not<Output = Self>
+    + ops::BitAndAssign
+    + ops::BitOrAssign
+    + ops::BitXorAssign
+    + ops::ShlAssign
+    + ops::ShrAssign
+{
+}
+
+pub trait PrimUInt: PrimInt + ConvertFrom<Self::SignedType> {
+    type SignedType: PrimSInt<UnsignedType = Self> + ConvertFrom<Self>;
+}
+
+pub trait PrimSInt: PrimInt + ops::Neg<Output = Self> + ConvertFrom<Self::UnsignedType> {
+    type UnsignedType: PrimUInt<SignedType = Self> + ConvertFrom<Self>;
+}
+
+macro_rules! impl_int {
+    ($uint:ident, $sint:ident) => {
+        impl PrimBase for $uint {}
+        impl PrimBase for $sint {}
+        impl PrimInt for $uint {}
+        impl PrimInt for $sint {}
+        impl PrimUInt for $uint {
+            type SignedType = $sint;
+        }
+        impl PrimSInt for $sint {
+            type UnsignedType = $uint;
+        }
+    };
+}
+
+impl_int!(u8, i8);
+impl_int!(u16, i16);
+impl_int!(u32, i32);
+impl_int!(u64, i64);
+
+pub trait PrimFloat:
+    PrimBase + ops::Neg<Output = Self> + ConvertFrom<Self::BitsType> + ConvertFrom<Self::SignedBitsType>
+{
+    type BitsType: PrimUInt<SignedType = Self::SignedBitsType> + ConvertFrom<Self>;
+    type SignedBitsType: PrimSInt<UnsignedType = Self::BitsType> + ConvertFrom<Self>;
+    const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
+    const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
+    const SIGN_FIELD_WIDTH: Self::BitsType;
+    const EXPONENT_FIELD_WIDTH: Self::BitsType;
+    const MANTISSA_FIELD_WIDTH: Self::BitsType;
+    const SIGN_FIELD_SHIFT: Self::BitsType;
+    const EXPONENT_FIELD_SHIFT: Self::BitsType;
+    const MANTISSA_FIELD_SHIFT: Self::BitsType;
+    const SIGN_FIELD_MASK: Self::BitsType;
+    const EXPONENT_FIELD_MASK: Self::BitsType;
+    const MANTISSA_FIELD_MASK: Self::BitsType;
+    const IMPLICIT_MANTISSA_BIT: Self::BitsType;
+    const ZERO_SUBNORMAL_EXPONENT: Self::BitsType;
+    const NAN_INFINITY_EXPONENT: Self::BitsType;
+    const INFINITY_BITS: Self::BitsType;
+    const NAN_BITS: Self::BitsType;
+}
+
+macro_rules! impl_float {
+    (
+        impl PrimFloat for $float:ident {
+            type BitsType = $bits_type:ident;
+            type SignedBitsType = $signed_bits_type:ident;
+            const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal;
+            const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal;
+        }
+    ) => {
+        impl PrimBase for $float {}
+
+        impl PrimFloat for $float {
+            type BitsType = $bits_type;
+            type SignedBitsType = $signed_bits_type;
+            const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
+                (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
+            const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
+            const SIGN_FIELD_WIDTH: Self::BitsType = 1;
+            const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width;
+            const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width;
+            const SIGN_FIELD_SHIFT: Self::BitsType =
+                Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
+            const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH;
+            const MANTISSA_FIELD_SHIFT: Self::BitsType = 0;
+            const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
+            const EXPONENT_FIELD_MASK: Self::BitsType =
+                ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
+            const MANTISSA_FIELD_MASK: Self::BitsType = (1 << Self::MANTISSA_FIELD_WIDTH) - 1;
+            const IMPLICIT_MANTISSA_BIT: Self::BitsType = 1 << Self::MANTISSA_FIELD_WIDTH;
+            const ZERO_SUBNORMAL_EXPONENT: Self::BitsType = 0;
+            const NAN_INFINITY_EXPONENT: Self::BitsType = (1 << Self::EXPONENT_FIELD_WIDTH) - 1;
+            const INFINITY_BITS: Self::BitsType =
+                Self::NAN_INFINITY_EXPONENT << Self::EXPONENT_FIELD_SHIFT;
+            const NAN_BITS: Self::BitsType =
+                Self::INFINITY_BITS | (1 << (Self::MANTISSA_FIELD_WIDTH - 1));
+        }
+    };
+}
+
+impl_float! {
+    impl PrimFloat for F16 {
+        type BitsType = u16;
+        type SignedBitsType = i16;
+        const EXPONENT_FIELD_WIDTH: u32 = 5;
+        const MANTISSA_FIELD_WIDTH: u32 = 10;
+    }
+}
+
+impl_float! {
+    impl PrimFloat for f32 {
+        type BitsType = u32;
+        type SignedBitsType = i32;
+        const EXPONENT_FIELD_WIDTH: u32 = 8;
+        const MANTISSA_FIELD_WIDTH: u32 = 23;
+    }
+}
+
+impl_float! {
+    impl PrimFloat for f64 {
+        type BitsType = u64;
+        type SignedBitsType = i64;
+        const EXPONENT_FIELD_WIDTH: u32 = 11;
+        const MANTISSA_FIELD_WIDTH: u32 = 52;
+    }
+}
index 4eb5b985627f03f3e9a75761437e1923b4f279c8..30aaa9e9cfc3eb981da354f0d9bb61aeafc98439 100644 (file)
@@ -1,5 +1,6 @@
 use crate::{
     f16::F16,
+    prim::{PrimSInt, PrimUInt},
     traits::{Bool, Compare, Context, ConvertFrom, Float, Int, Make, SInt, Select, UInt},
 };
 use core::ops::{
@@ -230,7 +231,10 @@ macro_rules! impl_uint {
     ($($ty:ident),*) => {
         $(
             impl_int!($ty);
-            impl UInt for Value<$ty> {}
+            impl UInt for Value<$ty> {
+                type PrimUInt = $ty;
+                type SignedType = Value<<$ty as PrimUInt>::SignedType>;
+            }
         )*
     };
 }
@@ -241,7 +245,10 @@ macro_rules! impl_sint {
     ($($ty:ident),*) => {
         $(
             impl_int!($ty);
-            impl SInt for Value<$ty> {}
+            impl SInt for Value<$ty> {
+                type PrimSInt = $ty;
+                type UnsignedType = Value<<$ty as PrimSInt>::UnsignedType>;
+            }
         )*
     };
 }
@@ -336,7 +343,7 @@ macro_rules! impl_float {
     ($ty:ident, $bits:ty, $signed_bits:ty) => {
         impl_float_ops!($ty);
         impl Float for Value<$ty> {
-            type FloatEncoding = $ty;
+            type PrimFloat = $ty;
             type BitsType = Value<$bits>;
             type SignedBitsType = Value<$signed_bits>;
             fn abs(self) -> Self {
index 1877c21f1cb7d1fde0c73114f7170a4f0261d1a8..2ec48157c3b691883936c6d2e485e8b042c3e1b7 100644 (file)
@@ -1,4 +1,7 @@
-use crate::{f16::F16, ieee754::FloatEncoding};
+use crate::{
+    f16::F16,
+    prim::{PrimFloat, PrimSInt, PrimUInt},
+};
 use core::ops::{
     Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
     Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
@@ -132,20 +135,42 @@ pub trait Int:
     fn count_ones(self) -> Self;
 }
 
-pub trait UInt: Int {}
-
-pub trait SInt: Int + Neg<Output = Self> {}
-
-pub trait Float: Number + Neg<Output = Self> {
-    type FloatEncoding: FloatEncoding + From<<Self as Make>::Prim> + Into<<Self as Make>::Prim>;
-    type BitsType: UInt
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as FloatEncoding>::BitsType>
-        + ConvertTo<Self::SignedBitsType>
+pub trait UInt: Int + Make<Prim = Self::PrimUInt> + ConvertFrom<Self::SignedType> {
+    type PrimUInt: PrimUInt<SignedType = <Self::SignedType as SInt>::PrimSInt>;
+    type SignedType: SInt
+        + ConvertFrom<Self>
+        + Make<Context = Self::Context>
         + Compare<Bool = Self::Bool>;
-    type SignedBitsType: SInt
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as FloatEncoding>::SignedBitsType>
-        + ConvertTo<Self::BitsType>
+}
+
+pub trait SInt:
+    Int + Neg<Output = Self> + Make<Prim = Self::PrimSInt> + ConvertFrom<Self::UnsignedType>
+{
+    type PrimSInt: PrimSInt<UnsignedType = <Self::UnsignedType as UInt>::PrimUInt>;
+    type UnsignedType: UInt
+        + ConvertFrom<Self>
+        + Make<Context = Self::Context>
         + Compare<Bool = Self::Bool>;
+}
+
+pub trait Float:
+    Number
+    + Neg<Output = Self>
+    + Make<Prim = Self::PrimFloat>
+    + ConvertFrom<Self::SignedBitsType>
+    + ConvertFrom<Self::BitsType>
+{
+    type PrimFloat: PrimFloat;
+    type BitsType: UInt<PrimUInt = <Self::PrimFloat as PrimFloat>::BitsType, SignedType = Self::SignedBitsType>
+        + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::BitsType>
+        + Compare<Bool = Self::Bool>
+        + ConvertFrom<Self>;
+    type SignedBitsType: SInt<
+            PrimSInt = <Self::PrimFloat as PrimFloat>::SignedBitsType,
+            UnsignedType = Self::BitsType,
+        > + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::SignedBitsType>
+        + Compare<Bool = Self::Bool>
+        + ConvertFrom<Self>;
     fn abs(self) -> Self;
     fn trunc(self) -> Self;
     fn ceil(self) -> Self;
@@ -169,43 +194,38 @@ pub trait Float: Number + Neg<Output = Self> {
         self.abs().eq(Self::infinity(self.ctx()))
     }
     fn infinity(ctx: Self::Context) -> Self {
-        Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS))
+        Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS))
     }
     fn nan(ctx: Self::Context) -> Self {
-        Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS))
+        Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS))
     }
     fn is_finite(self) -> Self::Bool;
     fn is_zero_or_subnormal(self) -> Self::Bool {
-        self.extract_exponent_field().eq(self
-            .ctx()
-            .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT))
+        self.extract_exponent_field()
+            .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT))
     }
     fn from_bits(v: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
     fn extract_exponent_field(self) -> Self::BitsType {
-        let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK);
-        let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT);
+        let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK);
+        let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT);
         (self.to_bits() & mask) >> shift
     }
     fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
         Self::sub_exponent_bias(self.extract_exponent_field())
     }
     fn extract_mantissa_field(self) -> Self::BitsType {
-        let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK);
+        let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK);
         self.to_bits() & mask
     }
     fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
-        exponent_field.to()
+        Self::SignedBitsType::cvt_from(exponent_field)
             - exponent_field
                 .ctx()
-                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)
+                .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)
     }
     fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
-        (exponent
-            + exponent
-                .ctx()
-                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED))
-        .to()
+        (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to()
     }
 }
 
index 5c4de02ac0984025a1a530761a25177272cbe1d4..89d2bc43f365c25bcbb4a714e3425a7ca5c84c20 100644 (file)
@@ -269,12 +269,16 @@ impl TraitSets {
                     let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
                     let type_trait = match type_kind {
                         TypeKind::Bool => quote! { Bool },
-                        TypeKind::UInt => quote! { UInt },
-                        TypeKind::SInt => quote! { SInt },
+                        TypeKind::UInt => {
+                            quote! { UInt<PrimUInt = #prim_ty, SignedType = Self::#sint_ty> }
+                        }
+                        TypeKind::SInt => {
+                            quote! { SInt<PrimSInt = #prim_ty, UnsignedType = Self::#uint_ty> }
+                        }
                         TypeKind::Float => quote! { Float<
                             BitsType = Self::#uint_ty,
                             SignedBitsType = Self::#sint_ty,
-                            FloatEncoding = #prim_ty,
+                            PrimFloat = #prim_ty,
                         > },
                     };
                     self.add_trait(type_kind, bits, vector_scalar, type_trait);