From 734beaf5e96afe7abd0bee1913f4057519f604de Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 4 May 2021 22:55:36 -0700 Subject: [PATCH 1/1] switch to using separate VecBool8/16/32/64 --- Cargo.toml | 4 + src/f16.rs | 2 +- src/ieee754.rs | 27 +- src/ir.rs | 158 ++++++---- src/scalar.rs | 29 +- src/traits.rs | 481 +++--------------------------- vector-math-proc-macro/Cargo.toml | 14 + vector-math-proc-macro/src/lib.rs | 382 ++++++++++++++++++++++++ 8 files changed, 555 insertions(+), 542 deletions(-) create mode 100644 vector-math-proc-macro/Cargo.toml create mode 100644 vector-math-proc-macro/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 5f82460..858dbd9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ license = "MIT OR Apache-2.0" half = { version = "1.7.1", optional = true } typed-arena = { version = "2.0.1", optional = true } core_simd = { version = "0.1.0", git = "https://github.com/rust-lang/stdsimd", optional = true } +vector-math-proc-macro = { version = "=0.1.0", path = "vector-math-proc-macro" } [features] default = ["f16", "fma"] @@ -17,3 +18,6 @@ fma = ["std"] std = [] ir = ["std", "typed-arena"] stdsimd = ["core_simd"] + +[workspace] +members = [".", "vector-math-proc-macro"] diff --git a/src/f16.rs b/src/f16.rs index bc40c78..ee13d99 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -161,7 +161,7 @@ impl_bin_op_using_f32! { Rem, rem, RemAssign, rem_assign; } -impl Float for F16 { +impl Float for F16 { type FloatEncoding = F16; type BitsType = u16; type SignedBitsType = i16; diff --git a/src/ieee754.rs b/src/ieee754.rs index 0da587d..6f9fea7 100644 --- a/src/ieee754.rs +++ b/src/ieee754.rs @@ -14,16 +14,16 @@ mod sealed { } pub trait FloatEncoding: - sealed::Sealed + Copy + 'static + Send + Sync + Float + Make + sealed::Sealed + Copy + 'static + Send + Sync + Float + Make { const EXPONENT_BIAS_UNSIGNED: Self::BitsType; const EXPONENT_BIAS_SIGNED: Self::SignedBitsType; - const SIGN_FIELD_WIDTH: u32; - const EXPONENT_FIELD_WIDTH: u32; - const MANTISSA_FIELD_WIDTH: u32; - const SIGN_FIELD_SHIFT: u32; - const EXPONENT_FIELD_SHIFT: u32; - const MANTISSA_FIELD_SHIFT: u32; + const SIGN_FIELD_WIDTH: Self::BitsType; + const EXPONENT_FIELD_WIDTH: Self::BitsType; + const MANTISSA_FIELD_WIDTH: Self::BitsType; + const SIGN_FIELD_SHIFT: Self::BitsType; + const EXPONENT_FIELD_SHIFT: Self::BitsType; + const MANTISSA_FIELD_SHIFT: Self::BitsType; const SIGN_FIELD_MASK: Self::BitsType; const EXPONENT_FIELD_MASK: Self::BitsType; const MANTISSA_FIELD_MASK: Self::BitsType; @@ -45,12 +45,13 @@ macro_rules! impl_float_encoding { const EXPONENT_BIAS_UNSIGNED: Self::BitsType = (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1; const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _; - const SIGN_FIELD_WIDTH: u32 = 1; - const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width; - const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width; - const SIGN_FIELD_SHIFT: u32 = Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH; - const EXPONENT_FIELD_SHIFT: u32 = Self::MANTISSA_FIELD_WIDTH; - const MANTISSA_FIELD_SHIFT: u32 = 0; + const SIGN_FIELD_WIDTH: Self::BitsType = 1; + const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width; + const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width; + const SIGN_FIELD_SHIFT: Self::BitsType = + Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH; + const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH; + const MANTISSA_FIELD_SHIFT: Self::BitsType = 0; const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT; const EXPONENT_FIELD_MASK: Self::BitsType = ((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT; diff --git a/src/ir.rs b/src/ir.rs index f1de05a..a31bece 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -809,6 +809,23 @@ macro_rules! ir_value { } } + impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> { + fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> { + let value = self + .ctx + .make_operation( + Opcode::Select, + [self.value, true_v.value, false_v.value], + $vec_name::TYPE, + ) + .into(); + $vec_name { + value, + ctx: self.ctx, + } + } + } + impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> { fn from(v: $name<'ctx>) -> Self { let value = v @@ -1060,12 +1077,41 @@ macro_rules! impl_number_ops { }; } +macro_rules! impl_bool_compare { + ($ty:ident) => { + impl<'ctx> Compare for $ty<'ctx> { + type Bool = Self; + fn eq(self, rhs: Self) -> Self::Bool { + !(self ^ rhs) + } + fn ne(self, rhs: Self) -> Self::Bool { + self ^ rhs + } + fn lt(self, rhs: Self) -> Self::Bool { + !self & rhs + } + fn gt(self, rhs: Self) -> Self::Bool { + self & !rhs + } + fn le(self, rhs: Self) -> Self::Bool { + !self | rhs + } + fn ge(self, rhs: Self) -> Self::Bool { + self | !rhs + } + } + }; +} + +impl_bool_compare!(IrBool); +impl_bool_compare!(IrVecBool); + macro_rules! impl_shift_ops { - ($ty:ident, $rhs:ident) => { - impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> { + ($ty:ident) => { + impl<'ctx> Shl for $ty<'ctx> { type Output = Self; - fn shl(self, rhs: $rhs<'ctx>) -> Self::Output { + fn shl(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE) @@ -1076,10 +1122,10 @@ macro_rules! impl_shift_ops { } } } - impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> { + impl<'ctx> Shr for $ty<'ctx> { type Output = Self; - fn shr(self, rhs: $rhs<'ctx>) -> Self::Output { + fn shr(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE) @@ -1090,13 +1136,13 @@ macro_rules! impl_shift_ops { } } } - impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> { - fn shl_assign(&mut self, rhs: $rhs<'ctx>) { + impl<'ctx> ShlAssign for $ty<'ctx> { + fn shl_assign(&mut self, rhs: Self) { *self = *self << rhs; } } - impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> { - fn shr_assign(&mut self, rhs: $rhs<'ctx>) { + impl<'ctx> ShrAssign for $ty<'ctx> { + fn shr_assign(&mut self, rhs: Self) { *self = *self >> rhs; } } @@ -1123,8 +1169,8 @@ macro_rules! impl_neg { } macro_rules! impl_int_trait { - ($ty:ident, $u32:ident) => { - impl<'ctx> Int<$u32<'ctx>> for $ty<'ctx> { + ($ty:ident) => { + impl<'ctx> Int for $ty<'ctx> { fn leading_zeros(self) -> Self { let value = self .ctx @@ -1163,12 +1209,12 @@ macro_rules! impl_integer_ops { ($scalar:ident, $vec:ident) => { impl_bit_ops!($scalar); impl_number_ops!($scalar, IrBool); - impl_shift_ops!($scalar, IrU32); + impl_shift_ops!($scalar); impl_bit_ops!($vec); impl_number_ops!($vec, IrVecBool); - impl_shift_ops!($vec, IrVecU32); - impl_int_trait!($scalar, IrU32); - impl_int_trait!($vec, IrVecU32); + impl_shift_ops!($vec); + impl_int_trait!($scalar); + impl_int_trait!($vec); }; } @@ -1176,8 +1222,8 @@ macro_rules! impl_uint_ops { ($scalar:ident, $vec:ident) => { impl_integer_ops!($scalar, $vec); - impl<'ctx> UInt> for $scalar<'ctx> {} - impl<'ctx> UInt> for $vec<'ctx> {} + impl<'ctx> UInt for $scalar<'ctx> {} + impl<'ctx> UInt for $vec<'ctx> {} }; } @@ -1192,8 +1238,8 @@ macro_rules! impl_sint_ops { impl_neg!($scalar); impl_neg!($vec); - impl<'ctx> SInt> for $scalar<'ctx> {} - impl<'ctx> SInt> for $vec<'ctx> {} + impl<'ctx> SInt for $scalar<'ctx> {} + impl<'ctx> SInt for $vec<'ctx> {} }; } @@ -1203,8 +1249,8 @@ impl_sint_ops!(IrI32, IrVecI32); impl_sint_ops!(IrI64, IrVecI64); macro_rules! impl_float { - ($float:ident, $bits:ident, $signed_bits:ident, $u32:ident) => { - impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> { + ($float:ident, $bits:ident, $signed_bits:ident) => { + impl<'ctx> Float for $float<'ctx> { type FloatEncoding = <$float<'ctx> as Make>::Prim; type BitsType = $bits<'ctx>; type SignedBitsType = $signed_bits<'ctx>; @@ -1330,8 +1376,8 @@ macro_rules! impl_float_ops { impl_number_ops!($vec, IrVecBool); impl_neg!($scalar); impl_neg!($vec); - impl_float!($scalar, $scalar_bits, $scalar_signed_bits, IrU32); - impl_float!($vec, $vec_bits, $vec_signed_bits, IrVecU32); + impl_float!($scalar, $scalar_bits, $scalar_signed_bits); + impl_float!($vec, $vec_bits, $vec_signed_bits); }; } @@ -1444,47 +1490,40 @@ ir_value!( ); macro_rules! impl_convert_to { - ($($src:ident -> [$($dest:ident),*];)*) => { - $($( - impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> { - fn to(self) -> $dest<'ctx> { - let value = if $src::TYPE == $dest::TYPE { - self.value - } else { - self - .ctx - .make_operation(Opcode::Cast, [self.value], $dest::TYPE) - .into() - }; - $dest { - value, - ctx: self.ctx, - } + ($src:ident -> $dest:ident) => { + impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> { + fn to(self) -> $dest<'ctx> { + let value = if $src::TYPE == $dest::TYPE { + self.value + } else { + self + .ctx + .make_operation(Opcode::Cast, [self.value], $dest::TYPE) + .into() + }; + $dest { + value, + ctx: self.ctx, } } - )*)* - }; - ([$($src:ident),*] -> $dest:tt;) => { - impl_convert_to! { - $( - $src -> $dest; - )* } }; - ([$($src:ident),*];) => { - impl_convert_to! { - [$($src),*] -> [$($src),*]; - } + ($first:ident $(, $ty:ident)*) => { + $( + impl_convert_to!($first -> $ty); + impl_convert_to!($ty -> $first); + )* + impl_convert_to![$($ty),*]; + }; + () => { }; } +impl_convert_to![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; -impl_convert_to! { - [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; -} - -impl_convert_to! { - [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64]; -} +impl_convert_to![ + IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, + IrVecF32, IrVecF64 +]; macro_rules! impl_from { ($src:ident => [$($dest:ident),*]) => { @@ -1564,15 +1603,18 @@ impl<'ctx> Context for &'ctx IrContext<'ctx> { type U64 = IrU64<'ctx>; type I64 = IrI64<'ctx>; type F64 = IrF64<'ctx>; - type VecBool = IrVecBool<'ctx>; + type VecBool8 = IrVecBool<'ctx>; type VecU8 = IrVecU8<'ctx>; type VecI8 = IrVecI8<'ctx>; + type VecBool16 = IrVecBool<'ctx>; type VecU16 = IrVecU16<'ctx>; type VecI16 = IrVecI16<'ctx>; type VecF16 = IrVecF16<'ctx>; + type VecBool32 = IrVecBool<'ctx>; type VecU32 = IrVecU32<'ctx>; type VecI32 = IrVecI32<'ctx>; type VecF32 = IrVecF32<'ctx>; + type VecBool64 = IrVecBool<'ctx>; type VecU64 = IrVecU64<'ctx>; type VecI64 = IrVecI64<'ctx>; type VecF64 = IrVecF64<'ctx>; diff --git a/src/scalar.rs b/src/scalar.rs index d1e137d..fb83af6 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -34,53 +34,32 @@ macro_rules! impl_context { impl_context! { impl Context for Scalar { type Bool = bool; - type U8 = u8; - type I8 = i8; - type U16 = u16; - type I16 = i16; - type F16 = crate::f16::F16; - type U32 = u32; - type I32 = i32; - type F32 = f32; - type U64 = u64; - type I64 = i64; - type F64 = f64; - #[vec] - - type VecBool = bool; - + type VecBool8 = bool; type VecU8 = u8; - type VecI8 = i8; - + type VecBool16 = bool; type VecU16 = u16; - type VecI16 = i16; - type VecF16 = crate::f16::F16; - + type VecBool32 = bool; type VecU32 = u32; - type VecI32 = i32; - type VecF32 = f32; - + type VecBool64 = bool; type VecU64 = u64; - type VecI64 = i64; - type VecF64 = f64; } } diff --git a/src/traits.rs b/src/traits.rs index 942c67f..ffc766b 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,414 +1,12 @@ +use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; -use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar}; - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint] - $uint_smaller:ident; - #[int] - $int_smaller:ident; - $( - #[float] - $float_smaller:ident; - )? - },)* - ], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - }, - [ - $({ - #[uint] - $uint_larger:ident; - #[int] - $int_larger:ident; - $( - #[float] - $float_larger:ident; - )? - },)* - ] - ) => { - type $float: Float - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - + ConvertTo - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ Into + ConvertTo)?)*; - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - { - #[uint] - $uint:ident; - #[int] - $int:ident; - }, - [$($larger:tt,)*] - ) => {}; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_uint_int_float_type { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [ - $({ - #[uint($($uint_smaller_traits:tt)*)] - $uint_smaller:ident; - #[int($($int_smaller_traits:tt)*)] - $int_smaller:ident; - $( - #[float($($float_smaller_traits:tt)*)] - $float_smaller:ident; - )? - },)* - ], - { - #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)] - $uint:ident; - #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)] - $int:ident; - $( - #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)] - $float:ident; - )? - }, - [ - $({ - #[uint($($uint_larger_traits:tt)*)] - $uint_larger:ident; - #[int($($int_larger_traits:tt)*)] - $int_larger:ident; - $( - #[float($($float_larger_traits:tt)*)] - $float_larger:ident; - )? - },)* - ] - ) => { - type $uint: UInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ Into + ConvertTo)* - $(+ Into + ConvertTo)* - $($(+ Into + ConvertTo)?)*; - type $int: SInt - $(+ From)? - + Compare - + Make - $(+ ConvertTo)* - $(+ ConvertTo)* - $($(+ ConvertTo)?)* - + ConvertTo - $(+ ConvertTo)? - $(+ ConvertTo)* - $(+ Into + ConvertTo)* - $($(+ Into + ConvertTo)?)*; - make_float_type! { - #[u32 = $u32] - #[bool = $bool] - [ - $({ - #[uint] - $uint_smaller; - #[int] - $int_smaller; - $( - #[float] - $float_smaller; - )? - },)* - ], - { - #[uint] - $uint; - #[int] - $int; - $( - #[float(prim = $float_prim $(, scalar = $float_scalar)?)] - $float; - )? - }, - [ - $({ - #[uint] - $uint_larger; - #[int] - $int_larger; - $( - #[float] - $float_larger; - )? - },)* - ] - } - }; -} - -macro_rules! make_uint_int_float_types { - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [$first_larger:tt, $($larger:tt,)*] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [$first_larger, $($larger,)*] - } - make_uint_int_float_types! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)* $current,], - $first_larger, - [$($larger,)*] - } - }; - ( - #[u32 = $u32:ident] - #[bool = $bool:ident] - [$($smaller:tt,)*], - $current:tt, - [] - ) => { - make_uint_int_float_type! { - #[u32 = $u32] - #[bool = $bool] - [$($smaller,)*], - $current, - [] - } - }; -} - -#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823 -macro_rules! make_types { - ( - #[bool] - $(#[scalar = $ScalarBool:ident])? - type $Bool:ident; - - #[u8] - $(#[scalar = $ScalarU8:ident])? - type $U8:ident; - - #[u16] - $(#[scalar = $ScalarU16:ident])? - type $U16:ident; - - #[u32] - $(#[scalar = $ScalarU32:ident])? - type $U32:ident; - - #[u64] - $(#[scalar = $ScalarU64:ident])? - type $U64:ident; - - #[i8] - $(#[scalar = $ScalarI8:ident])? - type $I8:ident; - - #[i16] - $(#[scalar = $ScalarI16:ident])? - type $I16:ident; - - #[i32] - $(#[scalar = $ScalarI32:ident])? - type $I32:ident; - - #[i64] - $(#[scalar = $ScalarI64:ident])? - type $I64:ident; - - #[f16] - $(#[scalar = $ScalarF16:ident])? - type $F16:ident; - - #[f32] - $(#[scalar = $ScalarF32:ident])? - type $F32:ident; - - #[f64] - $(#[scalar = $ScalarF64:ident])? - type $F64:ident; - ) => { - type $Bool: Bool - $(+ From)? - + Make - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select - + Select; - make_uint_int_float_types! { - #[u32 = $U32] - #[bool = $Bool] - [], - { - #[uint(prim = u8 $(, scalar = $ScalarU8)?)] - $U8; - #[int(prim = i8 $(, scalar = $ScalarI8)?)] - $I8; - }, - [ - { - #[uint(prim = u16 $(, scalar = $ScalarU16)?)] - $U16; - #[int(prim = i16 $(, scalar = $ScalarI16)?)] - $I16; - #[float(prim = F16 $(, scalar = $ScalarF16)?)] - $F16; - }, - { - #[uint(prim = u32 $(, scalar = $ScalarU32)?)] - $U32; - #[int(prim = i32 $(, scalar = $ScalarI32)?)] - $I32; - #[float(prim = f32 $(, scalar = $ScalarF32)?)] - $F32; - }, - { - #[uint(prim = u64 $(, scalar = $ScalarU64)?)] - $U64; - #[int(prim = i64 $(, scalar = $ScalarI64)?)] - $I64; - #[float(prim = f64 $(, scalar = $ScalarF64)?)] - $F64; - }, - ] - } - }; -} - /// reference used to build IR for Kazan; an empty type for `core::simd` pub trait Context: Copy { - make_types! { - #[bool] - type Bool; - - #[u8] - type U8; - - #[u16] - type U16; - - #[u32] - type U32; - - #[u64] - type U64; - - #[i8] - type I8; - - #[i16] - type I16; - - #[i32] - type I32; - - #[i64] - type I64; - - #[f16] - type F16; - - #[f32] - type F32; - - #[f64] - type F64; - } - make_types! { - #[bool] - #[scalar = Bool] - type VecBool; - - #[u8] - #[scalar = U8] - type VecU8; - - #[u16] - #[scalar = U16] - type VecU16; - - #[u32] - #[scalar = U32] - type VecU32; - - #[u64] - #[scalar = U64] - type VecU64; - - #[i8] - #[scalar = I8] - type VecI8; - - #[i16] - #[scalar = I16] - type VecI16; - - #[i32] - #[scalar = I32] - type VecI32; - - #[i64] - #[scalar = I64] - type VecI64; - - #[f16] - #[scalar = F16] - type VecF16; - - #[f32] - #[scalar = F32] - type VecF32; - - #[f64] - #[scalar = F64] - type VecF64; - } + vector_math_proc_macro::make_context_types!(); fn make>(self, v: T::Prim) -> T { T::make(self, v) } @@ -425,33 +23,33 @@ pub trait ConvertTo { fn to(self) -> T; } +impl ConvertTo for T { + fn to(self) -> T { + self + } +} + macro_rules! impl_convert_to_using_as { - ($($src:ident -> [$($dest:ident),*];)*) => { - $($( - impl ConvertTo<$dest> for $src { - fn to(self) -> $dest { - self as $dest + ($first:ident $(, $ty:ident)*) => { + $( + impl ConvertTo<$first> for $ty { + fn to(self) -> $first { + self as $first } } - )*)* - }; - ([$($src:ident),*] -> $dest:tt;) => { - impl_convert_to_using_as! { - $( - $src -> $dest; - )* - } + impl ConvertTo<$ty> for $first { + fn to(self) -> $ty { + self as $ty + } + } + )* + impl_convert_to_using_as![$($ty),*]; }; - ([$($src:ident),*];) => { - impl_convert_to_using_as! { - [$($src),*] -> [$($src),*]; - } + () => { }; } -impl_convert_to_using_as! { - [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; -} +impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; pub trait Number: Compare @@ -507,13 +105,8 @@ impl BitOps for T where { } -pub trait Int: - Number - + BitOps - + Shl - + Shr - + ShlAssign - + ShrAssign +pub trait Int: + Number + BitOps + Shl + Shr + ShlAssign + ShrAssign { fn leading_zeros(self) -> Self; fn leading_ones(self) -> Self { @@ -529,13 +122,13 @@ pub trait Int: fn count_ones(self) -> Self; } -pub trait UInt: Int {} +pub trait UInt: Int {} -pub trait SInt: Int + Neg {} +pub trait SInt: Int + Neg {} macro_rules! impl_int { ($ty:ident) => { - impl Int for $ty { + impl Int for $ty { fn leading_zeros(self) -> Self { self.leading_zeros() as Self } @@ -562,7 +155,7 @@ macro_rules! impl_uint { ($($ty:ident),*) => { $( impl_int!($ty); - impl UInt for $ty {} + impl UInt for $ty {} )* }; } @@ -573,23 +166,21 @@ macro_rules! impl_sint { ($($ty:ident),*) => { $( impl_int!($ty); - impl SInt for $ty {} + impl SInt for $ty {} )* }; } impl_sint![i8, i16, i32, i64]; -pub trait Float>: - Number + Neg -{ +pub trait Float: Number + Neg { type FloatEncoding: FloatEncoding + Make::Prim>; - type BitsType: UInt - + Make>::BitsType> + type BitsType: UInt + + Make::BitsType> + ConvertTo + Compare; - type SignedBitsType: SInt - + Make>::SignedBitsType> + type SignedBitsType: SInt + + Make::SignedBitsType> + ConvertTo + Compare; fn abs(self) -> Self; @@ -648,7 +239,7 @@ pub trait Float>: macro_rules! impl_float { ($ty:ty, $bits:ty, $signed_bits:ty) => { - impl Float for $ty { + impl Float for $ty { type FloatEncoding = $ty; type BitsType = $bits; type SignedBitsType = $signed_bits; @@ -763,4 +354,4 @@ macro_rules! impl_compare_using_partial_cmp { }; } -impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64]; +impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64]; diff --git a/vector-math-proc-macro/Cargo.toml b/vector-math-proc-macro/Cargo.toml new file mode 100644 index 0000000..f65c085 --- /dev/null +++ b/vector-math-proc-macro/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "vector-math-proc-macro" +version = "0.1.0" +authors = ["Jacob Lifshay "] +edition = "2018" +license = "MIT OR Apache-2.0" + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0" +proc-macro2 = "1.0" +syn = { version = "1.0", features = [] } diff --git a/vector-math-proc-macro/src/lib.rs b/vector-math-proc-macro/src/lib.rs new file mode 100644 index 0000000..1f5797a --- /dev/null +++ b/vector-math-proc-macro/src/lib.rs @@ -0,0 +1,382 @@ +use std::{ + cmp::Ordering, + collections::{BTreeSet, HashMap}, + hash::Hash, +}; + +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, +}; + +struct Input {} + +impl Parse for Input { + fn parse(_input: ParseStream) -> syn::Result { + Ok(Input {}) + } +} + +macro_rules! make_enum { + ( + $vis:vis enum $ty:ident { + $( + $field:ident $(= $value:expr)?, + )* + } + ) => { + #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + #[repr(u8)] + $vis enum $ty { + $( + $field $(= $value)?, + )* + } + + impl $ty { + #[allow(dead_code)] + $vis const VALUES: &'static [Self] = &[ + $( + Self::$field, + )* + ]; + } + }; +} + +make_enum! { + enum TypeKind { + Bool, + UInt, + SInt, + Float, + } +} + +make_enum! { + enum VectorScalar { + Scalar, + Vector, + } +} + +make_enum! { + enum TypeBits { + Bits8 = 8, + Bits16 = 16, + Bits32 = 32, + Bits64 = 64, + } +} + +impl TypeBits { + const fn bits(self) -> u32 { + self as u8 as u32 + } +} + +make_enum! { + enum Convertibility { + Impossible, + Lossy, + Lossless, + } +} + +impl Convertibility { + const fn make_possible(lossless: bool) -> Self { + if lossless { + Self::Lossless + } else { + Self::Lossy + } + } + const fn make_non_lossy(possible: bool) -> Self { + if possible { + Self::Lossless + } else { + Self::Impossible + } + } + const fn possible(self) -> bool { + match self { + Convertibility::Impossible => false, + Convertibility::Lossy | Convertibility::Lossless => true, + } + } +} + +impl TypeKind { + fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool { + match self { + TypeKind::Float => bits >= TypeBits::Bits16, + TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector, + TypeKind::UInt | TypeKind::SInt => true, + } + } + fn prim_ty(self, bits: TypeBits) -> Ident { + Ident::new( + &match self { + TypeKind::Bool => "bool".into(), + TypeKind::UInt => format!("u{}", bits.bits()), + TypeKind::SInt => format!("i{}", bits.bits()), + TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(), + TypeKind::Float => format!("f{}", bits.bits()), + }, + Span::call_site(), + ) + } + fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident { + let vec_prefix = match vector_scalar { + VectorScalar::Scalar => "", + VectorScalar::Vector => "Vec", + }; + Ident::new( + &match self { + TypeKind::Bool => match vector_scalar { + VectorScalar::Scalar => "Bool".into(), + VectorScalar::Vector => format!("VecBool{}", bits.bits()), + }, + TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()), + TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()), + TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()), + }, + Span::call_site(), + ) + } + fn convertibility_to( + self, + src_bits: TypeBits, + dest_type_kind: TypeKind, + dest_bits: TypeBits, + ) -> Convertibility { + Convertibility::make_possible(match (self, dest_type_kind) { + (TypeKind::Bool, _) | (_, TypeKind::Bool) => { + return Convertibility::make_non_lossy(self == dest_type_kind); + } + (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits, + (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits, + (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits, + (TypeKind::SInt, TypeKind::UInt) => false, + (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits, + (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits, + (TypeKind::Float, TypeKind::UInt) => false, + (TypeKind::Float, TypeKind::SInt) => false, + (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits, + }) + } +} + +#[derive(Default, Debug)] +struct TokenStreamSetElement { + token_stream: TokenStream, + text: String, +} + +impl Ord for TokenStreamSetElement { + fn cmp(&self, other: &Self) -> Ordering { + self.text.cmp(&other.text) + } +} + +impl PartialOrd for TokenStreamSetElement { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for TokenStreamSetElement { + fn eq(&self, other: &Self) -> bool { + self.text == other.text + } +} + +impl Eq for TokenStreamSetElement {} + +impl From for TokenStreamSetElement { + fn from(token_stream: TokenStream) -> Self { + let text = token_stream.to_string(); + Self { token_stream, text } + } +} + +impl ToTokens for TokenStreamSetElement { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.token_stream.to_tokens(tokens) + } + + fn to_token_stream(&self) -> TokenStream { + self.token_stream.to_token_stream() + } + + fn into_token_stream(self) -> TokenStream { + self.token_stream + } +} + +type TokenStreamSet = BTreeSet; + +#[derive(Debug, Default)] +struct TraitSets { + trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>, +} + +impl TraitSets { + fn get( + &mut self, + type_kind: TypeKind, + mut bits: TypeBits, + vector_scalar: VectorScalar, + ) -> &mut TokenStreamSet { + if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar { + bits = TypeBits::Bits8; + } + self.trait_sets_map + .entry((type_kind, bits, vector_scalar)) + .or_default() + } + fn add_trait( + &mut self, + type_kind: TypeKind, + bits: TypeBits, + vector_scalar: VectorScalar, + v: impl Into, + ) { + self.get(type_kind, bits, vector_scalar).insert(v.into()); + } + fn fill(&mut self) { + for &bits in TypeBits::VALUES { + for &type_kind in TypeKind::VALUES { + for &vector_scalar in VectorScalar::VALUES { + if !type_kind.is_valid(bits, vector_scalar) { + continue; + } + let prim_ty = type_kind.prim_ty(bits); + let ty = type_kind.ty(bits, vector_scalar); + if vector_scalar == VectorScalar::Vector { + let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar); + self.add_trait( + type_kind, + bits, + vector_scalar, + quote! { From }, + ); + } + let bool_ty = TypeKind::Bool.ty(bits, vector_scalar); + let uint_ty = TypeKind::UInt.ty(bits, vector_scalar); + let sint_ty = TypeKind::SInt.ty(bits, vector_scalar); + let type_trait = match type_kind { + TypeKind::Bool => quote! { Bool }, + TypeKind::UInt => quote! { UInt }, + TypeKind::SInt => quote! { SInt }, + TypeKind::Float => quote! { Float< + BitsType = Self::#uint_ty, + SignedBitsType = Self::#sint_ty, + FloatEncoding = #prim_ty, + > }, + }; + self.add_trait(type_kind, bits, vector_scalar, type_trait); + self.add_trait( + type_kind, + bits, + vector_scalar, + quote! { Compare }, + ); + self.add_trait( + TypeKind::Bool, + bits, + vector_scalar, + quote! { Select }, + ); + self.add_trait( + TypeKind::Bool, + TypeBits::Bits8, + VectorScalar::Scalar, + quote! { Select }, + ); + for &other_bits in TypeBits::VALUES { + for &other_type_kind in TypeKind::VALUES { + if !other_type_kind.is_valid(other_bits, vector_scalar) { + continue; + } + if other_bits == bits && other_type_kind == type_kind { + continue; + } + let other_ty = other_type_kind.ty(other_bits, vector_scalar); + let convertibility = + other_type_kind.convertibility_to(other_bits, type_kind, bits); + if convertibility == Convertibility::Lossless { + self.add_trait( + type_kind, + bits, + vector_scalar, + quote! { From }, + ); + } + if convertibility.possible() { + self.add_trait( + other_type_kind, + other_bits, + vector_scalar, + quote! { ConvertTo }, + ); + } + } + } + self.add_trait( + type_kind, + bits, + vector_scalar, + quote! { Make }, + ); + } + } + } + } +} + +impl Input { + fn to_tokens(&self) -> syn::Result { + let mut types = Vec::new(); + let mut trait_sets = TraitSets::default(); + trait_sets.fill(); + for &bits in TypeBits::VALUES { + for &type_kind in TypeKind::VALUES { + for &vector_scalar in VectorScalar::VALUES { + if !type_kind.is_valid(bits, vector_scalar) { + continue; + } + let ty = type_kind.ty(bits, vector_scalar); + let traits = trait_sets.get(type_kind, bits, vector_scalar); + types.push(quote! { + type #ty: #(#traits)+*; + }); + } + } + } + Ok(quote! {#(#types)*}) + } +} + +#[proc_macro] +pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as Input); + match input.to_tokens() { + Ok(retval) => retval, + Err(err) => err.to_compile_error(), + } + .into() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() -> syn::Result<()> { + Input {}.to_tokens()?; + Ok(()) + } +} -- 2.30.2