use std::{ cmp::Ordering, collections::{BTreeSet, HashMap}, hash::Hash, }; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, parse_macro_input, }; struct Input {} impl Parse for Input { fn parse(_input: ParseStream) -> syn::Result { Ok(Input {}) } } macro_rules! make_enum { ( $vis:vis enum $ty:ident { $( $field:ident $(= $value:expr)?, )* } ) => { #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] #[repr(u8)] $vis enum $ty { $( $field $(= $value)?, )* } impl $ty { #[allow(dead_code)] $vis const VALUES: &'static [Self] = &[ $( Self::$field, )* ]; } }; } make_enum! { enum TypeKind { Bool, UInt, SInt, Float, } } make_enum! { enum VectorScalar { Scalar, Vector, } } make_enum! { enum TypeBits { Bits8 = 8, Bits16 = 16, Bits32 = 32, Bits64 = 64, } } impl TypeBits { const fn bits(self) -> u32 { self as u8 as u32 } } make_enum! { enum Convertibility { Impossible, Lossy, Lossless, } } impl Convertibility { const fn make_possible(lossless: bool) -> Self { if lossless { Self::Lossless } else { Self::Lossy } } const fn make_non_lossy(possible: bool) -> Self { if possible { Self::Lossless } else { Self::Impossible } } const fn possible(self) -> bool { match self { Convertibility::Impossible => false, Convertibility::Lossy | Convertibility::Lossless => true, } } } impl TypeKind { fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool { match self { TypeKind::Float => bits >= TypeBits::Bits16, TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector, TypeKind::UInt | TypeKind::SInt => true, } } fn prim_ty(self, bits: TypeBits) -> Ident { Ident::new( &match self { TypeKind::Bool => "bool".into(), TypeKind::UInt => format!("u{}", bits.bits()), TypeKind::SInt => format!("i{}", bits.bits()), TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(), TypeKind::Float => format!("f{}", bits.bits()), }, Span::call_site(), ) } fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident { let vec_prefix = match vector_scalar { VectorScalar::Scalar => "", VectorScalar::Vector => "Vec", }; Ident::new( &match self { TypeKind::Bool => match vector_scalar { VectorScalar::Scalar => "Bool".into(), VectorScalar::Vector => format!("VecBool{}", bits.bits()), }, TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()), TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()), TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()), }, Span::call_site(), ) } fn convertibility_to( self, src_bits: TypeBits, dest_type_kind: TypeKind, dest_bits: TypeBits, ) -> Convertibility { Convertibility::make_possible(match (self, dest_type_kind) { (TypeKind::Bool, _) | (_, TypeKind::Bool) => { return Convertibility::make_non_lossy(self == dest_type_kind); } (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits, (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits, (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits, (TypeKind::SInt, TypeKind::UInt) => false, (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits, (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits, (TypeKind::Float, TypeKind::UInt) => false, (TypeKind::Float, TypeKind::SInt) => false, (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits, }) } } #[derive(Default, Debug)] struct TokenStreamSetElement { token_stream: TokenStream, text: String, } impl Ord for TokenStreamSetElement { fn cmp(&self, other: &Self) -> Ordering { self.text.cmp(&other.text) } } impl PartialOrd for TokenStreamSetElement { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl PartialEq for TokenStreamSetElement { fn eq(&self, other: &Self) -> bool { self.text == other.text } } impl Eq for TokenStreamSetElement {} impl From for TokenStreamSetElement { fn from(token_stream: TokenStream) -> Self { let text = token_stream.to_string(); Self { token_stream, text } } } impl ToTokens for TokenStreamSetElement { fn to_tokens(&self, tokens: &mut TokenStream) { self.token_stream.to_tokens(tokens) } fn to_token_stream(&self) -> TokenStream { self.token_stream.to_token_stream() } fn into_token_stream(self) -> TokenStream { self.token_stream } } type TokenStreamSet = BTreeSet; #[derive(Debug, Default)] struct TraitSets { trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>, } impl TraitSets { fn get( &mut self, type_kind: TypeKind, mut bits: TypeBits, vector_scalar: VectorScalar, ) -> &mut TokenStreamSet { if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar { bits = TypeBits::Bits8; } self.trait_sets_map .entry((type_kind, bits, vector_scalar)) .or_default() } fn add_trait( &mut self, type_kind: TypeKind, bits: TypeBits, vector_scalar: VectorScalar, v: impl Into, ) { self.get(type_kind, bits, vector_scalar).insert(v.into()); } fn fill(&mut self) { for &bits in TypeBits::VALUES { for &type_kind in TypeKind::VALUES { for &vector_scalar in VectorScalar::VALUES { if !type_kind.is_valid(bits, vector_scalar) { continue; } let prim_ty = type_kind.prim_ty(bits); let ty = type_kind.ty(bits, vector_scalar); if vector_scalar == VectorScalar::Vector { let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar); self.add_trait( type_kind, bits, vector_scalar, quote! { From }, ); } let bool_ty = TypeKind::Bool.ty(bits, vector_scalar); let uint_ty = TypeKind::UInt.ty(bits, vector_scalar); let sint_ty = TypeKind::SInt.ty(bits, vector_scalar); let type_trait = match type_kind { TypeKind::Bool => quote! { Bool }, TypeKind::UInt => quote! { UInt }, TypeKind::SInt => quote! { SInt }, TypeKind::Float => quote! { Float< BitsType = Self::#uint_ty, SignedBitsType = Self::#sint_ty, FloatEncoding = #prim_ty, > }, }; self.add_trait(type_kind, bits, vector_scalar, type_trait); self.add_trait( type_kind, bits, vector_scalar, quote! { Compare }, ); self.add_trait( TypeKind::Bool, bits, vector_scalar, quote! { Select }, ); self.add_trait( TypeKind::Bool, TypeBits::Bits8, VectorScalar::Scalar, quote! { Select }, ); for &other_bits in TypeBits::VALUES { for &other_type_kind in TypeKind::VALUES { if !other_type_kind.is_valid(other_bits, vector_scalar) { continue; } if other_bits == bits && other_type_kind == type_kind { continue; } let other_ty = other_type_kind.ty(other_bits, vector_scalar); let convertibility = other_type_kind.convertibility_to(other_bits, type_kind, bits); if convertibility == Convertibility::Lossless { self.add_trait( type_kind, bits, vector_scalar, quote! { From }, ); } if convertibility.possible() { self.add_trait( type_kind, bits, vector_scalar, quote! { ConvertFrom }, ); } } } self.add_trait( type_kind, bits, vector_scalar, quote! { Make }, ); } } } } } impl Input { fn to_tokens(&self) -> syn::Result { let mut types = Vec::new(); let mut trait_sets = TraitSets::default(); trait_sets.fill(); for &bits in TypeBits::VALUES { for &type_kind in TypeKind::VALUES { for &vector_scalar in VectorScalar::VALUES { if !type_kind.is_valid(bits, vector_scalar) { continue; } let ty = type_kind.ty(bits, vector_scalar); let traits = trait_sets.get(type_kind, bits, vector_scalar); types.push(quote! { type #ty: #(#traits)+*; }); } } } Ok(quote! {#(#types)*}) } } #[proc_macro] pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as Input); match input.to_tokens() { Ok(retval) => retval, Err(err) => err.to_compile_error(), } .into() } #[cfg(test)] mod tests { use super::*; #[test] fn test() -> syn::Result<()> { Input {}.to_tokens()?; Ok(()) } }