remove note to convert to sin_cos_tau instead sin_cos_pi
[vector-math.git] / src / ir.rs
1 use crate::{
2 f16::F16,
3 traits::{
4 Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt,
5 },
6 };
7 use std::{
8 borrow::Borrow,
9 cell::{Cell, RefCell},
10 collections::HashMap,
11 fmt::{self, Write as _},
12 format,
13 ops::{
14 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
15 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
16 SubAssign,
17 },
18 string::String,
19 vec::Vec,
20 };
21 use typed_arena::Arena;
22
/// Declares a fieldless `enum` together with a `const fn as_str` returning each
/// variant's identifier, and a `Display` impl that writes that same identifier.
///
/// The generated enum derives `Clone`, `Copy`, `Debug`, `PartialEq`, `Eq` and `Hash`.
/// Attributes written on a variant (including `///` docs) are forwarded to it.
macro_rules! make_enum {
    (
        $visibility:vis enum $Enum:ident {
            $(
                $(#[$attr:meta])*
                $Variant:ident,
            )*
        }
    ) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
        $visibility enum $Enum {
            $(
                $(#[$attr])*
                $Variant,
            )*
        }

        impl $Enum {
            /// Returns the variant's name exactly as written in the source.
            $visibility const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$Variant => stringify!($Variant),)*
                }
            }
        }

        impl fmt::Display for $Enum {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let name = self.as_str();
                f.write_str(name)
            }
        }
    };
}
57
58 make_enum! {
59 pub enum ScalarType {
60 Bool,
61 U8,
62 I8,
63 U16,
64 I16,
65 F16,
66 U32,
67 I32,
68 F32,
69 U64,
70 I64,
71 F64,
72 }
73 }
74
75 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
76 pub struct VectorType {
77 pub element: ScalarType,
78 }
79
80 impl fmt::Display for VectorType {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 write!(f, "vec<{}>", self.element)
83 }
84 }
85
86 impl From<ScalarType> for Type {
87 fn from(v: ScalarType) -> Self {
88 Type::Scalar(v)
89 }
90 }
91
92 impl From<VectorType> for Type {
93 fn from(v: VectorType) -> Self {
94 Type::Vector(v)
95 }
96 }
97
98 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
99 pub enum Type {
100 Scalar(ScalarType),
101 Vector(VectorType),
102 }
103
104 impl fmt::Display for Type {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Type::Scalar(v) => v.fmt(f),
108 Type::Vector(v) => v.fmt(f),
109 }
110 }
111 }
112
/// A scalar constant value.
///
/// Integer and boolean constants hold their value directly. Floating-point
/// constants hold only their raw bit pattern (`bits`), presumably so the type can
/// derive `Eq` and `Hash`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarConstant {
    Bool(bool),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    F16 { bits: u16 },
    F32 { bits: u32 },
    F64 { bits: u64 },
}

