refactor to easily allow algorithms generic over f16/32/64
[vector-math.git] / vector-math-proc-macro / src / lib.rs
1 use std::{
2 cmp::Ordering,
3 collections::{BTreeSet, HashMap},
4 hash::Hash,
5 };
6
7 use proc_macro2::{Ident, Span, TokenStream};
8 use quote::{quote, ToTokens};
9 use syn::{
10 parse::{Parse, ParseStream},
11 parse_macro_input,
12 };
13
/// Parsed arguments of `make_context_types!`.
///
/// The macro currently accepts no arguments, so this carries no data.
struct Input {}
15
16 impl Parse for Input {
17 fn parse(_input: ParseStream) -> syn::Result<Self> {
18 Ok(Input {})
19 }
20 }
21
/// Declares a fieldless `#[repr(u8)]` enum plus an associated `VALUES` constant.
///
/// The generated enum derives `Clone, Copy, Debug, Eq, PartialEq, Ord,
/// PartialOrd, Hash`. `VALUES` lists every variant in declaration order, so
/// callers can iterate over all variants without keeping a separate list in
/// sync.
///
/// Accepted input:
/// * optional attributes (including doc comments) on the enum and on each
///   variant — they are forwarded to the generated items;
/// * optional explicit discriminants (`Variant = expr`);
/// * an optional trailing comma after the last variant.
macro_rules! make_enum {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $ty:ident {
            $(
                $(#[$variant_meta:meta])*
                $field:ident $(= $value:expr)?
            ),* $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
        #[repr(u8)]
        $vis enum $ty {
            $(
                $(#[$variant_meta])*
                $field $(= $value)?,
            )*
        }

        impl $ty {
            /// Every variant, in declaration order.
            // Not every enum built with this macro iterates over its variants,
            // so an unused `VALUES` must not produce a warning.
            #[allow(dead_code)]
            $vis const VALUES: &[Self] = &[
                $(
                    Self::$field,
                )*
            ];
        }
    };
}
48
49 make_enum! {
50 enum TypeKind {
51 Bool,
52 UInt,
53 SInt,
54 Float,
55 }
56 }
57
58 make_enum! {
59 enum VectorScalar {
60 Scalar,
61 Vector,
62 }
63 }
64
65 make_enum! {
66 enum TypeBits {
67 Bits8 = 8,
68 Bits16 = 16,
69 Bits32 = 32,
70 Bits64 = 64,
71 }
72 }
73
74 impl TypeBits {
75 const fn bits(self) -> u32 {
76 self as u8 as u32
77 }
78 }
79
80 make_enum! {
81 enum Convertibility {
82 Impossible,
83 Lossy,
84 Lossless,
85 }
86 }
87
88 impl Convertibility {
89 const fn make_possible(lossless: bool) -> Self {
90 if lossless {
91 Self::Lossless
92 } else {
93 Self::Lossy
94 }
95 }
96 const fn make_non_lossy(possible: bool) -> Self {
97 if possible {
98 Self::Lossless
99 } else {
100 Self::Impossible
101 }
102 }
103 const fn possible(self) -> bool {
104 match self {
105 Convertibility::Impossible => false,
106 Convertibility::Lossy | Convertibility::Lossless => true,
107 }
108 }
109 }
110
111 impl TypeKind {
112 fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
113 match self {
114 TypeKind::Float => bits >= TypeBits::Bits16,
115 TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
116 TypeKind::UInt | TypeKind::SInt => true,
117 }
118 }
119 fn prim_ty(self, bits: TypeBits) -> Ident {
120 Ident::new(
121 &match self {
122 TypeKind::Bool => "bool".into(),
123 TypeKind::UInt => format!("u{}", bits.bits()),
124 TypeKind::SInt => format!("i{}", bits.bits()),
125 TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
126 TypeKind::Float => format!("f{}", bits.bits()),
127 },
128 Span::call_site(),
129 )
130 }
131 fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
132 let vec_prefix = match vector_scalar {
133 VectorScalar::Scalar => "",
134 VectorScalar::Vector => "Vec",
135 };
136 Ident::new(
137 &match self {
138 TypeKind::Bool => match vector_scalar {
139 VectorScalar::Scalar => "Bool".into(),
140 VectorScalar::Vector => format!("VecBool{}", bits.bits()),
141 },
142 TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
143 TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
144 TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
145 },
146 Span::call_site(),
147 )
148 }
149 fn convertibility_to(
150 self,
151 src_bits: TypeBits,
152 dest_type_kind: TypeKind,
153 dest_bits: TypeBits,
154 ) -> Convertibility {
155 Convertibility::make_possible(match (self, dest_type_kind) {
156 (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
157 return Convertibility::make_non_lossy(self == dest_type_kind);
158 }
159 (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
160 (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
161 (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
162 (TypeKind::SInt, TypeKind::UInt) => false,
163 (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
164 (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
165 (TypeKind::Float, TypeKind::UInt) => false,
166 (TypeKind::Float, TypeKind::SInt) => false,
167 (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
168 })
169 }
170 }
171
172 #[derive(Default, Debug)]
173 struct TokenStreamSetElement {
174 token_stream: TokenStream,
175 text: String,
176 }
177
178 impl Ord for TokenStreamSetElement {
179 fn cmp(&self, other: &Self) -> Ordering {
180 self.text.cmp(&other.text)
181 }
182 }
183
184 impl PartialOrd for TokenStreamSetElement {
185 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
186 Some(self.cmp(other))
187 }
188 }
189
190 impl PartialEq for TokenStreamSetElement {
191 fn eq(&self, other: &Self) -> bool {
192 self.text == other.text
193 }
194 }
195
196 impl Eq for TokenStreamSetElement {}
197
198 impl From<TokenStream> for TokenStreamSetElement {
199 fn from(token_stream: TokenStream) -> Self {
200 let text = token_stream.to_string();
201 Self { token_stream, text }
202 }
203 }
204
205 impl ToTokens for TokenStreamSetElement {
206 fn to_tokens(&self, tokens: &mut TokenStream) {
207 self.token_stream.to_tokens(tokens)
208 }
209
210 fn to_token_stream(&self) -> TokenStream {
211 self.token_stream.to_token_stream()
212 }
213
214 fn into_token_stream(self) -> TokenStream {
215 self.token_stream
216 }
217 }
218
219 type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
220
221 #[derive(Debug, Default)]
222 struct TraitSets {
223 trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
224 }
225
226 impl TraitSets {
227 fn get(
228 &mut self,
229 type_kind: TypeKind,
230 mut bits: TypeBits,
231 vector_scalar: VectorScalar,
232 ) -> &mut TokenStreamSet {
233 if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
234 bits = TypeBits::Bits8;
235 }
236 self.trait_sets_map
237 .entry((type_kind, bits, vector_scalar))
238 .or_default()
239 }
240 fn add_trait(
241 &mut self,
242 type_kind: TypeKind,
243 bits: TypeBits,
244 vector_scalar: VectorScalar,
245 v: impl Into<TokenStreamSetElement>,
246 ) {
247 self.get(type_kind, bits, vector_scalar).insert(v.into());
248 }
249 fn fill(&mut self) {
250 for &bits in TypeBits::VALUES {
251 for &type_kind in TypeKind::VALUES {
252 for &vector_scalar in VectorScalar::VALUES {
253 if !type_kind.is_valid(bits, vector_scalar) {
254 continue;
255 }
256 let prim_ty = type_kind.prim_ty(bits);
257 let ty = type_kind.ty(bits, vector_scalar);
258 if vector_scalar == VectorScalar::Vector {
259 let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
260 self.add_trait(
261 type_kind,
262 bits,
263 vector_scalar,
264 quote! { From<Self::#scalar_ty> },
265 );
266 }
267 let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
268 let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
269 let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
270 let type_trait = match type_kind {
271 TypeKind::Bool => quote! { Bool },
272 TypeKind::UInt => {
273 quote! { UInt<PrimUInt = #prim_ty, SignedType = Self::#sint_ty> }
274 }
275 TypeKind::SInt => {
276 quote! { SInt<PrimSInt = #prim_ty, UnsignedType = Self::#uint_ty> }
277 }
278 TypeKind::Float => quote! { Float<
279 BitsType = Self::#uint_ty,
280 SignedBitsType = Self::#sint_ty,
281 PrimFloat = #prim_ty,
282 > },
283 };
284 self.add_trait(type_kind, bits, vector_scalar, type_trait);
285 self.add_trait(
286 type_kind,
287 bits,
288 vector_scalar,
289 quote! { Compare<Bool = Self::#bool_ty> },
290 );
291 self.add_trait(
292 TypeKind::Bool,
293 bits,
294 vector_scalar,
295 quote! { Select<Self::#ty> },
296 );
297 self.add_trait(
298 TypeKind::Bool,
299 TypeBits::Bits8,
300 VectorScalar::Scalar,
301 quote! { Select<Self::#ty> },
302 );
303 for &other_bits in TypeBits::VALUES {
304 for &other_type_kind in TypeKind::VALUES {
305 if !other_type_kind.is_valid(other_bits, vector_scalar) {
306 continue;
307 }
308 if other_bits == bits && other_type_kind == type_kind {
309 continue;
310 }
311 let other_ty = other_type_kind.ty(other_bits, vector_scalar);
312 let convertibility =
313 other_type_kind.convertibility_to(other_bits, type_kind, bits);
314 if convertibility == Convertibility::Lossless {
315 self.add_trait(
316 type_kind,
317 bits,
318 vector_scalar,
319 quote! { From<Self::#other_ty> },
320 );
321 }
322 if convertibility.possible() {
323 self.add_trait(
324 type_kind,
325 bits,
326 vector_scalar,
327 quote! { ConvertFrom<Self::#other_ty> },
328 );
329 }
330 }
331 }
332 self.add_trait(
333 type_kind,
334 bits,
335 vector_scalar,
336 quote! { Make<Context = Self, Prim = #prim_ty> },
337 );
338 }
339 }
340 }
341 }
342 }
343
344 impl Input {
345 fn to_tokens(&self) -> syn::Result<TokenStream> {
346 let mut types = Vec::new();
347 let mut trait_sets = TraitSets::default();
348 trait_sets.fill();
349 for &bits in TypeBits::VALUES {
350 for &type_kind in TypeKind::VALUES {
351 for &vector_scalar in VectorScalar::VALUES {
352 if !type_kind.is_valid(bits, vector_scalar) {
353 continue;
354 }
355 let ty = type_kind.ty(bits, vector_scalar);
356 let traits = trait_sets.get(type_kind, bits, vector_scalar);
357 types.push(quote! {
358 type #ty: #(#traits)+*;
359 });
360 }
361 }
362 }
363 Ok(quote! {#(#types)*})
364 }
365 }
366
367 #[proc_macro]
368 pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
369 let input = parse_macro_input!(input as Input);
370 match input.to_tokens() {
371 Ok(retval) => retval,
372 Err(err) => err.to_compile_error(),
373 }
374 .into()
375 }
376
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::TokenTree;

    /// Collects `X` from every top-level `type X` token pair in `tokens`.
    fn declared_type_names(tokens: TokenStream) -> Vec<String> {
        let tokens: Vec<TokenTree> = tokens.into_iter().collect();
        tokens
            .windows(2)
            .filter_map(|pair| match pair {
                [TokenTree::Ident(keyword), TokenTree::Ident(name)] if *keyword == "type" => {
                    Some(name.to_string())
                }
                _ => None,
            })
            .collect()
    }

    /// Every valid (width, kind, shape) produces exactly one declaration, in
    /// `TypeBits` → `TypeKind` → `VectorScalar` order.
    #[test]
    fn declares_every_valid_type() -> syn::Result<()> {
        let names = declared_type_names(Input {}.to_tokens()?);
        let expected = [
            "Bool", "VecBool8", "U8", "VecU8", "I8", "VecI8", //
            "VecBool16", "U16", "VecU16", "I16", "VecI16", "F16", "VecF16", //
            "VecBool32", "U32", "VecU32", "I32", "VecI32", "F32", "VecF32", //
            "VecBool64", "U64", "VecU64", "I64", "VecI64", "F64", "VecF64",
        ];
        assert_eq!(names, expected);
        Ok(())
    }

    /// `fill` creates exactly one bound set per valid type. Scalar `Bool` is
    /// only ever keyed at `Bits8`.
    #[test]
    fn trait_sets_are_keyed_by_valid_types() {
        let mut trait_sets = TraitSets::default();
        trait_sets.fill();
        assert!(trait_sets
            .trait_sets_map
            .keys()
            .all(|&(kind, bits, shape)| kind.is_valid(bits, shape)));
        assert_eq!(trait_sets.trait_sets_map.len(), 27);
    }

    #[test]
    fn conversion_classification() {
        assert_eq!(
            TypeKind::UInt.convertibility_to(TypeBits::Bits8, TypeKind::SInt, TypeBits::Bits16),
            Convertibility::Lossless
        );
        assert_eq!(
            TypeKind::UInt.convertibility_to(TypeBits::Bits16, TypeKind::SInt, TypeBits::Bits16),
            Convertibility::Lossy
        );
        assert_eq!(
            TypeKind::Bool.convertibility_to(TypeBits::Bits8, TypeKind::Float, TypeBits::Bits32),
            Convertibility::Impossible
        );
    }
}