use crate::{ prim::PrimUInt, traits::{Context, ConvertFrom, ConvertTo, Make, SInt, Select, UInt}, }; pub fn count_leading_zeros_uint< Ctx: Context, VecU: UInt + Make, PrimU: PrimUInt, >( ctx: Ctx, mut v: VecU, ) -> VecU { let mut retval: VecU = ctx.make(PrimU::BITS); let mut bits = PrimU::BITS; while bits > 1.to() { bits /= 2.to(); let limit = PrimU::ONE << bits; let found = v.ge(ctx.make(limit)); let shift: VecU = found.select(ctx.make(bits), ctx.make(0.to())); retval -= shift; v >>= shift; } let nonzero = v.ne(ctx.make(0.to())); retval - nonzero.select(ctx.make(1.to()), ctx.make(0.to())) } pub fn count_leading_zeros_sint< Ctx: Context, VecU: UInt + Make + ConvertFrom, VecS: SInt + ConvertFrom, >( ctx: Ctx, v: VecS, ) -> VecS { count_leading_zeros_uint(ctx, VecU::cvt_from(v)).to() } pub fn count_trailing_zeros_uint< Ctx: Context, VecU: UInt + Make, PrimU: PrimUInt, >( ctx: Ctx, mut v: VecU, ) -> VecU { let mut retval: VecU = ctx.make(PrimU::ZERO); let mut bits = PrimU::BITS; while bits > 1.to() { bits /= 2.to(); let mask = (PrimU::ONE << bits) - 1.to(); let zero = (v & ctx.make(mask)).eq(ctx.make(0.to())); let shift: VecU = zero.select(ctx.make(bits), ctx.make(0.to())); retval += shift; v >>= shift; } let zero = v.eq(ctx.make(0.to())); retval + zero.select(ctx.make(1.to()), ctx.make(0.to())) } pub fn count_trailing_zeros_sint< Ctx: Context, VecU: UInt + Make + ConvertFrom, VecS: SInt + ConvertFrom, >( ctx: Ctx, v: VecS, ) -> VecS { count_trailing_zeros_uint(ctx, VecU::cvt_from(v)).to() } pub fn count_ones_uint< Ctx: Context, VecU: UInt + Make, PrimU: PrimUInt, >( ctx: Ctx, mut v: VecU, ) -> VecU { assert!(PrimU::BITS <= 64.to()); assert!(PrimU::BITS >= 8.to()); const SPLAT_BYTES_MULTIPLIER: u64 = u64::from_le_bytes([1; 8]); const EVERY_OTHER_BIT_MASK: u64 = 0x55 * SPLAT_BYTES_MULTIPLIER; const TWO_OUT_OF_FOUR_BITS_MASK: u64 = 0x33 * SPLAT_BYTES_MULTIPLIER; const FOUR_OUT_OF_EIGHT_BITS_MASK: u64 = 0x0F * SPLAT_BYTES_MULTIPLIER; // algorithm derived from popcount64c at https://en.wikipedia.org/wiki/Hamming_weight v -= (v >> ctx.make(1.to())) & ctx.make(EVERY_OTHER_BIT_MASK.to()); v = (v & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to())) + ((v >> ctx.make(2.to())) & ctx.make(TWO_OUT_OF_FOUR_BITS_MASK.to())); v = (v & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to())) + ((v >> ctx.make(4.to())) & ctx.make(FOUR_OUT_OF_EIGHT_BITS_MASK.to())); if PrimU::BITS > 8.to() { v * ctx.make(SPLAT_BYTES_MULTIPLIER.to()) >> ctx.make(PrimU::BITS - 8.to()) } else { v } } pub fn count_ones_sint< Ctx: Context, VecU: UInt + Make + ConvertFrom, VecS: SInt + ConvertFrom, >( ctx: Ctx, v: VecS, ) -> VecS { count_ones_uint(ctx, VecU::cvt_from(v)).to() } #[cfg(test)] mod tests { use super::*; use crate::scalar::{Scalar, Value}; #[test] fn test_count_leading_zeros_u16() { for v in 0..=u16::MAX { assert_eq!( v.leading_zeros() as u16, count_leading_zeros_uint(Scalar, Value(v)).0, "v = {:#X}", v, ); } } #[test] fn test_count_trailing_zeros_u16() { for v in 0..=u16::MAX { assert_eq!( v.trailing_zeros() as u16, count_trailing_zeros_uint(Scalar, Value(v)).0, "v = {:#X}", v, ); } } #[test] fn test_count_ones_u16() { for v in 0..=u16::MAX { assert_eq!( v.count_ones() as u16, count_ones_uint(Scalar, Value(v)).0, "v = {:#X}", v, ); } } } #[cfg(all(feature = "ir", test))] mod ir_tests { use super::*; use crate::ir::{IrContext, IrFunction, IrVecI64, IrVecU64, IrVecU8}; use std::{format, println}; #[test] fn test_display_count_leading_zeros_i64() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecI64<'ctx>) -> IrVecI64<'ctx> = count_leading_zeros_sint; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = Cast in op_1: vec = CompareGe op_0, splat(0x100000000_u64) op_2: vec = Select op_1, splat(0x20_u64), splat(0x0_u64) op_3: vec = Sub splat(0x40_u64), op_2 op_4: vec = Shr op_0, op_2 op_5: vec = CompareGe op_4, splat(0x10000_u64) op_6: vec = Select op_5, splat(0x10_u64), splat(0x0_u64) op_7: vec = Sub op_3, op_6 op_8: vec = Shr op_4, op_6 op_9: vec = CompareGe op_8, splat(0x100_u64) op_10: vec = Select op_9, splat(0x8_u64), splat(0x0_u64) op_11: vec = Sub op_7, op_10 op_12: vec = Shr op_8, op_10 op_13: vec = CompareGe op_12, splat(0x10_u64) op_14: vec = Select op_13, splat(0x4_u64), splat(0x0_u64) op_15: vec = Sub op_11, op_14 op_16: vec = Shr op_12, op_14 op_17: vec = CompareGe op_16, splat(0x4_u64) op_18: vec = Select op_17, splat(0x2_u64), splat(0x0_u64) op_19: vec = Sub op_15, op_18 op_20: vec = Shr op_16, op_18 op_21: vec = CompareGe op_20, splat(0x2_u64) op_22: vec = Select op_21, splat(0x1_u64), splat(0x0_u64) op_23: vec = Sub op_19, op_22 op_24: vec = Shr op_20, op_22 op_25: vec = CompareNe op_24, splat(0x0_u64) op_26: vec = Select op_25, splat(0x1_u64), splat(0x0_u64) op_27: vec = Sub op_23, op_26 op_28: vec = Cast op_27 Return op_28 } " ); } #[test] fn test_display_count_leading_zeros_u8() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_leading_zeros_uint; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = CompareGe in, splat(0x10_u8) op_1: vec = Select op_0, splat(0x4_u8), splat(0x0_u8) op_2: vec = Sub splat(0x8_u8), op_1 op_3: vec = Shr in, op_1 op_4: vec = CompareGe op_3, splat(0x4_u8) op_5: vec = Select op_4, splat(0x2_u8), splat(0x0_u8) op_6: vec = Sub op_2, op_5 op_7: vec = Shr op_3, op_5 op_8: vec = CompareGe op_7, splat(0x2_u8) op_9: vec = Select op_8, splat(0x1_u8), splat(0x0_u8) op_10: vec = Sub op_6, op_9 op_11: vec = Shr op_7, op_9 op_12: vec = CompareNe op_11, splat(0x0_u8) op_13: vec = Select op_12, splat(0x1_u8), splat(0x0_u8) op_14: vec = Sub op_10, op_13 Return op_14 } " ); } #[test] fn test_display_count_trailing_zeros_u8() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_trailing_zeros_uint; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = And in, splat(0xF_u8) op_1: vec = CompareEq op_0, splat(0x0_u8) op_2: vec = Select op_1, splat(0x4_u8), splat(0x0_u8) op_3: vec = Add splat(0x0_u8), op_2 op_4: vec = Shr in, op_2 op_5: vec = And op_4, splat(0x3_u8) op_6: vec = CompareEq op_5, splat(0x0_u8) op_7: vec = Select op_6, splat(0x2_u8), splat(0x0_u8) op_8: vec = Add op_3, op_7 op_9: vec = Shr op_4, op_7 op_10: vec = And op_9, splat(0x1_u8) op_11: vec = CompareEq op_10, splat(0x0_u8) op_12: vec = Select op_11, splat(0x1_u8), splat(0x0_u8) op_13: vec = Add op_8, op_12 op_14: vec = Shr op_9, op_12 op_15: vec = CompareEq op_14, splat(0x0_u8) op_16: vec = Select op_15, splat(0x1_u8), splat(0x0_u8) op_17: vec = Add op_13, op_16 Return op_17 } " ); } #[test] fn test_display_count_ones_u8() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>) -> IrVecU8<'ctx> = count_ones_uint; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = Shr in, splat(0x1_u8) op_1: vec = And op_0, splat(0x55_u8) op_2: vec = Sub in, op_1 op_3: vec = And op_2, splat(0x33_u8) op_4: vec = Shr op_2, splat(0x2_u8) op_5: vec = And op_4, splat(0x33_u8) op_6: vec = Add op_3, op_5 op_7: vec = And op_6, splat(0xF_u8) op_8: vec = Shr op_6, splat(0x4_u8) op_9: vec = And op_8, splat(0xF_u8) op_10: vec = Add op_7, op_9 Return op_10 } " ); } #[test] fn test_display_count_ones_u64() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecU64<'ctx>) -> IrVecU64<'ctx> = count_ones_uint; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = Shr in, splat(0x1_u64) op_1: vec = And op_0, splat(0x5555555555555555_u64) op_2: vec = Sub in, op_1 op_3: vec = And op_2, splat(0x3333333333333333_u64) op_4: vec = Shr op_2, splat(0x2_u64) op_5: vec = And op_4, splat(0x3333333333333333_u64) op_6: vec = Add op_3, op_5 op_7: vec = And op_6, splat(0xF0F0F0F0F0F0F0F_u64) op_8: vec = Shr op_6, splat(0x4_u64) op_9: vec = And op_8, splat(0xF0F0F0F0F0F0F0F_u64) op_10: vec = Add op_7, op_9 op_11: vec = Mul op_10, splat(0x101010101010101_u64) op_12: vec = Shr op_11, splat(0x38_u64) Return op_12 } " ); } }