/// Generates an accessor `const fn $ty(self) -> Option<$ty>` for use inside
/// `impl ScalarConstant`. The accessor returns the payload when `self` is the
/// `$enumerant` variant and `None` otherwise.
///
/// The accessor is a `const fn`, matching the hand-written `f16_bits`,
/// `f32_bits` and `f64_bits` accessors, so it can also be used in constant
/// contexts.
macro_rules! make_scalar_constant_get {
    ($ty:ident, $enumerant:ident) => {
        /// Returns the payload if this constant is of the matching variant,
        /// otherwise `None`.
        pub const fn $ty(self) -> Option<$ty> {
            if let Self::$enumerant(v) = self {
                Some(v)
            } else {
                None
            }
        }
    };
}
140
141 macro_rules! make_scalar_constant_from {
142 ($ty:ident, $enumerant:ident) => {
143 impl From<$ty> for ScalarConstant {
144 fn from(v: $ty) -> Self {
145 Self::$enumerant(v)
146 }
147 }
148 impl From<$ty> for Constant {
149 fn from(v: $ty) -> Self {
150 Self::Scalar(v.into())
151 }
152 }
153 impl From<$ty> for Value<'_> {
154 fn from(v: $ty) -> Self {
155 Self::Constant(v.into())
156 }
157 }
158 };
159 }
160
161 make_scalar_constant_from!(bool, Bool);
162 make_scalar_constant_from!(u8, U8);
163 make_scalar_constant_from!(u16, U16);
164 make_scalar_constant_from!(u32, U32);
165 make_scalar_constant_from!(u64, U64);
166 make_scalar_constant_from!(i8, I8);
167 make_scalar_constant_from!(i16, I16);
168 make_scalar_constant_from!(i32, I32);
169 make_scalar_constant_from!(i64, I64);
170
171 impl ScalarConstant {
172 pub const fn ty(self) -> ScalarType {
173 match self {
174 ScalarConstant::Bool(_) => ScalarType::Bool,
175 ScalarConstant::U8(_) => ScalarType::U8,
176 ScalarConstant::U16(_) => ScalarType::U16,
177 ScalarConstant::U32(_) => ScalarType::U32,
178 ScalarConstant::U64(_) => ScalarType::U64,
179 ScalarConstant::I8(_) => ScalarType::I8,
180 ScalarConstant::I16(_) => ScalarType::I16,
181 ScalarConstant::I32(_) => ScalarType::I32,
182 ScalarConstant::I64(_) => ScalarType::I64,
183 ScalarConstant::F16 { .. } => ScalarType::F16,
184 ScalarConstant::F32 { .. } => ScalarType::F32,
185 ScalarConstant::F64 { .. } => ScalarType::F64,
186 }
187 }
188 pub const fn from_f16_bits(bits: u16) -> Self {
189 Self::F16 { bits }
190 }
191 pub const fn from_f32_bits(bits: u32) -> Self {
192 Self::F32 { bits }
193 }
194 pub const fn from_f64_bits(bits: u64) -> Self {
195 Self::F64 { bits }
196 }
197 pub const fn f16_bits(self) -> Option<u16> {
198 if let Self::F16 { bits } = self {
199 Some(bits)
200 } else {
201 None
202 }
203 }
204 pub const fn f32_bits(self) -> Option<u32> {
205 if let Self::F32 { bits } = self {
206 Some(bits)
207 } else {
208 None
209 }
210 }
211 pub const fn f64_bits(self) -> Option<u64> {
212 if let Self::F64 { bits } = self {
213 Some(bits)
214 } else {
215 None
216 }
217 }
218 make_scalar_constant_get!(bool, Bool);
219 make_scalar_constant_get!(u8, U8);
220 make_scalar_constant_get!(u16, U16);
221 make_scalar_constant_get!(u32, U32);
222 make_scalar_constant_get!(u64, U64);
223 make_scalar_constant_get!(i8, I8);
224 make_scalar_constant_get!(i16, I16);
225 make_scalar_constant_get!(i32, I32);
226 make_scalar_constant_get!(i64, I64);
227 }
228
229 impl fmt::Display for ScalarConstant {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 match self {
232 ScalarConstant::Bool(false) => write!(f, "false"),
233 ScalarConstant::Bool(true) => write!(f, "true"),
234 ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
235 ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
236 ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
237 ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
238 ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
239 ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
240 ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
241 ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
242 ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
243 ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
244 ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
245 }
246 }
247 }
248
249 impl fmt::Debug for ScalarConstant {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 fmt::Display::fmt(self, f)
252 }
253 }
254
255 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
256 pub struct VectorSplatConstant {
257 pub element: ScalarConstant,
258 }
259
260 impl VectorSplatConstant {
261 pub const fn ty(self) -> VectorType {
262 VectorType {
263 element: self.element.ty(),
264 }
265 }
266 }
267
268 impl fmt::Display for VectorSplatConstant {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 write!(f, "splat({})", self.element)
271 }
272 }
273
274 impl From<ScalarConstant> for Constant {
275 fn from(v: ScalarConstant) -> Self {
276 Constant::Scalar(v)
277 }
278 }
279
280 impl From<VectorSplatConstant> for Constant {
281 fn from(v: VectorSplatConstant) -> Self {
282 Constant::VectorSplat(v)
283 }
284 }
285
286 impl From<ScalarConstant> for Value<'_> {
287 fn from(v: ScalarConstant) -> Self {
288 Value::Constant(v.into())
289 }
290 }
291
292 impl From<VectorSplatConstant> for Value<'_> {
293 fn from(v: VectorSplatConstant) -> Self {
294 Value::Constant(v.into())
295 }
296 }
297
298 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
299 pub enum Constant {
300 Scalar(ScalarConstant),
301 VectorSplat(VectorSplatConstant),
302 }
303
304 impl Constant {
305 pub const fn ty(self) -> Type {
306 match self {
307 Constant::Scalar(v) => Type::Scalar(v.ty()),
308 Constant::VectorSplat(v) => Type::Vector(v.ty()),
309 }
310 }
311 }
312
313 impl fmt::Display for Constant {
314 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315 match self {
316 Constant::Scalar(v) => v.fmt(f),
317 Constant::VectorSplat(v) => v.fmt(f),
318 }
319 }
320 }
321
322 #[derive(Debug)]
323 pub struct Input<'ctx> {
324 pub name: &'ctx str,
325 pub ty: Type,
326 }
327
328 impl fmt::Display for Input<'_> {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 write!(f, "in<{}>", self.name)
331 }
332 }
333
334 #[derive(Copy, Clone)]
335 pub enum Value<'ctx> {
336 Input(&'ctx Input<'ctx>),
337 Constant(Constant),
338 OpResult(&'ctx Operation<'ctx>),
339 }
340
341 impl<'ctx> Value<'ctx> {
342 pub const fn ty(self) -> Type {
343 match self {
344 Value::Input(v) => v.ty,
345 Value::Constant(v) => v.ty(),
346 Value::OpResult(v) => v.result_type,
347 }
348 }
349 }
350
351 impl fmt::Debug for Value<'_> {
352 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353 match self {
354 Value::Input(v) => v.fmt(f),
355 Value::Constant(v) => v.fmt(f),
356 Value::OpResult(v) => v.result_id.fmt(f),
357 }
358 }
359 }
360
361 impl fmt::Display for Value<'_> {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 match self {
364 Value::Input(v) => v.fmt(f),
365 Value::Constant(v) => v.fmt(f),
366 Value::OpResult(v) => v.result_id.fmt(f),
367 }
368 }
369 }
370
371 impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> {
372 fn from(v: &'ctx Input<'ctx>) -> Self {
373 Value::Input(v)
374 }
375 }
376
377 impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> {
378 fn from(v: &'ctx Operation<'ctx>) -> Self {
379 Value::OpResult(v)
380 }
381 }
382
383 impl<'ctx> From<Constant> for Value<'ctx> {
384 fn from(v: Constant) -> Self {
385 Value::Constant(v)
386 }
387 }
388
389 make_enum! {
390 pub enum Opcode {
391 Add,
392 Sub,
393 Mul,
394 Div,
395 Rem,
396 Fma,
397 Cast,
398 And,
399 Or,
400 Xor,
401 Not,
402 Shl,
403 Shr,
404 CountSetBits,
405 CountLeadingZeros,
406 CountTrailingZeros,
407 Neg,
408 Abs,
409 Trunc,
410 Ceil,
411 Floor,
412 Round,
413 IsInfinite,
414 IsFinite,
415 ToBits,
416 FromBits,
417 Splat,
418 CompareEq,
419 CompareNe,
420 CompareLt,
421 CompareLe,
422 CompareGt,
423 CompareGe,
424 Select,
425 }
426 }
427
428 #[derive(Debug)]
429 pub struct Operation<'ctx> {
430 pub opcode: Opcode,
431 pub arguments: Vec<Value<'ctx>>,
432 pub result_type: Type,
433 pub result_id: OperationId,
434 }
435
436 impl fmt::Display for Operation<'_> {
437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438 write!(
439 f,
440 "{}: {} = {}",
441 self.result_id, self.result_type, self.opcode
442 )?;
443 let mut separator = " ";
444 for i in &self.arguments {
445 write!(f, "{}{}", separator, i)?;
446 separator = ", ";
447 }
448 Ok(())
449 }
450 }
451
/// Identifier of an operation's result. `IrContext::make_operation` hands these
/// out sequentially, starting from the context's counter.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct OperationId(pub u64);

