switch to using separate VecBool8/16/32/64
[vector-math.git] / src / ir.rs
1 use crate::{
2 f16::F16,
3 traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt},
4 };
5 use std::{
6 borrow::Borrow,
7 cell::{Cell, RefCell},
8 collections::HashMap,
9 fmt::{self, Write as _},
10 format,
11 ops::{
12 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
13 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
14 SubAssign,
15 },
16 string::String,
17 vec::Vec,
18 };
19 use typed_arena::Arena;
20
/// Declares a fieldless enum and derives `Clone`, `Copy`, `Debug`, `PartialEq`, `Eq`
/// and `Hash` for it. Also generates a `const fn as_str` returning each variant's
/// identifier, plus a `fmt::Display` impl that prints that identifier.
///
/// Per-variant attributes (including doc comments) are forwarded to the enum.
macro_rules! make_enum {
    (
        $vis:vis enum $enum:ident {
            $(
                $(#[$meta:meta])*
                $name:ident,
            )*
        }
    ) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
        $vis enum $enum {
            $(
                $(#[$meta])*
                $name,
            )*
        }

        impl $enum {
            /// Returns the variant's identifier exactly as written in its declaration.
            $vis const fn as_str(self) -> &'static str {
                match self {
                    $(
                        Self::$name => stringify!($name),
                    )*
                }
            }
        }

        impl fmt::Display for $enum {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` (unlike `write_str`) honours width/fill/alignment/precision
                // flags, so `{:>8}` etc. behave like they do for `&str`; with no flags
                // the output is identical.
                f.pad(self.as_str())
            }
        }
    };
}
55
// Element types understood by the IR. `make_enum!` derives `Clone`, `Copy`, `Debug`,
// `PartialEq`, `Eq` and `Hash`, and generates `as_str`/`Display` printing the variant
// name (e.g. `U8`). The declaration order fixes the derived discriminants (and thus the
// derived `Hash` values), so keep it stable.
make_enum! {
    pub enum ScalarType {
        Bool,
        U8,
        I8,
        U16,
        I16,
        F16,
        U32,
        I32,
        F32,
        U64,
        I64,
        F64,
    }
}
72
73 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
74 pub struct VectorType {
75 pub element: ScalarType,
76 }
77
78 impl fmt::Display for VectorType {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(f, "vec<{}>", self.element)
81 }
82 }
83
84 impl From<ScalarType> for Type {
85 fn from(v: ScalarType) -> Self {
86 Type::Scalar(v)
87 }
88 }
89
90 impl From<VectorType> for Type {
91 fn from(v: VectorType) -> Self {
92 Type::Vector(v)
93 }
94 }
95
96 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
97 pub enum Type {
98 Scalar(ScalarType),
99 Vector(VectorType),
100 }
101
102 impl fmt::Display for Type {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 match self {
105 Type::Scalar(v) => v.fmt(f),
106 Type::Vector(v) => v.fmt(f),
107 }
108 }
109 }
110
/// A scalar constant operand.
///
/// Floating-point constants are stored as raw bit patterns (`bits`) rather than float
/// values — presumably so `Eq` and `Hash` can be derived (Rust's float types implement
/// neither). Two float constants therefore compare equal exactly when their bit
/// patterns are equal. Build them with `from_f16_bits` / `from_f32_bits` /
/// `from_f64_bits`.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalarConstant {
    Bool(bool),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    // Raw bit patterns of half/single/double precision values.
    F16 { bits: u16 },
    F32 { bits: u32 },
    F64 { bits: u64 },
}
126
/// Generates `pub const fn $ty(self) -> Option<$ty>`, which returns the payload of the
/// `Self::$enumerant` variant, or `None` for any other variant.
///
/// The getters are `const fn` so they match the hand-written `f16_bits` / `f32_bits` /
/// `f64_bits` accessors and can be used in constant contexts.
macro_rules! make_scalar_constant_get {
    ($ty:ident, $enumerant:ident) => {
        /// Returns the payload if `self` holds this variant, otherwise `None`.
        pub const fn $ty(self) -> Option<$ty> {
            if let Self::$enumerant(v) = self {
                Some(v)
            } else {
                None
            }
        }
    };
}
138
139 macro_rules! make_scalar_constant_from {
140 ($ty:ident, $enumerant:ident) => {
141 impl From<$ty> for ScalarConstant {
142 fn from(v: $ty) -> Self {
143 Self::$enumerant(v)
144 }
145 }
146 impl From<$ty> for Constant {
147 fn from(v: $ty) -> Self {
148 Self::Scalar(v.into())
149 }
150 }
151 impl From<$ty> for Value<'_> {
152 fn from(v: $ty) -> Self {
153 Self::Constant(v.into())
154 }
155 }
156 };
157 }
158
159 make_scalar_constant_from!(bool, Bool);
160 make_scalar_constant_from!(u8, U8);
161 make_scalar_constant_from!(u16, U16);
162 make_scalar_constant_from!(u32, U32);
163 make_scalar_constant_from!(u64, U64);
164 make_scalar_constant_from!(i8, I8);
165 make_scalar_constant_from!(i16, I16);
166 make_scalar_constant_from!(i32, I32);
167 make_scalar_constant_from!(i64, I64);
168
169 impl ScalarConstant {
170 pub const fn ty(self) -> ScalarType {
171 match self {
172 ScalarConstant::Bool(_) => ScalarType::Bool,
173 ScalarConstant::U8(_) => ScalarType::U8,
174 ScalarConstant::U16(_) => ScalarType::U16,
175 ScalarConstant::U32(_) => ScalarType::U32,
176 ScalarConstant::U64(_) => ScalarType::U64,
177 ScalarConstant::I8(_) => ScalarType::I8,
178 ScalarConstant::I16(_) => ScalarType::I16,
179 ScalarConstant::I32(_) => ScalarType::I32,
180 ScalarConstant::I64(_) => ScalarType::I64,
181 ScalarConstant::F16 { .. } => ScalarType::F16,
182 ScalarConstant::F32 { .. } => ScalarType::F32,
183 ScalarConstant::F64 { .. } => ScalarType::F64,
184 }
185 }
186 pub const fn from_f16_bits(bits: u16) -> Self {
187 Self::F16 { bits }
188 }
189 pub const fn from_f32_bits(bits: u32) -> Self {
190 Self::F32 { bits }
191 }
192 pub const fn from_f64_bits(bits: u64) -> Self {
193 Self::F64 { bits }
194 }
195 pub const fn f16_bits(self) -> Option<u16> {
196 if let Self::F16 { bits } = self {
197 Some(bits)
198 } else {
199 None
200 }
201 }
202 pub const fn f32_bits(self) -> Option<u32> {
203 if let Self::F32 { bits } = self {
204 Some(bits)
205 } else {
206 None
207 }
208 }
209 pub const fn f64_bits(self) -> Option<u64> {
210 if let Self::F64 { bits } = self {
211 Some(bits)
212 } else {
213 None
214 }
215 }
216 make_scalar_constant_get!(bool, Bool);
217 make_scalar_constant_get!(u8, U8);
218 make_scalar_constant_get!(u16, U16);
219 make_scalar_constant_get!(u32, U32);
220 make_scalar_constant_get!(u64, U64);
221 make_scalar_constant_get!(i8, I8);
222 make_scalar_constant_get!(i16, I16);
223 make_scalar_constant_get!(i32, I32);
224 make_scalar_constant_get!(i64, I64);
225 }
226
227 impl fmt::Display for ScalarConstant {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 match self {
230 ScalarConstant::Bool(false) => write!(f, "false"),
231 ScalarConstant::Bool(true) => write!(f, "true"),
232 ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
233 ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
234 ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
235 ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
236 ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
237 ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
238 ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
239 ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
240 ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
241 ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
242 ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
243 }
244 }
245 }
246
247 impl fmt::Debug for ScalarConstant {
248 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
249 fmt::Display::fmt(self, f)
250 }
251 }
252
253 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
254 pub struct VectorSplatConstant {
255 pub element: ScalarConstant,
256 }
257
258 impl VectorSplatConstant {
259 pub const fn ty(self) -> VectorType {
260 VectorType {
261 element: self.element.ty(),
262 }
263 }
264 }
265
266 impl fmt::Display for VectorSplatConstant {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 write!(f, "splat({})", self.element)
269 }
270 }
271
272 impl From<ScalarConstant> for Constant {
273 fn from(v: ScalarConstant) -> Self {
274 Constant::Scalar(v)
275 }
276 }
277
278 impl From<VectorSplatConstant> for Constant {
279 fn from(v: VectorSplatConstant) -> Self {
280 Constant::VectorSplat(v)
281 }
282 }
283
284 impl From<ScalarConstant> for Value<'_> {
285 fn from(v: ScalarConstant) -> Self {
286 Value::Constant(v.into())
287 }
288 }
289
290 impl From<VectorSplatConstant> for Value<'_> {
291 fn from(v: VectorSplatConstant) -> Self {
292 Value::Constant(v.into())
293 }
294 }
295
296 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
297 pub enum Constant {
298 Scalar(ScalarConstant),
299 VectorSplat(VectorSplatConstant),
300 }
301
302 impl Constant {
303 pub const fn ty(self) -> Type {
304 match self {
305 Constant::Scalar(v) => Type::Scalar(v.ty()),
306 Constant::VectorSplat(v) => Type::Vector(v.ty()),
307 }
308 }
309 }
310
311 impl fmt::Display for Constant {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 match self {
314 Constant::Scalar(v) => v.fmt(f),
315 Constant::VectorSplat(v) => v.fmt(f),
316 }
317 }
318 }
319
320 #[derive(Debug)]
321 pub struct Input<'ctx> {
322 pub name: &'ctx str,
323 pub ty: Type,
324 }
325
326 impl fmt::Display for Input<'_> {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 write!(f, "in<{}>", self.name)
329 }
330 }
331
332 #[derive(Copy, Clone)]
333 pub enum Value<'ctx> {
334 Input(&'ctx Input<'ctx>),
335 Constant(Constant),
336 OpResult(&'ctx Operation<'ctx>),
337 }
338
339 impl<'ctx> Value<'ctx> {
340 pub const fn ty(self) -> Type {
341 match self {
342 Value::Input(v) => v.ty,
343 Value::Constant(v) => v.ty(),
344 Value::OpResult(v) => v.result_type,
345 }
346 }
347 }
348
349 impl fmt::Debug for Value<'_> {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 match self {
352 Value::Input(v) => v.fmt(f),
353 Value::Constant(v) => v.fmt(f),
354 Value::OpResult(v) => v.result_id.fmt(f),
355 }
356 }
357 }
358
359 impl fmt::Display for Value<'_> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 match self {
362 Value::Input(v) => v.fmt(f),
363 Value::Constant(v) => v.fmt(f),
364 Value::OpResult(v) => v.result_id.fmt(f),
365 }
366 }
367 }
368
369 impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> {
370 fn from(v: &'ctx Input<'ctx>) -> Self {
371 Value::Input(v)
372 }
373 }
374
375 impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> {
376 fn from(v: &'ctx Operation<'ctx>) -> Self {
377 Value::OpResult(v)
378 }
379 }
380
381 impl<'ctx> From<Constant> for Value<'ctx> {
382 fn from(v: Constant) -> Self {
383 Value::Constant(v)
384 }
385 }
386
// Operation kinds. The notes below describe how the helpers in this file emit each
// opcode; argument order is the order of the `arguments` vector given to
// `IrContext::make_operation`. Variant order fixes the derived discriminants.
make_enum! {
    pub enum Opcode {
        // Binary arithmetic: [lhs, rhs]; the result has the operand type.
        Add,
        Sub,
        Mul,
        Div,
        Rem,
        // Emitted by `Float::fma` (only with the `fma` feature) as [self, a, b];
        // the exact fused formula is not visible here — TODO confirm.
        Fma,
        // Not emitted by any code visible in this file.
        Cast,
        // Bitwise ops: [lhs, rhs] (`Not` takes [value]); also used on bool values.
        And,
        Or,
        Xor,
        Not,
        // Shifts: [value, amount], both of the same type. Whether `Shr` is arithmetic
        // for signed types is not visible here — presumably yes; TODO confirm.
        Shl,
        Shr,
        // Bit counting for `count_ones`, `leading_zeros`, `trailing_zeros`:
        // [value]; the result has the operand type.
        CountSetBits,
        CountLeadingZeros,
        CountTrailingZeros,
        // Unary numeric ops: [value]; the result has the operand type.
        Neg,
        Abs,
        Trunc,
        Ceil,
        Floor,
        Round,
        // Float classification: [value]; the result is the matching bool type.
        IsInfinite,
        IsFinite,
        // Reinterpretation between a float and its same-width unsigned integer: [value].
        ToBits,
        FromBits,
        // Scalar -> vector broadcast: [scalar].
        Splat,
        // Comparisons: [lhs, rhs]; the result is the matching bool type.
        // `is_nan` is emitted as `CompareNe` of a value with itself.
        CompareEq,
        CompareNe,
        CompareLt,
        CompareLe,
        CompareGt,
        CompareGe,
        // [condition, value if true, value if false].
        Select,
    }
}
425
426 #[derive(Debug)]
427 pub struct Operation<'ctx> {
428 pub opcode: Opcode,
429 pub arguments: Vec<Value<'ctx>>,
430 pub result_type: Type,
431 pub result_id: OperationId,
432 }
433
434 impl fmt::Display for Operation<'_> {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 write!(
437 f,
438 "{}: {} = {}",
439 self.result_id, self.result_type, self.opcode
440 )?;
441 let mut separator = " ";
442 for i in &self.arguments {
443 write!(f, "{}{}", separator, i)?;
444 separator = ", ";
445 }
446 Ok(())
447 }
448 }
449
450 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
451 pub struct OperationId(pub u64);
452
453 impl fmt::Display for OperationId {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 write!(f, "op_{}", self.0)
456 }
457 }
458
459 #[derive(Default)]
460 pub struct IrContext<'ctx> {
461 bytes_arena: Arena<u8>,
462 inputs_arena: Arena<Input<'ctx>>,
463 inputs: RefCell<HashMap<&'ctx str, &'ctx Input<'ctx>>>,
464 operations_arena: Arena<Operation<'ctx>>,
465 operations: RefCell<Vec<&'ctx Operation<'ctx>>>,
466 next_operation_result_id: Cell<u64>,
467 }
468
469 impl fmt::Debug for IrContext<'_> {
470 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 f.write_str("IrContext { .. }")
472 }
473 }
474
475 impl<'ctx> IrContext<'ctx> {
476 pub fn new() -> Self {
477 Self::default()
478 }
479 pub fn make_input<N: Borrow<str> + Into<String>, T: Into<Type>>(
480 &'ctx self,
481 name: N,
482 ty: T,
483 ) -> &'ctx Input<'ctx> {
484 let mut inputs = self.inputs.borrow_mut();
485 let name_str = name.borrow();
486 let ty = ty.into();
487 if !name_str.is_empty() && !inputs.contains_key(name_str) {
488 let name = self.bytes_arena.alloc_str(name_str);
489 let input = self.inputs_arena.alloc(Input { name, ty });
490 inputs.insert(name, input);
491 return input;
492 }
493 let mut name: String = name.into();
494 if name.is_empty() {
495 name = "in".into();
496 }
497 let name_len = name.len();
498 let mut tag = 2usize;
499 loop {
500 name.truncate(name_len);
501 write!(name, "_{}", tag).unwrap();
502 if !inputs.contains_key(&*name) {
503 let name = self.bytes_arena.alloc_str(&name);
504 let input = self.inputs_arena.alloc(Input { name, ty });
505 inputs.insert(name, input);
506 return input;
507 }
508 tag += 1;
509 }
510 }
511 pub fn make_operation<A: Into<Vec<Value<'ctx>>>, T: Into<Type>>(
512 &'ctx self,
513 opcode: Opcode,
514 arguments: A,
515 result_type: T,
516 ) -> &'ctx Operation<'ctx> {
517 let arguments = arguments.into();
518 let result_type = result_type.into();
519 let result_id = OperationId(self.next_operation_result_id.get());
520 self.next_operation_result_id.set(result_id.0 + 1);
521 let operation = self.operations_arena.alloc(Operation {
522 opcode,
523 arguments,
524 result_type,
525 result_id,
526 });
527 self.operations.borrow_mut().push(operation);
528 operation
529 }
530 pub fn replace_operations(
531 &'ctx self,
532 new_operations: Vec<&'ctx Operation<'ctx>>,
533 ) -> Vec<&'ctx Operation<'ctx>> {
534 self.operations.replace(new_operations)
535 }
536 }
537
538 #[derive(Debug)]
539 pub struct IrFunction<'ctx> {
540 pub inputs: Vec<&'ctx Input<'ctx>>,
541 pub operations: Vec<&'ctx Operation<'ctx>>,
542 pub outputs: Vec<Value<'ctx>>,
543 }
544
545 impl fmt::Display for IrFunction<'_> {
546 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
547 write!(f, "function(")?;
548 let mut first = true;
549 for input in &self.inputs {
550 if first {
551 first = false
552 } else {
553 write!(f, ", ")?;
554 }
555 write!(f, "{}: {}", input, input.ty)?;
556 }
557 match self.outputs.len() {
558 0 => writeln!(f, ") {{")?,
559 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?,
560 _ => {
561 write!(f, ") -> ({}", self.outputs[0].ty())?;
562 for output in self.outputs.iter().skip(1) {
563 write!(f, ", {}", output.ty())?;
564 }
565 writeln!(f, ") {{")?;
566 }
567 }
568 for operation in &self.operations {
569 writeln!(f, " {}", operation)?;
570 }
571 match self.outputs.len() {
572 0 => writeln!(f, "}}")?,
573 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?,
574 _ => {
575 write!(f, " Return {}", self.outputs[0])?;
576 for output in self.outputs.iter().skip(1) {
577 write!(f, ", {}", output)?;
578 }
579 writeln!(f, "\n}}")?;
580 }
581 }
582 Ok(())
583 }
584 }
585
586 impl<'ctx> IrFunction<'ctx> {
587 pub fn make<F: IrFunctionMaker<'ctx>>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self {
588 let old_operations = ctx.replace_operations(Vec::new());
589 let (v, inputs) = F::make_inputs(ctx);
590 let outputs = f.call(ctx, v).outputs_to_vec();
591 let operations = ctx.replace_operations(old_operations);
592 Self {
593 inputs,
594 operations,
595 outputs,
596 }
597 }
598 }
599
600 pub trait IrFunctionMaker<'ctx>: Sized {
601 type Inputs;
602 type Outputs: IrFunctionMakerOutputs<'ctx>;
603 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs;
604 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>);
605 }
606
607 pub trait IrFunctionMakerOutputs<'ctx> {
608 fn outputs_to_vec(self) -> Vec<Value<'ctx>>;
609 }
610
611 impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T {
612 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
613 [self.value()].into()
614 }
615 }
616
617 impl<'ctx> IrFunctionMakerOutputs<'ctx> for () {
618 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
619 Vec::new()
620 }
621 }
622
623 impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx>
624 for fn(&'ctx IrContext<'ctx>) -> R
625 {
626 type Inputs = ();
627 type Outputs = R;
628 fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs {
629 self(ctx)
630 }
631 fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
632 ((), Vec::new())
633 }
634 }
635
// Implements `IrFunctionMaker` for `fn(&IrContext, A0, ..., Ak) -> R` and
// `IrFunctionMakerOutputs` for the tuple `(A0, ..., Ak)`, then recurses on the argument
// list minus its first entry. The invocation below therefore covers arities 12 down to
// 1; the empty-list arm ends the recursion (arity 0 is the separate hand-written impl
// for `fn(&IrContext) -> R`).
macro_rules! impl_ir_function_maker_io {
    () => {};
    ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => {
        impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R
        where
            $first_arg_ty: IrValue<'ctx>,
            $($arg_ty: IrValue<'ctx>,)*
            R: IrFunctionMakerOutputs<'ctx>,
        {
            type Inputs = ($first_arg_ty, $($arg_ty,)*);
            type Outputs = R;
            fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs {
                // Unpack the input tuple and pass each element positionally.
                let ($first_arg, $($arg,)*) = inputs;
                self(ctx, $first_arg$(, $arg)*)
            }
            fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
                // The argument identifiers are first used as `String` locals holding the
                // requested input names "arg_0", "arg_1", ... in parameter order.
                let mut $first_arg = String::new();
                $(let mut $arg = String::new();)*
                for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() {
                    **arg = format!("arg_{}", index);
                }
                // Each `make_input` returns `(wrapper value, &Input)`. The names are only
                // requests: `IrContext::make_input` appends a suffix if one is taken.
                let $first_arg = $first_arg_ty::make_input(ctx, $first_arg);
                $(let $arg = $arg_ty::make_input(ctx, $arg);)*
                (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into())
            }
        }
        // Tuples of IR values return one output per element, in order.
        impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*)
        where
            $first_arg_ty: IrValue<'ctx>,
            $($arg_ty: IrValue<'ctx>,)*
        {
            fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
                let ($first_arg, $($arg,)*) = self;
                [$first_arg.value() $(, $arg.value())*].into()
            }
        }
        // Recurse with one fewer argument.
        impl_ir_function_maker_io!($($arg: $arg_ty,)*);
    };
}

// Supports builder functions taking up to 12 IR-value arguments.
impl_ir_function_maker_io!(
    in0: In0,
    in1: In1,
    in2: In2,
    in3: In3,
    in4: In4,
    in5: In5,
    in6: In6,
    in7: In7,
    in8: In8,
    in9: In9,
    in10: In10,
    in11: In11,
);
690
691 pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
692 const TYPE: Type;
693 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
694 fn make_input<N: Borrow<str> + Into<String>>(
695 ctx: &'ctx IrContext<'ctx>,
696 name: N,
697 ) -> (Self, &'ctx Input<'ctx>) {
698 let input = ctx.make_input(name, Self::TYPE);
699 (Self::new(ctx, input.into()), input)
700 }
701 fn value(self) -> Value<'ctx>;
702 }
703
/// Declares a scalar IR wrapper `$name` and its vector counterpart `$vec_name`.
///
/// Each wrapper stores a `Value` plus the owning `IrContext`. Both get `IrValue` and
/// `Make` impls; `make` turns the primitive into a scalar or splat constant via `$make`.
/// The macro also generates `Select` impls driven by `IrBool`/`IrVecBool`, and a
/// `From<$name>` conversion that emits an `Opcode::Splat` operation.
macro_rules! ir_value {
    // Internal helper: `Select` over `$res` values, conditioned on a `$cond` value.
    (@select $cond:ident, $res:ident) => {
        impl<'ctx> Select<$res<'ctx>> for $cond<'ctx> {
            fn select(self, true_v: $res<'ctx>, false_v: $res<'ctx>) -> $res<'ctx> {
                let operation = self.ctx.make_operation(
                    Opcode::Select,
                    [self.value, true_v.value, false_v.value],
                    $res::TYPE,
                );
                $res {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }
    };
    ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => {
        #[derive(Clone, Copy, Debug)]
        pub struct $name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> $name<'ctx> {
            pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
        }

        impl<'ctx> IrValue<'ctx> for $name<'ctx> {
            const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE);
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                assert_eq!(value.ty(), Self::TYPE);
                Self { ctx, value }
            }
            fn value(self) -> Value<'ctx> {
                self.value
            }
        }

        impl<'ctx> Make for $name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;
            fn ctx(self) -> Self::Context {
                self.ctx
            }
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                let constant: ScalarConstant = $make;
                Self {
                    value: constant.into(),
                    ctx,
                }
            }
        }

        #[derive(Clone, Copy, Debug)]
        pub struct $vec_name<'ctx> {
            pub value: Value<'ctx>,
            pub ctx: &'ctx IrContext<'ctx>,
        }

        impl<'ctx> $vec_name<'ctx> {
            pub const VECTOR_TYPE: VectorType = VectorType {
                element: ScalarType::$scalar_type,
            };
        }

        impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> {
            const TYPE: Type = Type::Vector(Self::VECTOR_TYPE);
            fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
                assert_eq!(value.ty(), Self::TYPE);
                Self { ctx, value }
            }
            fn value(self) -> Value<'ctx> {
                self.value
            }
        }

        impl<'ctx> Make for $vec_name<'ctx> {
            type Prim = $prim;
            type Context = &'ctx IrContext<'ctx>;
            fn ctx(self) -> Self::Context {
                self.ctx
            }
            fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
                // A vector constant is the scalar constant splatted across all lanes.
                let element = $make;
                let splat = VectorSplatConstant { element };
                Self {
                    value: splat.into(),
                    ctx,
                }
            }
        }

        // Scalar condition over scalars, vector condition over vectors, and a scalar
        // condition choosing between whole vectors.
        ir_value!(@select IrBool, $name);
        ir_value!(@select IrVecBool, $vec_name);
        ir_value!(@select IrBool, $vec_name);

        impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
            fn from(scalar: $name<'ctx>) -> Self {
                let operation = scalar
                    .ctx
                    .make_operation(Opcode::Splat, [scalar.value], $vec_name::TYPE);
                Self {
                    value: operation.into(),
                    ctx: scalar.ctx,
                }
            }
        }
    };
}
840
/// Implements `&`, `|`, `^`, `!` and the compound-assignment forms for an IR wrapper.
/// Each binary or unary operator emits one operation of the wrapper's own type.
macro_rules! impl_bit_ops {
    // Internal: a binary operator lowered to a single `Opcode::$opcode` operation.
    (@binary $ty:ident, $trait_name:ident, $method:ident, $opcode:ident) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                let operation = self.ctx.make_operation(
                    Opcode::$opcode,
                    [self.value, rhs.value],
                    Self::TYPE,
                );
                Self {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }
    };
    // Internal: `a op= b` implemented as `a = a op b`.
    (@assign $ty:ident, $trait_name:ident, $method:ident, $op:tt) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            fn $method(&mut self, rhs: Self) {
                *self = *self $op rhs;
            }
        }
    };
    ($ty:ident) => {
        impl_bit_ops!(@binary $ty, BitAnd, bitand, And);
        impl_bit_ops!(@binary $ty, BitOr, bitor, Or);
        impl_bit_ops!(@binary $ty, BitXor, bitxor, Xor);

        impl<'ctx> Not for $ty<'ctx> {
            type Output = Self;

            fn not(self) -> Self::Output {
                let operation = self
                    .ctx
                    .make_operation(Opcode::Not, [self.value], Self::TYPE);
                Self {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }

        impl_bit_ops!(@assign $ty, BitAndAssign, bitand_assign, &);
        impl_bit_ops!(@assign $ty, BitOrAssign, bitor_assign, |);
        impl_bit_ops!(@assign $ty, BitXorAssign, bitxor_assign, ^);
    };
}
916
/// Implements arithmetic operators (`+ - * / %` and their assign forms) and `Compare`
/// for an IR wrapper. Comparison results are wrapped in `$bool`.
macro_rules! impl_number_ops {
    // Internal: a binary arithmetic operator emitting one `Opcode::$opcode` operation.
    (@binary $ty:ident, $trait_name:ident, $method:ident, $opcode:ident) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                let operation = self.ctx.make_operation(
                    Opcode::$opcode,
                    [self.value, rhs.value],
                    Self::TYPE,
                );
                Self {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }
    };
    // Internal: `a op= b` implemented as `a = a op b`.
    (@assign $ty:ident, $trait_name:ident, $method:ident, $op:tt) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            fn $method(&mut self, rhs: Self) {
                *self = *self $op rhs;
            }
        }
    };
    // Internal: one `Compare` method producing a `$bool`-typed operation.
    (@compare $bool:ident, $method:ident, $opcode:ident) => {
        fn $method(self, rhs: Self) -> Self::Bool {
            let operation = self.ctx.make_operation(
                Opcode::$opcode,
                [self.value, rhs.value],
                $bool::TYPE,
            );
            $bool {
                value: operation.into(),
                ctx: self.ctx,
            }
        }
    };
    ($ty:ident, $bool:ident) => {
        impl_number_ops!(@binary $ty, Add, add, Add);
        impl_number_ops!(@binary $ty, Sub, sub, Sub);
        impl_number_ops!(@binary $ty, Mul, mul, Mul);
        impl_number_ops!(@binary $ty, Div, div, Div);
        impl_number_ops!(@binary $ty, Rem, rem, Rem);

        impl_number_ops!(@assign $ty, AddAssign, add_assign, +);
        impl_number_ops!(@assign $ty, SubAssign, sub_assign, -);
        impl_number_ops!(@assign $ty, MulAssign, mul_assign, *);
        impl_number_ops!(@assign $ty, DivAssign, div_assign, /);
        impl_number_ops!(@assign $ty, RemAssign, rem_assign, %);

        impl<'ctx> Compare for $ty<'ctx> {
            type Bool = $bool<'ctx>;
            impl_number_ops!(@compare $bool, eq, CompareEq);
            impl_number_ops!(@compare $bool, ne, CompareNe);
            impl_number_ops!(@compare $bool, lt, CompareLt);
            impl_number_ops!(@compare $bool, gt, CompareGt);
            impl_number_ops!(@compare $bool, le, CompareLe);
            impl_number_ops!(@compare $bool, ge, CompareGe);
        }
    };
}
1079
1080 macro_rules! impl_bool_compare {
1081 ($ty:ident) => {
1082 impl<'ctx> Compare for $ty<'ctx> {
1083 type Bool = Self;
1084 fn eq(self, rhs: Self) -> Self::Bool {
1085 !(self ^ rhs)
1086 }
1087 fn ne(self, rhs: Self) -> Self::Bool {
1088 self ^ rhs
1089 }
1090 fn lt(self, rhs: Self) -> Self::Bool {
1091 !self & rhs
1092 }
1093 fn gt(self, rhs: Self) -> Self::Bool {
1094 self & !rhs
1095 }
1096 fn le(self, rhs: Self) -> Self::Bool {
1097 !self | rhs
1098 }
1099 fn ge(self, rhs: Self) -> Self::Bool {
1100 self | !rhs
1101 }
1102 }
1103 };
1104 }
1105
1106 impl_bool_compare!(IrBool);
1107 impl_bool_compare!(IrVecBool);
1108
/// Implements `<<`, `>>` and their assign forms for an IR wrapper. The shift amount has
/// the same type as the shifted value.
macro_rules! impl_shift_ops {
    // Internal: a shift operator emitting one `Opcode::$opcode` operation.
    (@binary $ty:ident, $trait_name:ident, $method:ident, $opcode:ident) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            type Output = Self;

            fn $method(self, rhs: Self) -> Self::Output {
                let operation = self.ctx.make_operation(
                    Opcode::$opcode,
                    [self.value, rhs.value],
                    Self::TYPE,
                );
                Self {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }
    };
    // Internal: `a op= b` implemented as `a = a op b`.
    (@assign $ty:ident, $trait_name:ident, $method:ident, $op:tt) => {
        impl<'ctx> $trait_name for $ty<'ctx> {
            fn $method(&mut self, rhs: Self) {
                *self = *self $op rhs;
            }
        }
    };
    ($ty:ident) => {
        impl_shift_ops!(@binary $ty, Shl, shl, Shl);
        impl_shift_ops!(@binary $ty, Shr, shr, Shr);
        impl_shift_ops!(@assign $ty, ShlAssign, shl_assign, <<);
        impl_shift_ops!(@assign $ty, ShrAssign, shr_assign, >>);
    };
}
1151
/// Implements unary `-` for an IR wrapper by emitting an `Opcode::Neg` operation.
macro_rules! impl_neg {
    ($ty:ident) => {
        impl<'ctx> Neg for $ty<'ctx> {
            type Output = Self;

            fn neg(self) -> Self::Output {
                let ctx = self.ctx;
                let operation = ctx.make_operation(Opcode::Neg, [self.value], Self::TYPE);
                Self {
                    value: operation.into(),
                    ctx,
                }
            }
        }
    };
}
1170
/// Implements the `Int` bit-counting methods for an IR wrapper. Each method emits one
/// counting operation whose result has the wrapper's own type.
macro_rules! impl_int_trait {
    // Internal: one unary counting method lowered to `Opcode::$opcode`.
    (@count $method:ident, $opcode:ident) => {
        fn $method(self) -> Self {
            let operation = self
                .ctx
                .make_operation(Opcode::$opcode, [self.value], Self::TYPE);
            Self {
                value: operation.into(),
                ctx: self.ctx,
            }
        }
    };
    ($ty:ident) => {
        impl<'ctx> Int for $ty<'ctx> {
            impl_int_trait!(@count leading_zeros, CountLeadingZeros);
            impl_int_trait!(@count trailing_zeros, CountTrailingZeros);
            impl_int_trait!(@count count_ones, CountSetBits);
        }
    };
}
1207
1208 macro_rules! impl_integer_ops {
1209 ($scalar:ident, $vec:ident) => {
1210 impl_bit_ops!($scalar);
1211 impl_number_ops!($scalar, IrBool);
1212 impl_shift_ops!($scalar);
1213 impl_bit_ops!($vec);
1214 impl_number_ops!($vec, IrVecBool);
1215 impl_shift_ops!($vec);
1216 impl_int_trait!($scalar);
1217 impl_int_trait!($vec);
1218 };
1219 }
1220
1221 macro_rules! impl_uint_ops {
1222 ($scalar:ident, $vec:ident) => {
1223 impl_integer_ops!($scalar, $vec);
1224
1225 impl<'ctx> UInt for $scalar<'ctx> {}
1226 impl<'ctx> UInt for $vec<'ctx> {}
1227 };
1228 }
1229
1230 impl_uint_ops!(IrU8, IrVecU8);
1231 impl_uint_ops!(IrU16, IrVecU16);
1232 impl_uint_ops!(IrU32, IrVecU32);
1233 impl_uint_ops!(IrU64, IrVecU64);
1234
1235 macro_rules! impl_sint_ops {
1236 ($scalar:ident, $vec:ident) => {
1237 impl_integer_ops!($scalar, $vec);
1238 impl_neg!($scalar);
1239 impl_neg!($vec);
1240
1241 impl<'ctx> SInt for $scalar<'ctx> {}
1242 impl<'ctx> SInt for $vec<'ctx> {}
1243 };
1244 }
1245
1246 impl_sint_ops!(IrI8, IrVecI8);
1247 impl_sint_ops!(IrI16, IrVecI16);
1248 impl_sint_ops!(IrI32, IrVecI32);
1249 impl_sint_ops!(IrI64, IrVecI64);
1250
/// Implements `Float` for an IR float wrapper. `$bits` / `$signed_bits` are the
/// unsigned / signed integer wrappers of the same width.
macro_rules! impl_float {
    // Internal: a unary rounding/abs method whose result has the float type.
    (@unary $method:ident, $opcode:ident) => {
        fn $method(self) -> Self {
            let operation = self
                .ctx
                .make_operation(Opcode::$opcode, [self.value], Self::TYPE);
            Self {
                value: operation.into(),
                ctx: self.ctx,
            }
        }
    };
    // Internal: a unary classification method whose result is the matching bool type.
    (@classify $method:ident, $opcode:ident) => {
        fn $method(self) -> Self::Bool {
            let operation = self
                .ctx
                .make_operation(Opcode::$opcode, [self.value], Self::Bool::TYPE);
            Self::Bool {
                value: operation.into(),
                ctx: self.ctx,
            }
        }
    };
    ($float:ident, $bits:ident, $signed_bits:ident) => {
        impl<'ctx> Float for $float<'ctx> {
            type FloatEncoding = <$float<'ctx> as Make>::Prim;
            type BitsType = $bits<'ctx>;
            type SignedBitsType = $signed_bits<'ctx>;
            impl_float!(@unary abs, Abs);
            impl_float!(@unary trunc, Trunc);
            impl_float!(@unary ceil, Ceil);
            impl_float!(@unary floor, Floor);
            impl_float!(@unary round, Round);
            #[cfg(feature = "fma")]
            fn fma(self, a: Self, b: Self) -> Self {
                let operation = self.ctx.make_operation(
                    Opcode::Fma,
                    [self.value, a.value, b.value],
                    Self::TYPE,
                );
                Self {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
            fn is_nan(self) -> Self::Bool {
                // NaN is the only value that compares unequal to itself.
                let operation = self.ctx.make_operation(
                    Opcode::CompareNe,
                    [self.value, self.value],
                    Self::Bool::TYPE,
                );
                Self::Bool {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
            impl_float!(@classify is_infinite, IsInfinite);
            impl_float!(@classify is_finite, IsFinite);
            fn from_bits(v: Self::BitsType) -> Self {
                let operation = v
                    .ctx
                    .make_operation(Opcode::FromBits, [v.value], Self::TYPE);
                Self {
                    value: operation.into(),
                    ctx: v.ctx,
                }
            }
            fn to_bits(self) -> Self::BitsType {
                let operation = self
                    .ctx
                    .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE);
                Self::BitsType {
                    value: operation.into(),
                    ctx: self.ctx,
                }
            }
        }
    };
}
1372
1373 macro_rules! impl_float_ops {
1374 ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
1375 impl_number_ops!($scalar, IrBool);
1376 impl_number_ops!($vec, IrVecBool);
1377 impl_neg!($scalar);
1378 impl_neg!($vec);
1379 impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
1380 impl_float!($vec, $vec_bits, $vec_signed_bits);
1381 };
1382 }
1383
1384 impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
1385 impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
1386 impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
1387
// Scalar and vector boolean IR value types; host constants are built from a `bool`.
// NOTE(review): `ir_value!` is defined earlier in the file; the exact items it generates
// are not visible here.
ir_value!(
    IrBool,
    IrVecBool,
    TYPE = Bool,
    fn make(v: bool) {
        v.into()
    }
);

// Empty impls: these compile only because `Bool` has no items that must be provided
// here, so they just mark both types as satisfying the `Bool` trait.
impl<'ctx> Bool for IrBool<'ctx> {}
impl<'ctx> Bool for IrVecBool<'ctx> {}

// NOTE(review): presumably the bitwise operators (`&`, `|`, `^`, `!` and the `*Assign`
// forms imported at the top of the file); `impl_bit_ops!` is not visible here.
impl_bit_ops!(IrBool);
impl_bit_ops!(IrVecBool);
1402
// Numeric IR value types. Each `ir_value!` invocation names a scalar type and its vector
// counterpart, tags them with a type name (`TYPE = ...`, matching a `ScalarType`
// variant), and gives `make`, which turns a host constant into the IR constant
// representation (presumably `ScalarConstant`, as the float entries below build one
// explicitly).
//
// Unsigned integers.
ir_value!(
    IrU8,
    IrVecU8,
    TYPE = U8,
    fn make(v: u8) {
        v.into()
    }
);
ir_value!(
    IrU16,
    IrVecU16,
    TYPE = U16,
    fn make(v: u16) {
        v.into()
    }
);
ir_value!(
    IrU32,
    IrVecU32,
    TYPE = U32,
    fn make(v: u32) {
        v.into()
    }
);
ir_value!(
    IrU64,
    IrVecU64,
    TYPE = U64,
    fn make(v: u64) {
        v.into()
    }
);
// Signed integers.
ir_value!(
    IrI8,
    IrVecI8,
    TYPE = I8,
    fn make(v: i8) {
        v.into()
    }
);
ir_value!(
    IrI16,
    IrVecI16,
    TYPE = I16,
    fn make(v: i16) {
        v.into()
    }
);
ir_value!(
    IrI32,
    IrVecI32,
    TYPE = I32,
    fn make(v: i32) {
        v.into()
    }
);
ir_value!(
    IrI64,
    IrVecI64,
    TYPE = I64,
    fn make(v: i64) {
        v.into()
    }
);
// Floating point: constants are stored by their raw bit pattern (`to_bits`), presumably
// so every value (including NaN payloads and the sign of zero) is kept exactly.
ir_value!(
    IrF16,
    IrVecF16,
    TYPE = F16,
    fn make(v: F16) {
        ScalarConstant::from_f16_bits(v.to_bits())
    }
);
ir_value!(
    IrF32,
    IrVecF32,
    TYPE = F32,
    fn make(v: f32) {
        ScalarConstant::from_f32_bits(v.to_bits())
    }
);
ir_value!(
    IrF64,
    IrVecF64,
    TYPE = F64,
    fn make(v: f64) {
        ScalarConstant::from_f64_bits(v.to_bits())
    }
);
1491
/// Implements `ConvertTo` between IR value types.
///
/// * `impl_convert_to!(Src -> Dest)` implements `ConvertTo<Dest> for Src`. When both
///   types carry the same IR `TYPE`, the value is forwarded unchanged. Otherwise an
///   `Opcode::Cast` operation producing `Dest::TYPE` is emitted.
/// * `impl_convert_to![A, B, C, ...]` implements the conversion in both directions for
///   every pair of list entries.
/// * `impl_convert_to!()` expands to nothing; it is the base case of the list recursion.
macro_rules! impl_convert_to {
    () => {};
    ($src:ident -> $dest:ident) => {
        impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> {
            fn to(self) -> $dest<'ctx> {
                let ctx = self.ctx;
                if $src::TYPE == $dest::TYPE {
                    // Identical IR types: forward the value without emitting a `Cast`.
                    return $dest {
                        value: self.value,
                        ctx,
                    };
                }
                let value = ctx
                    .make_operation(Opcode::Cast, [self.value], $dest::TYPE)
                    .into();
                $dest { value, ctx }
            }
        }
    };
    // Internal helper: both directions between two types.
    (@both $a:ident, $b:ident) => {
        impl_convert_to!($a -> $b);
        impl_convert_to!($b -> $a);
    };
    // Pair the head with every remaining entry, then recurse on the remaining entries.
    ($head:ident $(, $tail:ident)*) => {
        $(
            impl_convert_to!(@both $head, $tail);
        )*
        impl_convert_to!($($tail),*);
    };
}
// Make every pair of numeric scalar IR types convertible in both directions, and
// likewise every pair of numeric vector IR types. Booleans are not in either list, and
// no scalar <-> vector conversions are generated here.
impl_convert_to![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];

impl_convert_to![
    IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
    IrVecF32, IrVecF64
];
1527
/// Implements `From<$src>` for each listed destination type.
///
/// The conversion is delegated to the source type's `ConvertTo<Destination>`
/// implementation, so `From` and `ConvertTo` always agree.
macro_rules! impl_from {
    ($src:ident => [$($dest:ident),*]) => {
        $(
            impl<'ctx> From<$src<'ctx>> for $dest<'ctx> {
                fn from(source: $src<'ctx>) -> Self {
                    <$src<'ctx> as ConvertTo<Self>>::to(source)
                }
            }
        )*
    };
}
1539
/// Implements the widening `From` conversions for one family of IR types (all scalar or
/// all vector) by delegating to `impl_from!`.
///
/// The `#[u8]`, `#[i8]`, ... attributes in the matcher are literal tokens that label
/// each argument. Callers must supply all eleven types, in exactly this order.
///
/// Each row lists the destinations that can represent every value of the source type
/// exactly. This matches the lossless `From` impls the standard library provides between
/// the corresponding primitives, with `f16` added as follows:
/// * `u8` and `i8` widen into `f16`, since every 8-bit integer is exact in binary16.
/// * The 16-bit and wider integers do not widen into `f16`.
///
/// No signed -> unsigned conversion is listed.
macro_rules! impl_froms {
    (
        #[u8] $u8:ident;
        #[i8] $i8:ident;
        #[u16] $u16:ident;
        #[i16] $i16:ident;
        #[f16] $f16:ident;
        #[u32] $u32:ident;
        #[i32] $i32:ident;
        #[f32] $f32:ident;
        #[u64] $u64:ident;
        #[i64] $i64:ident;
        #[f64] $f64:ident;
    ) => {
        impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]);
        impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]);
        impl_from!($u32 => [$u64, $i64, $f64]);
        impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]);
        impl_from!($i16 => [$i32, $f32, $i64, $f64]);
        impl_from!($i32 => [$i64, $f64]);
        impl_from!($f16 => [$f32, $f64]);
        impl_from!($f32 => [$f64]);
    };
}
1564
// Widening `From` impls between the scalar IR types. The `#[u8]`-style labels are
// matched literally by `impl_froms!`, so they must appear in exactly this order.
impl_froms! {
    #[u8] IrU8;
    #[i8] IrI8;
    #[u16] IrU16;
    #[i16] IrI16;
    #[f16] IrF16;
    #[u32] IrU32;
    #[i32] IrI32;
    #[f32] IrF32;
    #[u64] IrU64;
    #[i64] IrI64;
    #[f64] IrF64;
}

