-use std::{
- cmp::Ordering,
- collections::{BTreeSet, HashMap},
- hash::Hash,
-};
-
-use proc_macro2::{Ident, Span, TokenStream};
-use quote::{quote, ToTokens};
-use syn::{
- parse::{Parse, ParseStream},
- parse_macro_input,
-};
-
-struct Input {}
-
-impl Parse for Input {
- fn parse(_input: ParseStream) -> syn::Result<Self> {
- Ok(Input {})
- }
-}
-
-macro_rules! make_enum {
- (
- $vis:vis enum $ty:ident {
- $(
- $field:ident $(= $value:expr)?,
- )*
- }
- ) => {
- #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
- #[repr(u8)]
- $vis enum $ty {
- $(
- $field $(= $value)?,
- )*
- }
-
- impl $ty {
- #[allow(dead_code)]
- $vis const VALUES: &'static [Self] = &[
- $(
- Self::$field,
- )*
- ];
- }
- };
-}
-
-make_enum! {
- enum TypeKind {
- Bool,
- UInt,
- SInt,
- Float,
- }
-}
-
-make_enum! {
- enum VectorScalar {
- Scalar,
- Vector,
- }
-}
-
-make_enum! {
- enum TypeBits {
- Bits8 = 8,
- Bits16 = 16,
- Bits32 = 32,
- Bits64 = 64,
- }
-}
-
-impl TypeBits {
- const fn bits(self) -> u32 {
- self as u8 as u32
- }
-}
-
-make_enum! {
- enum Convertibility {
- Impossible,
- Lossy,
- Lossless,
- }
-}
-
-impl Convertibility {
- const fn make_possible(lossless: bool) -> Self {
- if lossless {
- Self::Lossless
- } else {
- Self::Lossy
- }
- }
- const fn make_non_lossy(possible: bool) -> Self {
- if possible {
- Self::Lossless
- } else {
- Self::Impossible
- }
- }
- const fn possible(self) -> bool {
- match self {
- Convertibility::Impossible => false,
- Convertibility::Lossy | Convertibility::Lossless => true,
- }
- }
-}
-
-impl TypeKind {
- fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
- match self {
- TypeKind::Float => bits >= TypeBits::Bits16,
- TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
- TypeKind::UInt | TypeKind::SInt => true,
- }
- }
- fn prim_ty(self, bits: TypeBits) -> Ident {
- Ident::new(
- &match self {
- TypeKind::Bool => "bool".into(),
- TypeKind::UInt => format!("u{}", bits.bits()),
- TypeKind::SInt => format!("i{}", bits.bits()),
- TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
- TypeKind::Float => format!("f{}", bits.bits()),
- },
- Span::call_site(),
- )
- }
- fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
- let vec_prefix = match vector_scalar {
- VectorScalar::Scalar => "",
- VectorScalar::Vector => "Vec",
- };
- Ident::new(
- &match self {
- TypeKind::Bool => match vector_scalar {
- VectorScalar::Scalar => "Bool".into(),
- VectorScalar::Vector => format!("VecBool{}", bits.bits()),
- },
- TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
- TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
- TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
- },
- Span::call_site(),
- )
- }
- fn convertibility_to(
- self,
- src_bits: TypeBits,
- dest_type_kind: TypeKind,
- dest_bits: TypeBits,
- ) -> Convertibility {
- Convertibility::make_possible(match (self, dest_type_kind) {
- (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
- return Convertibility::make_non_lossy(self == dest_type_kind);
- }
- (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
- (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
- (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
- (TypeKind::SInt, TypeKind::UInt) => false,
- (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
- (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
- (TypeKind::Float, TypeKind::UInt) => false,
- (TypeKind::Float, TypeKind::SInt) => false,
- (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
- })
- }
-}
-
-#[derive(Default, Debug)]
-struct TokenStreamSetElement {
- token_stream: TokenStream,
- text: String,
-}
-
-impl Ord for TokenStreamSetElement {
- fn cmp(&self, other: &Self) -> Ordering {
- self.text.cmp(&other.text)
- }
-}
-
-impl PartialOrd for TokenStreamSetElement {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
-}
-
-impl PartialEq for TokenStreamSetElement {
- fn eq(&self, other: &Self) -> bool {
- self.text == other.text
- }
-}
-
-impl Eq for TokenStreamSetElement {}
-
-impl From<TokenStream> for TokenStreamSetElement {
- fn from(token_stream: TokenStream) -> Self {
- let text = token_stream.to_string();
- Self { token_stream, text }
- }
-}
-
-impl ToTokens for TokenStreamSetElement {
- fn to_tokens(&self, tokens: &mut TokenStream) {
- self.token_stream.to_tokens(tokens)
- }
-
- fn to_token_stream(&self) -> TokenStream {
- self.token_stream.to_token_stream()
- }
-
- fn into_token_stream(self) -> TokenStream {
- self.token_stream
- }
-}
-
-type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
-
-#[derive(Debug, Default)]
-struct TraitSets {
- trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
-}
-
-impl TraitSets {
- fn get(
- &mut self,
- type_kind: TypeKind,
- mut bits: TypeBits,
- vector_scalar: VectorScalar,
- ) -> &mut TokenStreamSet {
- if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
- bits = TypeBits::Bits8;
- }
- self.trait_sets_map
- .entry((type_kind, bits, vector_scalar))
- .or_default()
- }
- fn add_trait(
- &mut self,
- type_kind: TypeKind,
- bits: TypeBits,
- vector_scalar: VectorScalar,
- v: impl Into<TokenStreamSetElement>,
- ) {
- self.get(type_kind, bits, vector_scalar).insert(v.into());
- }
- fn fill(&mut self) {
- for &bits in TypeBits::VALUES {
- for &type_kind in TypeKind::VALUES {
- for &vector_scalar in VectorScalar::VALUES {
- if !type_kind.is_valid(bits, vector_scalar) {
- continue;
- }
- let prim_ty = type_kind.prim_ty(bits);
- let ty = type_kind.ty(bits, vector_scalar);
- if vector_scalar == VectorScalar::Vector {
- let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
- self.add_trait(
- type_kind,
- bits,
- vector_scalar,
- quote! { From<Self::#scalar_ty> },
- );
- }
- let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
- let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
- let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
- let type_trait = match type_kind {
- TypeKind::Bool => quote! { Bool },
- TypeKind::UInt => {
- quote! { UInt<PrimUInt = #prim_ty, SignedType = Self::#sint_ty> }
- }
- TypeKind::SInt => {
- quote! { SInt<PrimSInt = #prim_ty, UnsignedType = Self::#uint_ty> }
- }
- TypeKind::Float => quote! { Float<
- BitsType = Self::#uint_ty,
- SignedBitsType = Self::#sint_ty,
- PrimFloat = #prim_ty,
- > },
- };
- self.add_trait(type_kind, bits, vector_scalar, type_trait);
- self.add_trait(
- type_kind,
- bits,
- vector_scalar,
- quote! { Compare<Bool = Self::#bool_ty> },
- );
- self.add_trait(
- TypeKind::Bool,
- bits,
- vector_scalar,
- quote! { Select<Self::#ty> },
- );
- self.add_trait(
- TypeKind::Bool,
- TypeBits::Bits8,
- VectorScalar::Scalar,
- quote! { Select<Self::#ty> },
- );
- for &other_bits in TypeBits::VALUES {
- for &other_type_kind in TypeKind::VALUES {
- if !other_type_kind.is_valid(other_bits, vector_scalar) {
- continue;
- }
- if other_bits == bits && other_type_kind == type_kind {
- continue;
- }
- let other_ty = other_type_kind.ty(other_bits, vector_scalar);
- let convertibility =
- other_type_kind.convertibility_to(other_bits, type_kind, bits);
- if convertibility == Convertibility::Lossless {
- self.add_trait(
- type_kind,
- bits,
- vector_scalar,
- quote! { From<Self::#other_ty> },
- );
- }
- if convertibility.possible() {
- self.add_trait(
- type_kind,
- bits,
- vector_scalar,
- quote! { ConvertFrom<Self::#other_ty> },
- );
- }
- }
- }
- self.add_trait(
- type_kind,
- bits,
- vector_scalar,
- quote! { Make<Context = Self, Prim = #prim_ty> },
- );
- }
- }
- }
- }
-}
-
-impl Input {
- fn to_tokens(&self) -> syn::Result<TokenStream> {
- let mut types = Vec::new();
- let mut trait_sets = TraitSets::default();
- trait_sets.fill();
- for &bits in TypeBits::VALUES {
- for &type_kind in TypeKind::VALUES {
- for &vector_scalar in VectorScalar::VALUES {
- if !type_kind.is_valid(bits, vector_scalar) {
- continue;
- }
- let ty = type_kind.ty(bits, vector_scalar);
- let traits = trait_sets.get(type_kind, bits, vector_scalar);
- types.push(quote! {
- type #ty: #(#traits)+*;
- });
- }
- }
- }
- Ok(quote! {#(#types)*})
- }
-}
-
-#[proc_macro]
-pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
- let input = parse_macro_input!(input as Input);
- match input.to_tokens() {
- Ok(retval) => retval,
- Err(err) => err.to_compile_error(),
- }
- .into()
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test() -> syn::Result<()> {
- Input {}.to_tokens()?;
- Ok(())
- }
-}