add count_leading_zeros, count_trailing_zeros, and count_ones implementations
[vector-math.git] / src / stdsimd.rs
index 046a3379baa99d7d8cf60226b01c76ceab0d61ba..35692c9246c2bd562cd007ddde475f90f8ce4682 100644 (file)
@@ -2,7 +2,10 @@
 use crate::f16::panic_f16_feature_disabled;
 use crate::{
     f16::F16,
-    traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt},
+    prim::PrimFloat,
+    traits::{
+        Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt,
+    },
 };
 use core::{
     marker::PhantomData,
@@ -293,7 +296,11 @@ where
     Mask64<LANES>: Mask,
 {
     fn select(self, true_v: V, false_v: V) -> V {
-        self.0.select(true_v, false_v)
+        if self.0 {
+            true_v
+        } else {
+            false_v
+        }
     }
 }
 
@@ -319,27 +326,27 @@ macro_rules! impl_scalar_compare {
             type Bool = Wrapper<bool, LANES>;
 
             fn eq(self, rhs: Self) -> Self::Bool {
-                self.0.eq(rhs.0).into()
+                self.0.eq(&rhs.0).into()
             }
 
             fn ne(self, rhs: Self) -> Self::Bool {
-                self.0.ne(rhs.0).into()
+                self.0.ne(&rhs.0).into()
             }
 
             fn lt(self, rhs: Self) -> Self::Bool {
-                self.0.lt(rhs.0).into()
+                self.0.lt(&rhs.0).into()
             }
 
             fn gt(self, rhs: Self) -> Self::Bool {
-                self.0.gt(rhs.0).into()
+                self.0.gt(&rhs.0).into()
             }
 
             fn le(self, rhs: Self) -> Self::Bool {
-                self.0.le(rhs.0).into()
+                self.0.le(&rhs.0).into()
             }
 
             fn ge(self, rhs: Self) -> Self::Bool {
-                self.0.ge(rhs.0).into()
+                self.0.ge(&rhs.0).into()
             }
         }
     };
@@ -513,7 +520,7 @@ macro_rules! impl_int_scalar {
 }
 
 macro_rules! impl_int_vector {
-    ($ty:ident) => {
+    ($ty:ident, $count_leading_zeros:ident, $count_trailing_zeros:ident, $count_ones:ident) => {
         impl<const LANES: usize> Int for Wrapper<$ty<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
@@ -532,36 +539,35 @@ macro_rules! impl_int_vector {
             Mask64<LANES>: Mask,
         {
             fn leading_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_leading_zeros(self.ctx(), self)
             }
 
             fn trailing_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_trailing_zeros(self.ctx(), self)
             }
 
             fn count_ones(self) -> Self {
-                todo!()
-            }
-
-            fn leading_ones(self) -> Self {
-                todo!()
-            }
-
-            fn trailing_ones(self) -> Self {
-                todo!()
-            }
-
-            fn count_zeros(self) -> Self {
-                todo!()
+                crate::algorithms::integer::$count_ones(self.ctx(), self)
             }
         }
     };
 }
 