impl fmt::Display for OperationId {
    /// Formats as `op_N`, e.g. `op_3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(index) = *self;
        write!(f, "op_{}", index)
    }
}
460
461 #[derive(Default)]
462 pub struct IrContext<'ctx> {
463 bytes_arena: Arena<u8>,
464 inputs_arena: Arena<Input<'ctx>>,
465 inputs: RefCell<HashMap<&'ctx str, &'ctx Input<'ctx>>>,
466 operations_arena: Arena<Operation<'ctx>>,
467 operations: RefCell<Vec<&'ctx Operation<'ctx>>>,
468 next_operation_result_id: Cell<u64>,
469 }
470
471 impl fmt::Debug for IrContext<'_> {
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 f.write_str("IrContext { .. }")
474 }
475 }
476
477 impl<'ctx> IrContext<'ctx> {
478 pub fn new() -> Self {
479 Self::default()
480 }
481 pub fn make_input<N: Borrow<str> + Into<String>, T: Into<Type>>(
482 &'ctx self,
483 name: N,
484 ty: T,
485 ) -> &'ctx Input<'ctx> {
486 let mut inputs = self.inputs.borrow_mut();
487 let name_str = name.borrow();
488 let ty = ty.into();
489 if !name_str.is_empty() && !inputs.contains_key(name_str) {
490 let name = self.bytes_arena.alloc_str(name_str);
491 let input = self.inputs_arena.alloc(Input { name, ty });
492 inputs.insert(name, input);
493 return input;
494 }
495 let mut name: String = name.into();
496 if name.is_empty() {
497 name = "in".into();
498 }
499 let name_len = name.len();
500 let mut tag = 2usize;
501 loop {
502 name.truncate(name_len);
503 write!(name, "_{}", tag).unwrap();
504 if !inputs.contains_key(&*name) {
505 let name = self.bytes_arena.alloc_str(&name);
506 let input = self.inputs_arena.alloc(Input { name, ty });
507 inputs.insert(name, input);
508 return input;
509 }
510 tag += 1;
511 }
512 }
513 pub fn make_operation<A: Into<Vec<Value<'ctx>>>, T: Into<Type>>(
514 &'ctx self,
515 opcode: Opcode,
516 arguments: A,
517 result_type: T,
518 ) -> &'ctx Operation<'ctx> {
519 let arguments = arguments.into();
520 let result_type = result_type.into();
521 let result_id = OperationId(self.next_operation_result_id.get());
522 self.next_operation_result_id.set(result_id.0 + 1);
523 let operation = self.operations_arena.alloc(Operation {
524 opcode,
525 arguments,
526 result_type,
527 result_id,
528 });
529 self.operations.borrow_mut().push(operation);
530 operation
531 }
532 pub fn replace_operations(
533 &'ctx self,
534 new_operations: Vec<&'ctx Operation<'ctx>>,
535 ) -> Vec<&'ctx Operation<'ctx>> {
536 self.operations.replace(new_operations)
537 }
538 }
539
540 #[derive(Debug)]
541 pub struct IrFunction<'ctx> {
542 pub inputs: Vec<&'ctx Input<'ctx>>,
543 pub operations: Vec<&'ctx Operation<'ctx>>,
544 pub outputs: Vec<Value<'ctx>>,
545 }
546
547 impl fmt::Display for IrFunction<'_> {
548 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
549 write!(f, "function(")?;
550 let mut first = true;
551 for input in &self.inputs {
552 if first {
553 first = false
554 } else {
555 write!(f, ", ")?;
556 }
557 write!(f, "{}: {}", input, input.ty)?;
558 }
559 match self.outputs.len() {
560 0 => writeln!(f, ") {{")?,
561 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?,
562 _ => {
563 write!(f, ") -> ({}", self.outputs[0].ty())?;
564 for output in self.outputs.iter().skip(1) {
565 write!(f, ", {}", output.ty())?;
566 }
567 writeln!(f, ") {{")?;
568 }
569 }
570 for operation in &self.operations {
571 writeln!(f, " {}", operation)?;
572 }
573 match self.outputs.len() {
574 0 => writeln!(f, "}}")?,
575 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?,
576 _ => {
577 write!(f, " Return {}", self.outputs[0])?;
578 for output in self.outputs.iter().skip(1) {
579 write!(f, ", {}", output)?;
580 }
581 writeln!(f, "\n}}")?;
582 }
583 }
584 Ok(())
585 }
586 }
587
588 impl<'ctx> IrFunction<'ctx> {
589 pub fn make<F: IrFunctionMaker<'ctx>>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self {
590 let old_operations = ctx.replace_operations(Vec::new());
591 let (v, inputs) = F::make_inputs(ctx);
592 let outputs = f.call(ctx, v).outputs_to_vec();
593 let operations = ctx.replace_operations(old_operations);
594 Self {
595 inputs,
596 operations,
597 outputs,
598 }
599 }
600 }
601
602 pub trait IrFunctionMaker<'ctx>: Sized {
603 type Inputs;
604 type Outputs: IrFunctionMakerOutputs<'ctx>;
605 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs;
606 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>);
607 }
608
609 pub trait IrFunctionMakerOutputs<'ctx> {
610 fn outputs_to_vec(self) -> Vec<Value<'ctx>>;
611 }
612
613 impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T {
614 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
615 [self.value()].into()
616 }
617 }
618
619 impl<'ctx> IrFunctionMakerOutputs<'ctx> for () {
620 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
621 Vec::new()
622 }
623 }
624
625 impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx>
626 for fn(&'ctx IrContext<'ctx>) -> R
627 {
628 type Inputs = ();
629 type Outputs = R;
630 fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs {
631 self(ctx)
632 }
633 fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
634 ((), Vec::new())
635 }
636 }
637
638 macro_rules! impl_ir_function_maker_io {
639 () => {};
640 ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => {
641 impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R
642 where
643 $first_arg_ty: IrValue<'ctx>,
644 $($arg_ty: IrValue<'ctx>,)*
645 R: IrFunctionMakerOutputs<'ctx>,
646 {
647 type Inputs = ($first_arg_ty, $($arg_ty,)*);
648 type Outputs = R;
649 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs {
650 let ($first_arg, $($arg,)*) = inputs;
651 self(ctx, $first_arg$(, $arg)*)
652 }
653 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
654 let mut $first_arg = String::new();
655 $(let mut $arg = String::new();)*
656 for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() {
657 **arg = format!("arg_{}", index);
658 }
659 let $first_arg = $first_arg_ty::make_input(ctx, $first_arg);
660 $(let $arg = $arg_ty::make_input(ctx, $arg);)*
661 (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into())
662 }
663 }
664 impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*)
665 where
666 $first_arg_ty: IrValue<'ctx>,
667 $($arg_ty: IrValue<'ctx>,)*
668 {
669 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
670 let ($first_arg, $($arg,)*) = self;
671 [$first_arg.value() $(, $arg.value())*].into()
672 }
673 }
674 impl_ir_function_maker_io!($($arg: $arg_ty,)*);
675 };
676 }
677
678 impl_ir_function_maker_io!(
679 in0: In0,
680 in1: In1,
681 in2: In2,
682 in3: In3,
683 in4: In4,
684 in5: In5,
685 in6: In6,
686 in7: In7,
687 in8: In8,
688 in9: In9,
689 in10: In10,
690 in11: In11,
691 );
692
693 pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
694 const TYPE: Type;
695 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
696 fn make_input<N: Borrow<str> + Into<String>>(
697 ctx: &'ctx IrContext<'ctx>,
698 name: N,
699 ) -> (Self, &'ctx Input<'ctx>) {
700 let input = ctx.make_input(name, Self::TYPE);
701 (Self::new(ctx, input.into()), input)
702 }
703 fn value(self) -> Value<'ctx>;
704 }
705
/// Defines a scalar IR value type `$name` and its vector counterpart `$vec_name`,
/// both with element type `ScalarType::$scalar_type`, plus:
/// - `IrValue` and `Make` impls. `Make::make` turns a `$prim` into a constant using
///   the `$make` expression, which refers to the primitive as `$make_var`.
/// - `Select` impls: an `IrBool` chooses between two scalars or two vectors, and an
///   `IrVecBool` chooses between two vectors.
/// - `From<$name> for $vec_name`, which emits a `Splat` operation.
macro_rules! ir_value {
    ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => {
        /// Scalar IR value handle.
        #[derive(Clone, Copy, Debug)]
        pub struct $name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> $name<'ctx> {
            pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
        }

        impl<'ctx> IrValue<'ctx> for $name<'ctx> {
            const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE);

            /// # Panics
            /// Panics if `value` does not have type `Self::TYPE`.
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                assert_eq!(value.ty(), Self::TYPE);
                Self { ctx, value }
            }

            fn value(self) -> Value<'ctx> {
                self.value
            }
        }

        impl<'ctx> Make for $name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;

            fn ctx(self) -> Self::Context {
                self.ctx
            }

            /// Wraps the primitive as a scalar constant.
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                let constant: ScalarConstant = $make;
                Self {
                    value: Value::Constant(Constant::Scalar(constant)),
                    ctx,
                }
            }
        }

        /// Vector IR value handle (all lanes share one element type).
        #[derive(Clone, Copy, Debug)]
        pub struct $vec_name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> $vec_name<'ctx> {
            pub const VECTOR_TYPE: VectorType = VectorType {
                element: ScalarType::$scalar_type,
            };
        }

        impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> {
            const TYPE: Type = Type::Vector(Self::VECTOR_TYPE);

            /// # Panics
            /// Panics if `value` does not have type `Self::TYPE`.
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                assert_eq!(value.ty(), Self::TYPE);
                Self { ctx, value }
            }

            fn value(self) -> Value<'ctx> {
                self.value
            }
        }

        impl<'ctx> Make for $vec_name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;

            fn ctx(self) -> Self::Context {
                self.ctx
            }

            /// Wraps the primitive as a splat constant (same value in every lane).
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                let element: ScalarConstant = $make;
                let splat = VectorSplatConstant { element };
                Self {
                    value: Value::Constant(Constant::VectorSplat(splat)),
                    ctx,
                }
            }
        }

        impl<'ctx> Select<$name<'ctx>> for IrBool<'ctx> {
            fn select(self, if_true: $name<'ctx>, if_false: $name<'ctx>) -> $name<'ctx> {
                let ctx = self.ctx;
                let operation = ctx.make_operation(
                    Opcode::Select,
                    [self.value, if_true.value, if_false.value],
                    $name::TYPE,
                );
                $name {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Select<$vec_name<'ctx>> for IrVecBool<'ctx> {
            fn select(
                self,
                if_true: $vec_name<'ctx>,
                if_false: $vec_name<'ctx>,
            ) -> $vec_name<'ctx> {
                let ctx = self.ctx;
                let operation = ctx.make_operation(
                    Opcode::Select,
                    [self.value, if_true.value, if_false.value],
                    $vec_name::TYPE,
                );
                $vec_name {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> {
            fn select(
                self,
                if_true: $vec_name<'ctx>,
                if_false: $vec_name<'ctx>,
            ) -> $vec_name<'ctx> {
                let ctx = self.ctx;
                let operation = ctx.make_operation(
                    Opcode::Select,
                    [self.value, if_true.value, if_false.value],
                    $vec_name::TYPE,
                );
                $vec_name {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
            /// Emits a `Splat` operation on the scalar (presumably broadcasting it
            /// to every lane).
            fn from(scalar: $name<'ctx>) -> Self {
                let ctx = scalar.ctx;
                let operation = ctx.make_operation(Opcode::Splat, [scalar.value], $vec_name::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }
    };
}
842
/// Implements `&`, `|`, `^`, `!` and the `&=`, `|=`, `^=` forms for `$ty`. Each
/// operator emits an `And`, `Or`, `Xor` or `Not` operation of type `Self::TYPE`.
macro_rules! impl_bit_ops {
    ($ty:ident) => {
        impl<'ctx> BitAnd for $ty<'ctx> {
            type Output = Self;

            fn bitand(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::And, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> BitOr for $ty<'ctx> {
            type Output = Self;

            fn bitor(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Or, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> BitXor for $ty<'ctx> {
            type Output = Self;

            fn bitxor(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Xor, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Not for $ty<'ctx> {
            type Output = Self;

            fn not(self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Not, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> BitAndAssign for $ty<'ctx> {
            fn bitand_assign(&mut self, rhs: Self) {
                *self = BitAnd::bitand(*self, rhs);
            }
        }

        impl<'ctx> BitOrAssign for $ty<'ctx> {
            fn bitor_assign(&mut self, rhs: Self) {
                *self = BitOr::bitor(*self, rhs);
            }
        }

        impl<'ctx> BitXorAssign for $ty<'ctx> {
            fn bitxor_assign(&mut self, rhs: Self) {
                *self = BitXor::bitxor(*self, rhs);
            }
        }
    };
}
918
/// Implements `+ - * / %`, their compound-assignment forms, and `Compare` for `$ty`.
/// Arithmetic emits `Add`/`Sub`/`Mul`/`Div`/`Rem` operations of type `Self::TYPE`.
/// Comparisons emit `Compare*` operations whose result type is `$bool::TYPE`.
macro_rules! impl_number_ops {
    ($ty:ident, $bool:ident) => {
        impl<'ctx> Add for $ty<'ctx> {
            type Output = Self;

            fn add(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Add, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Sub for $ty<'ctx> {
            type Output = Self;

            fn sub(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Sub, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Mul for $ty<'ctx> {
            type Output = Self;

            fn mul(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Mul, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Div for $ty<'ctx> {
            type Output = Self;

            fn div(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Div, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Rem for $ty<'ctx> {
            type Output = Self;

            fn rem(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Rem, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> AddAssign for $ty<'ctx> {
            fn add_assign(&mut self, rhs: Self) {
                *self = Add::add(*self, rhs);
            }
        }

        impl<'ctx> SubAssign for $ty<'ctx> {
            fn sub_assign(&mut self, rhs: Self) {
                *self = Sub::sub(*self, rhs);
            }
        }

        impl<'ctx> MulAssign for $ty<'ctx> {
            fn mul_assign(&mut self, rhs: Self) {
                *self = Mul::mul(*self, rhs);
            }
        }

        impl<'ctx> DivAssign for $ty<'ctx> {
            fn div_assign(&mut self, rhs: Self) {
                *self = Div::div(*self, rhs);
            }
        }

        impl<'ctx> RemAssign for $ty<'ctx> {
            fn rem_assign(&mut self, rhs: Self) {
                *self = Rem::rem(*self, rhs);
            }
        }

        impl<'ctx> Compare for $ty<'ctx> {
            type Bool = $bool<'ctx>;

            fn eq(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareEq, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn ne(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareNe, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn lt(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareLt, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn gt(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareGt, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn le(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareLe, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn ge(self, rhs: Self) -> Self::Bool {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CompareGe, [self.value, rhs.value], $bool::TYPE);
                $bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }
    };
}
1081
1082 macro_rules! impl_bool_compare {
1083 ($ty:ident) => {
1084 impl<'ctx> Compare for $ty<'ctx> {
1085 type Bool = Self;
1086 fn eq(self, rhs: Self) -> Self::Bool {
1087 !(self ^ rhs)
1088 }
1089 fn ne(self, rhs: Self) -> Self::Bool {
1090 self ^ rhs
1091 }
1092 fn lt(self, rhs: Self) -> Self::Bool {
1093 !self & rhs
1094 }
1095 fn gt(self, rhs: Self) -> Self::Bool {
1096 self & !rhs
1097 }
1098 fn le(self, rhs: Self) -> Self::Bool {
1099 !self | rhs
1100 }
1101 fn ge(self, rhs: Self) -> Self::Bool {
1102 self | !rhs
1103 }
1104 }
1105 };
1106 }
1107
1108 impl_bool_compare!(IrBool);
1109 impl_bool_compare!(IrVecBool);
1110
/// Implements `<<`, `>>`, `<<=` and `>>=` for `$ty`, with the shift amount of the
/// same type. Each shift emits a `Shl` or `Shr` operation of type `Self::TYPE`.
macro_rules! impl_shift_ops {
    ($ty:ident) => {
        impl<'ctx> Shl for $ty<'ctx> {
            type Output = Self;

            fn shl(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> Shr for $ty<'ctx> {
            type Output = Self;

            fn shr(self, rhs: Self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }

        impl<'ctx> ShlAssign for $ty<'ctx> {
            fn shl_assign(&mut self, rhs: Self) {
                *self = Shl::shl(*self, rhs);
            }
        }

        impl<'ctx> ShrAssign for $ty<'ctx> {
            fn shr_assign(&mut self, rhs: Self) {
                *self = Shr::shr(*self, rhs);
            }
        }
    };
}
1153
/// Implements unary `-` for `$ty` by emitting a `Neg` operation of type `Self::TYPE`.
macro_rules! impl_neg {
    ($ty:ident) => {
        impl<'ctx> Neg for $ty<'ctx> {
            type Output = Self;

            fn neg(self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Neg, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }
    };
}
1172
/// Implements `Int` for `$ty`. The bit-counting methods emit `CountLeadingZeros`,
/// `CountTrailingZeros` and `CountSetBits` operations whose result has the same
/// type as the operand.
macro_rules! impl_int_trait {
    ($ty:ident) => {
        impl<'ctx> Int for $ty<'ctx> {
            fn leading_zeros(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn trailing_zeros(self) -> Self {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn count_ones(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::CountSetBits, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }
    };
}
1209
1210 macro_rules! impl_integer_ops {
1211 ($scalar:ident, $vec:ident) => {
1212 impl_bit_ops!($scalar);
1213 impl_number_ops!($scalar, IrBool);
1214 impl_shift_ops!($scalar);
1215 impl_bit_ops!($vec);
1216 impl_number_ops!($vec, IrVecBool);
1217 impl_shift_ops!($vec);
1218 impl_int_trait!($scalar);
1219 impl_int_trait!($vec);
1220 };
1221 }
1222
1223 macro_rules! impl_uint_sint_ops {
1224 ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => {
1225 impl_integer_ops!($uint_scalar, $uint_vec);
1226 impl_integer_ops!($sint_scalar, $sint_vec);
1227 impl_neg!($sint_scalar);
1228 impl_neg!($sint_vec);
1229
1230 impl<'ctx> UInt for $uint_scalar<'ctx> {
1231 type PrimUInt = Self::Prim;
1232 type SignedType = $sint_scalar<'ctx>;
1233 }
1234 impl<'ctx> UInt for $uint_vec<'ctx> {
1235 type PrimUInt = Self::Prim;
1236 type SignedType = $sint_vec<'ctx>;
1237 }
1238 impl<'ctx> SInt for $sint_scalar<'ctx> {
1239 type PrimSInt = Self::Prim;
1240 type UnsignedType = $uint_scalar<'ctx>;
1241 }
1242 impl<'ctx> SInt for $sint_vec<'ctx> {
1243 type PrimSInt = Self::Prim;
1244 type UnsignedType = $uint_vec<'ctx>;
1245 }
1246 };
1247 }
1248
1249 impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8);
1250 impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16);
1251 impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32);
1252 impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64);
1253
/// Implements `Float` for `$float`. `$bits` and `$signed_bits` are the unsigned
/// and signed integer IR types used as its `BitsType` and `SignedBitsType`.
///
/// Rounding and `abs` emit single operations of type `Self::TYPE`. `is_nan` is
/// lowered to a `CompareNe` of the value with itself. `is_infinite`/`is_finite`
/// emit dedicated operations of type `Self::Bool::TYPE`.
macro_rules! impl_float {
    ($float:ident, $bits:ident, $signed_bits:ident) => {
        impl<'ctx> Float for $float<'ctx> {
            type PrimFloat = <Self as Make>::Prim;
            type BitsType = $bits<'ctx>;
            type SignedBitsType = $signed_bits<'ctx>;

            fn abs(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Abs, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn trunc(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Trunc, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn ceil(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Ceil, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn floor(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Floor, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn round(self) -> Self {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Round, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            #[cfg(feature = "fma")]
            fn fma(self, a: Self, b: Self) -> Self {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::Fma, [self.value, a.value, b.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn is_nan(self) -> Self::Bool {
                // NaN is the only value that compares unequal to itself.
                let ctx = self.ctx;
                let operation = ctx.make_operation(
                    Opcode::CompareNe,
                    [self.value, self.value],
                    Self::Bool::TYPE,
                );
                Self::Bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn is_infinite(self) -> Self::Bool {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::IsInfinite, [self.value], Self::Bool::TYPE);
                Self::Bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn is_finite(self) -> Self::Bool {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::IsFinite, [self.value], Self::Bool::TYPE);
                Self::Bool {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn from_bits(bits: Self::BitsType) -> Self {
                let ctx = bits.ctx;
                let operation = ctx.make_operation(Opcode::FromBits, [bits.value], Self::TYPE);
                Self {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }

            fn to_bits(self) -> Self::BitsType {
                let ctx = self.ctx;
                let operation =
                    ctx.make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE);
                Self::BitsType {
                    value: Value::OpResult(operation),
                    ctx,
                }
            }
        }
    };
}
1375
1376 macro_rules! impl_float_ops {
1377 ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
1378 impl_number_ops!($scalar, IrBool);
1379 impl_number_ops!($vec, IrVecBool);
1380 impl_neg!($scalar);
1381 impl_neg!($vec);
1382 impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
1383 impl_float!($vec, $vec_bits, $vec_signed_bits);
1384 };
1385 }
1386
1387 impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
1388 impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
1389 impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
1390
// Boolean IR value types: scalar `IrBool` and vector `IrVecBool`, both tagged
// with `TYPE = Bool`. A Rust `bool` constant is converted with `v.into()`.
// NOTE(review): presumably into a `ScalarConstant`, which the float
// invocations below construct explicitly; confirm against `ir_value!`.
ir_value!(
    IrBool,
    IrVecBool,
    TYPE = Bool,
    fn make(v: bool) {
        v.into()
    }
);
1399
1400 impl<'ctx> Bool for IrBool<'ctx> {}
1401 impl<'ctx> Bool for IrVecBool<'ctx> {}
1402
// Bit-operator impls for the scalar and vector boolean IR types, generated by
// `impl_bit_ops!` (defined outside this chunk).
// NOTE(review): presumably the `BitAnd`/`BitOr`/`BitXor`/`Not` family
// imported at the top of the file; confirm against the macro.
impl_bit_ops!(IrBool);
impl_bit_ops!(IrVecBool);
1405
// Integer IR value types. Each invocation names the scalar wrapper, the vector
// wrapper, the `TYPE` tag, and how a primitive constant is turned into a
// constant (`v.into()`).

// Unsigned integers.
ir_value!(
    IrU8,
    IrVecU8,
    TYPE = U8,
    fn make(v: u8) {
        v.into()
    }
);
ir_value!(
    IrU16,
    IrVecU16,
    TYPE = U16,
    fn make(v: u16) {
        v.into()
    }
);
ir_value!(
    IrU32,
    IrVecU32,
    TYPE = U32,
    fn make(v: u32) {
        v.into()
    }
);
ir_value!(
    IrU64,
    IrVecU64,
    TYPE = U64,
    fn make(v: u64) {
        v.into()
    }
);
// Signed integers.
ir_value!(
    IrI8,
    IrVecI8,
    TYPE = I8,
    fn make(v: i8) {
        v.into()
    }
);
ir_value!(
    IrI16,
    IrVecI16,
    TYPE = I16,
    fn make(v: i16) {
        v.into()
    }
);
ir_value!(
    IrI32,
    IrVecI32,
    TYPE = I32,
    fn make(v: i32) {
        v.into()
    }
);
ir_value!(
    IrI64,
    IrVecI64,
    TYPE = I64,
    fn make(v: i64) {
        v.into()
    }
);
// Float IR value types. Constants are recorded by their raw bit patterns
// (`v.to_bits()` passed to `ScalarConstant::from_fNN_bits`) rather than as
// float values. This is also why constants print as hex bits, e.g.
// `0x40A00000_f32`. NOTE(review): presumably this keeps constant
// equality/hashing exact (NaN payloads, -0.0); confirm.
ir_value!(
    IrF16,
    IrVecF16,
    TYPE = F16,
    fn make(v: F16) {
        ScalarConstant::from_f16_bits(v.to_bits())
    }
);
ir_value!(
    IrF32,
    IrVecF32,
    TYPE = F32,
    fn make(v: f32) {
        ScalarConstant::from_f32_bits(v.to_bits())
    }
);
ir_value!(
    IrF64,
    IrVecF64,
    TYPE = F64,
    fn make(v: f64) {
        ScalarConstant::from_f64_bits(v.to_bits())
    }
);
1494
1495 macro_rules! impl_convert_from {
1496 ($src:ident -> $dest:ident) => {
1497 impl<'ctx> ConvertFrom<$src<'ctx>> for $dest<'ctx> {
1498 fn cvt_from(v: $src<'ctx>) -> Self {
1499 let value = if $src::TYPE == $dest::TYPE {
1500 v.value
1501 } else {
1502 v
1503 .ctx
1504 .make_operation(Opcode::Cast, [v.value], $dest::TYPE)
1505 .into()
1506 };
1507 $dest {
1508 value,
1509 ctx: v.ctx,
1510 }
1511 }
1512 }
1513 };
1514 ($first:ident $(, $ty:ident)*) => {
1515 $(
1516 impl_convert_from!($first -> $ty);
1517 impl_convert_from!($ty -> $first);
1518 )*
1519 impl_convert_from![$($ty),*];
1520 };
1521 () => {
1522 };
1523 }
1524 impl_convert_from![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
1525
1526 impl_convert_from![
1527 IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
1528 IrVecF32, IrVecF64
1529 ];
1530
/// Implements `From<$source>` for each listed target type.
///
/// Every impl delegates to the value's `.to()` conversion, presumably
/// `ConvertTo::to`. The target type is inferred from the `Self` return type.
macro_rules! impl_from {
    ($source:ident => [$($target:ident),*]) => {
        $(
            impl<'ctx> From<$source<'ctx>> for $target<'ctx> {
                fn from(input: $source<'ctx>) -> Self {
                    input.to()
                }
            }
        )*
    };
}
1542
/// Implements the exact `From` conversions within one family of IR numeric
/// types. The invocation supplies one `#[prim] Type;` entry per primitive, in
/// the fixed order shown in the matcher.
///
/// Every generated conversion is exact: each source value is representable in
/// the destination type. Conversions such as `u16 -> f16`, `u32 -> f32` or
/// `u64 -> f64` would round, so they are not generated.
macro_rules! impl_froms {
    (
        #[u8] $u8:ident;
        #[i8] $i8:ident;
        #[u16] $u16:ident;
        #[i16] $i16:ident;
        #[f16] $f16:ident;
        #[u32] $u32:ident;
        #[i32] $i32:ident;
        #[f32] $f32:ident;
        #[u64] $u64:ident;
        #[i64] $i64:ident;
        #[f64] $f64:ident;
    ) => {
        // Unsigned sources: wider unsigned, strictly wider signed, and floats
        // whose precision covers the whole source range.
        impl_from!($u8 => [$u16, $u32, $u64, $i16, $i32, $i64, $f16, $f32, $f64]);
        impl_from!($u16 => [$u32, $u64, $i32, $i64, $f32, $f64]);
        impl_from!($u32 => [$u64, $i64, $f64]);

        // Signed sources: wider signed and sufficiently precise floats.
        impl_from!($i8 => [$i16, $i32, $i64, $f16, $f32, $f64]);
        impl_from!($i16 => [$i32, $i64, $f32, $f64]);
        impl_from!($i32 => [$i64, $f64]);

        // Float sources: wider floats only.
        impl_from!($f16 => [$f32, $f64]);
        impl_from!($f32 => [$f64]);
    };
}
1567
// `From` conversions among the scalar IR numeric types. The entries must keep
// the exact order and `#[prim]` tags expected by the `impl_froms!` matcher.
impl_froms! {
    #[u8] IrU8;
    #[i8] IrI8;
    #[u16] IrU16;
    #[i16] IrI16;
    #[f16] IrF16;
    #[u32] IrU32;
    #[i32] IrI32;
    #[f32] IrF32;
    #[u64] IrU64;
    #[i64] IrI64;
    #[f64] IrF64;
}

// The same set of `From` conversions among the vector IR numeric types.
impl_froms! {
    #[u8] IrVecU8;
    #[i8] IrVecI8;
    #[u16] IrVecU16;
    #[i16] IrVecI16;
    #[f16] IrVecF16;
    #[u32] IrVecU32;
    #[i32] IrVecI32;
    #[f32] IrVecF32;
    #[u64] IrVecU64;
    #[i64] IrVecI64;
    #[f64] IrVecF64;
}
1595
1596 impl<'ctx> Context for &'ctx IrContext<'ctx> {
1597 type Bool = IrBool<'ctx>;
1598 type U8 = IrU8<'ctx>;
1599 type I8 = IrI8<'ctx>;
1600 type U16 = IrU16<'ctx>;
1601 type I16 = IrI16<'ctx>;
1602 type F16 = IrF16<'ctx>;
1603 type U32 = IrU32<'ctx>;
1604 type I32 = IrI32<'ctx>;
1605 type F32 = IrF32<'ctx>;
1606 type U64 = IrU64<'ctx>;
1607 type I64 = IrI64<'ctx>;
1608 type F64 = IrF64<'ctx>;
1609 type VecBool8 = IrVecBool<'ctx>;
1610 type VecU8 = IrVecU8<'ctx>;
1611 type VecI8 = IrVecI8<'ctx>;
1612 type VecBool16 = IrVecBool<'ctx>;
1613 type VecU16 = IrVecU16<'ctx>;
1614 type VecI16 = IrVecI16<'ctx>;
1615 type VecF16 = IrVecF16<'ctx>;
1616 type VecBool32 = IrVecBool<'ctx>;
1617 type VecU32 = IrVecU32<'ctx>;
1618 type VecI32 = IrVecI32<'ctx>;
1619 type VecF32 = IrVecF32<'ctx>;
1620 type VecBool64 = IrVecBool<'ctx>;
1621 type VecU64 = IrVecU64<'ctx>;
1622 type VecI64 = IrVecI64<'ctx>;
1623 type VecF64 = IrVecF64<'ctx>;
1624 }
1625
// Unit tests. Each test builds an `IrFunction` from code written against the
// generic `Context` API, then compares its `Display` output with expected text.
#[cfg(test)]
mod tests {
    use crate::algorithms;

    use super::*;
    use std::println;

    /// A small generic numeric function lowers to the expected IR listing:
    /// cast, add, subtract a constant, floor, subtract, cast.
    #[test]
    fn test_display() {
        // Written only against the generic `Context` traits, so it can be
        // instantiated with the IR context below.
        fn f<Ctx: Context>(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 {
            // `VecU8 -> VecF32` via `Into`; appears as `op_0 = Cast` below.
            let a: Ctx::VecF32 = a.into();
            // The final `.to()` widens to `VecF64` (`op_5 = Cast`).
            (a - (a + b - ctx.make(5f32)).floor()).to()
        }
        let ctx = IrContext::new();
        // Coerces the generic `f` to a fn pointer over the concrete IR types
        // so it can be handed to `IrFunction::make`.
        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
            let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f;
            IrFunction::make(ctx, f)
        }
        // The leading newline matches the raw string below, which begins on
        // the line after `r"`.
        let text = format!("\n{}", make_it(&ctx));
        // Printed so the actual IR is shown by the test harness on failure.
        println!("{}", text);
        // NOTE(review): the comparison is exact, including each line's
        // leading whitespace inside the raw string.
        assert_eq!(
            text,
            r"
function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
op_0: vec<F32> = Cast in<arg_0>
op_1: vec<F32> = Add op_0, in<arg_1>
op_2: vec<F32> = Sub op_1, splat(0x40A00000_f32)
op_3: vec<F32> = Floor op_2
op_4: vec<F32> = Sub op_0, op_3
op_5: vec<F64> = Cast op_4
Return op_5
}
"
        );
    }

    /// The IR recorded for `algorithms::ilogb::ilogb_f32` matches the
    /// expected listing exactly.
    #[test]
    fn test_display_ilogb_f32() {
        let ctx = IrContext::new();
        // Coerces `ilogb_f32` to a fn pointer over the concrete IR types so it
        // can be handed to `IrFunction::make`.
        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
            let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
                algorithms::ilogb::ilogb_f32;
            IrFunction::make(ctx, f)
        }
        // The leading newline matches the raw string below.
        let text = format!("\n{}", make_it(&ctx));
        println!("{}", text);
        // `is_nan` is expected to lower to `CompareNe x, x` (see `op_5`).
        assert_eq!(
            text,
            r"
function(in<arg_0>: vec<F32>) -> vec<I32> {
op_0: vec<Bool> = IsFinite in<arg_0>
op_1: vec<U32> = ToBits in<arg_0>
op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
op_3: vec<U32> = Shr op_2, splat(0x17_u32)
op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
op_6: vec<I32> = Splat 0x80000001_i32
op_7: vec<I32> = Splat 0x7FFFFFFF_i32
op_8: vec<I32> = Select op_5, op_6, op_7
op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
op_10: vec<U32> = ToBits op_9
op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
op_12: vec<U32> = Shr op_11, splat(0x17_u32)
op_13: vec<I32> = Cast op_12
op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
op_15: vec<U32> = ToBits in<arg_0>
op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
op_17: vec<U32> = Shr op_16, splat(0x17_u32)
op_18: vec<I32> = Cast op_17
op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
op_20: vec<I32> = Select op_0, op_19, op_8
op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
op_22: vec<I32> = Splat 0x80000000_i32
op_23: vec<I32> = Sub op_14, splat(0x17_i32)
op_24: vec<I32> = Select op_21, op_22, op_23
op_25: vec<I32> = Select op_4, op_24, op_20
Return op_25
}
"
        );
    }
}