IR works!
[vector-math.git] / src / ir.rs
1 use crate::{
2 f16::F16,
3 traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt},
4 };
5 use std::{
6 borrow::Borrow,
7 cell::{Cell, RefCell},
8 collections::HashMap,
9 fmt::{self, Write as _},
10 format,
11 ops::{
12 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
13 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
14 SubAssign,
15 },
16 string::String,
17 vec::Vec,
18 };
19 use typed_arena::Arena;
20
// Defines a fieldless enum together with:
// - a `const fn as_str` returning each variant's identifier verbatim, and
// - a `Display` impl that prints that identifier.
//
// Outer attributes (including `///` doc comments) are accepted both on the enum itself and on
// each variant, and the comma after the last variant is optional. Both extensions are
// backward-compatible: every invocation accepted before still matches.
macro_rules! make_enum {
    (
        $(#[$enum_meta:meta])*
        $vis:vis enum $enum:ident {
            $(
                $(#[$meta:meta])*
                $name:ident
            ),* $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
        $vis enum $enum {
            $(
                $(#[$meta])*
                $name,
            )*
        }

        impl $enum {
            /// Returns the variant's identifier exactly as written in the source.
            $vis const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$name => stringify!($name),)*
                }
            }
        }

        impl fmt::Display for $enum {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(self.as_str())
            }
        }
    };
}
55
56 make_enum! {
57 pub enum ScalarType {
58 Bool,
59 U8,
60 I8,
61 U16,
62 I16,
63 F16,
64 U32,
65 I32,
66 F32,
67 U64,
68 I64,
69 F64,
70 }
71 }
72
73 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
74 pub struct VectorType {
75 pub element: ScalarType,
76 }
77
78 impl fmt::Display for VectorType {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "vec<{}>", self.element)
81 }
82 }
83
84 impl From<ScalarType> for Type {
85 fn from(v: ScalarType) -> Self {
86 Type::Scalar(v)
87 }
88 }
89
90 impl From<VectorType> for Type {
91 fn from(v: VectorType) -> Self {
92 Type::Vector(v)
93 }
94 }
95
96 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
97 pub enum Type {
98 Scalar(ScalarType),
99 Vector(VectorType),
100 }
101
102 impl fmt::Display for Type {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 match self {
105 Type::Scalar(v) => v.fmt(f),
106 Type::Vector(v) => v.fmt(f),
107 }
108 }
109 }
110
/// A scalar constant operand.
///
/// Integers and booleans are stored by value. Floats are stored as their raw bit patterns
/// (`bits`), which lets the enum derive `Eq` and `Hash`; Rust's float types implement neither.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarConstant {
    Bool(bool),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    /// Bit pattern of a 16-bit float.
    F16 { bits: u16 },
    /// Bit pattern of an `f32`.
    F32 { bits: u32 },
    /// Bit pattern of an `f64`.
    F64 { bits: u64 },
}
126
// Generates a getter `pub const fn $ty(self) -> Option<$ty>` for a tuple variant
// `Self::$enumerant($ty)`. The getter returns the payload when `self` is that variant and `None`
// otherwise. It is a `const fn` so that it matches the hand-written `f16_bits`/`f32_bits`/
// `f64_bits` accessors and can be used in constant expressions. The receiver type must be `Copy`.
macro_rules! make_scalar_constant_get {
    ($ty:ident, $enumerant:ident) => {
        /// Returns the payload if this constant is the matching variant, otherwise `None`.
        pub const fn $ty(self) -> Option<$ty> {
            if let Self::$enumerant(v) = self {
                Some(v)
            } else {
                None
            }
        }
    };
}
138
139 macro_rules! make_scalar_constant_from {
140 ($ty:ident, $enumerant:ident) => {
141 impl From<$ty> for ScalarConstant {
142 fn from(v: $ty) -> Self {
143 Self::$enumerant(v)
144 }
145 }
146 impl From<$ty> for Constant {
147 fn from(v: $ty) -> Self {
148 Self::Scalar(v.into())
149 }
150 }
151 impl From<$ty> for Value<'_> {
152 fn from(v: $ty) -> Self {
153 Self::Constant(v.into())
154 }
155 }
156 };
157 }
158
159 make_scalar_constant_from!(bool, Bool);
160 make_scalar_constant_from!(u8, U8);
161 make_scalar_constant_from!(u16, U16);
162 make_scalar_constant_from!(u32, U32);
163 make_scalar_constant_from!(u64, U64);
164 make_scalar_constant_from!(i8, I8);
165 make_scalar_constant_from!(i16, I16);
166 make_scalar_constant_from!(i32, I32);
167 make_scalar_constant_from!(i64, I64);
168
169 impl ScalarConstant {
170 pub const fn ty(self) -> ScalarType {
171 match self {
172 ScalarConstant::Bool(_) => ScalarType::Bool,
173 ScalarConstant::U8(_) => ScalarType::U8,
174 ScalarConstant::U16(_) => ScalarType::U16,
175 ScalarConstant::U32(_) => ScalarType::U32,
176 ScalarConstant::U64(_) => ScalarType::U64,
177 ScalarConstant::I8(_) => ScalarType::I8,
178 ScalarConstant::I16(_) => ScalarType::I16,
179 ScalarConstant::I32(_) => ScalarType::I32,
180 ScalarConstant::I64(_) => ScalarType::I64,
181 ScalarConstant::F16 { .. } => ScalarType::F16,
182 ScalarConstant::F32 { .. } => ScalarType::F32,
183 ScalarConstant::F64 { .. } => ScalarType::F64,
184 }
185 }
186 pub const fn from_f16_bits(bits: u16) -> Self {
187 Self::F16 { bits }
188 }
189 pub const fn from_f32_bits(bits: u32) -> Self {
190 Self::F32 { bits }
191 }
192 pub const fn from_f64_bits(bits: u64) -> Self {
193 Self::F64 { bits }
194 }
195 pub const fn f16_bits(self) -> Option<u16> {
196 if let Self::F16 { bits } = self {
197 Some(bits)
198 } else {
199 None
200 }
201 }
202 pub const fn f32_bits(self) -> Option<u32> {
203 if let Self::F32 { bits } = self {
204 Some(bits)
205 } else {
206 None
207 }
208 }
209 pub const fn f64_bits(self) -> Option<u64> {
210 if let Self::F64 { bits } = self {
211 Some(bits)
212 } else {
213 None
214 }
215 }
216 make_scalar_constant_get!(bool, Bool);
217 make_scalar_constant_get!(u8, U8);
218 make_scalar_constant_get!(u16, U16);
219 make_scalar_constant_get!(u32, U32);
220 make_scalar_constant_get!(u64, U64);
221 make_scalar_constant_get!(i8, I8);
222 make_scalar_constant_get!(i16, I16);
223 make_scalar_constant_get!(i32, I32);
224 make_scalar_constant_get!(i64, I64);
225 }
226
227 impl fmt::Display for ScalarConstant {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 match self {
230 ScalarConstant::Bool(false) => write!(f, "false"),
231 ScalarConstant::Bool(true) => write!(f, "true"),
232 ScalarConstant::U8(v) => write!(f, "{}_u8", v),
233 ScalarConstant::U16(v) => write!(f, "{}_u16", v),
234 ScalarConstant::U32(v) => write!(f, "{}_u32", v),
235 ScalarConstant::U64(v) => write!(f, "{}_u64", v),
236 ScalarConstant::I8(v) => write!(f, "{}_i8", v),
237 ScalarConstant::I16(v) => write!(f, "{}_i16", v),
238 ScalarConstant::I32(v) => write!(f, "{}_i32", v),
239 ScalarConstant::I64(v) => write!(f, "{}_i64", v),
240 ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
241 ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
242 ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
243 }
244 }
245 }
246
247 impl fmt::Debug for ScalarConstant {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 fmt::Display::fmt(self, f)
250 }
251 }
252
253 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
254 pub struct VectorSplatConstant {
255 pub element: ScalarConstant,
256 }
257
258 impl VectorSplatConstant {
259 pub const fn ty(self) -> VectorType {
260 VectorType {
261 element: self.element.ty(),
262 }
263 }
264 }
265
266 impl fmt::Display for VectorSplatConstant {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 write!(f, "splat({})", self.element)
269 }
270 }
271
272 impl From<ScalarConstant> for Constant {
273 fn from(v: ScalarConstant) -> Self {
274 Constant::Scalar(v)
275 }
276 }
277
278 impl From<VectorSplatConstant> for Constant {
279 fn from(v: VectorSplatConstant) -> Self {
280 Constant::VectorSplat(v)
281 }
282 }
283
284 impl From<ScalarConstant> for Value<'_> {
285 fn from(v: ScalarConstant) -> Self {
286 Value::Constant(v.into())
287 }
288 }
289
290 impl From<VectorSplatConstant> for Value<'_> {
291 fn from(v: VectorSplatConstant) -> Self {
292 Value::Constant(v.into())
293 }
294 }
295
296 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
297 pub enum Constant {
298 Scalar(ScalarConstant),
299 VectorSplat(VectorSplatConstant),
300 }
301
302 impl Constant {
303 pub const fn ty(self) -> Type {
304 match self {
305 Constant::Scalar(v) => Type::Scalar(v.ty()),
306 Constant::VectorSplat(v) => Type::Vector(v.ty()),
307 }
308 }
309 }
310
311 impl fmt::Display for Constant {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 match self {
314 Constant::Scalar(v) => v.fmt(f),
315 Constant::VectorSplat(v) => v.fmt(f),
316 }
317 }
318 }
319
320 #[derive(Debug)]
321 pub struct Input<'ctx> {
322 pub name: &'ctx str,
323 pub ty: Type,
324 }
325
326 impl fmt::Display for Input<'_> {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 write!(f, "in<{}>", self.name)
329 }
330 }
331
332 #[derive(Copy, Clone)]
333 pub enum Value<'ctx> {
334 Input(&'ctx Input<'ctx>),
335 Constant(Constant),
336 OpResult(&'ctx Operation<'ctx>),
337 }
338
339 impl<'ctx> Value<'ctx> {
340 pub const fn ty(self) -> Type {
341 match self {
342 Value::Input(v) => v.ty,
343 Value::Constant(v) => v.ty(),
344 Value::OpResult(v) => v.result_type,
345 }
346 }
347 }
348
349 impl fmt::Debug for Value<'_> {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 match self {
352 Value::Input(v) => v.fmt(f),
353 Value::Constant(v) => v.fmt(f),
354 Value::OpResult(v) => v.result_id.fmt(f),
355 }
356 }
357 }
358
359 impl fmt::Display for Value<'_> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 match self {
362 Value::Input(v) => v.fmt(f),
363 Value::Constant(v) => v.fmt(f),
364 Value::OpResult(v) => v.result_id.fmt(f),
365 }
366 }
367 }
368
369 impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> {
370 fn from(v: &'ctx Input<'ctx>) -> Self {
371 Value::Input(v)
372 }
373 }
374
375 impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> {
376 fn from(v: &'ctx Operation<'ctx>) -> Self {
377 Value::OpResult(v)
378 }
379 }
380
381 impl<'ctx> From<Constant> for Value<'ctx> {
382 fn from(v: Constant) -> Self {
383 Value::Constant(v)
384 }
385 }
386
387 make_enum! {
388 pub enum Opcode {
389 Add,
390 Sub,
391 Mul,
392 Div,
393 Rem,
394 Fma,
395 Cast,
396 And,
397 Or,
398 Xor,
399 Not,
400 Shl,
401 Shr,
402 Neg,
403 Abs,
404 Trunc,
405 Ceil,
406 Floor,
407 Round,
408 IsInfinite,
409 IsFinite,
410 ToBits,
411 FromBits,
412 Splat,
413 CompareEq,
414 CompareNe,
415 CompareLt,
416 CompareLe,
417 CompareGt,
418 CompareGe,
419 Select,
420 }
421 }
422
423 #[derive(Debug)]
424 pub struct Operation<'ctx> {
425 pub opcode: Opcode,
426 pub arguments: Vec<Value<'ctx>>,
427 pub result_type: Type,
428 pub result_id: OperationId,
429 }
430
431 impl fmt::Display for Operation<'_> {
432 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
433 write!(
434 f,
435 "{}: {} = {}",
436 self.result_id, self.result_type, self.opcode
437 )?;
438 let mut separator = " ";
439 for i in &self.arguments {
440 write!(f, "{}{}", separator, i)?;
441 separator = ", ";
442 }
443 Ok(())
444 }
445 }
446
447 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
448 pub struct OperationId(pub u64);
449
450 impl fmt::Display for OperationId {
451 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452 write!(f, "op_{}", self.0)
453 }
454 }
455
456 #[derive(Default)]
457 pub struct IrContext<'ctx> {
458 bytes_arena: Arena<u8>,
459 inputs_arena: Arena<Input<'ctx>>,
460 inputs: RefCell<HashMap<&'ctx str, &'ctx Input<'ctx>>>,
461 operations_arena: Arena<Operation<'ctx>>,
462 operations: RefCell<Vec<&'ctx Operation<'ctx>>>,
463 next_operation_result_id: Cell<u64>,
464 }
465
466 impl fmt::Debug for IrContext<'_> {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 f.write_str("IrContext { .. }")
469 }
470 }
471
472 impl<'ctx> IrContext<'ctx> {
473 pub fn new() -> Self {
474 Self::default()
475 }
476 pub fn make_input<N: Borrow<str> + Into<String>, T: Into<Type>>(
477 &'ctx self,
478 name: N,
479 ty: T,
480 ) -> &'ctx Input<'ctx> {
481 let mut inputs = self.inputs.borrow_mut();
482 let name_str = name.borrow();
483 let ty = ty.into();
484 if !name_str.is_empty() && !inputs.contains_key(name_str) {
485 let name = self.bytes_arena.alloc_str(name_str);
486 let input = self.inputs_arena.alloc(Input { name, ty });
487 inputs.insert(name, input);
488 return input;
489 }
490 let mut name: String = name.into();
491 if name.is_empty() {
492 name = "in".into();
493 }
494 let name_len = name.len();
495 let mut tag = 2usize;
496 loop {
497 name.truncate(name_len);
498 write!(name, "_{}", tag).unwrap();
499 if !inputs.contains_key(&*name) {
500 let name = self.bytes_arena.alloc_str(&name);
501 let input = self.inputs_arena.alloc(Input { name, ty });
502 inputs.insert(name, input);
503 return input;
504 }
505 tag += 1;
506 }
507 }
508 pub fn make_operation<A: Into<Vec<Value<'ctx>>>, T: Into<Type>>(
509 &'ctx self,
510 opcode: Opcode,
511 arguments: A,
512 result_type: T,
513 ) -> &'ctx Operation<'ctx> {
514 let arguments = arguments.into();
515 let result_type = result_type.into();
516 let result_id = OperationId(self.next_operation_result_id.get());
517 self.next_operation_result_id.set(result_id.0 + 1);
518 let operation = self.operations_arena.alloc(Operation {
519 opcode,
520 arguments,
521 result_type,
522 result_id,
523 });
524 self.operations.borrow_mut().push(operation);
525 operation
526 }
527 pub fn replace_operations(
528 &'ctx self,
529 new_operations: Vec<&'ctx Operation<'ctx>>,
530 ) -> Vec<&'ctx Operation<'ctx>> {
531 self.operations.replace(new_operations)
532 }
533 }
534
535 #[derive(Debug)]
536 pub struct IrFunction<'ctx> {
537 pub inputs: Vec<&'ctx Input<'ctx>>,
538 pub operations: Vec<&'ctx Operation<'ctx>>,
539 pub outputs: Vec<Value<'ctx>>,
540 }
541
542 impl fmt::Display for IrFunction<'_> {
543 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
544 write!(f, "function(")?;
545 let mut first = true;
546 for input in &self.inputs {
547 if first {
548 first = false
549 } else {
550 write!(f, ", ")?;
551 }
552 write!(f, "{}: {}", input, input.ty)?;
553 }
554 match self.outputs.len() {
555 0 => writeln!(f, ") {{")?,
556 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?,
557 _ => {
558 write!(f, ") -> ({}", self.outputs[0].ty())?;
559 for output in self.outputs.iter().skip(1) {
560 write!(f, ", {}", output.ty())?;
561 }
562 writeln!(f, ") {{")?;
563 }
564 }
565 for operation in &self.operations {
566 writeln!(f, " {}", operation)?;
567 }
568 match self.outputs.len() {
569 0 => writeln!(f, "}}")?,
570 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?,
571 _ => {
572 write!(f, " Return {}", self.outputs[0])?;
573 for output in self.outputs.iter().skip(1) {
574 write!(f, ", {}", output)?;
575 }
576 writeln!(f, "\n}}")?;
577 }
578 }
579 Ok(())
580 }
581 }
582
583 impl<'ctx> IrFunction<'ctx> {
584 pub fn make<F: IrFunctionMaker<'ctx>>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self {
585 let old_operations = ctx.replace_operations(Vec::new());
586 let (v, inputs) = F::make_inputs(ctx);
587 let outputs = f.call(ctx, v).outputs_to_vec();
588 let operations = ctx.replace_operations(old_operations);
589 Self {
590 inputs,
591 operations,
592 outputs,
593 }
594 }
595 }
596
597 pub trait IrFunctionMaker<'ctx>: Sized {
598 type Inputs;
599 type Outputs: IrFunctionMakerOutputs<'ctx>;
600 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs;
601 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>);
602 }
603
604 pub trait IrFunctionMakerOutputs<'ctx> {
605 fn outputs_to_vec(self) -> Vec<Value<'ctx>>;
606 }
607
608 impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T {
609 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
610 [self.value()].into()
611 }
612 }
613
614 impl<'ctx> IrFunctionMakerOutputs<'ctx> for () {
615 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
616 Vec::new()
617 }
618 }
619
620 impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx>
621 for fn(&'ctx IrContext<'ctx>) -> R
622 {
623 type Inputs = ();
624 type Outputs = R;
625 fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs {
626 self(ctx)
627 }
628 fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
629 ((), Vec::new())
630 }
631 }
632
633 macro_rules! impl_ir_function_maker_io {
634 () => {};
635 ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => {
636 impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R
637 where
638 $first_arg_ty: IrValue<'ctx>,
639 $($arg_ty: IrValue<'ctx>,)*
640 R: IrFunctionMakerOutputs<'ctx>,
641 {
642 type Inputs = ($first_arg_ty, $($arg_ty,)*);
643 type Outputs = R;
644 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs {
645 let ($first_arg, $($arg,)*) = inputs;
646 self(ctx, $first_arg$(, $arg)*)
647 }
648 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
649 let mut $first_arg = String::new();
650 $(let mut $arg = String::new();)*
651 for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() {
652 **arg = format!("arg_{}", index);
653 }
654 let $first_arg = $first_arg_ty::make_input(ctx, $first_arg);
655 $(let $arg = $arg_ty::make_input(ctx, $arg);)*
656 (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into())
657 }
658 }
659 impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*)
660 where
661 $first_arg_ty: IrValue<'ctx>,
662 $($arg_ty: IrValue<'ctx>,)*
663 {
664 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
665 let ($first_arg, $($arg,)*) = self;
666 [$first_arg.value() $(, $arg.value())*].into()
667 }
668 }
669 impl_ir_function_maker_io!($($arg: $arg_ty,)*);
670 };
671 }
672
673 impl_ir_function_maker_io!(
674 in0: In0,
675 in1: In1,
676 in2: In2,
677 in3: In3,
678 in4: In4,
679 in5: In5,
680 in6: In6,
681 in7: In7,
682 in8: In8,
683 in9: In9,
684 in10: In10,
685 in11: In11,
686 );
687
688 pub trait IrValue<'ctx>: Copy {
689 const TYPE: Type;
690 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
691 fn make_input<N: Borrow<str> + Into<String>>(
692 ctx: &'ctx IrContext<'ctx>,
693 name: N,
694 ) -> (Self, &'ctx Input<'ctx>) {
695 let input = ctx.make_input(name, Self::TYPE);
696 (Self::new(ctx, input.into()), input)
697 }
698 fn ctx(self) -> &'ctx IrContext<'ctx>;
699 fn value(self) -> Value<'ctx>;
700 }
701
// Defines a scalar wrapper `$name` and a vector wrapper `$vec_name` for the element type
// `ScalarType::$scalar_type`, together with:
// - `IrValue` impls and the inherent `SCALAR_TYPE` / `VECTOR_TYPE` constants,
// - `Make` impls that turn a `$prim` into a constant (`$make` must evaluate to a
//   `ScalarConstant` built from `$make_var`; the vector form splats it),
// - `Select` by `IrBool` / `IrVecBool`, and
// - `From<$name> for $vec_name`, which records a `Splat`.
macro_rules! ir_value {
    // A `Copy` handle pairing an untyped value with its context, plus its `IrValue` impl.
    (@handle $ty:ident, $ir_type:expr) => {
        #[derive(Clone, Copy, Debug)]
        pub struct $ty<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> IrValue<'ctx> for $ty<'ctx> {
            const TYPE: Type = $ir_type;
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                assert_eq!(value.ty(), Self::TYPE);
                Self { ctx, value }
            }
            fn ctx(self) -> &'ctx IrContext<'ctx> {
                self.ctx
            }
            fn value(self) -> Value<'ctx> {
                self.value
            }
        }
    };
    // `cond.select(a, b)` records `Select [cond, a, b]` with result type `$ty`.
    (@select $cond:ident, $ty:ident) => {
        impl<'ctx> Select<$ty<'ctx>> for $cond<'ctx> {
            fn select(self, true_v: $ty<'ctx>, false_v: $ty<'ctx>) -> $ty<'ctx> {
                let recorded = self.ctx.make_operation(
                    Opcode::Select,
                    [self.value, true_v.value, false_v.value],
                    $ty::TYPE,
                );
                $ty {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }
    };
    ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => {
        ir_value!(@handle $name, Type::Scalar(ScalarType::$scalar_type));
        ir_value!(@handle $vec_name, Type::Vector(VectorType { element: ScalarType::$scalar_type }));

        impl<'ctx> $name<'ctx> {
            pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
        }

        impl<'ctx> $vec_name<'ctx> {
            pub const VECTOR_TYPE: VectorType = VectorType {
                element: ScalarType::$scalar_type,
            };
        }

        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $name<'ctx> {
            type Prim = $prim;

            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
                let constant: ScalarConstant = $make;
                Self {
                    value: Value::Constant(Constant::Scalar(constant)),
                    ctx,
                }
            }
        }

        impl<'ctx> Make<&'ctx IrContext<'ctx>> for $vec_name<'ctx> {
            type Prim = $prim;

            fn make(ctx: &'ctx IrContext<'ctx>, $make_var: Self::Prim) -> Self {
                let element: ScalarConstant = $make;
                Self {
                    value: Value::Constant(Constant::VectorSplat(VectorSplatConstant { element })),
                    ctx,
                }
            }
        }

        ir_value!(@select IrBool, $name);
        ir_value!(@select IrVecBool, $vec_name);

        impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
            fn from(scalar: $name<'ctx>) -> Self {
                let recorded = scalar
                    .ctx
                    .make_operation(Opcode::Splat, [scalar.value], $vec_name::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: scalar.ctx,
                }
            }
        }
    };
}
821
// Implements `&`, `|`, `^`, `!` and the compound-assignment forms for the IR wrapper `$ty`.
// Each operator records an `And`/`Or`/`Xor`/`Not` operation whose result has `$ty`'s own type.
macro_rules! impl_bit_ops {
    (@binary $ty:ident, $op:ident, $method:ident, $assign_op:ident, $assign_method:ident, $opcode:ident) => {
        impl<'ctx> $op for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::$opcode, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }

        impl<'ctx> $assign_op for $ty<'ctx> {
            fn $assign_method(&mut self, rhs: Self) {
                *self = $op::$method(*self, rhs);
            }
        }
    };
    ($ty:ident) => {
        impl_bit_ops!(@binary $ty, BitAnd, bitand, BitAndAssign, bitand_assign, And);
        impl_bit_ops!(@binary $ty, BitOr, bitor, BitOrAssign, bitor_assign, Or);
        impl_bit_ops!(@binary $ty, BitXor, bitxor, BitXorAssign, bitxor_assign, Xor);

        impl<'ctx> Not for $ty<'ctx> {
            type Output = Self;

            fn not(self) -> Self::Output {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::Not, [self.value], Self::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }
    };
}
897
// Implements `+ - * / %`, their compound-assignment forms, and `Compare` for the IR wrapper
// `$ty`. Arithmetic records an operation of the same-named opcode typed as `$ty`. Comparisons
// record `CompareEq`..`CompareGe` typed as the boolean wrapper `$bool`.
macro_rules! impl_number_ops {
    (@binary $ty:ident, $op:ident, $method:ident, $assign_op:ident, $assign_method:ident) => {
        impl<'ctx> $op for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::$op, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }

        impl<'ctx> $assign_op for $ty<'ctx> {
            fn $assign_method(&mut self, rhs: Self) {
                *self = $op::$method(*self, rhs);
            }
        }
    };
    (@compare $ty:ident, $bool:ident, $($method:ident => $opcode:ident),*) => {
        impl<'ctx> Compare for $ty<'ctx> {
            type Bool = $bool<'ctx>;
            $(
                fn $method(self, rhs: Self) -> Self::Bool {
                    let recorded = self.ctx.make_operation(
                        Opcode::$opcode,
                        [self.value, rhs.value],
                        $bool::TYPE,
                    );
                    $bool {
                        value: Value::OpResult(recorded),
                        ctx: self.ctx,
                    }
                }
            )*
        }
    };
    ($ty:ident, $bool:ident) => {
        impl_number_ops!(@binary $ty, Add, add, AddAssign, add_assign);
        impl_number_ops!(@binary $ty, Sub, sub, SubAssign, sub_assign);
        impl_number_ops!(@binary $ty, Mul, mul, MulAssign, mul_assign);
        impl_number_ops!(@binary $ty, Div, div, DivAssign, div_assign);
        impl_number_ops!(@binary $ty, Rem, rem, RemAssign, rem_assign);
        impl_number_ops!(
            @compare $ty, $bool,
            eq => CompareEq,
            ne => CompareNe,
            lt => CompareLt,
            gt => CompareGt,
            le => CompareLe,
            ge => CompareGe
        );
    };
}
1060
// Implements `<<`, `>>` and their compound-assignment forms for the IR wrapper `$ty`, with the
// shift amount given as a `$rhs` wrapper. The result keeps `$ty`'s type.
macro_rules! impl_shift_ops {
    (@shift $ty:ident, $rhs:ident, $op:ident, $method:ident, $assign_op:ident, $assign_method:ident) => {
        impl<'ctx> $op<$rhs<'ctx>> for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: $rhs<'ctx>) -> Self::Output {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::$op, [self.value, rhs.value], Self::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }

        impl<'ctx> $assign_op<$rhs<'ctx>> for $ty<'ctx> {
            fn $assign_method(&mut self, rhs: $rhs<'ctx>) {
                *self = $op::$method(*self, rhs);
            }
        }
    };
    ($ty:ident, $rhs:ident) => {
        impl_shift_ops!(@shift $ty, $rhs, Shl, shl, ShlAssign, shl_assign);
        impl_shift_ops!(@shift $ty, $rhs, Shr, shr, ShrAssign, shr_assign);
    };
}
1103
// Implements unary `-` for the IR wrapper `$ty` by recording a `Neg` operation of the same type.
macro_rules! impl_neg {
    ($ty:ident) => {
        impl<'ctx> Neg for $ty<'ctx> {
            type Output = Self;

            fn neg(self) -> Self::Output {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::Neg, [self.value], Self::TYPE);
                Self {
                    ctx: self.ctx,
                    value: Value::OpResult(recorded),
                }
            }
        }
    };
}
1122
1123 macro_rules! impl_integer_ops {
1124 ($scalar:ident, $vec:ident) => {
1125 impl_bit_ops!($scalar);
1126 impl_number_ops!($scalar, IrBool);
1127 impl_shift_ops!($scalar, IrU32);
1128 impl_bit_ops!($vec);
1129 impl_number_ops!($vec, IrVecBool);
1130 impl_shift_ops!($vec, IrVecU32);
1131
1132 impl<'ctx> Int<IrU32<'ctx>> for $scalar<'ctx> {}
1133 impl<'ctx> Int<IrVecU32<'ctx>> for $vec<'ctx> {}
1134 };
1135 }
1136
1137 macro_rules! impl_uint_ops {
1138 ($scalar:ident, $vec:ident) => {
1139 impl_integer_ops!($scalar, $vec);
1140
1141 impl<'ctx> UInt<IrU32<'ctx>> for $scalar<'ctx> {}
1142 impl<'ctx> UInt<IrVecU32<'ctx>> for $vec<'ctx> {}
1143 };
1144 }
1145
1146 impl_uint_ops!(IrU8, IrVecU8);
1147 impl_uint_ops!(IrU16, IrVecU16);
1148 impl_uint_ops!(IrU32, IrVecU32);
1149 impl_uint_ops!(IrU64, IrVecU64);
1150
1151 macro_rules! impl_sint_ops {
1152 ($scalar:ident, $vec:ident) => {
1153 impl_integer_ops!($scalar, $vec);
1154 impl_neg!($scalar);
1155 impl_neg!($vec);
1156
1157 impl<'ctx> SInt<IrU32<'ctx>> for $scalar<'ctx> {}
1158 impl<'ctx> SInt<IrVecU32<'ctx>> for $vec<'ctx> {}
1159 };
1160 }
1161
1162 impl_sint_ops!(IrI8, IrVecI8);
1163 impl_sint_ops!(IrI16, IrVecI16);
1164 impl_sint_ops!(IrI32, IrVecI32);
1165 impl_sint_ops!(IrI64, IrVecI64);
1166
// Implements `Float<$u32>` for the float wrapper `$float`, whose bit-reinterpretation type is
// `$bits`. Every method records a single operation.
// - Rounding/abs results keep the float type.
// - Classification results use `Self::Bool`.
// - `is_nan` is expressed as `CompareNe(x, x)`.
macro_rules! impl_float {
    // A method producing a value of the same float type from one operand.
    (@unary $method:ident, $opcode:ident) => {
        fn $method(self) -> Self {
            let recorded = self
                .ctx
                .make_operation(Opcode::$opcode, [self.value], Self::TYPE);
            Self {
                value: Value::OpResult(recorded),
                ctx: self.ctx,
            }
        }
    };
    // A one-operand classification producing `Self::Bool`.
    (@classify $method:ident, $opcode:ident) => {
        fn $method(self) -> Self::Bool {
            let recorded = self
                .ctx
                .make_operation(Opcode::$opcode, [self.value], Self::Bool::TYPE);
            Self::Bool {
                value: Value::OpResult(recorded),
                ctx: self.ctx,
            }
        }
    };
    ($float:ident, $bits:ident, $u32:ident) => {
        impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
            type BitsType = $bits<'ctx>;

            impl_float!(@unary abs, Abs);
            impl_float!(@unary trunc, Trunc);
            impl_float!(@unary ceil, Ceil);
            impl_float!(@unary floor, Floor);
            impl_float!(@unary round, Round);

            #[cfg(feature = "fma")]
            fn fma(self, a: Self, b: Self) -> Self {
                let recorded = self.ctx.make_operation(
                    Opcode::Fma,
                    [self.value, a.value, b.value],
                    Self::TYPE,
                );
                Self {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }

            fn is_nan(self) -> Self::Bool {
                // NaN is the only value that compares unequal to itself.
                let recorded = self.ctx.make_operation(
                    Opcode::CompareNe,
                    [self.value, self.value],
                    Self::Bool::TYPE,
                );
                Self::Bool {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }

            impl_float!(@classify is_infinite, IsInfinite);
            impl_float!(@classify is_finite, IsFinite);

            fn from_bits(v: Self::BitsType) -> Self {
                let recorded = v
                    .ctx
                    .make_operation(Opcode::FromBits, [v.value], Self::TYPE);
                Self {
                    value: Value::OpResult(recorded),
                    ctx: v.ctx,
                }
            }

            fn to_bits(self) -> Self::BitsType {
                let recorded = self
                    .ctx
                    .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE);
                Self::BitsType {
                    value: Value::OpResult(recorded),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1286
1287 macro_rules! impl_float_ops {
1288 ($scalar:ident, $scalar_bits:ident, $vec:ident, $vec_bits:ident) => {
1289 impl_number_ops!($scalar, IrBool);
1290 impl_number_ops!($vec, IrVecBool);
1291 impl_neg!($scalar);
1292 impl_neg!($vec);
1293 impl_float!($scalar, $scalar_bits, IrU32);
1294 impl_float!($vec, $vec_bits, IrVecU32);
1295 };
1296 }
1297
1298 impl_float_ops!(IrF16, IrU16, IrVecF16, IrVecU16);
1299 impl_float_ops!(IrF32, IrU32, IrVecF32, IrVecU32);
1300 impl_float_ops!(IrF64, IrU64, IrVecF64, IrVecU64);
1301
1302 ir_value!(
1303 IrBool,
1304 IrVecBool,
1305 TYPE = Bool,
1306 fn make(v: bool) {
1307 v.into()
1308 }
1309 );
1310
1311 impl<'ctx> Bool for IrBool<'ctx> {}
1312 impl<'ctx> Bool for IrVecBool<'ctx> {}
1313
1314 impl_bit_ops!(IrBool);
1315 impl_bit_ops!(IrVecBool);
1316
1317 ir_value!(
1318 IrU8,
1319 IrVecU8,
1320 TYPE = U8,
1321 fn make(v: u8) {
1322 v.into()
1323 }
1324 );
1325 ir_value!(
1326 IrU16,
1327 IrVecU16,
1328 TYPE = U16,
1329 fn make(v: u16) {
1330 v.into()
1331 }
1332 );
1333 ir_value!(
1334 IrU32,
1335 IrVecU32,
1336 TYPE = U32,
1337 fn make(v: u32) {
1338 v.into()
1339 }
1340 );
1341 ir_value!(
1342 IrU64,
1343 IrVecU64,
1344 TYPE = U64,
1345 fn make(v: u64) {
1346 v.into()
1347 }
1348 );
1349 ir_value!(
1350 IrI8,
1351 IrVecI8,
1352 TYPE = I8,
1353 fn make(v: i8) {
1354 v.into()
1355 }
1356 );
1357 ir_value!(
1358 IrI16,
1359 IrVecI16,
1360 TYPE = I16,
1361 fn make(v: i16) {
1362 v.into()
1363 }
1364 );
1365 ir_value!(
1366 IrI32,
1367 IrVecI32,
1368 TYPE = I32,
1369 fn make(v: i32) {
1370 v.into()
1371 }
1372 );
1373 ir_value!(
1374 IrI64,
1375 IrVecI64,
1376 TYPE = I64,
1377 fn make(v: i64) {
1378 v.into()
1379 }
1380 );
1381 ir_value!(
1382 IrF16,
1383 IrVecF16,
1384 TYPE = F16,
1385 fn make(v: F16) {
1386 ScalarConstant::from_f16_bits(v.to_bits())
1387 }
1388 );
1389 ir_value!(
1390 IrF32,
1391 IrVecF32,
1392 TYPE = F32,
1393 fn make(v: f32) {
1394 ScalarConstant::from_f32_bits(v.to_bits())
1395 }
1396 );
1397 ir_value!(
1398 IrF64,
1399 IrVecF64,
1400 TYPE = F64,
1401 fn make(v: f64) {
1402 ScalarConstant::from_f64_bits(v.to_bits())
1403 }
1404 );
1405
1406 macro_rules! impl_convert_to {
1407 ($($src:ident -> [$($dest:ident),*];)*) => {
1408 $($(
1409 impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
1410 fn to(self) -> $dest<'ctx> {
1411 let value = if $src::TYPE == $dest::TYPE {
1412 self.value
1413 } else {
1414 self
1415 .ctx
1416 .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
1417 .into()
1418 };
1419 $dest {
1420 value,
1421 ctx: self.ctx,
1422 }
1423 }
1424 }
1425 )*)*
1426 };
1427 ([$($src:ident),*] -> $dest:tt;) => {
1428 impl_convert_to! {
1429 $(
1430 $src -> $dest;
1431 )*
1432 }
1433 };
1434 ([$($src:ident),*];) => {
1435 impl_convert_to! {
1436 [$($src),*] -> [$($src),*];
1437 }
1438 };
1439 }
1440
1441 impl_convert_to! {
1442 [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
1443 }
1444
1445 impl_convert_to! {
1446 [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64];
1447 }
1448
1449 macro_rules! impl_from {
1450 ($src:ident => [$($dest:ident),*]) => {
1451 $(
1452 impl<'ctx> From<$src<'ctx>> for $dest<'ctx> {
1453 fn from(v: $src<'ctx>) -> Self {
1454 v.to()
1455 }
1456 }
1457 )*
1458 };
1459 }
1460
1461 macro_rules! impl_froms {
1462 (
1463 #[u8] $u8:ident;
1464 #[i8] $i8:ident;
1465 #[u16] $u16:ident;
1466 #[i16] $i16:ident;
1467 #[f16] $f16:ident;
1468 #[u32] $u32:ident;
1469 #[i32] $i32:ident;
1470 #[f32] $f32:ident;
1471 #[u64] $u64:ident;
1472 #[i64] $i64:ident;
1473 #[f64] $f64:ident;
1474 ) => {
1475 impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]);
1476 impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]);
1477 impl_from!($u32 => [$u64, $i64, $f64]);
1478 impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]);
1479 impl_from!($i16 => [$i32, $f32, $i64, $f64]);
1480 impl_from!($i32 => [$i64, $f64]);
1481 impl_from!($f16 => [$f32, $f64]);
1482 impl_from!($f32 => [$f64]);
1483 };
1484 }
1485
1486 impl_froms! {
1487 #[u8] IrU8;
1488 #[i8] IrI8;
1489 #[u16] IrU16;
1490 #[i16] IrI16;
1491 #[f16] IrF16;
1492 #[u32] IrU32;
1493 #[i32] IrI32;
1494 #[f32] IrF32;
1495 #[u64] IrU64;
1496 #[i64] IrI64;
1497 #[f64] IrF64;
1498 }
1499
1500 impl_froms! {
1501 #[u8] IrVecU8;
1502 #[i8] IrVecI8;
1503 #[u16] IrVecU16;
1504 #[i16] IrVecI16;
1505 #[f16] IrVecF16;
1506 #[u32] IrVecU32;
1507 #[i32] IrVecI32;
1508 #[f32] IrVecF32;
1509 #[u64] IrVecU64;
1510 #[i64] IrVecI64;
1511 #[f64] IrVecF64;
1512 }
1513
1514 impl<'ctx> Context for &'ctx IrContext<'ctx> {
1515 type Bool = IrBool<'ctx>;
1516 type U8 = IrU8<'ctx>;
1517 type I8 = IrI8<'ctx>;
1518 type U16 = IrU16<'ctx>;
1519 type I16 = IrI16<'ctx>;
1520 type F16 = IrF16<'ctx>;
1521 type U32 = IrU32<'ctx>;
1522 type I32 = IrI32<'ctx>;
1523 type F32 = IrF32<'ctx>;
1524 type U64 = IrU64<'ctx>;
1525 type I64 = IrI64<'ctx>;
1526 type F64 = IrF64<'ctx>;
1527 type VecBool = IrVecBool<'ctx>;
1528 type VecU8 = IrVecU8<'ctx>;
1529 type VecI8 = IrVecI8<'ctx>;
1530 type VecU16 = IrVecU16<'ctx>;
1531 type VecI16 = IrVecI16<'ctx>;
1532 type VecF16 = IrVecF16<'ctx>;
1533 type VecU32 = IrVecU32<'ctx>;
1534 type VecI32 = IrVecI32<'ctx>;
1535 type VecF32 = IrVecF32<'ctx>;
1536 type VecU64 = IrVecU64<'ctx>;
1537 type VecI64 = IrVecI64<'ctx>;
1538 type VecF64 = IrVecF64<'ctx>;
1539 }
1540
#[cfg(test)]
mod tests {
    use super::*;
    use std::println;

    #[test]
    fn test_display() {
        // Generic kernel under test. Written against `Context`, so running it
        // with an `IrContext` records IR ops in evaluation order:
        // cast, add, sub, floor, sub, cast.
        fn kernel<Ctx: Context>(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 {
            let a: Ctx::VecF32 = a.into();
            let sum = a + b;
            let shifted = sum - ctx.make(5f32);
            let floored = shifted.floor();
            let diff = a - floored;
            diff.to()
        }

        // Pins `kernel` to the IR context type and traces it into a function.
        fn build<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
            let kernel_ptr: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = kernel;
            IrFunction::make(ctx, kernel_ptr)
        }

        let ctx = IrContext::new();
        let text = format!("\n{}", build(&ctx));
        println!("{}", text);
        let expected = r"
function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
    op_0: vec<F32> = Cast in<arg_0>
    op_1: vec<F32> = Add op_0, in<arg_1>
    op_2: vec<F32> = Sub op_1, splat(0x40A00000_f32)
    op_3: vec<F32> = Floor op_2
    op_4: vec<F32> = Sub op_0, op_3
    op_5: vec<F64> = Cast op_4
    Return op_5
}
";
        assert_eq!(text, expected);
    }
}