refactor to easily allow algorithms generic over f16/32/64
[vector-math.git] / src / traits.rs
index 1877c21f1cb7d1fde0c73114f7170a4f0261d1a8..2ec48157c3b691883936c6d2e485e8b042c3e1b7 100644 (file)
@@ -1,4 +1,7 @@
-use crate::{f16::F16, ieee754::FloatEncoding};
+use crate::{
+    f16::F16,
+    prim::{PrimFloat, PrimSInt, PrimUInt},
+};
 use core::ops::{
     Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
     Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
@@ -132,20 +135,42 @@ pub trait Int:
     fn count_ones(self) -> Self;
 }
 
-pub trait UInt: Int {}
-
-pub trait SInt: Int + Neg<Output = Self> {}
-
-pub trait Float: Number + Neg<Output = Self> {
-    type FloatEncoding: FloatEncoding + From<<Self as Make>::Prim> + Into<<Self as Make>::Prim>;
-    type BitsType: UInt
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as FloatEncoding>::BitsType>
-        + ConvertTo<Self::SignedBitsType>
+pub trait UInt: Int + Make<Prim = Self::PrimUInt> + ConvertFrom<Self::SignedType> {
+    type PrimUInt: PrimUInt<SignedType = <Self::SignedType as SInt>::PrimSInt>;
+    type SignedType: SInt
+        + ConvertFrom<Self>
+        + Make<Context = Self::Context>
         + Compare<Bool = Self::Bool>;
-    type SignedBitsType: SInt
-        + Make<Context = Self::Context, Prim = <Self::FloatEncoding as FloatEncoding>::SignedBitsType>
-        + ConvertTo<Self::BitsType>
+}
+
+pub trait SInt:
+    Int + Neg<Output = Self> + Make<Prim = Self::PrimSInt> + ConvertFrom<Self::UnsignedType>
+{
+    type PrimSInt: PrimSInt<UnsignedType = <Self::UnsignedType as UInt>::PrimUInt>;
+    type UnsignedType: UInt
+        + ConvertFrom<Self>
+        + Make<Context = Self::Context>
         + Compare<Bool = Self::Bool>;
+}
+
+pub trait Float:
+    Number
+    + Neg<Output = Self>
+    + Make<Prim = Self::PrimFloat>
+    + ConvertFrom<Self::SignedBitsType>
+    + ConvertFrom<Self::BitsType>
+{
+    type PrimFloat: PrimFloat;
+    type BitsType: UInt<PrimUInt = <Self::PrimFloat as PrimFloat>::BitsType, SignedType = Self::SignedBitsType>
+        + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::BitsType>
+        + Compare<Bool = Self::Bool>
+        + ConvertFrom<Self>;
+    type SignedBitsType: SInt<
+            PrimSInt = <Self::PrimFloat as PrimFloat>::SignedBitsType,
+            UnsignedType = Self::BitsType,
+        > + Make<Context = Self::Context, Prim = <Self::PrimFloat as PrimFloat>::SignedBitsType>
+        + Compare<Bool = Self::Bool>
+        + ConvertFrom<Self>;
     fn abs(self) -> Self;
     fn trunc(self) -> Self;
     fn ceil(self) -> Self;
@@ -169,43 +194,38 @@ pub trait Float: Number + Neg<Output = Self> {
         self.abs().eq(Self::infinity(self.ctx()))
     }
     fn infinity(ctx: Self::Context) -> Self {
-        Self::from_bits(ctx.make(Self::FloatEncoding::INFINITY_BITS))
+        Self::from_bits(ctx.make(Self::PrimFloat::INFINITY_BITS))
     }
     fn nan(ctx: Self::Context) -> Self {
-        Self::from_bits(ctx.make(Self::FloatEncoding::NAN_BITS))
+        Self::from_bits(ctx.make(Self::PrimFloat::NAN_BITS))
     }
     fn is_finite(self) -> Self::Bool;
     fn is_zero_or_subnormal(self) -> Self::Bool {
-        self.extract_exponent_field().eq(self
-            .ctx()
-            .make(Self::FloatEncoding::ZERO_SUBNORMAL_EXPONENT))
+        self.extract_exponent_field()
+            .eq(self.ctx().make(Self::PrimFloat::ZERO_SUBNORMAL_EXPONENT))
     }
     fn from_bits(v: Self::BitsType) -> Self;
     fn to_bits(self) -> Self::BitsType;
     fn extract_exponent_field(self) -> Self::BitsType {
-        let mask = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_MASK);
-        let shift = self.ctx().make(Self::FloatEncoding::EXPONENT_FIELD_SHIFT);
+        let mask = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_MASK);
+        let shift = self.ctx().make(Self::PrimFloat::EXPONENT_FIELD_SHIFT);
         (self.to_bits() & mask) >> shift
     }
     fn extract_exponent_unbiased(self) -> Self::SignedBitsType {
         Self::sub_exponent_bias(self.extract_exponent_field())
     }
     fn extract_mantissa_field(self) -> Self::BitsType {
-        let mask = self.ctx().make(Self::FloatEncoding::MANTISSA_FIELD_MASK);
+        let mask = self.ctx().make(Self::PrimFloat::MANTISSA_FIELD_MASK);
         self.to_bits() & mask
     }
     fn sub_exponent_bias(exponent_field: Self::BitsType) -> Self::SignedBitsType {
-        exponent_field.to()
+        Self::SignedBitsType::cvt_from(exponent_field)
             - exponent_field
                 .ctx()
-                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED)
+                .make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)
     }
     fn add_exponent_bias(exponent: Self::SignedBitsType) -> Self::BitsType {
-        (exponent
-            + exponent
-                .ctx()
-                .make(Self::FloatEncoding::EXPONENT_BIAS_SIGNED))
-        .to()
+        (exponent + exponent.ctx().make(Self::PrimFloat::EXPONENT_BIAS_SIGNED)).to()
     }
 }