add count_leading_zeros, count_trailing_zeros, and count_ones implementations
[vector-math.git] / src / ir.rs
index 3eaccb987d6fa0f239d208470350f92046e84de1..279902041a39bfcdee4f2385cd550f8891d16374 100644 (file)
--- a/src/ir.rs
+++ b/src/ir.rs
@@ -1,6 +1,8 @@
 use crate::{
     f16::F16,
-    traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt},
+    traits::{
+        Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt,
+    },
 };
 use std::{
     borrow::Borrow,
@@ -229,14 +231,14 @@ impl fmt::Display for ScalarConstant {
         match self {
             ScalarConstant::Bool(false) => write!(f, "false"),
             ScalarConstant::Bool(true) => write!(f, "true"),
-            ScalarConstant::U8(v) => write!(f, "{}_u8", v),
-            ScalarConstant::U16(v) => write!(f, "{}_u16", v),
-            ScalarConstant::U32(v) => write!(f, "{}_u32", v),
-            ScalarConstant::U64(v) => write!(f, "{}_u64", v),
-            ScalarConstant::I8(v) => write!(f, "{}_i8", v),
-            ScalarConstant::I16(v) => write!(f, "{}_i16", v),
-            ScalarConstant::I32(v) => write!(f, "{}_i32", v),
-            ScalarConstant::I64(v) => write!(f, "{}_i64", v),
+            ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
+            ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
+            ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
+            ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
+            ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
+            ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
+            ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
+            ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
             ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
             ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
             ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
@@ -399,6 +401,9 @@ make_enum! {
         Not,
         Shl,
         Shr,
+        CountSetBits,
+        CountLeadingZeros,
+        CountTrailingZeros,
         Neg,
         Abs,
         Trunc,
@@ -685,7 +690,7 @@ impl_ir_function_maker_io!(
     in11: In11,
 );
 
-pub trait IrValue<'ctx>: Copy {
+pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
     const TYPE: Type;
     fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
     fn make_input<N: Borrow<str> + Into<String>>(
@@ -695,7 +700,6 @@ pub trait IrValue<'ctx>: Copy {
         let input = ctx.make_input(name, Self::TYPE);
         (Self::new(ctx, input.into()), input)
     }
-    fn ctx(self) -> &'ctx IrContext<'ctx>;
     fn value(self) -> Value<'ctx>;
 }
 
@@ -713,9 +717,6 @@ macro_rules! ir_value {
                 assert_eq!(value.ty(), Self::TYPE);
                 Self { ctx, value }
             }
-            fn ctx(self) -> &'ctx IrContext<'ctx> {
-                self.ctx
-            }
             fn value(self) -> Value<'ctx> {
                 self.value
             }
@@ -725,10 +726,13 @@ macro_rules! ir_value {
             pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
         }
 
-        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $name<'ctx> {
+        impl<'ctx> Make for $name<'ctx> {
             type Prim = $prim;
-
-            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
+            type Context = &'ctx IrContext<'ctx>;
+            fn ctx(self) -> Self::Context {
+                self.ctx
+            }
+            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                 let value: ScalarConstant = $make;
                 let value = value.into();
                 Self { value, ctx }
@@ -747,9 +751,6 @@ macro_rules! ir_value {
                 assert_eq!(value.ty(), Self::TYPE);
                 Self { ctx, value }
             }
-            fn ctx(self) -> &'ctx IrContext<'ctx> {
-                self.ctx
-            }
             fn value(self) -> Value<'ctx> {
                 self.value
             }
@@ -761,10 +762,13 @@ macro_rules! ir_value {
             };
         }
 
-        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $vec_name<'ctx> {
+        impl<'ctx> Make for $vec_name<'ctx> {
             type Prim = $prim;
-
-            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
+            type Context = &'ctx IrContext<'ctx>;
+            fn ctx(self) -> Self::Context {
+                self.ctx
+            }
+            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                 let element = $make;
                 Self {
                     value: VectorSplatConstant { element }.into(),
@@ -807,6 +811,23 @@ macro_rules! ir_value {
             }
         }
 
+        impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> {
+            fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
+                let value = self
+                    .ctx
+                    .make_operation(
+                        Opcode::Select,
+                        [self.value, true_v.value, false_v.value],
+                        $vec_name::TYPE,
+                    )
+                    .into();
+                $vec_name {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+        }
+
         impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
             fn from(v: $name<'ctx>) -> Self {
                 let value = v
@@ -1058,12 +1079,41 @@ macro_rules! impl_number_ops {
     };
 }
 
+macro_rules! impl_bool_compare {
+    ($ty:ident) => {
+        impl<'ctx> Compare for $ty<'ctx> {
+            type Bool = Self;
+            fn eq(self, rhs: Self) -> Self::Bool {
+                !(self ^ rhs)
+            }
+            fn ne(self, rhs: Self) -> Self::Bool {
+                self ^ rhs
+            }
+            fn lt(self, rhs: Self) -> Self::Bool {
+                !self & rhs
+            }
+            fn gt(self, rhs: Self) -> Self::Bool {
+                self & !rhs
+            }
+            fn le(self, rhs: Self) -> Self::Bool {
+                !self | rhs
+            }
+            fn ge(self, rhs: Self) -> Self::Bool {
+                self | !rhs
+            }
+        }
+    };
+}
+
+impl_bool_compare!(IrBool);
+impl_bool_compare!(IrVecBool);
+
 macro_rules! impl_shift_ops {
-    ($ty:ident, $rhs:ident) => {
-        impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> {
+    ($ty:ident) => {
+        impl<'ctx> Shl for $ty<'ctx> {
             type Output = Self;
 
-            fn shl(self, rhs: $rhs<'ctx>) -> Self::Output {
+            fn shl(self, rhs: Self) -> Self::Output {
                 let value = self
                     .ctx
                     .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE)
@@ -1074,10 +1124,10 @@ macro_rules! impl_shift_ops {
                 }
             }
         }
-        impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> {
+        impl<'ctx> Shr for $ty<'ctx> {
             type Output = Self;
 
-            fn shr(self, rhs: $rhs<'ctx>) -> Self::Output {
+            fn shr(self, rhs: Self) -> Self::Output {
                 let value = self
                     .ctx
                     .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE)
@@ -1088,13 +1138,13 @@ macro_rules! impl_shift_ops {
                 }
             }
         }
-        impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> {
-            fn shl_assign(&mut self, rhs: $rhs<'ctx>) {
+        impl<'ctx> ShlAssign for $ty<'ctx> {
+            fn shl_assign(&mut self, rhs: Self) {
                 *self = *self << rhs;
             }
         }
-        impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> {
-            fn shr_assign(&mut self, rhs: $rhs<'ctx>) {
+        impl<'ctx> ShrAssign for $ty<'ctx> {
+            fn shr_assign(&mut self, rhs: Self) {
                 *self = *self >> rhs;
             }
         }
@@ -1120,54 +1170,93 @@ macro_rules! impl_neg {
     };
 }
 
+macro_rules! impl_int_trait {
+    ($ty:ident) => {
+        impl<'ctx> Int for $ty<'ctx> {
+            fn leading_zeros(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+            fn trailing_zeros(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+            fn count_ones(self) -> Self {
+                let value = self
+                    .ctx
+                    .make_operation(Opcode::CountSetBits, [self.value], Self::TYPE)
+                    .into();
+                Self {
+                    value,
+                    ctx: self.ctx,
+                }
+            }
+        }
+    };
+}
+
 macro_rules! impl_integer_ops {
     ($scalar:ident, $vec:ident) => {
         impl_bit_ops!($scalar);
         impl_number_ops!($scalar, IrBool);
-        impl_shift_ops!($scalar, IrU32);
+        impl_shift_ops!($scalar);
         impl_bit_ops!($vec);
         impl_number_ops!($vec, IrVecBool);
-        impl_shift_ops!($vec, IrVecU32);
-
-        impl<'ctx> Int<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> Int<IrVecU32<'ctx>> for $vec<'ctx> {}
-    };
-}
-
-macro_rules! impl_uint_ops {
-    ($scalar:ident, $vec:ident) => {
-        impl_integer_ops!($scalar, $vec);
-
-        impl<'ctx> UInt<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> UInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+        impl_shift_ops!($vec);
+        impl_int_trait!($scalar);
+        impl_int_trait!($vec);
     };
 }
 
-impl_uint_ops!(IrU8, IrVecU8);
-impl_uint_ops!(IrU16, IrVecU16);
-impl_uint_ops!(IrU32, IrVecU32);
-impl_uint_ops!(IrU64, IrVecU64);
+macro_rules! impl_uint_sint_ops {
+    ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => {
+        impl_integer_ops!($uint_scalar, $uint_vec);
+        impl_integer_ops!($sint_scalar, $sint_vec);
+        impl_neg!($sint_scalar);
+        impl_neg!($sint_vec);
 
-macro_rules! impl_sint_ops {
-    ($scalar:ident, $vec:ident) => {
-        impl_integer_ops!($scalar, $vec);
-        impl_neg!($scalar);
-        impl_neg!($vec);
-
-        impl<'ctx> SInt<IrU32<'ctx>> for $scalar<'ctx> {}
-        impl<'ctx> SInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+        impl<'ctx> UInt for $uint_scalar<'ctx> {
+            type PrimUInt = Self::Prim;
+            type SignedType = $sint_scalar<'ctx>;
+        }
+        impl<'ctx> UInt for $uint_vec<'ctx> {
+            type PrimUInt = Self::Prim;
+            type SignedType = $sint_vec<'ctx>;
+        }
+        impl<'ctx> SInt for $sint_scalar<'ctx> {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = $uint_scalar<'ctx>;
+        }
+        impl<'ctx> SInt for $sint_vec<'ctx> {
+            type PrimSInt = Self::Prim;
+            type UnsignedType = $uint_vec<'ctx>;
+        }
     };
 }
 
-impl_sint_ops!(IrI8, IrVecI8);
-impl_sint_ops!(IrI16, IrVecI16);
-impl_sint_ops!(IrI32, IrVecI32);
-impl_sint_ops!(IrI64, IrVecI64);
+impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8);
+impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16);
+impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32);
+impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64);
 
 macro_rules! impl_float {
-    ($float:ident, $bits:ident, $u32:ident) => {
-        impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
+    ($float:ident, $bits:ident, $signed_bits:ident) => {
+        impl<'ctx> Float for $float<'ctx> {
+            type PrimFloat = <$float<'ctx> as Make>::Prim;
             type BitsType = $bits<'ctx>;
+            type SignedBitsType = $signed_bits<'ctx>;
             fn abs(self) -> Self {
                 let value = self
                     .ctx
@@ -1285,19 +1374,19 @@ macro_rules! impl_float {
 }
 
 macro_rules! impl_float_ops {
-    ($scalar:ident, $scalar_bits:ident, $vec:ident, $vec_bits:ident) => {
+    ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
         impl_number_ops!($scalar, IrBool);
         impl_number_ops!($vec, IrVecBool);
         impl_neg!($scalar);
         impl_neg!($vec);
-        impl_float!($scalar, $scalar_bits, IrU32);
-        impl_float!($vec, $vec_bits, IrVecU32);
+        impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
+        impl_float!($vec, $vec_bits, $vec_signed_bits);
     };
 }
 
-impl_float_ops!(IrF16, IrU16, IrVecF16, IrVecU16);
-impl_float_ops!(IrF32, IrU32, IrVecF32, IrVecU32);
-impl_float_ops!(IrF64, IrU64, IrVecF64, IrVecU64);
+impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
+impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
+impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
 
 ir_value!(
     IrBool,
@@ -1403,48 +1492,41 @@ ir_value!(
     }
 );
 
-macro_rules! impl_convert_to {
-    ($($src:ident -> [$($dest:ident),*];)*) => {
-        $($(
-            impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
-                fn to(self) -> $dest<'ctx> {
-                    let value = if $src::TYPE == $dest::TYPE {
-                        self.value
-                    } else {
-                        self
-                            .ctx
-                            .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
-                            .into()
-                    };
-                    $dest {
-                        value,
-                        ctx: self.ctx,
-                    }
+macro_rules! impl_convert_from {
+    ($src:ident -> $dest:ident) => {
+        impl<'ctx> ConvertFrom<$src<'ctx>> for $dest<'ctx> {
+            fn cvt_from(v: $src<'ctx>) -> Self {
+                let value = if $src::TYPE == $dest::TYPE {
+                    v.value
+                } else {
+                    v
+                        .ctx
+                        .make_operation(Opcode::Cast, [v.value], $dest::TYPE)
+                        .into()
+                };
+                $dest {
+                    value,
+                    ctx: v.ctx,
                 }
             }
-        )*)*
-    };
-    ([$($src:ident),*] -> $dest:tt;) => {
-        impl_convert_to! {
-            $(
-                $src -> $dest;
-            )*
         }
     };
-    ([$($src:ident),*];) => {
-        impl_convert_to! {
-            [$($src),*] -> [$($src),*];
-        }
+    ($first:ident $(, $ty:ident)*) => {
+        $(
+            impl_convert_from!($first -> $ty);
+            impl_convert_from!($ty -> $first);
+        )*
+        impl_convert_from![$($ty),*];
+    };
+    () => {
     };
 }
+impl_convert_from![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
 
-impl_convert_to! {
-    [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
-}
-
-impl_convert_to! {
-    [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64];
-}
+impl_convert_from![
+    IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
+    IrVecF32, IrVecF64
+];
 
 macro_rules! impl_from {
     ($src:ident => [$($dest:ident),*]) => {
@@ -1524,15 +1606,18 @@ impl<'ctx> Context for &'ctx IrContext<'ctx> {
     type U64 = IrU64<'ctx>;
     type I64 = IrI64<'ctx>;
     type F64 = IrF64<'ctx>;
-    type VecBool = IrVecBool<'ctx>;
+    type VecBool8 = IrVecBool<'ctx>;
     type VecU8 = IrVecU8<'ctx>;
     type VecI8 = IrVecI8<'ctx>;
+    type VecBool16 = IrVecBool<'ctx>;
     type VecU16 = IrVecU16<'ctx>;
     type VecI16 = IrVecI16<'ctx>;
     type VecF16 = IrVecF16<'ctx>;
+    type VecBool32 = IrVecBool<'ctx>;
     type VecU32 = IrVecU32<'ctx>;
     type VecI32 = IrVecI32<'ctx>;
     type VecF32 = IrVecF32<'ctx>;
+    type VecBool64 = IrVecBool<'ctx>;
     type VecU64 = IrVecU64<'ctx>;
     type VecI64 = IrVecI64<'ctx>;
     type VecF64 = IrVecF64<'ctx>;
@@ -1540,6 +1625,8 @@ impl<'ctx> Context for &'ctx IrContext<'ctx> {
 
 #[cfg(test)]
 mod tests {
+    use crate::algorithms;
+
     use super::*;
     use std::println;
 
@@ -1568,6 +1655,52 @@ function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
     op_5: vec<F64> = Cast op_4
     Return op_5
 }
+"
+        );
+    }
+
+    #[test]
+    fn test_display_ilogb_f32() {
+        let ctx = IrContext::new();
+        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
+            let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
+                algorithms::ilogb::ilogb_f32;
+            IrFunction::make(ctx, f)
+        }
+        let text = format!("\n{}", make_it(&ctx));
+        println!("{}", text);
+        assert_eq!(
+            text,
+            r"
+function(in<arg_0>: vec<F32>) -> vec<I32> {
+    op_0: vec<Bool> = IsFinite in<arg_0>
+    op_1: vec<U32> = ToBits in<arg_0>
+    op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
+    op_3: vec<U32> = Shr op_2, splat(0x17_u32)
+    op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
+    op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
+    op_6: vec<I32> = Splat 0x80000001_i32
+    op_7: vec<I32> = Splat 0x7FFFFFFF_i32
+    op_8: vec<I32> = Select op_5, op_6, op_7
+    op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
+    op_10: vec<U32> = ToBits op_9
+    op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
+    op_12: vec<U32> = Shr op_11, splat(0x17_u32)
+    op_13: vec<I32> = Cast op_12
+    op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
+    op_15: vec<U32> = ToBits in<arg_0>
+    op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
+    op_17: vec<U32> = Shr op_16, splat(0x17_u32)
+    op_18: vec<I32> = Cast op_17
+    op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
+    op_20: vec<I32> = Select op_0, op_19, op_8
+    op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
+    op_22: vec<I32> = Splat 0x80000000_i32
+    op_23: vec<I32> = Sub op_14, splat(0x17_i32)
+    op_24: vec<I32> = Select op_21, op_22, op_23
+    op_25: vec<I32> = Select op_4, op_24, op_20
+    Return op_25
+}
 "
         );
     }