+use std::{
+ cmp::Ordering,
+ collections::{BTreeSet, HashMap},
+ hash::Hash,
+};
+
+use proc_macro2::{Ident, Span, TokenStream};
+use quote::{quote, ToTokens};
+use syn::{
+ parse::{Parse, ParseStream},
+ parse_macro_input,
+};
+
+struct Input {}
+
+impl Parse for Input {
+ fn parse(_input: ParseStream) -> syn::Result<Self> {
+ Ok(Input {})
+ }
+}
+
+macro_rules! make_enum {
+ (
+ $vis:vis enum $ty:ident {
+ $(
+ $field:ident $(= $value:expr)?,
+ )*
+ }
+ ) => {
+ #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+ #[repr(u8)]
+ $vis enum $ty {
+ $(
+ $field $(= $value)?,
+ )*
+ }
+
+ impl $ty {
+ #[allow(dead_code)]
+ $vis const VALUES: &'static [Self] = &[
+ $(
+ Self::$field,
+ )*
+ ];
+ }
+ };
+}
+
+make_enum! {
+ enum TypeKind {
+ Bool,
+ UInt,
+ SInt,
+ Float,
+ }
+}
+
+make_enum! {
+ enum VectorScalar {
+ Scalar,
+ Vector,
+ }
+}
+
+make_enum! {
+ enum TypeBits {
+ Bits8 = 8,
+ Bits16 = 16,
+ Bits32 = 32,
+ Bits64 = 64,
+ }
+}
+
+impl TypeBits {
+ const fn bits(self) -> u32 {
+ self as u8 as u32
+ }
+}
+
+make_enum! {
+ enum Convertibility {
+ Impossible,
+ Lossy,
+ Lossless,
+ }
+}
+
+impl Convertibility {
+ const fn make_possible(lossless: bool) -> Self {
+ if lossless {
+ Self::Lossless
+ } else {
+ Self::Lossy
+ }
+ }
+ const fn make_non_lossy(possible: bool) -> Self {
+ if possible {
+ Self::Lossless
+ } else {
+ Self::Impossible
+ }
+ }
+ const fn possible(self) -> bool {
+ match self {
+ Convertibility::Impossible => false,
+ Convertibility::Lossy | Convertibility::Lossless => true,
+ }
+ }
+}
+
+impl TypeKind {
+ fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
+ match self {
+ TypeKind::Float => bits >= TypeBits::Bits16,
+ TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
+ TypeKind::UInt | TypeKind::SInt => true,
+ }
+ }
+ fn prim_ty(self, bits: TypeBits) -> Ident {
+ Ident::new(
+ &match self {
+ TypeKind::Bool => "bool".into(),
+ TypeKind::UInt => format!("u{}", bits.bits()),
+ TypeKind::SInt => format!("i{}", bits.bits()),
+ TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
+ TypeKind::Float => format!("f{}", bits.bits()),
+ },
+ Span::call_site(),
+ )
+ }
+ fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
+ let vec_prefix = match vector_scalar {
+ VectorScalar::Scalar => "",
+ VectorScalar::Vector => "Vec",
+ };
+ Ident::new(
+ &match self {
+ TypeKind::Bool => match vector_scalar {
+ VectorScalar::Scalar => "Bool".into(),
+ VectorScalar::Vector => format!("VecBool{}", bits.bits()),
+ },
+ TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
+ TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
+ TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
+ },
+ Span::call_site(),
+ )
+ }
+ fn convertibility_to(
+ self,
+ src_bits: TypeBits,
+ dest_type_kind: TypeKind,
+ dest_bits: TypeBits,
+ ) -> Convertibility {
+ Convertibility::make_possible(match (self, dest_type_kind) {
+ (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
+ return Convertibility::make_non_lossy(self == dest_type_kind);
+ }
+ (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
+ (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
+ (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
+ (TypeKind::SInt, TypeKind::UInt) => false,
+ (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
+ (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
+ (TypeKind::Float, TypeKind::UInt) => false,
+ (TypeKind::Float, TypeKind::SInt) => false,
+ (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
+ })
+ }
+}
+
+#[derive(Default, Debug)]
+struct TokenStreamSetElement {
+ token_stream: TokenStream,
+ text: String,
+}
+
+impl Ord for TokenStreamSetElement {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.text.cmp(&other.text)
+ }
+}
+
+impl PartialOrd for TokenStreamSetElement {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl PartialEq for TokenStreamSetElement {
+ fn eq(&self, other: &Self) -> bool {
+ self.text == other.text
+ }
+}
+
+impl Eq for TokenStreamSetElement {}
+
+impl From<TokenStream> for TokenStreamSetElement {
+ fn from(token_stream: TokenStream) -> Self {
+ let text = token_stream.to_string();
+ Self { token_stream, text }
+ }
+}
+
+impl ToTokens for TokenStreamSetElement {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ self.token_stream.to_tokens(tokens)
+ }
+
+ fn to_token_stream(&self) -> TokenStream {
+ self.token_stream.to_token_stream()
+ }
+
+ fn into_token_stream(self) -> TokenStream {
+ self.token_stream
+ }
+}
+
+type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
+
+#[derive(Debug, Default)]
+struct TraitSets {
+ trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
+}
+
+impl TraitSets {
+ fn get(
+ &mut self,
+ type_kind: TypeKind,
+ mut bits: TypeBits,
+ vector_scalar: VectorScalar,
+ ) -> &mut TokenStreamSet {
+ if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
+ bits = TypeBits::Bits8;
+ }
+ self.trait_sets_map
+ .entry((type_kind, bits, vector_scalar))
+ .or_default()
+ }
+ fn add_trait(
+ &mut self,
+ type_kind: TypeKind,
+ bits: TypeBits,
+ vector_scalar: VectorScalar,
+ v: impl Into<TokenStreamSetElement>,
+ ) {
+ self.get(type_kind, bits, vector_scalar).insert(v.into());
+ }
+ fn fill(&mut self) {
+ for &bits in TypeBits::VALUES {
+ for &type_kind in TypeKind::VALUES {
+ for &vector_scalar in VectorScalar::VALUES {
+ if !type_kind.is_valid(bits, vector_scalar) {
+ continue;
+ }
+ let prim_ty = type_kind.prim_ty(bits);
+ let ty = type_kind.ty(bits, vector_scalar);
+ if vector_scalar == VectorScalar::Vector {
+ let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { From<Self::#scalar_ty> },
+ );
+ }
+ let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
+ let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
+ let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
+ let type_trait = match type_kind {
+ TypeKind::Bool => quote! { Bool },
+ TypeKind::UInt => quote! { UInt },
+ TypeKind::SInt => quote! { SInt },
+ TypeKind::Float => quote! { Float<
+ BitsType = Self::#uint_ty,
+ SignedBitsType = Self::#sint_ty,
+ FloatEncoding = #prim_ty,
+ > },
+ };
+ self.add_trait(type_kind, bits, vector_scalar, type_trait);
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { Compare<Bool = Self::#bool_ty> },
+ );
+ self.add_trait(
+ TypeKind::Bool,
+ bits,
+ vector_scalar,
+ quote! { Select<Self::#ty> },
+ );
+ self.add_trait(
+ TypeKind::Bool,
+ TypeBits::Bits8,
+ VectorScalar::Scalar,
+ quote! { Select<Self::#ty> },
+ );
+ for &other_bits in TypeBits::VALUES {
+ for &other_type_kind in TypeKind::VALUES {
+ if !other_type_kind.is_valid(other_bits, vector_scalar) {
+ continue;
+ }
+ if other_bits == bits && other_type_kind == type_kind {
+ continue;
+ }
+ let other_ty = other_type_kind.ty(other_bits, vector_scalar);
+ let convertibility =
+ other_type_kind.convertibility_to(other_bits, type_kind, bits);
+ if convertibility == Convertibility::Lossless {
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { From<Self::#other_ty> },
+ );
+ }
+ if convertibility.possible() {
+ self.add_trait(
+ other_type_kind,
+ other_bits,
+ vector_scalar,
+ quote! { ConvertTo<Self::#ty> },
+ );
+ }
+ }
+ }
+ self.add_trait(
+ type_kind,
+ bits,
+ vector_scalar,
+ quote! { Make<Context = Self, Prim = #prim_ty> },
+ );
+ }
+ }
+ }
+ }
+}
+
+impl Input {
+ fn to_tokens(&self) -> syn::Result<TokenStream> {
+ let mut types = Vec::new();
+ let mut trait_sets = TraitSets::default();
+ trait_sets.fill();
+ for &bits in TypeBits::VALUES {
+ for &type_kind in TypeKind::VALUES {
+ for &vector_scalar in VectorScalar::VALUES {
+ if !type_kind.is_valid(bits, vector_scalar) {
+ continue;
+ }
+ let ty = type_kind.ty(bits, vector_scalar);
+ let traits = trait_sets.get(type_kind, bits, vector_scalar);
+ types.push(quote! {
+ type #ty: #(#traits)+*;
+ });
+ }
+ }
+ }
+ Ok(quote! {#(#types)*})
+ }
+}
+
+#[proc_macro]
+pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input = parse_macro_input!(input as Input);
+ match input.to_tokens() {
+ Ok(retval) => retval,
+ Err(err) => err.to_compile_error(),
+ }
+ .into()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test() -> syn::Result<()> {
+ Input {}.to_tokens()?;
+ Ok(())
+ }
+}