// The same widening `From` impls between the corresponding vector IR types.
impl_froms! {
    #[u8] IrVecU8;
    #[i8] IrVecI8;
    #[u16] IrVecU16;
    #[i16] IrVecI16;
    #[f16] IrVecF16;
    #[u32] IrVecU32;
    #[i32] IrVecI32;
    #[f32] IrVecF32;
    #[u64] IrVecU64;
    #[i64] IrVecI64;
    #[f64] IrVecF64;
}
1592
1593 impl<'ctx> Context for &'ctx IrContext<'ctx> {
1594 type Bool = IrBool<'ctx>;
1595 type U8 = IrU8<'ctx>;
1596 type I8 = IrI8<'ctx>;
1597 type U16 = IrU16<'ctx>;
1598 type I16 = IrI16<'ctx>;
1599 type F16 = IrF16<'ctx>;
1600 type U32 = IrU32<'ctx>;
1601 type I32 = IrI32<'ctx>;
1602 type F32 = IrF32<'ctx>;
1603 type U64 = IrU64<'ctx>;
1604 type I64 = IrI64<'ctx>;
1605 type F64 = IrF64<'ctx>;
1606 type VecBool8 = IrVecBool<'ctx>;
1607 type VecU8 = IrVecU8<'ctx>;
1608 type VecI8 = IrVecI8<'ctx>;
1609 type VecBool16 = IrVecBool<'ctx>;
1610 type VecU16 = IrVecU16<'ctx>;
1611 type VecI16 = IrVecI16<'ctx>;
1612 type VecF16 = IrVecF16<'ctx>;
1613 type VecBool32 = IrVecBool<'ctx>;
1614 type VecU32 = IrVecU32<'ctx>;
1615 type VecI32 = IrVecI32<'ctx>;
1616 type VecF32 = IrVecF32<'ctx>;
1617 type VecBool64 = IrVecBool<'ctx>;
1618 type VecU64 = IrVecU64<'ctx>;
1619 type VecI64 = IrVecI64<'ctx>;
1620 type VecF64 = IrVecF64<'ctx>;
1621 }
1622
// Golden-output tests. Each one builds an `IrFunction` from code written against the
// generic `Context` interface and compares its `Display` text with an exact expected
// string. Any change to operation numbering, ordering or constant formatting makes
// them fail.
#[cfg(test)]
mod tests {
    use crate::algorithms;

