remove note to convert to sin_cos_tau instead of sin_cos_pi
[vector-math.git] / vector-math-build-helpers / src / lib.rs
1 use proc_macro2::{Ident, Span, TokenStream};
2 use quote::{quote, ToTokens};
3 use std::{
4 cmp::Ordering,
5 collections::{BTreeSet, HashMap},
6 hash::Hash,
7 };
8
/// Declares a field-less `#[repr(u8)]` enum and a `VALUES` constant listing
/// every variant in declaration order.
///
/// Each variant may carry an explicit discriminant (`Name = expr`), and every
/// variant, including the last one, must be followed by a comma. The generated
/// enum derives `Clone`, `Copy`, `Debug`, `Eq`, `PartialEq`, `Ord`,
/// `PartialOrd` and `Hash`.
macro_rules! make_enum {
    (
        $vis:vis enum $name:ident {
            $(
                $variant:ident $(= $discriminant:expr)?,
            )*
        }
    ) => {
        #[repr(u8)]
        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
        $vis enum $name {
            $($variant $(= $discriminant)?,)*
        }

        impl $name {
            // Not every generated enum iterates over its variants.
            #[allow(dead_code)]
            $vis const VALUES: &'static [Self] = &[$(Self::$variant,)*];
        }
    };
}
35
36 make_enum! {
37 enum TypeKind {
38 Bool,
39 UInt,
40 SInt,
41 Float,
42 }
43 }
44
45 make_enum! {
46 enum VectorScalar {
47 Scalar,
48 Vector,
49 }
50 }
51
52 make_enum! {
53 enum TypeBits {
54 Bits8 = 8,
55 Bits16 = 16,
56 Bits32 = 32,
57 Bits64 = 64,
58 }
59 }
60
61 impl TypeBits {
62 const fn bits(self) -> u32 {
63 self as u8 as u32
64 }
65 }
66
67 make_enum! {
68 enum Convertibility {
69 Impossible,
70 Lossy,
71 Lossless,
72 }
73 }
74
75 impl Convertibility {
76 const fn make_possible(lossless: bool) -> Self {
77 if lossless {
78 Self::Lossless
79 } else {
80 Self::Lossy
81 }
82 }
83 const fn make_non_lossy(possible: bool) -> Self {
84 if possible {
85 Self::Lossless
86 } else {
87 Self::Impossible
88 }
89 }
90 const fn possible(self) -> bool {
91 match self {
92 Convertibility::Impossible => false,
93 Convertibility::Lossy | Convertibility::Lossless => true,
94 }
95 }
96 }
97
98 impl TypeKind {
99 fn is_valid(self, bits: TypeBits, vector_scalar: VectorScalar) -> bool {
100 match self {
101 TypeKind::Float => bits >= TypeBits::Bits16,
102 TypeKind::Bool => bits == TypeBits::Bits8 || vector_scalar == VectorScalar::Vector,
103 TypeKind::UInt | TypeKind::SInt => true,
104 }
105 }
106 fn prim_ty(self, bits: TypeBits) -> Ident {
107 Ident::new(
108 &match self {
109 TypeKind::Bool => "bool".into(),
110 TypeKind::UInt => format!("u{}", bits.bits()),
111 TypeKind::SInt => format!("i{}", bits.bits()),
112 TypeKind::Float if bits == TypeBits::Bits16 => "F16".into(),
113 TypeKind::Float => format!("f{}", bits.bits()),
114 },
115 Span::call_site(),
116 )
117 }
118 fn ty(self, bits: TypeBits, vector_scalar: VectorScalar) -> Ident {
119 let vec_prefix = match vector_scalar {
120 VectorScalar::Scalar => "",
121 VectorScalar::Vector => "Vec",
122 };
123 Ident::new(
124 &match self {
125 TypeKind::Bool => match vector_scalar {
126 VectorScalar::Scalar => "Bool".into(),
127 VectorScalar::Vector => format!("VecBool{}", bits.bits()),
128 },
129 TypeKind::UInt => format!("{}U{}", vec_prefix, bits.bits()),
130 TypeKind::SInt => format!("{}I{}", vec_prefix, bits.bits()),
131 TypeKind::Float => format!("{}F{}", vec_prefix, bits.bits()),
132 },
133 Span::call_site(),
134 )
135 }
136 fn convertibility_to(
137 self,
138 src_bits: TypeBits,
139 dest_type_kind: TypeKind,
140 dest_bits: TypeBits,
141 ) -> Convertibility {
142 Convertibility::make_possible(match (self, dest_type_kind) {
143 (TypeKind::Bool, _) | (_, TypeKind::Bool) => {
144 return Convertibility::make_non_lossy(self == dest_type_kind);
145 }
146 (TypeKind::UInt, TypeKind::UInt) => dest_bits >= src_bits,
147 (TypeKind::UInt, TypeKind::SInt) => dest_bits > src_bits,
148 (TypeKind::UInt, TypeKind::Float) => dest_bits > src_bits,
149 (TypeKind::SInt, TypeKind::UInt) => false,
150 (TypeKind::SInt, TypeKind::SInt) => dest_bits >= src_bits,
151 (TypeKind::SInt, TypeKind::Float) => dest_bits > src_bits,
152 (TypeKind::Float, TypeKind::UInt) => false,
153 (TypeKind::Float, TypeKind::SInt) => false,
154 (TypeKind::Float, TypeKind::Float) => dest_bits >= src_bits,
155 })
156 }
157 }
158
159 #[derive(Default, Debug)]
160 struct TokenStreamSetElement {
161 token_stream: TokenStream,
162 text: String,
163 }
164
165 impl Ord for TokenStreamSetElement {
166 fn cmp(&self, other: &Self) -> Ordering {
167 self.text.cmp(&other.text)
168 }
169 }
170
171 impl PartialOrd for TokenStreamSetElement {
172 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
173 Some(self.cmp(other))
174 }
175 }
176
177 impl PartialEq for TokenStreamSetElement {
178 fn eq(&self, other: &Self) -> bool {
179 self.text == other.text
180 }
181 }
182
183 impl Eq for TokenStreamSetElement {}
184
185 impl From<TokenStream> for TokenStreamSetElement {
186 fn from(token_stream: TokenStream) -> Self {
187 let text = token_stream.to_string();
188 Self { token_stream, text }
189 }
190 }
191
192 impl ToTokens for TokenStreamSetElement {
193 fn to_tokens(&self, tokens: &mut TokenStream) {
194 self.token_stream.to_tokens(tokens)
195 }
196
197 fn to_token_stream(&self) -> TokenStream {
198 self.token_stream.to_token_stream()
199 }
200
201 fn into_token_stream(self) -> TokenStream {
202 self.token_stream
203 }
204 }
205
206 type TokenStreamSet = BTreeSet<TokenStreamSetElement>;
207
208 #[derive(Debug, Default)]
209 struct TraitSets {
210 trait_sets_map: HashMap<(TypeKind, TypeBits, VectorScalar), TokenStreamSet>,
211 }
212
213 impl TraitSets {
214 fn get(
215 &mut self,
216 type_kind: TypeKind,
217 mut bits: TypeBits,
218 vector_scalar: VectorScalar,
219 ) -> &mut TokenStreamSet {
220 if type_kind == TypeKind::Bool && vector_scalar == VectorScalar::Scalar {
221 bits = TypeBits::Bits8;
222 }
223 self.trait_sets_map
224 .entry((type_kind, bits, vector_scalar))
225 .or_default()
226 }
227 fn add_trait(
228 &mut self,
229 type_kind: TypeKind,
230 bits: TypeBits,
231 vector_scalar: VectorScalar,
232 v: impl Into<TokenStreamSetElement>,
233 ) {
234 self.get(type_kind, bits, vector_scalar).insert(v.into());
235 }
236 fn fill(&mut self) {
237 for &bits in TypeBits::VALUES {
238 for &type_kind in TypeKind::VALUES {
239 for &vector_scalar in VectorScalar::VALUES {
240 if !type_kind.is_valid(bits, vector_scalar) {
241 continue;
242 }
243 let prim_ty = type_kind.prim_ty(bits);
244 let ty = type_kind.ty(bits, vector_scalar);
245 if vector_scalar == VectorScalar::Vector {
246 let scalar_ty = type_kind.ty(bits, VectorScalar::Scalar);
247 self.add_trait(
248 type_kind,
249 bits,
250 vector_scalar,
251 quote! { From<Self::#scalar_ty> },
252 );
253 }
254 let bool_ty = TypeKind::Bool.ty(bits, vector_scalar);
255 let uint_ty = TypeKind::UInt.ty(bits, vector_scalar);
256 let sint_ty = TypeKind::SInt.ty(bits, vector_scalar);
257 let type_trait = match type_kind {
258 TypeKind::Bool => quote! { Bool },
259 TypeKind::UInt => {
260 quote! { UInt<PrimUInt = #prim_ty, SignedType = Self::#sint_ty> }
261 }
262 TypeKind::SInt => {
263 quote! { SInt<PrimSInt = #prim_ty, UnsignedType = Self::#uint_ty> }
264 }
265 TypeKind::Float => quote! { Float<
266 BitsType = Self::#uint_ty,
267 SignedBitsType = Self::#sint_ty,
268 PrimFloat = #prim_ty,
269 > },
270 };
271 self.add_trait(type_kind, bits, vector_scalar, type_trait);
272 self.add_trait(
273 type_kind,
274 bits,
275 vector_scalar,
276 quote! { Compare<Bool = Self::#bool_ty> },
277 );
278 self.add_trait(
279 TypeKind::Bool,
280 bits,
281 vector_scalar,
282 quote! { Select<Self::#ty> },
283 );
284 self.add_trait(
285 TypeKind::Bool,
286 TypeBits::Bits8,
287 VectorScalar::Scalar,
288 quote! { Select<Self::#ty> },
289 );
290 for &other_bits in TypeBits::VALUES {
291 for &other_type_kind in TypeKind::VALUES {
292 if !other_type_kind.is_valid(other_bits, vector_scalar) {
293 continue;
294 }
295 if other_bits == bits && other_type_kind == type_kind {
296 continue;
297 }
298 let other_ty = other_type_kind.ty(other_bits, vector_scalar);
299 let convertibility =
300 other_type_kind.convertibility_to(other_bits, type_kind, bits);
301 if convertibility == Convertibility::Lossless {
302 self.add_trait(
303 type_kind,
304 bits,
305 vector_scalar,
306 quote! { From<Self::#other_ty> },
307 );
308 }
309 if convertibility.possible() {
310 self.add_trait(
311 type_kind,
312 bits,
313 vector_scalar,
314 quote! { ConvertFrom<Self::#other_ty> },
315 );
316 }
317 }
318 }
319 self.add_trait(
320 type_kind,
321 bits,
322 vector_scalar,
323 quote! { Make<Context = Self, Prim = #prim_ty> },
324 );
325 }
326 }
327 }
328 }
329 }
330
331 pub fn make_context_types() -> TokenStream {
332 let mut types = Vec::new();
333 let mut trait_sets = TraitSets::default();
334 trait_sets.fill();
335 for &bits in TypeBits::VALUES {
336 for &type_kind in TypeKind::VALUES {
337 for &vector_scalar in VectorScalar::VALUES {
338 if !type_kind.is_valid(bits, vector_scalar) {
339 continue;
340 }
341 let ty = type_kind.ty(bits, vector_scalar);
342 let traits = trait_sets.get(type_kind, bits, vector_scalar);
343 types.push(quote! {
344 type #ty: #(#traits)+*;
345 });
346 }
347 }
348 }
349 quote! {
350 /// reference used to build IR for Kazan; an empty type for `core::simd`
351 pub trait Context: Copy {
352 #(#types)*
353 fn make<T: Make<Context = Self>>(self, v: T::Prim) -> T {
354 T::make(self, v)
355 }
356 }
357 }
358 }