From 8a170330691c442c16cf6a7c6606fc19493e9e81 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 17 May 2021 21:13:12 -0700 Subject: [PATCH] add count_leading_zeros, count_trailing_zeros, and count_ones implementations --- src/algorithms.rs | 1 + src/algorithms/base.rs | 8 +- src/algorithms/integer.rs | 341 ++++++++++++++++++++++++++++++++++++++ src/prim.rs | 3 + src/stdsimd.rs | 34 ++-- 5 files changed, 367 insertions(+), 20 deletions(-) create mode 100644 src/algorithms/integer.rs diff --git a/src/algorithms.rs b/src/algorithms.rs index cfa78b8..4278ac2 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,3 +1,4 @@ pub mod base; pub mod ilogb; +pub mod integer; pub mod trig_pi; diff --git a/src/algorithms/base.rs b/src/algorithms/base.rs index b4ec103..4ebd849 100644 --- a/src/algorithms/base.rs +++ b/src/algorithms/base.rs @@ -88,7 +88,9 @@ pub fn floor< let offset_value: VecF = v.abs() + offset; let rounded = (offset_value - offset).copy_sign(v); let need_round_down = v.lt(rounded); - let in_range_value = need_round_down.select(rounded - ctx.make(1.to()), rounded).copy_sign(v); + let in_range_value = need_round_down + .select(rounded - ctx.make(1.to()), rounded) + .copy_sign(v); big.select(v, in_range_value) } @@ -108,7 +110,9 @@ pub fn ceil< let offset_value: VecF = v.abs() + offset; let rounded = (offset_value - offset).copy_sign(v); let need_round_up = v.gt(rounded); - let in_range_value = need_round_up.select(rounded + ctx.make(1.to()), rounded).copy_sign(v); + let in_range_value = need_round_up + .select(rounded + ctx.make(1.to()), rounded) + .copy_sign(v); big.select(v, in_range_value) } diff --git a/src/algorithms/integer.rs b/src/algorithms/integer.rs new file mode 100644 index 0000000..1091723 --- /dev/null +++ b/src/algorithms/integer.rs @@ -0,0 +1,341 @@ +use crate::{ + prim::PrimUInt, + traits::{Context, ConvertFrom, ConvertTo, Make, SInt, Select, UInt}, +}; + +pub fn count_leading_zeros_uint< + Ctx: Context, + VecU: UInt + Make, + PrimU: PrimUInt, +>( + ctx: Ctx, + mut v: VecU, +) -> VecU { + let mut retval: VecU = ctx.make(PrimU::BITS); + let mut bits = PrimU::BITS; + while bits > 1.to() { + bits /= 2.to(); + let limit = PrimU::ONE << bits; + let found = v.ge(ctx.make(limit)); + let shift: VecU = found.select(ctx.make(bits), ctx.make(0.to())); + retval -= shift; + v >>= shift; + } + let nonzero = v.ne(ctx.make(0.to())); + retval - nonzero.select(ctx.make(1.to()), ctx.make(0.to())) +} + +pub fn count_leading_zeros_sint< + Ctx: Context, + VecU: UInt + Make + ConvertFrom, + VecS: SInt + ConvertFrom, +>( + ctx: Ctx, + v: VecS, +) -> VecS { + count_leading_zeros_uint(ctx, VecU::cvt_from(v)).to() +} + +pub fn count_trailing_zeros_uint< + Ctx: Context, + VecU: UInt + Make, + PrimU: PrimUInt, +>( + ctx: Ctx, + mut v: VecU, +) -> VecU { + let mut retval: VecU = ctx.make(PrimU::ZERO); + let mut bits = PrimU::BITS; + while bits > 1.to() { + bits /= 2.to(); + let mask = (PrimU::ONE << bits) - 1.to(); + let zero = (v & ctx.make(mask)).eq(ctx.make(0.to())); + let shift: VecU = zero.select(ctx.make(bits), ctx.make(0.to())); + retval += shift; + v >>= shift; + } + let zero = v.eq(ctx.make(0.to())); + retval + zero.select(ctx.make(1.to()), ctx.make(0.to())) +} + +pub fn count_trailing_zeros_sint< + Ctx: Context, + VecU: UInt + Make + ConvertFrom, + VecS: SInt + ConvertFrom, +>( + ctx: Ctx, + v: VecS, +) -> VecS { + count_trailing_zeros_uint(ctx, VecU::cvt_from(v)).to() +} + +pub fn count_ones_uint< + Ctx: Context, + VecU: UInt + Make, + PrimU: PrimUInt, +>( + ctx: Ctx, + mut v: VecU, +) -> VecU { + assert!(PrimU::BITS <= 64.to()); + assert!(PrimU::BITS >= 8.to()); + const SPLAT_BYTES_MULTIPLIER: u64 = u64::from_le_bytes([1; 8]); + const EVERY_OTHER_BIT_MASK: u64 = 0x55 * SPLAT_BYTES_MULTIPLIER; + const TWO_OUT_OF_FOUR_BITS_MASK: u64 = 0x33 * SPLAT_BYTES_MULTIPLIER; + const FOUR_OUT_OF_EIGHT_BITS_MASK: u64 = 0x0F * SPLAT_BYTES_MULTIPLIER; + // algorithm derived from popcount64c at https://en.wikipedia.org/wiki/Hamming_weight + v -= (v >> ctx.make(1.to())) & ctx.make(EVERY_OTHER_BIT_MASK.to()); + v = (v & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to())) + + ((v >> ctx.make(2.to())) & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to())); + v = (v & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to())) + + ((v >> ctx.make(4.to())) & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to())); + if PrimU::BITS > 8.to() { + v * ctx.make(SPLAT_BYTES_MULTIPLIER.to()) >> ctx.make(PrimU::BITS - 8.to()) + } else { + v + } +} + +pub fn count_ones_sint< + Ctx: Context, + VecU: UInt + Make + ConvertFrom, + VecS: SInt + ConvertFrom, +>( + ctx: Ctx, + v: VecS, +) -> VecS { + count_ones_uint(ctx, VecU::cvt_from(v)).to() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::scalar::{Scalar, Value}; + + #[test] + fn test_count_leading_zeros_u16() { + for v in 0..=u16::MAX { + assert_eq!( + v.leading_zeros() as u16, + count_leading_zeros_uint(Scalar, Value(v)).0, + "v = {:#X}", + v, + ); + } + } + + #[test] + fn test_count_trailing_zeros_u16() { + for v in 0..=u16::MAX { + assert_eq!( + v.trailing_zeros() as u16, + count_trailing_zeros_uint(Scalar, Value(v)).0, + "v = {:#X}", + v, + ); + } + } + + #[test] + fn test_count_ones_u16() { + for v in 0..=u16::MAX { + assert_eq!( + v.count_ones() as u16, + count_ones_uint(Scalar, Value(v)).0, + "v = {:#X}", + v, + ); + } + } +} + +#[cfg(all(feature = "ir", test))] +mod ir_tests { + use super::*; + use crate::ir::{IrContext, IrFunction, IrVecI64, IrVecU64, IrVecU8}; + use std::{format, println}; + + #[test] + fn test_display_count_leading_zeros_i64() { + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecI64<'ctx>) -> IrVecI64<'ctx> = + count_leading_zeros_sint; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec) -> vec { + op_0: vec = Cast in + op_1: vec = CompareGe op_0, splat(0x100000000_u64) + op_2: vec = Select op_1, splat(0x20_u64), splat(0x0_u64) + op_3: vec = Sub splat(0x40_u64), op_2 + op_4: vec = Shr op_0, op_2 + op_5: vec = CompareGe op_4, splat(0x10000_u64) + op_6: vec = Select op_5, splat(0x10_u64), splat(0x0_u64) + op_7: vec = Sub op_3, op_6 + op_8: vec = Shr op_4, op_6 + op_9: vec = CompareGe op_8, splat(0x100_u64) + op_10: vec = Select op_9, splat(0x8_u64), splat(0x0_u64) + op_11: vec = Sub op_7, op_10 + op_12: vec = Shr op_8, op_10 + op_13: vec = CompareGe op_12, splat(0x10_u64) + op_14: vec = Select op_13, splat(0x4_u64), splat(0x0_u64) + op_15: vec = Sub op_11, op_14 + op_16: vec = Shr op_12, op_14 + op_17: vec = CompareGe op_16, splat(0x4_u64) + op_18: vec = Select op_17, splat(0x2_u64), splat(0x0_u64) + op_19: vec = Sub op_15, op_18 + op_20: vec = Shr op_16, op_18 + op_21: vec = CompareGe op_20, splat(0x2_u64) + op_22: vec = Select op_21, splat(0x1_u64), splat(0x0_u64) + op_23: vec = Sub op_19, op_22 + op_24: vec = Shr op_20, op_22 + op_25: vec = CompareNe op_24, splat(0x0_u64) + op_26: vec = Select op_25, splat(0x1_u64), splat(0x0_u64) + op_27: vec = Sub op_23, op_26 + op_28: vec = Cast op_27 + Return op_28 +} +" + ); + } + + #[test] + fn test_display_count_leading_zeros_u8() { + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = + count_leading_zeros_uint; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec) -> vec { + op_0: vec = CompareGe in, splat(0x10_u8) + op_1: vec = Select op_0, splat(0x4_u8), splat(0x0_u8) + op_2: vec = Sub splat(0x8_u8), op_1 + op_3: vec = Shr in, op_1 + op_4: vec = CompareGe op_3, splat(0x4_u8) + op_5: vec = Select op_4, splat(0x2_u8), splat(0x0_u8) + op_6: vec = Sub op_2, op_5 + op_7: vec = Shr op_3, op_5 + op_8: vec = CompareGe op_7, splat(0x2_u8) + op_9: vec = Select op_8, splat(0x1_u8), splat(0x0_u8) + op_10: vec = Sub op_6, op_9 + op_11: vec = Shr op_7, op_9 + op_12: vec = CompareNe op_11, splat(0x0_u8) + op_13: vec = Select op_12, splat(0x1_u8), splat(0x0_u8) + op_14: vec = Sub op_10, op_13 + Return op_14 +} +" + ); + } + + #[test] + fn test_display_count_trailing_zeros_u8() { + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = + count_trailing_zeros_uint; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec) -> vec { + op_0: vec = And in, splat(0xF_u8) + op_1: vec = CompareEq op_0, splat(0x0_u8) + op_2: vec = Select op_1, splat(0x4_u8), splat(0x0_u8) + op_3: vec = Add splat(0x0_u8), op_2 + op_4: vec = Shr in, op_2 + op_5: vec = And op_4, splat(0x3_u8) + op_6: vec = CompareEq op_5, splat(0x0_u8) + op_7: vec = Select op_6, splat(0x2_u8), splat(0x0_u8) + op_8: vec = Add op_3, op_7 + op_9: vec = Shr op_4, op_7 + op_10: vec = And op_9, splat(0x1_u8) + op_11: vec = CompareEq op_10, splat(0x0_u8) + op_12: vec = Select op_11, splat(0x1_u8), splat(0x0_u8) + op_13: vec = Add op_8, op_12 + op_14: vec = Shr op_9, op_12 + op_15: vec = CompareEq op_14, splat(0x0_u8) + op_16: vec = Select op_15, splat(0x1_u8), splat(0x0_u8) + op_17: vec = Add op_13, op_16 + Return op_17 +} +" + ); + } + + #[test] + fn test_display_count_ones_u8() { + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_ones_uint; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec) -> vec { + op_0: vec = Shr in, splat(0x1_u8) + op_1: vec = And op_0, splat(0x55_u8) + op_2: vec = Sub in, op_1 + op_3: vec = And op_2, splat(0x33_u8) + op_4: vec = Shr op_2, splat(0x2_u8) + op_5: vec = And op_4, splat(0x33_u8) + op_6: vec = Add op_3, op_5 + op_7: vec = And op_6, splat(0xF_u8) + op_8: vec = Shr op_6, splat(0x4_u8) + op_9: vec = And op_8, splat(0xF_u8) + op_10: vec = Add op_7, op_9 + Return op_10 +} +" + ); + } + + #[test] + fn test_display_count_ones_u64() { + let ctx = IrContext::new(); + fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { + let f: fn(&'ctx IrContext<'ctx>, IrVecU64<'ctx>) -> IrVecU64<'ctx> = count_ones_uint; + IrFunction::make(ctx, f) + } + let text = format!("\n{}", make_it(&ctx)); + println!("{}", text); + assert_eq!( + text, + r" +function(in: vec) -> vec { + op_0: vec = Shr in, splat(0x1_u64) + op_1: vec = And op_0, splat(0x5555555555555555_u64) + op_2: vec = Sub in, op_1 + op_3: vec = And op_2, splat(0x3333333333333333_u64) + op_4: vec = Shr op_2, splat(0x2_u64) + op_5: vec = And op_4, splat(0x3333333333333333_u64) + op_6: vec = Add op_3, op_5 + op_7: vec = And op_6, splat(0xF0F0F0F0F0F0F0F_u64) + op_8: vec = Shr op_6, splat(0x4_u64) + op_9: vec = And op_8, splat(0xF0F0F0F0F0F0F0F_u64) + op_10: vec = Add op_7, op_9 + op_11: vec = Mul op_10, splat(0x101010101010101_u64) + op_12: vec = Shr op_11, splat(0x38_u64) + Return op_12 +} +" + ); + } +} diff --git a/src/prim.rs b/src/prim.rs index ea7b6b9..7ba23e5 100644 --- a/src/prim.rs +++ b/src/prim.rs @@ -91,6 +91,7 @@ pub trait PrimInt: const ONE: Self; const MIN: Self; const MAX: Self; + const BITS: Self; } pub trait PrimUInt: PrimInt + ConvertFrom { @@ -110,12 +111,14 @@ macro_rules! impl_int { const ONE: Self = 1; const MIN: Self = 0; const MAX: Self = !0; + const BITS: Self = (0 as $uint).count_zeros() as $uint; } impl PrimInt for $sint { const ZERO: Self = 0; const ONE: Self = 1; const MIN: Self = $sint::MIN; const MAX: Self = $sint::MAX; + const BITS: Self = (0 as $sint).count_zeros() as $sint; } impl PrimUInt for $uint { type SignedType = $sint; diff --git a/src/stdsimd.rs b/src/stdsimd.rs index 3d757cc..35692c9 100644 --- a/src/stdsimd.rs +++ b/src/stdsimd.rs @@ -520,7 +520,7 @@ macro_rules! impl_int_scalar { } macro_rules! impl_int_vector { - ($ty:ident) => { + ($ty:ident, $count_leading_zeros:ident, $count_trailing_zeros:ident, $count_ones:ident) => { impl Int for Wrapper<$ty, LANES> where SimdI8: LanesAtMost32, @@ -539,27 +539,15 @@ macro_rules! impl_int_vector { Mask64: Mask, { fn leading_zeros(self) -> Self { - todo!() + crate::algorithms::integer::$count_leading_zeros(self.ctx(), self) } fn trailing_zeros(self) -> Self { - todo!() + crate::algorithms::integer::$count_trailing_zeros(self.ctx(), self) } fn count_ones(self) -> Self { - todo!() - } - - fn leading_ones(self) -> Self { - todo!() - } - - fn trailing_ones(self) -> Self { - todo!() - } - - fn count_zeros(self) -> Self { - todo!() + crate::algorithms::integer::$count_ones(self.ctx(), self) } } }; @@ -567,8 +555,18 @@ macro_rules! impl_int_vector { macro_rules! impl_uint_sint_vector { ($uint:ident, $sint:ident) => { - impl_int_vector!($uint); - impl_int_vector!($sint); + impl_int_vector!( + $uint, + count_leading_zeros_uint, + count_trailing_zeros_uint, + count_ones_uint + ); + impl_int_vector!( + $sint, + count_leading_zeros_sint, + count_trailing_zeros_sint, + count_ones_sint + ); impl UInt for Wrapper<$uint, LANES> where SimdI8: LanesAtMost32, -- 2.30.2