    use super::*;
    use std::println;

    #[test]
    fn test_display() {
        // Written against any `Context`; instantiated below with `&IrContext`, so every
        // operation it performs shows up as an IR operation in the printed function.
        fn f<Ctx: Context>(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 {
            let a: Ctx::VecF32 = a.into();
            (a - (a + b - ctx.make(5f32)).floor()).to()
        }
        let ctx = IrContext::new();
        // Pins `f` to the concrete IR types and turns it into an `IrFunction`.
        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
            let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f;
            IrFunction::make(ctx, f)
        }
        // The leading newline lets the expected raw string below start on its own line.
        let text = format!("\n{}", make_it(&ctx));
        println!("{}", text);
        // `a.into()` shows up as the first `Cast`. `ctx.make(5f32)` appears as the
        // splatted bit pattern 0x40A00000. The final `.to()` becomes the `Cast` to
        // `vec<F64>`.
        assert_eq!(
            text,
            r"
function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
op_0: vec<F32> = Cast in<arg_0>
op_1: vec<F32> = Add op_0, in<arg_1>
op_2: vec<F32> = Sub op_1, splat(0x40A00000_f32)
op_3: vec<F32> = Floor op_2
op_4: vec<F32> = Sub op_0, op_3
op_5: vec<F64> = Cast op_4
Return op_5
}
"
        );
    }

    #[test]
    fn test_display_ilogb_f32() {
        let ctx = IrContext::new();
        // Instantiates `algorithms::ilogb::ilogb_f32` with the IR types and builds it.
        fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
            let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
                algorithms::ilogb::ilogb_f32;
            IrFunction::make(ctx, f)
        }
        let text = format!("\n{}", make_it(&ctx));
        println!("{}", text);
        // The expected text contains repeated identical operations (e.g. op_1 and op_15
        // are both `ToBits in<arg_0>`). The IR builder is not expected to deduplicate
        // them, so this listing pins the exact emitted sequence.
        assert_eq!(
            text,
            r"
function(in<arg_0>: vec<F32>) -> vec<I32> {
op_0: vec<Bool> = IsFinite in<arg_0>
op_1: vec<U32> = ToBits in<arg_0>
op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
op_3: vec<U32> = Shr op_2, splat(0x17_u32)
op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
op_6: vec<I32> = Splat 0x80000001_i32
op_7: vec<I32> = Splat 0x7FFFFFFF_i32
op_8: vec<I32> = Select op_5, op_6, op_7
op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
op_10: vec<U32> = ToBits op_9
op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
op_12: vec<U32> = Shr op_11, splat(0x17_u32)
op_13: vec<I32> = Cast op_12
op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
op_15: vec<U32> = ToBits in<arg_0>
op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
op_17: vec<U32> = Shr op_16, splat(0x17_u32)
op_18: vec<I32> = Cast op_17
op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
op_20: vec<I32> = Select op_0, op_19, op_8
op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
op_22: vec<I32> = Splat 0x80000000_i32
op_23: vec<I32> = Sub op_14, splat(0x17_i32)
op_24: vec<I32> = Select op_21, op_22, op_23
op_25: vec<I32> = Select op_4, op_24, op_20
Return op_25
}
"
        );
    }
}