add sin_pi_f16, cos_pi_f16, and sin_cos_pi_f16
[vector-math.git] / vector-math-proc-macro / src / lib.rs
1 use std::{
2 cmp::Ordering,
3 collections::{BTreeSet, HashMap},
4 hash::Hash,
5 };
6
7 use proc_macro2::{Ident, Span, TokenStream};
8 use quote::{quote, ToTokens};
9 use syn::{
10 parse::{Parse, ParseStream},
11 parse_macro_input,
12 };
13
14 struct Input {}
15
16 impl Parse for Input {
17 fn parse(_input: ParseStream) -> syn::Result<Self> {
18 Ok(Input {})
19 }
20 }
21
/// Declares a field-less `#[repr(u8)]` enum and an associated `VALUES`
/// constant listing every variant in declaration order.
///
/// The generated enum derives `Clone`, `Copy`, `Debug`, `Eq`, `PartialEq`,
/// `Ord`, `PartialOrd` and `Hash`. Each variant may carry an explicit
/// discriminant (`Name = expr`).
///
/// Variants are comma-separated; the comma after the last variant is
/// optional, just like in a hand-written enum.
macro_rules! make_enum {
    (
        $vis:vis enum $ty:ident {
            $(
                $field:ident $(= $value:expr)?
            ),*
            $(,)?
        }
    ) => {
        #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
        #[repr(u8)]
        $vis enum $ty {
            $(
                $field $(= $value)?,
            )*
        }

        impl $ty {
            // Not every generated enum iterates over its variants.
            #[allow(dead_code)]
            $vis const VALUES: &'static [Self] = &[
                $(
                    Self::$field,
                )*
            ];
        }
    };
}
48
49 make_enum! {
50 enum TypeKind {
51 Bool,
52 UInt,
53 SInt,
54 Float,
55 }
56 }
57
58 make_enum! {
59 enum VectorScalar {
60 Scalar,
61 Vector,
62 }
63 }
64
65 make_enum! {
66 enum TypeBits {
67 Bits8 = 8,
68 Bits16 = 16,
69 Bits32 = 32,
70 Bits64 = 64,
71 }
72 }
73
74 impl TypeBits {
75 const fn bits(self) -> u32 {
76 self as u8 as u32
77 }
78 }
79
80 make_enum! {
81 enum Convertibility {
82 Impossible,
83 Lossy,
84 Lossless,
85 }
86 }
87
88 impl Convertibility {
89 const fn make_possible(lossless: bool) -> Self {
90 if lossless {
91 Self::Lossless
92 } else {
93 Self::Lossy
94 }
95 }
96 const fn make_non_lossy(possible: bool) -> Self {
97 if possible {
98 Self::Lossless
99 } else {
100 Self::Impossible
101 }
102 }
103 const fn possible(self) -> bool {
104 match self {
105 Convertibility::Impossible => false,
106 Convertibility::Lossy | Convertibility::Lossless => true,
107 }
108 }
109 }
110
111 impl TypeKind {
112 fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
113 match self {
114 TypeKind::Float => bits >= TypeBits::Bits16,
115 TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
116 TypeKind::UInt | TypeKind::SInt => true,
117 }
118 }
119 fn prim_ty(self, bits: TypeBits) -> Ident {
120 Ident::new(
121 &match self {
122 TypeKind::Bool => "bool".into(),
123 TypeKind::UInt => format!("u{}", bits.bits()),
124 TypeKind::SInt => format!("i{}", bits.bits()),
125 TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
126 TypeKind::Float => format!("f{}", bits.bits()),
127 },
128 Span::call_site(),
129 )
130 }
131 fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
132 let vec_prefix = match vector_scalar {
133 VectorScalar::Scalar => "",
134 VectorScalar::Vector => "Vec",
135 };
136 Ident::new(
137 &match self {
138 TypeKind::Bool => match vector_scalar {
139 VectorScalar::Scalar => "Bool".into(),
140 VectorScalar::Vector => format!("VecBool{}", bits.bits()),
141 },
142 TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
143 TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
144 TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
145 },
146 Span::call_site(),
147 )
148 }
149 fn convertibility_to(
150 self,
151 src_bits: TypeBits,
152 dest_type_kind: TypeKind,
153 dest_bits: TypeBits,
154 ) -> Convertibility {
155 Convertibility::make_possible(match (self, dest_type_kind) {
156 (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
157 return Convertibility::make_non_lossy(self == dest_type_kind);
158 }
159 (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
160 (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
161 (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
162 (TypeKind::SInt, TypeKind::UInt) => false,
163 (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
164 (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
165 (TypeKind::Float, TypeKind::UInt) => false,
166 (TypeKind::Float, TypeKind::SInt) => false,
167 (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
168 })
169 }
170 }
171
172 #[derive(Default, Debug)]
173 struct TokenStreamSetElement {
174 token_stream: TokenStream,
175 text: String,
176 }
177
178 impl Ord for TokenStreamSetElement {
179 fn cmp(&self, other: &Self) -> Ordering {
180 self.text.cmp(&other.text)
181 }
182 }
183
184 impl PartialOrd for TokenStreamSetElement {
185 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
186 Some(self.cmp(other))
187 }
188 }
189
190 impl PartialEq for TokenStreamSetElement {
191 fn eq(&self, other: &Self) -> bool {
192 self.text == other.text
193 }
194 }
195
196 impl Eq for TokenStreamSetElement {}
197
198 impl From<TokenStream> for TokenStreamSetElement {
199 fn from(token_stream: TokenStream) -> Self {
200 let text = token_stream.to_string();
201 Self { token_stream, text }
202 }
203 }
204
205 impl ToTokens for TokenStreamSetElement {
206 fn to_tokens(&self, tokens: &mut TokenStream) {
207 self.token_stream.to_tokens(tokens)
208 }
209
210 fn to_token_stream(&self) -> TokenStream {
211 self.token_stream.to_token_stream()
212 }
213
214 fn into_token_stream(self) -> TokenStream {
215 self.token_stream
216 }
217 }
218
219 type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
220
221 #[derive(Debug, Default)]
222 struct TraitSets {
223 trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
224 }
225
226 impl TraitSets {
227 fn get(
228 &mut self,
229 type_kind: TypeKind,
230 mut bits: TypeBits,
231 vector_scalar: VectorScalar,
232 ) -> &mut TokenStreamSet {
233 if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
234 bits = TypeBits::Bits8;
235 }
236 self.trait_sets_map
237 .entry((type_kind, bits, vector_scalar))
238 .or_default()
239 }
240 fn add_trait(
241 &mut self,
242 type_kind: TypeKind,
243 bits: TypeBits,
244 vector_scalar: VectorScalar,
245 v: impl Into<TokenStreamSetElement>,
246 ) {
247 self.get(type_kind, bits, vector_scalar).insert(v.into());
248 }
249 fn fill(&mut self) {
250 for &bits in TypeBits::VALUES {
251 for &type_kind in TypeKind::VALUES {
252 for &vector_scalar in VectorScalar::VALUES {
253 if !type_kind.is_valid(bits, vector_scalar) {
254 continue;
255 }
256 let prim_ty = type_kind.prim_ty(bits);
257 let ty = type_kind.ty(bits, vector_scalar);
258 if vector_scalar == VectorScalar::Vector {
259 let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
260 self.add_trait(
261 type_kind,
262 bits,
263 vector_scalar,
264 quote! { From<Self::#scalar_ty> },
265 );
266 }
267 let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
268 let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
269 let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
270 let type_trait = match type_kind {
271 TypeKind::Bool => quote! { Bool },
272 TypeKind::UInt => quote! { UInt },
273 TypeKind::SInt => quote! { SInt },
274 TypeKind::Float => quote! { Float<
275 BitsType = Self::#uint_ty,
276 SignedBitsType = Self::#sint_ty,
277 FloatEncoding = #prim_ty,
278 > },
279 };
280 self.add_trait(type_kind, bits, vector_scalar, type_trait);
281 self.add_trait(
282 type_kind,
283 bits,
284 vector_scalar,
285 quote! { Compare<Bool = Self::#bool_ty> },
286 );
287 self.add_trait(
288 TypeKind::Bool,
289 bits,
290 vector_scalar,
291 quote! { Select<Self::#ty> },
292 );
293 self.add_trait(
294 TypeKind::Bool,
295 TypeBits::Bits8,
296 VectorScalar::Scalar,
297 quote! { Select<Self::#ty> },
298 );
299 for &other_bits in TypeBits::VALUES {
300 for &other_type_kind in TypeKind::VALUES {
301 if !other_type_kind.is_valid(other_bits, vector_scalar) {
302 continue;
303 }
304 if other_bits == bits && other_type_kind == type_kind {
305 continue;
306 }
307 let other_ty = other_type_kind.ty(other_bits, vector_scalar);
308 let convertibility =
309 other_type_kind.convertibility_to(other_bits, type_kind, bits);
310 if convertibility == Convertibility::Lossless {
311 self.add_trait(
312 type_kind,
313 bits,
314 vector_scalar,
315 quote! { From<Self::#other_ty> },
316 );
317 }
318 if convertibility.possible() {
319 self.add_trait(
320 type_kind,
321 bits,
322 vector_scalar,
323 quote! { ConvertFrom<Self::#other_ty> },
324 );
325 }
326 }
327 }
328 self.add_trait(
329 type_kind,
330 bits,
331 vector_scalar,
332 quote! { Make<Context = Self, Prim = #prim_ty> },
333 );
334 }
335 }
336 }
337 }
338 }
339
340 impl Input {
341 fn to_tokens(&self) -> syn::Result<TokenStream> {
342 let mut types = Vec::new();
343 let mut trait_sets = TraitSets::default();
344 trait_sets.fill();
345 for &bits in TypeBits::VALUES {
346 for &type_kind in TypeKind::VALUES {
347 for &vector_scalar in VectorScalar::VALUES {
348 if !type_kind.is_valid(bits, vector_scalar) {
349 continue;
350 }
351 let ty = type_kind.ty(bits, vector_scalar);
352 let traits = trait_sets.get(type_kind, bits, vector_scalar);
353 types.push(quote! {
354 type #ty: #(#traits)+*;
355 });
356 }
357 }
358 }
359 Ok(quote! {#(#types)*})
360 }
361 }
362
363 #[proc_macro]
364 pub fn make_context_types(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
365 let input = parse_macro_input!(input as Input);
366 match input.to_tokens() {
367 Ok(retval) => retval,
368 Err(err) => err.to_compile_error(),
369 }
370 .into()
371 }
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: token generation completes without returning an error.
    #[test]
    fn test() -> syn::Result<()> {
        let input = Input {};
        input.to_tokens().map(|_tokens| ())
    }
}