half = { version = "1.7.1", optional = true }
typed-arena = { version = "2.0.1", optional = true }
core_simd = { version = "0.1.0", git = "https://github.com/rust-lang/stdsimd", optional = true }
+vector-math-proc-macro = { version = "=0.1.0", path = "vector-math-proc-macro" }
[features]
default = ["f16", "fma"]
std = []
ir = ["std", "typed-arena"]
stdsimd = ["core_simd"]
+
+[workspace]
+members = [".", "vector-math-proc-macro"]
Rem, rem, RemAssign, rem_assign;
}
-impl Float<u32> for F16 {
+impl Float for F16 {
type FloatEncoding = F16;
type BitsType = u16;
type SignedBitsType = i16;
}
pub trait FloatEncoding:
- sealed::Sealed + Copy + 'static + Send + Sync + Float<u32> + Make<Context = Scalar>
+ sealed::Sealed + Copy + 'static + Send + Sync + Float + Make<Context = Scalar>
{
const EXPONENT_BIAS_UNSIGNED: Self::BitsType;
const EXPONENT_BIAS_SIGNED: Self::SignedBitsType;
- const SIGN_FIELD_WIDTH: u32;
- const EXPONENT_FIELD_WIDTH: u32;
- const MANTISSA_FIELD_WIDTH: u32;
- const SIGN_FIELD_SHIFT: u32;
- const EXPONENT_FIELD_SHIFT: u32;
- const MANTISSA_FIELD_SHIFT: u32;
+ const SIGN_FIELD_WIDTH: Self::BitsType;
+ const EXPONENT_FIELD_WIDTH: Self::BitsType;
+ const MANTISSA_FIELD_WIDTH: Self::BitsType;
+ const SIGN_FIELD_SHIFT: Self::BitsType;
+ const EXPONENT_FIELD_SHIFT: Self::BitsType;
+ const MANTISSA_FIELD_SHIFT: Self::BitsType;
const SIGN_FIELD_MASK: Self::BitsType;
const EXPONENT_FIELD_MASK: Self::BitsType;
const MANTISSA_FIELD_MASK: Self::BitsType;
const EXPONENT_BIAS_UNSIGNED: Self::BitsType =
(1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1;
const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _;
- const SIGN_FIELD_WIDTH: u32 = 1;
- const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width;
- const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width;
- const SIGN_FIELD_SHIFT: u32 = Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
- const EXPONENT_FIELD_SHIFT: u32 = Self::MANTISSA_FIELD_WIDTH;
- const MANTISSA_FIELD_SHIFT: u32 = 0;
+ const SIGN_FIELD_WIDTH: Self::BitsType = 1;
+ const EXPONENT_FIELD_WIDTH: Self::BitsType = $exponent_field_width;
+ const MANTISSA_FIELD_WIDTH: Self::BitsType = $mantissa_field_width;
+ const SIGN_FIELD_SHIFT: Self::BitsType =
+ Self::EXPONENT_FIELD_SHIFT + Self::EXPONENT_FIELD_WIDTH;
+ const EXPONENT_FIELD_SHIFT: Self::BitsType = Self::MANTISSA_FIELD_WIDTH;
+ const MANTISSA_FIELD_SHIFT: Self::BitsType = 0;
const SIGN_FIELD_MASK: Self::BitsType = 1 << Self::SIGN_FIELD_SHIFT;
const EXPONENT_FIELD_MASK: Self::BitsType =
((1 << Self::EXPONENT_FIELD_WIDTH) - 1) << Self::EXPONENT_FIELD_SHIFT;
}
}
+ impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> {
+ fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
+ let value = self
+ .ctx
+ .make_operation(
+ Opcode::Select,
+ [self.value, true_v.value, false_v.value],
+ $vec_name::TYPE,
+ )
+ .into();
+ $vec_name {
+ value,
+ ctx: self.ctx,
+ }
+ }
+ }
+
impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
fn from(v: $name<'ctx>) -> Self {
let value = v
};
}
+macro_rules! impl_bool_compare {
+ ($ty:ident) => {
+ impl<'ctx> Compare for $ty<'ctx> {
+ type Bool = Self;
+ fn eq(self, rhs: Self) -> Self::Bool {
+ !(self ^ rhs)
+ }
+ fn ne(self, rhs: Self) -> Self::Bool {
+ self ^ rhs
+ }
+ fn lt(self, rhs: Self) -> Self::Bool {
+ !self & rhs
+ }
+ fn gt(self, rhs: Self) -> Self::Bool {
+ self & !rhs
+ }
+ fn le(self, rhs: Self) -> Self::Bool {
+ !self | rhs
+ }
+ fn ge(self, rhs: Self) -> Self::Bool {
+ self | !rhs
+ }
+ }
+ };
+}
+
+impl_bool_compare!(IrBool);
+impl_bool_compare!(IrVecBool);
+
macro_rules! impl_shift_ops {
- ($ty:ident, $rhs:ident) => {
- impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> {
+ ($ty:ident) => {
+ impl<'ctx> Shl for $ty<'ctx> {
type Output = Self;
- fn shl(self, rhs: $rhs<'ctx>) -> Self::Output {
+ fn shl(self, rhs: Self) -> Self::Output {
let value = self
.ctx
.make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE)
}
}
}
- impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> {
+ impl<'ctx> Shr for $ty<'ctx> {
type Output = Self;
- fn shr(self, rhs: $rhs<'ctx>) -> Self::Output {
+ fn shr(self, rhs: Self) -> Self::Output {
let value = self
.ctx
.make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE)
}
}
}
- impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> {
- fn shl_assign(&mut self, rhs: $rhs<'ctx>) {
+ impl<'ctx> ShlAssign for $ty<'ctx> {
+ fn shl_assign(&mut self, rhs: Self) {
*self = *self << rhs;
}
}
- impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> {
- fn shr_assign(&mut self, rhs: $rhs<'ctx>) {
+ impl<'ctx> ShrAssign for $ty<'ctx> {
+ fn shr_assign(&mut self, rhs: Self) {
*self = *self >> rhs;
}
}
}
macro_rules! impl_int_trait {
- ($ty:ident, $u32:ident) => {
- impl<'ctx> Int<$u32<'ctx>> for $ty<'ctx> {
+ ($ty:ident) => {
+ impl<'ctx> Int for $ty<'ctx> {
fn leading_zeros(self) -> Self {
let value = self
.ctx
($scalar:ident, $vec:ident) => {
impl_bit_ops!($scalar);
impl_number_ops!($scalar, IrBool);
- impl_shift_ops!($scalar, IrU32);
+ impl_shift_ops!($scalar);
impl_bit_ops!($vec);
impl_number_ops!($vec, IrVecBool);
- impl_shift_ops!($vec, IrVecU32);
- impl_int_trait!($scalar, IrU32);
- impl_int_trait!($vec, IrVecU32);
+ impl_shift_ops!($vec);
+ impl_int_trait!($scalar);
+ impl_int_trait!($vec);
};
}
($scalar:ident, $vec:ident) => {
impl_integer_ops!($scalar, $vec);
- impl<'ctx> UInt<IrU32<'ctx>> for $scalar<'ctx> {}
- impl<'ctx> UInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+ impl<'ctx> UInt for $scalar<'ctx> {}
+ impl<'ctx> UInt for $vec<'ctx> {}
};
}
impl_neg!($scalar);
impl_neg!($vec);
- impl<'ctx> SInt<IrU32<'ctx>> for $scalar<'ctx> {}
- impl<'ctx> SInt<IrVecU32<'ctx>> for $vec<'ctx> {}
+ impl<'ctx> SInt for $scalar<'ctx> {}
+ impl<'ctx> SInt for $vec<'ctx> {}
};
}
impl_sint_ops!(IrI64, IrVecI64);
macro_rules! impl_float {
- ($float:ident, $bits:ident, $signed_bits:ident, $u32:ident) => {
- impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
+ ($float:ident, $bits:ident, $signed_bits:ident) => {
+ impl<'ctx> Float for $float<'ctx> {
type FloatEncoding = <$float<'ctx> as Make>::Prim;
type BitsType = $bits<'ctx>;
type SignedBitsType = $signed_bits<'ctx>;
impl_number_ops!($vec, IrVecBool);
impl_neg!($scalar);
impl_neg!($vec);
- impl_float!($scalar, $scalar_bits, $scalar_signed_bits, IrU32);
- impl_float!($vec, $vec_bits, $vec_signed_bits, IrVecU32);
+ impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
+ impl_float!($vec, $vec_bits, $vec_signed_bits);
};
}
);
macro_rules! impl_convert_to {
- ($($src:ident -> [$($dest:ident),*];)*) => {
- $($(
- impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
- fn to(self) -> $dest<'ctx> {
- let value = if $src::TYPE == $dest::TYPE {
- self.value
- } else {
- self
- .ctx
- .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
- .into()
- };
- $dest {
- value,
- ctx: self.ctx,
- }
+ ($src:ident -> $dest:ident) => {
+ impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
+ fn to(self) -> $dest<'ctx> {
+ let value = if $src::TYPE == $dest::TYPE {
+ self.value
+ } else {
+ self
+ .ctx
+ .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
+ .into()
+ };
+ $dest {
+ value,
+ ctx: self.ctx,
}
}
- )*)*
- };
- ([$($src:ident),*] -> $dest:tt;) => {
- impl_convert_to! {
- $(
- $src -> $dest;
- )*
}
};
- ([$($src:ident),*];) => {
- impl_convert_to! {
- [$($src),*] -> [$($src),*];
- }
+ ($first:ident $(, $ty:ident)*) => {
+ $(
+ impl_convert_to!($first -> $ty);
+ impl_convert_to!($ty -> $first);
+ )*
+ impl_convert_to![$($ty),*];
+ };
+ () => {
};
}
+impl_convert_to![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
-impl_convert_to! {
- [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
-}
-
-impl_convert_to! {
- [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64];
-}
+impl_convert_to![
+ IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
+ IrVecF32, IrVecF64
+];
macro_rules! impl_from {
($src:ident => [$($dest:ident),*]) => {
type U64 = IrU64<'ctx>;
type I64 = IrI64<'ctx>;
type F64 = IrF64<'ctx>;
- type VecBool = IrVecBool<'ctx>;
+ type VecBool8 = IrVecBool<'ctx>;
type VecU8 = IrVecU8<'ctx>;
type VecI8 = IrVecI8<'ctx>;
+ type VecBool16 = IrVecBool<'ctx>;
type VecU16 = IrVecU16<'ctx>;
type VecI16 = IrVecI16<'ctx>;
type VecF16 = IrVecF16<'ctx>;
+ type VecBool32 = IrVecBool<'ctx>;
type VecU32 = IrVecU32<'ctx>;
type VecI32 = IrVecI32<'ctx>;
type VecF32 = IrVecF32<'ctx>;
+ type VecBool64 = IrVecBool<'ctx>;
type VecU64 = IrVecU64<'ctx>;
type VecI64 = IrVecI64<'ctx>;
type VecF64 = IrVecF64<'ctx>;
impl_context! {
impl Context for Scalar {
type Bool = bool;
-
type U8 = u8;
-
type I8 = i8;
-
type U16 = u16;
-
type I16 = i16;
-
type F16 = crate::f16::F16;
-
type U32 = u32;
-
type I32 = i32;
-
type F32 = f32;
-
type U64 = u64;
-
type I64 = i64;
-
type F64 = f64;
-
#[vec]
-
- type VecBool = bool;
-
+ type VecBool8 = bool;
type VecU8 = u8;
-
type VecI8 = i8;
-
+ type VecBool16 = bool;
type VecU16 = u16;
-
type VecI16 = i16;
-
type VecF16 = crate::f16::F16;
-
+ type VecBool32 = bool;
type VecU32 = u32;
-
type VecI32 = i32;
-
type VecF32 = f32;
-
+ type VecBool64 = bool;
type VecU64 = u64;
-
type VecI64 = i64;
-
type VecF64 = f64;
}
}
+use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
use core::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
-use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar};
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_float_type {
- (
- #[u32 = $u32:ident]
- #[bool = $bool:ident]
- [
- $({
- #[uint]
- $uint_smaller:ident;
- #[int]
- $int_smaller:ident;
- $(
- #[float]
- $float_smaller:ident;
- )?
- },)*
- ],
- {
- #[uint]
- $uint:ident;
- #[int]
- $int:ident;
- #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)]
- $float:ident;
- },
- [
- $({
- #[uint]
- $uint_larger:ident;
- #[int]
- $int_larger:ident;
- $(
- #[float]
- $float_larger:ident;
- )?
- },)*
- ]
- ) => {
- type $float: Float<Self::$u32, BitsType = Self::$uint, SignedBitsType = Self::$int, FloatEncoding = $float_prim>
- $(+ From<Self::$float_scalar>)?
- + Compare<Bool = Self::$bool>
- + Make<Context = Self, Prim = $float_prim>
- $(+ ConvertTo<Self::$uint_smaller>)*
- $(+ ConvertTo<Self::$int_smaller>)*
- $($(+ ConvertTo<Self::$float_smaller>)?)*
- + ConvertTo<Self::$uint>
- + ConvertTo<Self::$int>
- $(+ ConvertTo<Self::$uint_larger>)*
- $(+ ConvertTo<Self::$int_larger>)*
- $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
- };
- (
- #[u32 = $u32:ident]
- #[bool = $bool:ident]
- [$($smaller:tt,)*],
- {
- #[uint]
- $uint:ident;
- #[int]
- $int:ident;
- },
- [$($larger:tt,)*]
- ) => {};
-}
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_uint_int_float_type {
- (
- #[u32 = $u32:ident]
- #[bool = $bool:ident]
- [
- $({
- #[uint($($uint_smaller_traits:tt)*)]
- $uint_smaller:ident;
- #[int($($int_smaller_traits:tt)*)]
- $int_smaller:ident;
- $(
- #[float($($float_smaller_traits:tt)*)]
- $float_smaller:ident;
- )?
- },)*
- ],
- {
- #[uint(prim = $uint_prim:ident $(, scalar = $uint_scalar:ident)?)]
- $uint:ident;
- #[int(prim = $int_prim:ident $(, scalar = $int_scalar:ident)?)]
- $int:ident;
- $(
- #[float(prim = $float_prim:ident $(, scalar = $float_scalar:ident)?)]
- $float:ident;
- )?
- },
- [
- $({
- #[uint($($uint_larger_traits:tt)*)]
- $uint_larger:ident;
- #[int($($int_larger_traits:tt)*)]
- $int_larger:ident;
- $(
- #[float($($float_larger_traits:tt)*)]
- $float_larger:ident;
- )?
- },)*
- ]
- ) => {
- type $uint: UInt<Self::$u32>
- $(+ From<Self::$uint_scalar>)?
- + Compare<Bool = Self::$bool>
- + Make<Context = Self, Prim = $uint_prim>
- $(+ ConvertTo<Self::$uint_smaller>)*
- $(+ ConvertTo<Self::$int_smaller>)*
- $($(+ ConvertTo<Self::$float_smaller>)?)*
- + ConvertTo<Self::$int>
- $(+ ConvertTo<Self::$float>)?
- $(+ Into<Self::$uint_larger> + ConvertTo<Self::$uint_larger>)*
- $(+ Into<Self::$int_larger> + ConvertTo<Self::$int_larger>)*
- $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
- type $int: SInt<Self::$u32>
- $(+ From<Self::$int_scalar>)?
- + Compare<Bool = Self::$bool>
- + Make<Context = Self, Prim = $int_prim>
- $(+ ConvertTo<Self::$uint_smaller>)*
- $(+ ConvertTo<Self::$int_smaller>)*
- $($(+ ConvertTo<Self::$float_smaller>)?)*
- + ConvertTo<Self::$uint>
- $(+ ConvertTo<Self::$float>)?
- $(+ ConvertTo<Self::$uint_larger>)*
- $(+ Into<Self::$int_larger> + ConvertTo<Self::$int_larger>)*
- $($(+ Into<Self::$float_larger> + ConvertTo<Self::$float_larger>)?)*;
- make_float_type! {
- #[u32 = $u32]
- #[bool = $bool]
- [
- $({
- #[uint]
- $uint_smaller;
- #[int]
- $int_smaller;
- $(
- #[float]
- $float_smaller;
- )?
- },)*
- ],
- {
- #[uint]
- $uint;
- #[int]
- $int;
- $(
- #[float(prim = $float_prim $(, scalar = $float_scalar)?)]
- $float;
- )?
- },
- [
- $({
- #[uint]
- $uint_larger;
- #[int]
- $int_larger;
- $(
- #[float]
- $float_larger;
- )?
- },)*
- ]
- }
- };
-}
-
-macro_rules! make_uint_int_float_types {
- (
- #[u32 = $u32:ident]
- #[bool = $bool:ident]
- [$($smaller:tt,)*],
- $current:tt,
- [$first_larger:tt, $($larger:tt,)*]
- ) => {
- make_uint_int_float_type! {
- #[u32 = $u32]
- #[bool = $bool]
- [$($smaller,)*],
- $current,
- [$first_larger, $($larger,)*]
- }
- make_uint_int_float_types! {
- #[u32 = $u32]
- #[bool = $bool]
- [$($smaller,)* $current,],
- $first_larger,
- [$($larger,)*]
- }
- };
- (
- #[u32 = $u32:ident]
- #[bool = $bool:ident]
- [$($smaller:tt,)*],
- $current:tt,
- []
- ) => {
- make_uint_int_float_type! {
- #[u32 = $u32]
- #[bool = $bool]
- [$($smaller,)*],
- $current,
- []
- }
- };
-}
-
-#[rustfmt::skip] // work around for https://github.com/rust-lang/rustfmt/issues/4823
-macro_rules! make_types {
- (
- #[bool]
- $(#[scalar = $ScalarBool:ident])?
- type $Bool:ident;
-
- #[u8]
- $(#[scalar = $ScalarU8:ident])?
- type $U8:ident;
-
- #[u16]
- $(#[scalar = $ScalarU16:ident])?
- type $U16:ident;
-
- #[u32]
- $(#[scalar = $ScalarU32:ident])?
- type $U32:ident;
-
- #[u64]
- $(#[scalar = $ScalarU64:ident])?
- type $U64:ident;
-
- #[i8]
- $(#[scalar = $ScalarI8:ident])?
- type $I8:ident;
-
- #[i16]
- $(#[scalar = $ScalarI16:ident])?
- type $I16:ident;
-
- #[i32]
- $(#[scalar = $ScalarI32:ident])?
- type $I32:ident;
-
- #[i64]
- $(#[scalar = $ScalarI64:ident])?
- type $I64:ident;
-
- #[f16]
- $(#[scalar = $ScalarF16:ident])?
- type $F16:ident;
-
- #[f32]
- $(#[scalar = $ScalarF32:ident])?
- type $F32:ident;
-
- #[f64]
- $(#[scalar = $ScalarF64:ident])?
- type $F64:ident;
- ) => {
- type $Bool: Bool
- $(+ From<Self::$ScalarBool>)?
- + Make<Context = Self, Prim = bool>
- + Select<Self::$Bool>
- + Select<Self::$U8>
- + Select<Self::$U16>
- + Select<Self::$U32>
- + Select<Self::$U64>
- + Select<Self::$I8>
- + Select<Self::$I16>
- + Select<Self::$I32>
- + Select<Self::$I64>
- + Select<Self::$F16>
- + Select<Self::$F32>
- + Select<Self::$F64>;
- make_uint_int_float_types! {
- #[u32 = $U32]
- #[bool = $Bool]
- [],
- {
- #[uint(prim = u8 $(, scalar = $ScalarU8)?)]
- $U8;
- #[int(prim = i8 $(, scalar = $ScalarI8)?)]
- $I8;
- },
- [
- {
- #[uint(prim = u16 $(, scalar = $ScalarU16)?)]
- $U16;
- #[int(prim = i16 $(, scalar = $ScalarI16)?)]
- $I16;
- #[float(prim = F16 $(, scalar = $ScalarF16)?)]
- $F16;
- },
- {
- #[uint(prim = u32 $(, scalar = $ScalarU32)?)]
- $U32;
- #[int(prim = i32 $(, scalar = $ScalarI32)?)]
- $I32;
- #[float(prim = f32 $(, scalar = $ScalarF32)?)]
- $F32;
- },
- {
- #[uint(prim = u64 $(, scalar = $ScalarU64)?)]
- $U64;
- #[int(prim = i64 $(, scalar = $ScalarI64)?)]
- $I64;
- #[float(prim = f64 $(, scalar = $ScalarF64)?)]
- $F64;
- },
- ]
- }
- };
-}
-
/// reference used to build IR for Kazan; an empty type for `core::simd`
pub trait Context: Copy {
- make_types! {
- #[bool]
- type Bool;
-
- #[u8]
- type U8;
-
- #[u16]
- type U16;
-
- #[u32]
- type U32;
-
- #[u64]
- type U64;
-
- #[i8]
- type I8;
-
- #[i16]
- type I16;
-
- #[i32]
- type I32;
-
- #[i64]
- type I64;
-
- #[f16]
- type F16;
-
- #[f32]
- type F32;
-
- #[f64]
- type F64;
- }
- make_types! {
- #[bool]
- #[scalar = Bool]
- type VecBool;
-
- #[u8]
- #[scalar = U8]
- type VecU8;
-
- #[u16]
- #[scalar = U16]
- type VecU16;
-
- #[u32]
- #[scalar = U32]
- type VecU32;
-
- #[u64]
- #[scalar = U64]
- type VecU64;
-
- #[i8]
- #[scalar = I8]
- type VecI8;
-
- #[i16]
- #[scalar = I16]
- type VecI16;
-
- #[i32]
- #[scalar = I32]
- type VecI32;
-
- #[i64]
- #[scalar = I64]
- type VecI64;
-
- #[f16]
- #[scalar = F16]
- type VecF16;
-
- #[f32]
- #[scalar = F32]
- type VecF32;
-
- #[f64]
- #[scalar = F64]
- type VecF64;
- }
+ vector_math_proc_macro::make_context_types!();
fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
T::make(self, v)
}
fn to(self) -> T;
}
+impl<T> ConvertTo<T> for T {
+ fn to(self) -> T {
+ self
+ }
+}
+
macro_rules! impl_convert_to_using_as {
- ($($src:ident -> [$($dest:ident),*];)*) => {
- $($(
- impl ConvertTo<$dest> for $src {
- fn to(self) -> $dest {
- self as $dest
+ ($first:ident $(, $ty:ident)*) => {
+ $(
+ impl ConvertTo<$first> for $ty {
+ fn to(self) -> $first {
+ self as $first
}
}
- )*)*
- };
- ([$($src:ident),*] -> $dest:tt;) => {
- impl_convert_to_using_as! {
- $(
- $src -> $dest;
- )*
- }
+ impl ConvertTo<$ty> for $first {
+ fn to(self) -> $ty {
+ self as $ty
+ }
+ }
+ )*
+ impl_convert_to_using_as![$($ty),*];
};
- ([$($src:ident),*];) => {
- impl_convert_to_using_as! {
- [$($src),*] -> [$($src),*];
- }
+ () => {
};
}
-impl_convert_to_using_as! {
- [u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
-}
+impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64];
pub trait Number:
Compare
{
}
-pub trait Int<ShiftRhs>:
- Number
- + BitOps
- + Shl<ShiftRhs, Output = Self>
- + Shr<ShiftRhs, Output = Self>
- + ShlAssign<ShiftRhs>
- + ShrAssign<ShiftRhs>
+pub trait Int:
+ Number + BitOps + Shl<Output = Self> + Shr<Output = Self> + ShlAssign + ShrAssign
{
fn leading_zeros(self) -> Self;
fn leading_ones(self) -> Self {
fn count_ones(self) -> Self;
}
-pub trait UInt<ShiftRhs>: Int<ShiftRhs> {}
+pub trait UInt: Int {}
-pub trait SInt<ShiftRhs>: Int<ShiftRhs> + Neg<Output = Self> {}
+pub trait SInt: Int + Neg<Output = Self> {}
macro_rules! impl_int {
($ty:ident) => {
- impl Int<u32> for $ty {
+ impl Int for $ty {
fn leading_zeros(self) -> Self {
self.leading_zeros() as Self
}
($($ty:ident),*) => {
$(
impl_int!($ty);
- impl UInt<u32> for $ty {}
+ impl UInt for $ty {}
)*
};
}
($($ty:ident),*) => {
$(
impl_int!($ty);
- impl SInt<u32> for $ty {}
+ impl SInt for $ty {}
)*
};
}
impl_sint![i8, i16, i32, i64];
-pub trait Float<BitsShiftRhs: Make<Context = Self::Context, Prim = u32>>:
- Number + Neg<Output = Self>
-{
+pub trait Float: Number + Neg<Output = Self> {
type FloatEncoding: FloatEncoding + Make<Context = Scalar, Prim = <Self as Make>::Prim>;
- type BitsType: UInt<BitsShiftRhs>
- + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::BitsType>
+ type BitsType: UInt
+ + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::BitsType>
+ ConvertTo<Self::SignedBitsType>
+ Compare<Bool = Self::Bool>;
- type SignedBitsType: SInt<BitsShiftRhs>
- + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float<u32>>::SignedBitsType>
+ type SignedBitsType: SInt
+ + Make<Context = Self::Context, Prim = <Self::FloatEncoding as Float>::SignedBitsType>
+ ConvertTo<Self::BitsType>
+ Compare<Bool = Self::Bool>;
fn abs(self) -> Self;
macro_rules! impl_float {
($ty:ty, $bits:ty, $signed_bits:ty) => {
- impl Float<u32> for $ty {
+ impl Float for $ty {
type FloatEncoding = $ty;
type BitsType = $bits;
type SignedBitsType = $signed_bits;
};
}
-impl_compare_using_partial_cmp![u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];
+impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64];
--- /dev/null
+[package]
+name = "vector-math-proc-macro"
+version = "0.1.0"
+authors = ["Jacob Lifshay <programmerjake@gmail.com>"]
+edition = "2018"
+license = "MIT OR Apache-2.0"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+quote = "1.0"
+proc-macro2 = "1.0"
+syn = { version = "1.0", features = [] }
--- /dev/null
+use std::{
+ cmp::Ordering,
+ collections::{BTreeSet, HashMap},
+ hash::Hash,
+};
+
+use proc_macro2::{Ident, Span, TokenStream};
+use quote::{quote, ToTokens};
+use syn::{
+ parse::{Parse, ParseStream},
+ parse_macro_input,
+};
+
+struct Input {}
+
+impl Parse for Input {
+ fn parse(_input: ParseStream) -> syn::Result<Self> {
+ Ok(Input {})
+ }
+}
+
+macro_rules! make_enum {
+ (
+ $vis:vis enum $ty:ident {
+ $(
+ $field:ident $(= $value:expr)?,
+ )*
+ }
+ ) => {
+ #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+ #[repr(u8)]
+ $vis enum $ty {
+ $(
+ $field $(= $value)?,
+ )*
+ }
+
+ impl $ty {
+ #[allow(dead_code)]
+ $vis const VALUES: &'static [Self] = &[
+ $(
+ Self::$field,
+ )*
+ ];
+ }
+ };
+}
+
+make_enum! {
+ enum TypeKind {
+ Bool,
+ UInt,
+ SInt,
+ Float,
+ }
+}
+
+make_enum! {
+ enum VectorScalar {
+ Scalar,
+ Vector,
+ }
+}
+
+make_enum! {
+ enum TypeBits {
+ Bits8 = 8,
+ Bits16 = 16,
+ Bits32 = 32,
+ Bits64 = 64,
+ }
+}
+
+impl TypeBits {
+ const fn bits(self) -> u32 {
+ self as u8 as u32
+ }
+}
+
+make_enum! {
+ enum Convertibility {
+ Impossible,
+ Lossy,
+ Lossless,
+ }
+}
+
+impl Convertibility {
+ const fn make_possible(lossless: bool) -> Self {
+ if lossless {
+ Self::Lossless
+ } else {
+ Self::Lossy
+ }
+ }
+ const fn make_non_lossy(possible: bool) -> Self {
+ if possible {
+ Self::Lossless
+ } else {
+ Self::Impossible
+ }
+ }
+ const fn possible(self) -> bool {
+ match self {
+ Convertibility::Impossible => false,
+ Convertibility::Lossy | Convertibility::Lossless => true,
+ }
+ }
+}
+
+impl TypeKind {
+ fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
+ match self {
+ TypeKind::Float => bits >= TypeBits::Bits16,
+ TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
+ TypeKind::UInt | TypeKind::SInt => true,
+ }
+ }
+ fn prim_ty(self, bits: TypeBits) -> Ident {
+ Ident::new(
+ &match self {
+ TypeKind::Bool => "bool".into(),
+ TypeKind::UInt => format!("u{}", bits.bits()),
+ TypeKind::SInt => format!("i{}", bits.bits()),
+ TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
+ TypeKind::Float => format!("f{}", bits.bits()),
+ },
+ Span::call_site(),
+ )
+ }
+ fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
+ let vec_prefix = match vector_scalar {
+ VectorScalar::Scalar => "",
+ VectorScalar::Vector => "Vec",
+ };
+ Ident::new(
+ &match self {
+ TypeKind::Bool => match vector_scalar {
+ VectorScalar::Scalar => "Bool".into(),
+ VectorScalar::Vector => format!("VecBool{}", bits.bits()),
+ },
+ TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
+ TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
+ TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
+ },
+ Span::call_site(),
+ )
+ }
+ fn convertibility_to(
+ self,
+ src_bits: TypeBits,
+ dest_type_kind: TypeKind,
+ dest_bits: TypeBits,
+ ) -> Convertibility {
+ Convertibility::make_possible(match (self, dest_type_kind) {
+ (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
+ return Convertibility::make_non_lossy(self == dest_type_kind);
+ }
+ (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
+ (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
+ (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
+ (TypeKind::SInt, TypeKind::UInt) => false,
+ (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
+ (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
+ (TypeKind::Float, TypeKind::UInt) => false,
+ (TypeKind::Float, TypeKind::SInt) => false,
+ (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
+ })
+ }
+}
+
+#[derive(Default, Debug)]
+struct TokenStreamSetElement {
+ token_stream: TokenStream,
+ text: String,
+}
+
+impl Ord for TokenStreamSetElement {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.text.cmp(&other.text)
+ }
+}
+
+impl PartialOrd for TokenStreamSetElement {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl PartialEq for TokenStreamSetElement {
+ fn eq(&self, other: &Self) -> bool {
+ self.text == other.text
+ }
+}
+
+impl Eq for TokenStreamSetElement {}
+
+impl From<TokenStream> for TokenStreamSetElement {
+ fn from(token_stream: TokenStream) -> Self {
+ let text = token_stream.to_string();
+ Self { token_stream, text }
+ }
+}
+
+impl ToTokens for TokenStreamSetElement {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ self.token_stream.to_tokens(tokens)
+ }
+
+ fn to_token_stream(&self) -> TokenStream {
+ self.token_stream.to_token_stream()
+ }
+
+ fn into_token_stream(self) -> TokenStream {
+ self.token_stream
+ }
+}
+
+type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
+
+#[derive(Debug, Default)]
+struct TraitSets {
+ trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
+}
+
+impl TraitSets {
+ fn get(
+ &mut self,
+ type_kind: TypeKind,
+ mut bits: TypeBits,
+ vector_scalar: VectorScalar,
+ ) -> &mut TokenStreamSet {
+ if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
+ bits = TypeBits::Bits8;
+ }
+ self.trait_sets_map
+ .entry((type_kind, bits, vector_scalar))
+ .or_default()
+ }
+ fn add_trait(
+ &mut self,
+ type_kind: TypeKind,
+ bits: TypeBits,
+ vector_scalar: VectorScalar,
+ v: impl Into<TokenStreamSetElement>,
+ ) {
+ self.get(type_kind, bits, vector_scalar).insert(v.into());
+ }
+ fn fill(&mut self) {
+ for &bits in TypeBits::VALUES {
+ for &type_kind in TypeKind::VALUES {
+ for &vector_scalar in VectorScalar::VALUES {
+ if !type_kind.is_valid(bits, vector_scalar) {
+ continue;
+ }
+ let prim_ty = type_kind.prim_ty(bits);
+ let ty = type_kind.ty(bits, vector_scalar);
+ if vector_scalar == VectorScalar::Vector {
+ let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { From<Self::#scalar_ty> },
+ );
+ }
+ let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
+ let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
+ let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
+ let type_trait = match type_kind {
+ TypeKind::Bool => quote! { Bool },
+ TypeKind::UInt => quote! { UInt },
+ TypeKind::SInt => quote! { SInt },
+ TypeKind::Float => quote! { Float<
+ BitsType = Self::#uint_ty,
+ SignedBitsType = Self::#sint_ty,
+ FloatEncoding = #prim_ty,
+ > },
+ };
+ self.add_trait(type_kind, bits, vector_scalar, type_trait);
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { Compare<Bool = Self::#bool_ty> },
+ );
+ self.add_trait(
+ TypeKind::Bool,
+ bits,
+ vector_scalar,
+ quote! { Select<Self::#ty> },
+ );
+ self.add_trait(
+ TypeKind::Bool,
+ TypeBits::Bits8,
+ VectorScalar::Scalar,
+ quote! { Select<Self::#ty> },
+ );
+ for &other_bits in TypeBits::VALUES {
+ for &other_type_kind in TypeKind::VALUES {
+ if !other_type_kind.is_valid(other_bits, vector_scalar) {
+ continue;
+ }
+ if other_bits == bits && other_type_kind == type_kind {
+ continue;
+ }
+ let other_ty = other_type_kind.ty(other_bits, vector_scalar);
+ let convertibility =
+ other_type_kind.convertibility_to(other_bits, type_kind, bits);
+ if convertibility == Convertibility::Lossless {
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { From<Self::#other_ty> },
+ );
+ }
+ if convertibility.possible() {
+ self.add_trait(
+ other_type_kind,
+ other_bits,
+ vector_scalar,
+ quote! { ConvertTo<Self::#ty> },
+ );
+ }
+ }
+ }
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { Make<Context = Self, Prim = #prim_ty> },
+ );
+ }
+ }
+ }
+ }
+}
+
+impl Input {
+ fn to_tokens(&self) -> syn::Result<TokenStream> {
+ let mut types = Vec::new();
+ let mut trait_sets = TraitSets::default();
+ trait_sets.fill();
+ for &bits in TypeBits::VALUES {
+ for &type_kind in TypeKind::VALUES {
+ for &vector_scalar in VectorScalar::VALUES {
+ if !type_kind.is_valid(bits, vector_scalar) {
+ continue;
+ }
+ let ty = type_kind.ty(bits, vector_scalar);
+ let traits = trait_sets.get(type_kind, bits, vector_scalar);
+ types.push(quote! {
+ type #ty: #(#traits)+*;
+ });
+ }
+ }
+ }
+ Ok(quote! {#(#types)*})
+ }
+}
+
+#[proc_macro]
+pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input = parse_macro_input!(input as Input);
+ match input.to_tokens() {
+ Ok(retval) => retval,
+ Err(err) => err.to_compile_error(),
+ }
+ .into()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test() -> syn::Result<()> {
+ Input {}.to_tokens()?;
+ Ok(())
+ }
+}