-macro_rules! impl_uint_vector {
-    ($ty:ident) => {
-        impl_int_vector!($ty);
-        impl<const LANES: usize> UInt for Wrapper<$ty<LANES>, LANES>
+macro_rules! impl_uint_sint_vector {
+    ($uint:ident, $sint:ident) => {
+        impl_int_vector!(
+            $uint,
+            count_leading_zeros_uint,
+            count_trailing_zeros_uint,
+            count_ones_uint
+        );
+        impl_int_vector!(
+            $sint,
+            count_leading_zeros_sint,
+            count_trailing_zeros_sint,
+            count_ones_sint
+        );
+        impl<const LANES: usize> UInt for Wrapper<$uint<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -578,19 +584,11 @@ macro_rules! impl_uint_vector {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
+            type PrimUInt = Self::Prim;
+            type SignedType = Wrapper<$sint<LANES>, LANES>;
         }
-    };
-}
 
-impl_uint_vector!(SimdU8);
-impl_uint_vector!(SimdU16);
-impl_uint_vector!(SimdU32);
-impl_uint_vector!(SimdU64);
-
-macro_rules! impl_uint_scalar {
-    ($ty:ident) => {
-        impl_int_scalar!($ty);
-        impl<const LANES: usize> UInt for Wrapper<$ty, LANES>
+        impl<const LANES: usize> SInt for Wrapper<$sint<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -607,19 +605,22 @@ macro_rules! impl_uint_scalar {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = Wrapper<$uint<LANES>, LANES>;
         }
     };
 }
 
-impl_uint_scalar!(u8);
-impl_uint_scalar!(u16);
-impl_uint_scalar!(u32);
-impl_uint_scalar!(u64);
+impl_uint_sint_vector!(SimdU8, SimdI8);
+impl_uint_sint_vector!(SimdU16, SimdI16);
+impl_uint_sint_vector!(SimdU32, SimdI32);
+impl_uint_sint_vector!(SimdU64, SimdI64);
 
-macro_rules! impl_sint_vector {
-    ($ty:ident) => {
-        impl_int_vector!($ty);
-        impl<const LANES: usize> SInt for Wrapper<$ty<LANES>, LANES>
+macro_rules! impl_uint_sint_scalar {
+    ($uint:ident, $sint:ident) => {
+        impl_int_scalar!($uint);
+        impl_int_scalar!($sint);
+        impl<const LANES: usize> UInt for Wrapper<$uint, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -636,19 +637,11 @@ macro_rules! impl_sint_vector {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
+            type PrimUInt = Self::Prim;
+            type SignedType = Wrapper<$sint, LANES>;
         }
-    };
-}
-
-impl_sint_vector!(SimdI8);
-impl_sint_vector!(SimdI16);
-impl_sint_vector!(SimdI32);
-impl_sint_vector!(SimdI64);
 
-macro_rules! impl_sint_scalar {
-    ($ty:ident) => {
-        impl_int_scalar!($ty);
-        impl<const LANES: usize> SInt for Wrapper<$ty, LANES>
+        impl<const LANES: usize> SInt for Wrapper<$sint, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -665,14 +658,16 @@ macro_rules! impl_sint_scalar {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = Wrapper<$uint, LANES>;
         }
     };
 }
 
-impl_sint_scalar!(i8);
-impl_sint_scalar!(i16);
-impl_sint_scalar!(i32);
-impl_sint_scalar!(i64);
+impl_uint_sint_scalar!(u8, i8);
+impl_uint_sint_scalar!(u16, i16);
+impl_uint_sint_scalar!(u32, i32);
+impl_uint_sint_scalar!(u64, i64);
 
 macro_rules! impl_float {
     ($ty:ident, $prim:ident, $uint:ident, $sint:ident) => {
@@ -693,11 +688,11 @@ macro_rules! impl_float {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
-            type FloatEncoding = $prim;
+            type PrimFloat = $prim;
 
-            type BitsType = Wrapper<<$prim as Float>::BitsType, LANES>;
+            type BitsType = Wrapper<<$prim as PrimFloat>::BitsType, LANES>;
 
-            type SignedBitsType = Wrapper<<$prim as Float>::SignedBitsType, LANES>;
+            type SignedBitsType = Wrapper<<$prim as PrimFloat>::SignedBitsType, LANES>;
 
             fn abs(self) -> Self {
                 self.0.abs().into()
@@ -719,8 +714,12 @@ macro_rules! impl_float {
                 self.0.round().into()
             }
 
+            #[cfg(feature = "fma")]
             fn fma(self, a: Self, b: Self) -> Self {
-                self.0.fma(a.0, b.0).into()
+                use crate::scalar::Value;
+                let a = Value(a.0);
+                let b = Value(b.0);
+                Value(self.0).fma(a, b).0.into()
             }
 
             fn is_finite(self) -> Self::Bool {
@@ -753,7 +752,7 @@ macro_rules! impl_float {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
-            type FloatEncoding = $prim;
+            type PrimFloat = $prim;
 
             type BitsType = Wrapper<$uint<LANES>, LANES>;
 
@@ -779,6 +778,7 @@ macro_rules! impl_float {
                 self.0.round().into()
             }
 
+            #[cfg(feature = "fma")]
             fn fma(self, _a: Self, _b: Self) -> Self {
                 // FIXME(programmerjake): implement once core_simd gains support:
                 // https://github.com/rust-lang/stdsimd/issues/102
@@ -804,9 +804,9 @@ impl_float!(SimdF16, F16, SimdU16, SimdI16);
 impl_float!(SimdF32, f32, SimdU32, SimdI32);
 impl_float!(SimdF64, f64, SimdU64, SimdI64);
 
-macro_rules! impl_scalar_convert_to_helper {
+macro_rules! impl_vector_convert_from_helper {
     ($src:ty => $dest:ty) => {
-        impl<const LANES: usize> ConvertTo<Wrapper<$dest, LANES>> for Wrapper<$src, LANES>
+        impl<const LANES: usize> ConvertFrom<Wrapper<$src, LANES>> for Wrapper<$dest, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -823,31 +823,31 @@ macro_rules! impl_scalar_convert_to_helper {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
-            fn to(self) -> Wrapper<$dest, LANES> {
-                let v: $dest = self.0.to();
+            fn cvt_from(v: Wrapper<$src, LANES>) -> Self {
+                let v: $dest = v.0.to();
                 v.into()
             }
         }
     };
 }
 
-macro_rules! impl_scalar_convert_to {
+macro_rules! impl_vector_convert_from {
     ($first:ty $(, $ty:ty)*) => {
         $(
-            impl_scalar_convert_to_helper!($first => $ty);
-            impl_scalar_convert_to_helper!($ty => $first);
+            impl_vector_convert_from_helper!($first => $ty);
+            impl_vector_convert_from_helper!($ty => $first);
         )*
-        impl_scalar_convert_to![$($ty),*];
+        impl_vector_convert_from![$($ty),*];
     };
     () => {};
 }
 
-impl_scalar_convert_to![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64];
+impl_vector_convert_from![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64];
 
-macro_rules! impl_vector_convert_to_helper {
+macro_rules! impl_vector_convert_from_helper {
     (($(#[From = $From:ident])? $src:ident, $src_prim:ident) => ($(#[From = $From2:ident])? $dest:ident, $dest_prim:ident)) => {
-        impl<const LANES: usize> ConvertTo<Wrapper<$dest<LANES>, LANES>>
-            for Wrapper<$src<LANES>, LANES>
+        impl<const LANES: usize> ConvertFrom<Wrapper<$src<LANES>, LANES>>
+            for Wrapper<$dest<LANES>, LANES>
         where
             SimdI8<LANES>: LanesAtMost32,
             SimdU8<LANES>: LanesAtMost32,
@@ -864,9 +864,9 @@ macro_rules! impl_vector_convert_to_helper {
             SimdF64<LANES>: LanesAtMost32,
             Mask64<LANES>: Mask,
         {
-            fn to(self) -> Wrapper<$dest<LANES>, LANES> {
+            fn cvt_from(v: Wrapper<$src<LANES>, LANES>) -> Self {
                 // FIXME(programmerjake): workaround https://github.com/rust-lang/stdsimd/issues/116
-                let src: [$src_prim; LANES] = self.0.into();
+                let src: [$src_prim; LANES] = v.0.into();
                 let mut dest: [$dest_prim; LANES] = [Default::default(); LANES];
                 for i in 0..LANES {
                     dest[i] = src[i].to();
@@ -899,18 +899,18 @@ macro_rules! impl_vector_convert_to_helper {
     };
 }
 
-macro_rules! impl_vector_convert_to {
+macro_rules! impl_vector_convert_from {
     ($first:tt $(, $ty:tt)*) => {
         $(
-            impl_vector_convert_to_helper!($first => $ty);
-            impl_vector_convert_to_helper!($ty => $first);
+            impl_vector_convert_from_helper!($first => $ty);
+            impl_vector_convert_from_helper!($ty => $first);
         )*
-        impl_vector_convert_to![$($ty),*];
+        impl_vector_convert_from![$($ty),*];
     };
     () => {};
 }
 
-impl_vector_convert_to![
+impl_vector_convert_from![
     (SimdU8, u8),
     (SimdI8, i8),
     (SimdU16, u16),
@@ -924,7 +924,7 @@ impl_vector_convert_to![
     (SimdF64, f64)
 ];
 
-impl_vector_convert_to![
+impl_vector_convert_from![
     (
         #[From = From]
         Mask8,
@@ -967,7 +967,7 @@ macro_rules! impl_from_helper {
             Mask64<LANES>: Mask,
         {
             fn from(v: $src) -> Self {
-                <$src as ConvertTo<$dest>>::to(v)
+                <$dest>::cvt_from(v)
             }
         }
     };