f1de05aa7d6bad699471d3afcbd4a5bd86724ac2
[vector-math.git] / src / ir.rs
1 use crate::{
2 f16::F16,
3 traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt},
4 };
5 use std::{
6 borrow::Borrow,
7 cell::{Cell, RefCell},
8 collections::HashMap,
9 fmt::{self, Write as _},
10 format,
11 ops::{
12 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
13 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
14 SubAssign,
15 },
16 string::String,
17 vec::Vec,
18 };
19 use typed_arena::Arena;
20
// Declares a field-less enum deriving `Clone`, `Copy`, `Debug`, `PartialEq`, `Eq` and `Hash`,
// gives it an `as_str` method returning each variant's identifier exactly as written, and
// implements `Display` by printing that same identifier.
//
// Attributes (including doc comments) are accepted on variants only; every variant must be
// followed by a comma.
macro_rules! make_enum {
    (
        $vis:vis enum $enum:ident {
            $(
                $(#[$meta:meta])*
                $name:ident,
            )*
        }
    ) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
        $vis enum $enum {
            $(
                $(#[$meta])*
                $name,
            )*
        }

        impl $enum {
            /// Returns the variant's name as spelled in the enum declaration.
            $vis const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$name => stringify!($name),)*
                }
            }
        }

        impl fmt::Display for $enum {
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                out.write_str(Self::as_str(*self))
            }
        }
    };
}
55
// The element types known to the IR. Declared through `make_enum!`, so `ScalarType` also gets
// `as_str()` and a `Display` impl that print the variant name (e.g. `F32`).
make_enum! {
    pub enum ScalarType {
        Bool,
        U8,
        I8,
        U16,
        I16,
        F16,
        U32,
        I32,
        F32,
        U64,
        I64,
        F64,
    }
}
72
/// Type of a vector value, described only by its element type (no lane count is stored).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VectorType {
    /// Scalar type of every element of the vector.
    pub element: ScalarType,
}
77
78 impl fmt::Display for VectorType {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "vec<{}>", self.element)
81 }
82 }
83
84 impl From<ScalarType> for Type {
85 fn from(v: ScalarType) -> Self {
86 Type::Scalar(v)
87 }
88 }
89
90 impl From<VectorType> for Type {
91 fn from(v: VectorType) -> Self {
92 Type::Vector(v)
93 }
94 }
95
/// Type of an IR value: either a scalar or a vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single scalar of the given type.
    Scalar(ScalarType),
    /// A vector whose elements have the type recorded in the `VectorType`.
    Vector(VectorType),
}
101
102 impl fmt::Display for Type {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 match self {
105 Type::Scalar(v) => v.fmt(f),
106 Type::Vector(v) => v.fmt(f),
107 }
108 }
109 }
110
/// A constant scalar value.
///
/// Integer and boolean variants hold the value itself; the floating-point variants hold the
/// value's bit pattern in the unsigned integer of the same width, which lets the enum derive
/// `Eq` and `Hash`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarConstant {
    Bool(bool),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    /// Bit pattern of an `F16` constant.
    F16 { bits: u16 },
    /// Bit pattern of an `F32` constant.
    F32 { bits: u32 },
    /// Bit pattern of an `F64` constant.
    F64 { bits: u64 },
}
126
// Generates a by-value accessor named after the primitive type `$ty`: it returns
// `Some(payload)` when `self` is the tuple variant `$enumerant`, and `None` for any other
// variant. Intended to be invoked inside an inherent `impl` block of the enum.
macro_rules! make_scalar_constant_get {
    ($ty:ident, $enumerant:ident) => {
        pub fn $ty(self) -> Option<$ty> {
            match self {
                Self::$enumerant(payload) => Some(payload),
                _ => None,
            }
        }
    };
}
138
// Generates the conversion chain for one primitive type `$ty`:
//   `$ty` -> `ScalarConstant::$enumerant`
//   `$ty` -> `Constant::Scalar(..)`
//   `$ty` -> `Value::Constant(..)`
// Each step wraps the result of the previous one.
macro_rules! make_scalar_constant_from {
    ($ty:ident, $enumerant:ident) => {
        impl From<$ty> for ScalarConstant {
            fn from(primitive: $ty) -> Self {
                Self::$enumerant(primitive)
            }
        }
        impl From<$ty> for Constant {
            fn from(primitive: $ty) -> Self {
                Self::Scalar(ScalarConstant::from(primitive))
            }
        }
        impl From<$ty> for Value<'_> {
            fn from(primitive: $ty) -> Self {
                Self::Constant(Constant::from(primitive))
            }
        }
    };
}
158
// `From` conversions into `ScalarConstant`, `Constant` and `Value` for `bool` and every
// integer primitive. The floating-point variants are not covered here; they are built through
// `ScalarConstant::from_f16_bits` / `from_f32_bits` / `from_f64_bits`.
make_scalar_constant_from!(bool, Bool);
make_scalar_constant_from!(u8, U8);
make_scalar_constant_from!(u16, U16);
make_scalar_constant_from!(u32, U32);
make_scalar_constant_from!(u64, U64);
make_scalar_constant_from!(i8, I8);
make_scalar_constant_from!(i16, I16);
make_scalar_constant_from!(i32, I32);
make_scalar_constant_from!(i64, I64);
168
169 impl ScalarConstant {
170 pub const fn ty(self) -> ScalarType {
171 match self {
172 ScalarConstant::Bool(_) => ScalarType::Bool,
173 ScalarConstant::U8(_) => ScalarType::U8,
174 ScalarConstant::U16(_) => ScalarType::U16,
175 ScalarConstant::U32(_) => ScalarType::U32,
176 ScalarConstant::U64(_) => ScalarType::U64,
177 ScalarConstant::I8(_) => ScalarType::I8,
178 ScalarConstant::I16(_) => ScalarType::I16,
179 ScalarConstant::I32(_) => ScalarType::I32,
180 ScalarConstant::I64(_) => ScalarType::I64,
181 ScalarConstant::F16 { .. } => ScalarType::F16,
182 ScalarConstant::F32 { .. } => ScalarType::F32,
183 ScalarConstant::F64 { .. } => ScalarType::F64,
184 }
185 }
186 pub const fn from_f16_bits(bits: u16) -> Self {
187 Self::F16 { bits }
188 }
189 pub const fn from_f32_bits(bits: u32) -> Self {
190 Self::F32 { bits }
191 }
192 pub const fn from_f64_bits(bits: u64) -> Self {
193 Self::F64 { bits }
194 }
195 pub const fn f16_bits(self) -> Option<u16> {
196 if let Self::F16 { bits } = self {
197 Some(bits)
198 } else {
199 None
200 }
201 }
202 pub const fn f32_bits(self) -> Option<u32> {
203 if let Self::F32 { bits } = self {
204 Some(bits)
205 } else {
206 None
207 }
208 }
209 pub const fn f64_bits(self) -> Option<u64> {
210 if let Self::F64 { bits } = self {
211 Some(bits)
212 } else {
213 None
214 }
215 }
216 make_scalar_constant_get!(bool, Bool);
217 make_scalar_constant_get!(u8, U8);
218 make_scalar_constant_get!(u16, U16);
219 make_scalar_constant_get!(u32, U32);
220 make_scalar_constant_get!(u64, U64);
221 make_scalar_constant_get!(i8, I8);
222 make_scalar_constant_get!(i16, I16);
223 make_scalar_constant_get!(i32, I32);
224 make_scalar_constant_get!(i64, I64);
225 }
226
227 impl fmt::Display for ScalarConstant {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 match self {
230 ScalarConstant::Bool(false) => write!(f, "false"),
231 ScalarConstant::Bool(true) => write!(f, "true"),
232 ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
233 ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
234 ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
235 ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
236 ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
237 ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
238 ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
239 ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
240 ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
241 ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
242 ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
243 }
244 }
245 }
246
247 impl fmt::Debug for ScalarConstant {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 fmt::Display::fmt(self, f)
250 }
251 }
252
/// A constant vector described by a single scalar constant (printed as `splat(ELEMENT)`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct VectorSplatConstant {
    /// The scalar constant the vector is built from.
    pub element: ScalarConstant,
}
257
258 impl VectorSplatConstant {
259 pub const fn ty(self) -> VectorType {
260 VectorType {
261 element: self.element.ty(),
262 }
263 }
264 }
265
266 impl fmt::Display for VectorSplatConstant {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 write!(f, "splat({})", self.element)
269 }
270 }
271
272 impl From<ScalarConstant> for Constant {
273 fn from(v: ScalarConstant) -> Self {
274 Constant::Scalar(v)
275 }
276 }
277
278 impl From<VectorSplatConstant> for Constant {
279 fn from(v: VectorSplatConstant) -> Self {
280 Constant::VectorSplat(v)
281 }
282 }
283
284 impl From<ScalarConstant> for Value<'_> {
285 fn from(v: ScalarConstant) -> Self {
286 Value::Constant(v.into())
287 }
288 }
289
290 impl From<VectorSplatConstant> for Value<'_> {
291 fn from(v: VectorSplatConstant) -> Self {
292 Value::Constant(v.into())
293 }
294 }
295
/// A constant IR value: a scalar constant or a vector splat constant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Constant {
    /// A scalar constant.
    Scalar(ScalarConstant),
    /// A vector constant described by one scalar element.
    VectorSplat(VectorSplatConstant),
}
301
302 impl Constant {
303 pub const fn ty(self) -> Type {
304 match self {
305 Constant::Scalar(v) => Type::Scalar(v.ty()),
306 Constant::VectorSplat(v) => Type::Vector(v.ty()),
307 }
308 }
309 }
310
311 impl fmt::Display for Constant {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 match self {
314 Constant::Scalar(v) => v.fmt(f),
315 Constant::VectorSplat(v) => v.fmt(f),
316 }
317 }
318 }
319
/// A named, typed input of an IR function. Instances are created by `IrContext::make_input`.
#[derive(Debug)]
pub struct Input<'ctx> {
    /// Name of the input; `IrContext::make_input` picks a name not yet used in its context.
    pub name: &'ctx str,
    /// Type of the value this input supplies.
    pub ty: Type,
}
325
326 impl fmt::Display for Input<'_> {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 write!(f, "in<{}>", self.name)
329 }
330 }
331
/// Anything that can be used as an operand or a function output: a function input, a constant,
/// or the result of an earlier operation.
#[derive(Copy, Clone)]
pub enum Value<'ctx> {
    /// Reference to a function input.
    Input(&'ctx Input<'ctx>),
    /// An inline constant.
    Constant(Constant),
    /// The result of the referenced operation.
    OpResult(&'ctx Operation<'ctx>),
}
338
339 impl<'ctx> Value<'ctx> {
340 pub const fn ty(self) -> Type {
341 match self {
342 Value::Input(v) => v.ty,
343 Value::Constant(v) => v.ty(),
344 Value::OpResult(v) => v.result_type,
345 }
346 }
347 }
348
impl fmt::Debug for Value<'_> {
    /// Formats the value by delegating to the wrapped input, the wrapped constant, or the
    /// id of the producing operation (the operation itself is not printed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // NOTE(review): these `fmt` calls are unqualified, and `Input`, `Constant` and
        // `OperationId` each implement both `Debug` and `Display`. Which one is used depends on
        // the formatting traits in scope, and this file's visible imports name neither.
        // Confirm the intended trait and qualify the calls (e.g. `fmt::Debug::fmt(v, f)`).
        match self {
            Value::Input(v) => v.fmt(f),
            Value::Constant(v) => v.fmt(f),
            Value::OpResult(v) => v.result_id.fmt(f),
        }
    }
}
358
359 impl fmt::Display for Value<'_> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 match self {
362 Value::Input(v) => v.fmt(f),
363 Value::Constant(v) => v.fmt(f),
364 Value::OpResult(v) => v.result_id.fmt(f),
365 }
366 }
367 }
368
369 impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> {
370 fn from(v: &'ctx Input<'ctx>) -> Self {
371 Value::Input(v)
372 }
373 }
374
375 impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> {
376 fn from(v: &'ctx Operation<'ctx>) -> Self {
377 Value::OpResult(v)
378 }
379 }
380
381 impl<'ctx> From<Constant> for Value<'ctx> {
382 fn from(v: Constant) -> Self {
383 Value::Constant(v)
384 }
385 }
386
// The operation kinds an `Operation` can carry. Declared through `make_enum!`, so `Opcode`
// also gets `as_str()` and a `Display` impl that print the variant name; `Operation`'s
// `Display` output relies on that.
make_enum! {
    pub enum Opcode {
        Add,
        Sub,
        Mul,
        Div,
        Rem,
        Fma,
        Cast,
        And,
        Or,
        Xor,
        Not,
        Shl,
        Shr,
        CountSetBits,
        CountLeadingZeros,
        CountTrailingZeros,
        Neg,
        Abs,
        Trunc,
        Ceil,
        Floor,
        Round,
        IsInfinite,
        IsFinite,
        ToBits,
        FromBits,
        Splat,
        CompareEq,
        CompareNe,
        CompareLt,
        CompareLe,
        CompareGt,
        CompareGe,
        Select,
    }
}
425
/// One recorded IR operation. Instances are created by `IrContext::make_operation`.
#[derive(Debug)]
pub struct Operation<'ctx> {
    /// Which operation this is.
    pub opcode: Opcode,
    /// Operands, in the order given to `make_operation`.
    pub arguments: Vec<Value<'ctx>>,
    /// Type of the value the operation produces.
    pub result_type: Type,
    /// Id assigned by the context; used when the result is printed as a `Value`.
    pub result_id: OperationId,
}
433
434 impl fmt::Display for Operation<'_> {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 write!(
437 f,
438 "{}: {} = {}",
439 self.result_id, self.result_type, self.opcode
440 )?;
441 let mut separator = " ";
442 for i in &self.arguments {
443 write!(f, "{}{}", separator, i)?;
444 separator = ", ";
445 }
446 Ok(())
447 }
448 }
449
/// Numeric id of an operation's result; `IrContext::make_operation` hands these out from a
/// per-context counter. Displayed as `op_N`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct OperationId(pub u64);
452
453 impl fmt::Display for OperationId {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 write!(f, "op_{}", self.0)
456 }
457 }
458
/// Owner of everything created while building IR: inputs, operations and the strings that
/// name inputs. All state uses interior mutability, so items are created through `&'ctx self`.
#[derive(Default)]
pub struct IrContext<'ctx> {
    // Backing storage for input names (filled through `alloc_str`).
    bytes_arena: Arena<u8>,
    // Backing storage for `Input` values.
    inputs_arena: Arena<Input<'ctx>>,
    // Inputs created so far, keyed by name; used to keep names unique.
    inputs: RefCell<HashMap<&'ctx str, &'ctx Input<'ctx>>>,
    // Backing storage for `Operation` values.
    operations_arena: Arena<Operation<'ctx>>,
    // Operations recorded since the list was last swapped by `replace_operations`,
    // in creation order.
    operations: RefCell<Vec<&'ctx Operation<'ctx>>>,
    // Id that the next created operation will receive.
    next_operation_result_id: Cell<u64>,
}
468
469 impl fmt::Debug for IrContext<'_> {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 f.write_str("IrContext { .. }")
472 }
473 }
474
475 impl<'ctx> IrContext<'ctx> {
476 pub fn new() -> Self {
477 Self::default()
478 }
479 pub fn make_input<N: Borrow<str> + Into<String>, T: Into<Type>>(
480 &'ctx self,
481 name: N,
482 ty: T,
483 ) -> &'ctx Input<'ctx> {
484 let mut inputs = self.inputs.borrow_mut();
485 let name_str = name.borrow();
486 let ty = ty.into();
487 if !name_str.is_empty() && !inputs.contains_key(name_str) {
488 let name = self.bytes_arena.alloc_str(name_str);
489 let input = self.inputs_arena.alloc(Input { name, ty });
490 inputs.insert(name, input);
491 return input;
492 }
493 let mut name: String = name.into();
494 if name.is_empty() {
495 name = "in".into();
496 }
497 let name_len = name.len();
498 let mut tag = 2usize;
499 loop {
500 name.truncate(name_len);
501 write!(name, "_{}", tag).unwrap();
502 if !inputs.contains_key(&*name) {
503 let name = self.bytes_arena.alloc_str(&name);
504 let input = self.inputs_arena.alloc(Input { name, ty });
505 inputs.insert(name, input);
506 return input;
507 }
508 tag += 1;
509 }
510 }
511 pub fn make_operation<A: Into<Vec<Value<'ctx>>>, T: Into<Type>>(
512 &'ctx self,
513 opcode: Opcode,
514 arguments: A,
515 result_type: T,
516 ) -> &'ctx Operation<'ctx> {
517 let arguments = arguments.into();
518 let result_type = result_type.into();
519 let result_id = OperationId(self.next_operation_result_id.get());
520 self.next_operation_result_id.set(result_id.0 + 1);
521 let operation = self.operations_arena.alloc(Operation {
522 opcode,
523 arguments,
524 result_type,
525 result_id,
526 });
527 self.operations.borrow_mut().push(operation);
528 operation
529 }
530 pub fn replace_operations(
531 &'ctx self,
532 new_operations: Vec<&'ctx Operation<'ctx>>,
533 ) -> Vec<&'ctx Operation<'ctx>> {
534 self.operations.replace(new_operations)
535 }
536 }
537
/// A complete IR function: its inputs, the operations recorded while building it, and the
/// values it returns. Built by `IrFunction::make`.
#[derive(Debug)]
pub struct IrFunction<'ctx> {
    /// Function inputs, in argument order.
    pub inputs: Vec<&'ctx Input<'ctx>>,
    /// Operations recorded while the function body ran, in creation order.
    pub operations: Vec<&'ctx Operation<'ctx>>,
    /// Returned values; empty when the function returns nothing.
    pub outputs: Vec<Value<'ctx>>,
}
544
545 impl fmt::Display for IrFunction<'_> {
546 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
547 write!(f, "function(")?;
548 let mut first = true;
549 for input in &self.inputs {
550 if first {
551 first = false
552 } else {
553 write!(f, ", ")?;
554 }
555 write!(f, "{}: {}", input, input.ty)?;
556 }
557 match self.outputs.len() {
558 0 => writeln!(f, ") {{")?,
559 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?,
560 _ => {
561 write!(f, ") -> ({}", self.outputs[0].ty())?;
562 for output in self.outputs.iter().skip(1) {
563 write!(f, ", {}", output.ty())?;
564 }
565 writeln!(f, ") {{")?;
566 }
567 }
568 for operation in &self.operations {
569 writeln!(f, " {}", operation)?;
570 }
571 match self.outputs.len() {
572 0 => writeln!(f, "}}")?,
573 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?,
574 _ => {
575 write!(f, " Return {}", self.outputs[0])?;
576 for output in self.outputs.iter().skip(1) {
577 write!(f, ", {}", output)?;
578 }
579 writeln!(f, "\n}}")?;
580 }
581 }
582 Ok(())
583 }
584 }
585
586 impl<'ctx> IrFunction<'ctx> {
587 pub fn make<F: IrFunctionMaker<'ctx>>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self {
588 let old_operations = ctx.replace_operations(Vec::new());
589 let (v, inputs) = F::make_inputs(ctx);
590 let outputs = f.call(ctx, v).outputs_to_vec();
591 let operations = ctx.replace_operations(old_operations);
592 Self {
593 inputs,
594 operations,
595 outputs,
596 }
597 }
598 }
599
/// A callable that `IrFunction::make` can turn into an IR function.
pub trait IrFunctionMaker<'ctx>: Sized {
    /// The argument value(s) passed to `call`, as produced by `make_inputs`.
    type Inputs;
    /// What the callable returns; convertible to the function's list of output values.
    type Outputs: IrFunctionMakerOutputs<'ctx>;
    /// Runs the callable with the given context and inputs.
    fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs;
    /// Creates the inputs in `ctx`, returning them both in the form `call` expects and as the
    /// list of `Input` references stored in the resulting `IrFunction`.
    fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>);
}
606
/// Return values of an `IrFunctionMaker`, convertible to a flat list of IR values.
pub trait IrFunctionMakerOutputs<'ctx> {
    /// Consumes the outputs and returns their underlying `Value`s in order.
    fn outputs_to_vec(self) -> Vec<Value<'ctx>>;
}
610
611 impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T {
612 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
613 [self.value()].into()
614 }
615 }
616
617 impl<'ctx> IrFunctionMakerOutputs<'ctx> for () {
618 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
619 Vec::new()
620 }
621 }
622
623 impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx>
624 for fn(&'ctx IrContext<'ctx>) -> R
625 {
626 type Inputs = ();
627 type Outputs = R;
628 fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs {
629 self(ctx)
630 }
631 fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
632 ((), Vec::new())
633 }
634 }
635
// Implements, for a list of `name: TypeParam` pairs:
//   * `IrFunctionMaker` for function pointers `fn(ctx, T0, T1, ...) -> R` whose arguments are
//     all `IrValue`s, and
//   * `IrFunctionMakerOutputs` for the tuple `(T0, T1, ...)` of `IrValue`s.
// After emitting both impls the macro calls itself with the first pair removed, so a single
// invocation with N pairs covers every arity from N down to 1. The empty arm ends the
// recursion.
macro_rules! impl_ir_function_maker_io {
    () => {};
    ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => {
        impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R
        where
            $first_arg_ty: IrValue<'ctx>,
            $($arg_ty: IrValue<'ctx>,)*
            R: IrFunctionMakerOutputs<'ctx>,
        {
            type Inputs = ($first_arg_ty, $($arg_ty,)*);
            type Outputs = R;
            fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs {
                // Unpack the tuple back into individual arguments.
                let ($first_arg, $($arg,)*) = inputs;
                self(ctx, $first_arg$(, $arg)*)
            }
            fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
                // One name buffer per argument, filled in with `arg_<position>`.
                let mut $first_arg = String::new();
                $(let mut $arg = String::new();)*
                for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() {
                    **arg = format!("arg_{}", index);
                }
                // Each `make_input` call yields `(wrapper value, &Input)`; the names are
                // shadowed by those pairs.
                let $first_arg = $first_arg_ty::make_input(ctx, $first_arg);
                $(let $arg = $arg_ty::make_input(ctx, $arg);)*
                (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into())
            }
        }
        impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*)
        where
            $first_arg_ty: IrValue<'ctx>,
            $($arg_ty: IrValue<'ctx>,)*
        {
            fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
                let ($first_arg, $($arg,)*) = self;
                [$first_arg.value() $(, $arg.value())*].into()
            }
        }
        // Recurse with one fewer argument.
        impl_ir_function_maker_io!($($arg: $arg_ty,)*);
    };
}
675
// Twelve pairs: the macro's recursion therefore provides impls for functions with 1 to 12
// IR-value arguments and for output tuples of 1 to 12 elements.
impl_ir_function_maker_io!(
    in0: In0,
    in1: In1,
    in2: In2,
    in3: In3,
    in4: In4,
    in5: In5,
    in6: In6,
    in7: In7,
    in8: In8,
    in9: In9,
    in10: In10,
    in11: In11,
);
690
691 pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
692 const TYPE: Type;
693 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
694 fn make_input<N: Borrow<str> + Into<String>>(
695 ctx: &'ctx IrContext<'ctx>,
696 name: N,
697 ) -> (Self, &'ctx Input<'ctx>) {
698 let input = ctx.make_input(name, Self::TYPE);
699 (Self::new(ctx, input.into()), input)
700 }
701 fn value(self) -> Value<'ctx>;
702 }
703
// Generates the scalar wrapper `$name` and the vector wrapper `$vec_name` for one element
// type. For both wrappers it emits:
//   * the struct itself (a `Value` plus the owning context),
//   * `IrValue` (with `new` asserting that the wrapped value has the wrapper's type),
//   * an associated `SCALAR_TYPE` / `VECTOR_TYPE` constant,
//   * `Make`, which turns a `$prim` into a constant using the `$make` expression
//     (a scalar constant for `$name`, a splat constant for `$vec_name`),
//   * `Select` on `IrBool` / `IrVecBool`, recorded as an `Opcode::Select` operation,
// plus `From<$name> for $vec_name`, recorded as an `Opcode::Splat` operation.
macro_rules! ir_value {
    ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => {
        #[derive(Clone, Copy, Debug)]
        pub struct $name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> IrValue<'ctx> for $name<'ctx> {
            const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE);
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                // Wrapping a value of the wrong type is a caller bug.
                assert_eq!(value.ty(), Self::TYPE);
                Self { value, ctx }
            }
            fn value(self) -> Value<'ctx> {
                let Self { value, .. } = self;
                value
            }
        }

        impl<'ctx> $name<'ctx> {
            pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
        }

        impl<'ctx> Make for $name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;
            fn ctx(self) -> Self::Context {
                let Self { ctx, .. } = self;
                ctx
            }
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                let constant: ScalarConstant = $make;
                Self {
                    ctx,
                    value: Value::from(constant),
                }
            }
        }

        #[derive(Clone, Copy, Debug)]
        pub struct $vec_name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> {
            const TYPE: Type = Type::Vector(Self::VECTOR_TYPE);
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                // Wrapping a value of the wrong type is a caller bug.
                assert_eq!(value.ty(), Self::TYPE);
                Self { value, ctx }
            }
            fn value(self) -> Value<'ctx> {
                let Self { value, .. } = self;
                value
            }
        }

        impl<'ctx> $vec_name<'ctx> {
            pub const VECTOR_TYPE: VectorType = VectorType {
                element: ScalarType::$scalar_type,
            };
        }

        impl<'ctx> Make for $vec_name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;
            fn ctx(self) -> Self::Context {
                let Self { ctx, .. } = self;
                ctx
            }
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                let element: ScalarConstant = $make;
                Self {
                    ctx,
                    value: Value::from(VectorSplatConstant { element }),
                }
            }
        }

        impl<'ctx> Select<$name<'ctx>> for IrBool<'ctx> {
            fn select(self, true_v: $name<'ctx>, false_v: $name<'ctx>) -> $name<'ctx> {
                let operands = [self.value, true_v.value, false_v.value];
                $name {
                    value: self
                        .ctx
                        .make_operation(Opcode::Select, operands, $name::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }

        impl<'ctx> Select<$vec_name<'ctx>> for IrVecBool<'ctx> {
            fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
                let operands = [self.value, true_v.value, false_v.value];
                $vec_name {
                    value: self
                        .ctx
                        .make_operation(Opcode::Select, operands, $vec_name::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }

        impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
            fn from(v: $name<'ctx>) -> Self {
                Self {
                    value: v
                        .ctx
                        .make_operation(Opcode::Splat, [v.value], $vec_name::TYPE)
                        .into(),
                    ctx: v.ctx,
                }
            }
        }
    };
}
823
// Implements `BitAnd`, `BitOr`, `BitXor` and `Not` for the IR wrapper `$ty`, together with the
// compound-assignment forms of the three binary operators. Each operator records one
// operation (`And`, `Or`, `Xor`, `Not`) in the left operand's context with result type
// `$ty::TYPE`, and returns the wrapped result.
macro_rules! impl_bit_ops {
    ($ty:ident) => {
        impl<'ctx> BitAnd for $ty<'ctx> {
            type Output = Self;

            fn bitand(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::And, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> BitOr for $ty<'ctx> {
            type Output = Self;

            fn bitor(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Or, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> BitXor for $ty<'ctx> {
            type Output = Self;

            fn bitxor(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Xor, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Not for $ty<'ctx> {
            type Output = Self;

            fn not(self) -> Self::Output {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Not, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> BitAndAssign for $ty<'ctx> {
            fn bitand_assign(&mut self, rhs: Self) {
                *self = BitAnd::bitand(*self, rhs);
            }
        }
        impl<'ctx> BitOrAssign for $ty<'ctx> {
            fn bitor_assign(&mut self, rhs: Self) {
                *self = BitOr::bitor(*self, rhs);
            }
        }
        impl<'ctx> BitXorAssign for $ty<'ctx> {
            fn bitxor_assign(&mut self, rhs: Self) {
                *self = BitXor::bitxor(*self, rhs);
            }
        }
    };
}
899
// Implements the arithmetic operators (`+ - * / %` and their compound-assignment forms) and
// the `Compare` trait for the IR wrapper `$ty`. Every method records one operation in the
// left operand's context: arithmetic results have type `$ty::TYPE`, comparison results are
// wrapped in `$bool` and have type `$bool::TYPE`.
macro_rules! impl_number_ops {
    ($ty:ident, $bool:ident) => {
        impl<'ctx> Add for $ty<'ctx> {
            type Output = Self;

            fn add(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Add, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Sub for $ty<'ctx> {
            type Output = Self;

            fn sub(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Sub, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Mul for $ty<'ctx> {
            type Output = Self;

            fn mul(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Mul, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Div for $ty<'ctx> {
            type Output = Self;

            fn div(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Div, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Rem for $ty<'ctx> {
            type Output = Self;

            fn rem(self, rhs: Self) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Rem, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> AddAssign for $ty<'ctx> {
            fn add_assign(&mut self, rhs: Self) {
                *self = Add::add(*self, rhs);
            }
        }
        impl<'ctx> SubAssign for $ty<'ctx> {
            fn sub_assign(&mut self, rhs: Self) {
                *self = Sub::sub(*self, rhs);
            }
        }
        impl<'ctx> MulAssign for $ty<'ctx> {
            fn mul_assign(&mut self, rhs: Self) {
                *self = Mul::mul(*self, rhs);
            }
        }
        impl<'ctx> DivAssign for $ty<'ctx> {
            fn div_assign(&mut self, rhs: Self) {
                *self = Div::div(*self, rhs);
            }
        }
        impl<'ctx> RemAssign for $ty<'ctx> {
            fn rem_assign(&mut self, rhs: Self) {
                *self = Rem::rem(*self, rhs);
            }
        }
        impl<'ctx> Compare for $ty<'ctx> {
            type Bool = $bool<'ctx>;
            fn eq(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareEq, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn ne(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareNe, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn lt(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareLt, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn gt(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareGt, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn le(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareLe, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn ge(self, rhs: Self) -> Self::Bool {
                let operands = [self.value, rhs.value];
                $bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareGe, operands, $bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1062
// Implements `<<`, `>>` and their compound-assignment forms for the IR wrapper `$ty`, with the
// shift amount given as the wrapper type `$rhs`. Each shift records one `Shl`/`Shr` operation
// in the left operand's context; the result has type `$ty::TYPE`.
macro_rules! impl_shift_ops {
    ($ty:ident, $rhs:ident) => {
        impl<'ctx> Shl<$rhs<'ctx>> for $ty<'ctx> {
            type Output = Self;

            fn shl(self, rhs: $rhs<'ctx>) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Shl, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> Shr<$rhs<'ctx>> for $ty<'ctx> {
            type Output = Self;

            fn shr(self, rhs: $rhs<'ctx>) -> Self::Output {
                let operands = [self.value, rhs.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Shr, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
        impl<'ctx> ShlAssign<$rhs<'ctx>> for $ty<'ctx> {
            fn shl_assign(&mut self, rhs: $rhs<'ctx>) {
                *self = Shl::shl(*self, rhs);
            }
        }
        impl<'ctx> ShrAssign<$rhs<'ctx>> for $ty<'ctx> {
            fn shr_assign(&mut self, rhs: $rhs<'ctx>) {
                *self = Shr::shr(*self, rhs);
            }
        }
    };
}
1105
// Implements unary `-` for the IR wrapper `$ty`: records one `Neg` operation in the operand's
// context with result type `$ty::TYPE`.
macro_rules! impl_neg {
    ($ty:ident) => {
        impl<'ctx> Neg for $ty<'ctx> {
            type Output = Self;

            fn neg(self) -> Self::Output {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Neg, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1124
// Implements the `Int` trait (parameterised by the wrapper type `$u32`) for the IR wrapper
// `$ty`. Each method records one bit-counting operation on `self`; the result is wrapped as
// `$ty` and has type `$ty::TYPE`.
macro_rules! impl_int_trait {
    ($ty:ident, $u32:ident) => {
        impl<'ctx> Int<$u32<'ctx>> for $ty<'ctx> {
            fn leading_zeros(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn trailing_zeros(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn count_ones(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::CountSetBits, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1161
// Everything shared by signed and unsigned integer wrappers, for one scalar/vector pair:
// bitwise operators, arithmetic plus comparisons, shifts, and the `Int` trait. Scalars compare
// to `IrBool` and shift by `IrU32`; vectors compare to `IrVecBool` and shift by `IrVecU32`.
macro_rules! impl_integer_ops {
    ($scalar:ident, $vec:ident) => {
        impl_bit_ops!($scalar);
        impl_number_ops!($scalar, IrBool);
        impl_shift_ops!($scalar, IrU32);
        impl_bit_ops!($vec);
        impl_number_ops!($vec, IrVecBool);
        impl_shift_ops!($vec, IrVecU32);
        impl_int_trait!($scalar, IrU32);
        impl_int_trait!($vec, IrVecU32);
    };
}
1174
// Unsigned integer wrappers: all shared integer operations plus the `UInt` marker impls.
macro_rules! impl_uint_ops {
    ($scalar:ident, $vec:ident) => {
        impl_integer_ops!($scalar, $vec);

        impl<'ctx> UInt<IrU32<'ctx>> for $scalar<'ctx> {}
        impl<'ctx> UInt<IrVecU32<'ctx>> for $vec<'ctx> {}
    };
}
1183
// Operator and trait impls for the unsigned integer wrappers (scalar, vector).
impl_uint_ops!(IrU8, IrVecU8);
impl_uint_ops!(IrU16, IrVecU16);
impl_uint_ops!(IrU32, IrVecU32);
impl_uint_ops!(IrU64, IrVecU64);
1188
// Signed integer wrappers: all shared integer operations, unary negation, and the `SInt`
// marker impls.
macro_rules! impl_sint_ops {
    ($scalar:ident, $vec:ident) => {
        impl_integer_ops!($scalar, $vec);
        impl_neg!($scalar);
        impl_neg!($vec);

        impl<'ctx> SInt<IrU32<'ctx>> for $scalar<'ctx> {}
        impl<'ctx> SInt<IrVecU32<'ctx>> for $vec<'ctx> {}
    };
}
1199
// Operator and trait impls for the signed integer wrappers (scalar, vector).
impl_sint_ops!(IrI8, IrVecI8);
impl_sint_ops!(IrI16, IrVecI16);
impl_sint_ops!(IrI32, IrVecI32);
impl_sint_ops!(IrI64, IrVecI64);
1204
// Implements the `Float` trait for the IR wrapper `$float`. `$bits` / `$signed_bits` are the
// unsigned / signed integer wrappers used as `BitsType` / `SignedBitsType`, and `$u32` is the
// trait's type parameter. Every method records exactly one operation in the operand's
// context; `is_nan` is expressed as `CompareNe` of the value with itself.
macro_rules! impl_float {
    ($float:ident, $bits:ident, $signed_bits:ident, $u32:ident) => {
        impl<'ctx> Float<$u32<'ctx>> for $float<'ctx> {
            type FloatEncoding = <$float<'ctx> as Make>::Prim;
            type BitsType = $bits<'ctx>;
            type SignedBitsType = $signed_bits<'ctx>;
            fn abs(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Abs, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn trunc(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Trunc, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn ceil(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Ceil, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn floor(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Floor, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn round(self) -> Self {
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Round, [self.value], Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            #[cfg(feature = "fma")]
            fn fma(self, a: Self, b: Self) -> Self {
                let operands = [self.value, a.value, b.value];
                Self {
                    value: self
                        .ctx
                        .make_operation(Opcode::Fma, operands, Self::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn is_nan(self) -> Self::Bool {
                // Recorded as "self != self".
                let operands = [self.value, self.value];
                Self::Bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::CompareNe, operands, Self::Bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn is_infinite(self) -> Self::Bool {
                Self::Bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::IsInfinite, [self.value], Self::Bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn is_finite(self) -> Self::Bool {
                Self::Bool {
                    value: self
                        .ctx
                        .make_operation(Opcode::IsFinite, [self.value], Self::Bool::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
            fn from_bits(v: Self::BitsType) -> Self {
                Self {
                    value: v
                        .ctx
                        .make_operation(Opcode::FromBits, [v.value], Self::TYPE)
                        .into(),
                    ctx: v.ctx,
                }
            }
            fn to_bits(self) -> Self::BitsType {
                Self::BitsType {
                    value: self
                        .ctx
                        .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE)
                        .into(),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1326
// Floating-point wrappers, for one scalar/vector pair: arithmetic plus comparisons, unary
// negation, and the `Float` trait. The `*_bits` / `*_signed_bits` arguments name the integer
// wrappers used as the bit-pattern types of the scalar and vector float respectively.
macro_rules! impl_float_ops {
    ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
        impl_number_ops!($scalar, IrBool);
        impl_number_ops!($vec, IrVecBool);
        impl_neg!($scalar);
        impl_neg!($vec);
        impl_float!($scalar, $scalar_bits, $scalar_signed_bits, IrU32);
        impl_float!($vec, $vec_bits, $vec_signed_bits, IrVecU32);
    };
}
1337
1338 impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
1339 impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
1340 impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
1341
1342 ir_value!(
1343 IrBool,
1344 IrVecBool,
1345 TYPE = Bool,
1346 fn make(v: bool) {
1347 v.into()
1348 }
1349 );
1350
1351 impl<'ctx> Bool for IrBool<'ctx> {}
1352 impl<'ctx> Bool for IrVecBool<'ctx> {}
1353
1354 impl_bit_ops!(IrBool);
1355 impl_bit_ops!(IrVecBool);
1356
1357 ir_value!(
1358 IrU8,
1359 IrVecU8,
1360 TYPE = U8,
1361 fn make(v: u8) {
1362 v.into()
1363 }
1364 );
1365 ir_value!(
1366 IrU16,
1367 IrVecU16,
1368 TYPE = U16,
1369 fn make(v: u16) {
1370 v.into()
1371 }
1372 );
1373 ir_value!(
1374 IrU32,
1375 IrVecU32,
1376 TYPE = U32,
1377 fn make(v: u32) {
1378 v.into()
1379 }
1380 );
1381 ir_value!(
1382 IrU64,
1383 IrVecU64,
1384 TYPE = U64,
1385 fn make(v: u64) {
1386 v.into()
1387 }
1388 );
1389 ir_value!(
1390 IrI8,
1391 IrVecI8,
1392 TYPE = I8,
1393 fn make(v: i8) {
1394 v.into()
1395 }
1396 );
1397 ir_value!(
1398 IrI16,
1399 IrVecI16,
1400 TYPE = I16,
1401 fn make(v: i16) {
1402 v.into()
1403 }
1404 );
1405 ir_value!(
1406 IrI32,
1407 IrVecI32,
1408 TYPE = I32,
1409 fn make(v: i32) {
1410 v.into()
1411 }
1412 );
1413 ir_value!(
1414 IrI64,
1415 IrVecI64,
1416 TYPE = I64,
1417 fn make(v: i64) {
1418 v.into()
1419 }
1420 );
1421 ir_value!(
1422 IrF16,
1423 IrVecF16,
1424 TYPE = F16,
1425 fn make(v: F16) {
1426 ScalarConstant::from_f16_bits(v.to_bits())
1427 }
1428 );
1429 ir_value!(
1430 IrF32,
1431 IrVecF32,
1432 TYPE = F32,
1433 fn make(v: f32) {
1434 ScalarConstant::from_f32_bits(v.to_bits())
1435 }
1436 );
1437 ir_value!(
1438 IrF64,
1439 IrVecF64,
1440 TYPE = F64,
1441 fn make(v: f64) {
1442 ScalarConstant::from_f64_bits(v.to_bits())
1443 }
1444 );
1445
1446 macro_rules! impl_convert_to {
1447 ($($src:ident -> [$($dest:ident),*];)*) => {
1448 $($(
1449 impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
1450 fn to(self) -> $dest<'ctx> {
1451 let value = if $src::TYPE == $dest::TYPE {
1452 self.value
1453 } else {
1454 self
1455 .ctx
1456 .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
1457 .into()
1458 };
1459 $dest {
1460 value,
1461 ctx: self.ctx,
1462 }
1463 }
1464 }
1465 )*)*
1466 };
1467 ([$($src:ident),*] -> $dest:tt;) => {
1468 impl_convert_to! {
1469 $(
1470 $src -> $dest;
1471 )*
1472 }
1473 };
1474 ([$($src:ident),*];) => {
1475 impl_convert_to! {
1476 [$($src),*] -> [$($src),*];
1477 }
1478 };
1479 }
1480
1481 impl_convert_to! {
1482 [IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
1483 }
1484
1485 impl_convert_to! {
1486 [IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64];
1487 }
1488
// Implements `From<Src>` for each listed destination type by delegating to the
// `ConvertTo` impl between the same pair of types.
macro_rules! impl_from {
    ($from:ident => [$($to:ident),*]) => {
        $(
            impl<'ctx> From<$from<'ctx>> for $to<'ctx> {
                fn from(source: $from<'ctx>) -> Self {
                    ConvertTo::<Self>::to(source)
                }
            }
        )*
    };
}
1500
1501 macro_rules! impl_froms {
1502 (
1503 #[u8] $u8:ident;
1504 #[i8] $i8:ident;
1505 #[u16] $u16:ident;
1506 #[i16] $i16:ident;
1507 #[f16] $f16:ident;
1508 #[u32] $u32:ident;
1509 #[i32] $i32:ident;
1510 #[f32] $f32:ident;
1511 #[u64] $u64:ident;
1512 #[i64] $i64:ident;
1513 #[f64] $f64:ident;
1514 ) => {
1515 impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]);
1516 impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]);
1517 impl_from!($u32 => [$u64, $i64, $f64]);
1518 impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]);
1519 impl_from!($i16 => [$i32, $f32, $i64, $f64]);
1520 impl_from!($i32 => [$i64, $f64]);
1521 impl_from!($f16 => [$f32, $f64]);
1522 impl_from!($f32 => [$f64]);
1523 };
1524 }
1525
1526 impl_froms! {
1527 #[u8] IrU8;
1528 #[i8] IrI8;
1529 #[u16] IrU16;
1530 #[i16] IrI16;
1531 #[f16] IrF16;
1532 #[u32] IrU32;
1533 #[i32] IrI32;
1534 #[f32] IrF32;
1535 #[u64] IrU64;
1536 #[i64] IrI64;
1537 #[f64] IrF64;
1538 }
1539
1540 impl_froms! {
1541 #[u8] IrVecU8;
1542 #[i8] IrVecI8;
1543 #[u16] IrVecU16;
1544 #[i16] IrVecI16;
1545 #[f16] IrVecF16;
1546 #[u32] IrVecU32;
1547 #[i32] IrVecI32;
1548 #[f32] IrVecF32;
1549 #[u64] IrVecU64;
1550 #[i64] IrVecI64;
1551 #[f64] IrVecF64;
1552 }
1553
1554 impl<'ctx> Context for &'ctx IrContext<'ctx> {
1555 type Bool = IrBool<'ctx>;
1556 type U8 = IrU8<'ctx>;
1557 type I8 = IrI8<'ctx>;
1558 type U16 = IrU16<'ctx>;
1559 type I16 = IrI16<'ctx>;
1560 type F16 = IrF16<'ctx>;
1561 type U32 = IrU32<'ctx>;
1562 type I32 = IrI32<'ctx>;
1563 type F32 = IrF32<'ctx>;
1564 type U64 = IrU64<'ctx>;
1565 type I64 = IrI64<'ctx>;
1566 type F64 = IrF64<'ctx>;
1567 type VecBool = IrVecBool<'ctx>;
1568 type VecU8 = IrVecU8<'ctx>;
1569 type VecI8 = IrVecI8<'ctx>;
1570 type VecU16 = IrVecU16<'ctx>;
1571 type VecI16 = IrVecI16<'ctx>;
1572 type VecF16 = IrVecF16<'ctx>;
1573 type VecU32 = IrVecU32<'ctx>;
1574 type VecI32 = IrVecI32<'ctx>;
1575 type VecF32 = IrVecF32<'ctx>;
1576 type VecU64 = IrVecU64<'ctx>;
1577 type VecI64 = IrVecI64<'ctx>;
1578 type VecF64 = IrVecF64<'ctx>;
1579 }
1580
// Tests that build an `IrFunction` from ordinary generic code and compare its
// `Display` output against an expected listing.
1581 #[cfg(test)]
1582 mod tests {
1583 use crate::algorithms;
1584
1585 use super::*;
1586 use std::println;
1587
// Formats a small hand-written function that mixes a `From` conversion,
// arithmetic, `floor` and a `ConvertTo` conversion.
1588 #[test]
1589 fn test_display() {
// Written against the generic `Context` trait only; the IR types are
// substituted below when `f` is coerced to a concrete function pointer.
1590 fn f<Ctx: Context>(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 {
1591 let a: Ctx::VecF32 = a.into();
1592 (a - (a + b - ctx.make(5f32)).floor()).to()
1593 }
1594 let ctx = IrContext::new();
// Helper that names the `'ctx` lifetime so `f` can be coerced to a function
// pointer over the IR types before it is passed to `IrFunction::make`.
1595 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
1596 let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f;
1597 IrFunction::make(ctx, f)
1598 }
// The leading "\n" matches the line break that follows the opening `r"` of
// the expected raw string below.
1599 let text = format!("\n{}", make_it(&ctx));
// Printed as well, so the actual listing can be read from the test output.
1600 println!("{}", text);
1601 assert_eq!(
1602 text,
1603 r"
1604 function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
1605 op_0: vec<F32> = Cast in<arg_0>
1606 op_1: vec<F32> = Add op_0, in<arg_1>
1607 op_2: vec<F32> = Sub op_1, splat(0x40A00000_f32)
1608 op_3: vec<F32> = Floor op_2
1609 op_4: vec<F32> = Sub op_0, op_3
1610 op_5: vec<F64> = Cast op_4
1611 Return op_5
1612 }
1613 "
1614 );
1615 }
1616
// Formats the IR produced by the library's `ilogb_f32` algorithm when it is
// instantiated with the vector IR types.
1617 #[test]
1618 fn test_display_ilogb_f32() {
1619 let ctx = IrContext::new();
// Same lifetime-naming helper as in `test_display`, here wrapping
// `algorithms::ilogb::ilogb_f32`.
1620 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
1621 let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
1622 algorithms::ilogb::ilogb_f32;
1623 IrFunction::make(ctx, f)
1624 }
// The leading "\n" again matches the line break after the opening `r"`.
1625 let text = format!("\n{}", make_it(&ctx));
1626 println!("{}", text);
1627 assert_eq!(
1628 text,
1629 r"
1630 function(in<arg_0>: vec<F32>) -> vec<I32> {
1631 op_0: vec<Bool> = IsFinite in<arg_0>
1632 op_1: vec<U32> = ToBits in<arg_0>
1633 op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
1634 op_3: vec<U32> = Shr op_2, splat(0x17_u32)
1635 op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
1636 op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
1637 op_6: vec<I32> = Splat 0x80000001_i32
1638 op_7: vec<I32> = Splat 0x7FFFFFFF_i32
1639 op_8: vec<I32> = Select op_5, op_6, op_7
1640 op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
1641 op_10: vec<U32> = ToBits op_9
1642 op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
1643 op_12: vec<U32> = Shr op_11, splat(0x17_u32)
1644 op_13: vec<I32> = Cast op_12
1645 op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
1646 op_15: vec<U32> = ToBits in<arg_0>
1647 op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
1648 op_17: vec<U32> = Shr op_16, splat(0x17_u32)
1649 op_18: vec<I32> = Cast op_17
1650 op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
1651 op_20: vec<I32> = Select op_0, op_19, op_8
1652 op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
1653 op_22: vec<I32> = Splat 0x80000000_i32
1654 op_23: vec<I32> = Sub op_14, splat(0x17_i32)
1655 op_24: vec<I32> = Select op_21, op_22, op_23
1656 op_25: vec<I32> = Select op_4, op_24, op_20
1657 Return op_25
1658 }
1659 "
1660 );
1661 }
1662 }