4 Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt,
11 fmt::{self, Write as _},
14 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div,
15 DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub,
21 use typed_arena::Arena;
23 macro_rules! make_enum {
25 $vis:vis enum $enum:ident {
32 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
41 $vis const fn as_str(self) -> &'static str {
44 Self::$name => stringify!($name),
50 impl fmt::Display for $enum {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.write_str(self.as_str())
75 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
76 pub struct VectorType {
77 pub element: ScalarType,
80 impl fmt::Display for VectorType {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 write!(f, "vec<{}>", self.element)
86 impl From<ScalarType> for Type {
87 fn from(v: ScalarType) -> Self {
92 impl From<VectorType> for Type {
93 fn from(v: VectorType) -> Self {
98 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
104 impl fmt::Display for Type {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 Type::Scalar(v) => v.fmt(f),
108 Type::Vector(v) => v.fmt(f),
113 #[derive(Clone, Copy, PartialEq, Eq, Hash)]
114 pub enum ScalarConstant {
129 macro_rules! make_scalar_constant_get {
130 ($ty:ident, $enumerant:ident) => {
131 pub fn $ty(self) -> Option<$ty> {
132 if let Self::$enumerant(v) = self {
141 macro_rules! make_scalar_constant_from {
142 ($ty:ident, $enumerant:ident) => {
143 impl From<$ty> for ScalarConstant {
144 fn from(v: $ty) -> Self {
148 impl From<$ty> for Constant {
149 fn from(v: $ty) -> Self {
150 Self::Scalar(v.into())
153 impl From<$ty> for Value<'_> {
154 fn from(v: $ty) -> Self {
155 Self::Constant(v.into())
161 make_scalar_constant_from!(bool, Bool);
162 make_scalar_constant_from!(u8, U8);
163 make_scalar_constant_from!(u16, U16);
164 make_scalar_constant_from!(u32, U32);
165 make_scalar_constant_from!(u64, U64);
166 make_scalar_constant_from!(i8, I8);
167 make_scalar_constant_from!(i16, I16);
168 make_scalar_constant_from!(i32, I32);
169 make_scalar_constant_from!(i64, I64);
171 impl ScalarConstant {
172 pub const fn ty(self) -> ScalarType {
174 ScalarConstant::Bool(_) => ScalarType::Bool,
175 ScalarConstant::U8(_) => ScalarType::U8,
176 ScalarConstant::U16(_) => ScalarType::U16,
177 ScalarConstant::U32(_) => ScalarType::U32,
178 ScalarConstant::U64(_) => ScalarType::U64,
179 ScalarConstant::I8(_) => ScalarType::I8,
180 ScalarConstant::I16(_) => ScalarType::I16,
181 ScalarConstant::I32(_) => ScalarType::I32,
182 ScalarConstant::I64(_) => ScalarType::I64,
183 ScalarConstant::F16 { .. } => ScalarType::F16,
184 ScalarConstant::F32 { .. } => ScalarType::F32,
185 ScalarConstant::F64 { .. } => ScalarType::F64,
188 pub const fn from_f16_bits(bits: u16) -> Self {
191 pub const fn from_f32_bits(bits: u32) -> Self {
194 pub const fn from_f64_bits(bits: u64) -> Self {
197 pub const fn f16_bits(self) -> Option<u16> {
198 if let Self::F16 { bits } = self {
204 pub const fn f32_bits(self) -> Option<u32> {
205 if let Self::F32 { bits } = self {
211 pub const fn f64_bits(self) -> Option<u64> {
212 if let Self::F64 { bits } = self {
218 make_scalar_constant_get!(bool, Bool);
219 make_scalar_constant_get!(u8, U8);
220 make_scalar_constant_get!(u16, U16);
221 make_scalar_constant_get!(u32, U32);
222 make_scalar_constant_get!(u64, U64);
223 make_scalar_constant_get!(i8, I8);
224 make_scalar_constant_get!(i16, I16);
225 make_scalar_constant_get!(i32, I32);
226 make_scalar_constant_get!(i64, I64);
229 impl fmt::Display for ScalarConstant {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 ScalarConstant::Bool(false) => write!(f, "false"),
233 ScalarConstant::Bool(true) => write!(f, "true"),
234 ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v),
235 ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v),
236 ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v),
237 ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v),
238 ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v),
239 ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v),
240 ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v),
241 ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v),
242 ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits),
243 ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits),
244 ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits),
249 impl fmt::Debug for ScalarConstant {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 fmt::Display::fmt(self, f)
255 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
256 pub struct VectorSplatConstant {
257 pub element: ScalarConstant,
260 impl VectorSplatConstant {
261 pub const fn ty(self) -> VectorType {
263 element: self.element.ty(),
268 impl fmt::Display for VectorSplatConstant {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 write!(f, "splat({})", self.element)
274 impl From<ScalarConstant> for Constant {
275 fn from(v: ScalarConstant) -> Self {
280 impl From<VectorSplatConstant> for Constant {
281 fn from(v: VectorSplatConstant) -> Self {
282 Constant::VectorSplat(v)
286 impl From<ScalarConstant> for Value<'_> {
287 fn from(v: ScalarConstant) -> Self {
288 Value::Constant(v.into())
292 impl From<VectorSplatConstant> for Value<'_> {
293 fn from(v: VectorSplatConstant) -> Self {
294 Value::Constant(v.into())
298 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
300 Scalar(ScalarConstant),
301 VectorSplat(VectorSplatConstant),
305 pub const fn ty(self) -> Type {
307 Constant::Scalar(v) => Type::Scalar(v.ty()),
308 Constant::VectorSplat(v) => Type::Vector(v.ty()),
313 impl fmt::Display for Constant {
314 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 Constant::Scalar(v) => v.fmt(f),
317 Constant::VectorSplat(v) => v.fmt(f),
323 pub struct Input<'ctx> {
328 impl fmt::Display for Input<'_> {
329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
330 write!(f, "in<{}>", self.name)
334 #[derive(Copy, Clone)]
335 pub enum Value<'ctx> {
336 Input(&'ctx Input<'ctx>),
338 OpResult(&'ctx Operation<'ctx>),
341 impl<'ctx> Value<'ctx> {
342 pub const fn ty(self) -> Type {
344 Value::Input(v) => v.ty,
345 Value::Constant(v) => v.ty(),
346 Value::OpResult(v) => v.result_type,
351 impl fmt::Debug for Value<'_> {
352 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 Value::Input(v) => v.fmt(f),
355 Value::Constant(v) => v.fmt(f),
356 Value::OpResult(v) => v.result_id.fmt(f),
361 impl fmt::Display for Value<'_> {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364 Value::Input(v) => v.fmt(f),
365 Value::Constant(v) => v.fmt(f),
366 Value::OpResult(v) => v.result_id.fmt(f),
371 impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> {
372 fn from(v: &'ctx Input<'ctx>) -> Self {
377 impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> {
378 fn from(v: &'ctx Operation<'ctx>) -> Self {
383 impl<'ctx> From<Constant> for Value<'ctx> {
384 fn from(v: Constant) -> Self {
429 pub struct Operation<'ctx> {
431 pub arguments: Vec<Value<'ctx>>,
432 pub result_type: Type,
433 pub result_id: OperationId,
436 impl fmt::Display for Operation<'_> {
437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441 self.result_id, self.result_type, self.opcode
443 let mut separator = " ";
444 for i in &self.arguments {
445 write!(f, "{}{}", separator, i)?;
452 #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
453 pub struct OperationId(pub u64);
455 impl fmt::Display for OperationId {
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 write!(f, "op_{}", self.0)
462 pub struct IrContext<'ctx> {
463 bytes_arena: Arena<u8>,
464 inputs_arena: Arena<Input<'ctx>>,
465 inputs: RefCell<HashMap<&'ctx str, &'ctx Input<'ctx>>>,
466 operations_arena: Arena<Operation<'ctx>>,
467 operations: RefCell<Vec<&'ctx Operation<'ctx>>>,
468 next_operation_result_id: Cell<u64>,
471 impl fmt::Debug for IrContext<'_> {
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 f.write_str("IrContext { .. }")
477 impl<'ctx> IrContext<'ctx> {
478 pub fn new() -> Self {
481 pub fn make_input<N: Borrow<str> + Into<String>, T: Into<Type>>(
485 ) -> &'ctx Input<'ctx> {
486 let mut inputs = self.inputs.borrow_mut();
487 let name_str = name.borrow();
489 if !name_str.is_empty() && !inputs.contains_key(name_str) {
490 let name = self.bytes_arena.alloc_str(name_str);
491 let input = self.inputs_arena.alloc(Input { name, ty });
492 inputs.insert(name, input);
495 let mut name: String = name.into();
499 let name_len = name.len();
500 let mut tag = 2usize;
502 name.truncate(name_len);
503 write!(name, "_{}", tag).unwrap();
504 if !inputs.contains_key(&*name) {
505 let name = self.bytes_arena.alloc_str(&name);
506 let input = self.inputs_arena.alloc(Input { name, ty });
507 inputs.insert(name, input);
513 pub fn make_operation<A: Into<Vec<Value<'ctx>>>, T: Into<Type>>(
518 ) -> &'ctx Operation<'ctx> {
519 let arguments = arguments.into();
520 let result_type = result_type.into();
521 let result_id = OperationId(self.next_operation_result_id.get());
522 self.next_operation_result_id.set(result_id.0 + 1);
523 let operation = self.operations_arena.alloc(Operation {
529 self.operations.borrow_mut().push(operation);
532 pub fn replace_operations(
534 new_operations: Vec<&'ctx Operation<'ctx>>,
535 ) -> Vec<&'ctx Operation<'ctx>> {
536 self.operations.replace(new_operations)
541 pub struct IrFunction<'ctx> {
542 pub inputs: Vec<&'ctx Input<'ctx>>,
543 pub operations: Vec<&'ctx Operation<'ctx>>,
544 pub outputs: Vec<Value<'ctx>>,
547 impl fmt::Display for IrFunction<'_> {
548 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
549 write!(f, "function(")?;
550 let mut first = true;
551 for input in &self.inputs {
557 write!(f, "{}: {}", input, input.ty)?;
559 match self.outputs.len() {
560 0 => writeln!(f, ") {{")?,
561 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?,
563 write!(f, ") -> ({}", self.outputs[0].ty())?;
564 for output in self.outputs.iter().skip(1) {
565 write!(f, ", {}", output.ty())?;
567 writeln!(f, ") {{")?;
570 for operation in &self.operations {
571 writeln!(f, " {}", operation)?;
573 match self.outputs.len() {
574 0 => writeln!(f, "}}")?,
575 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?,
577 write!(f, " Return {}", self.outputs[0])?;
578 for output in self.outputs.iter().skip(1) {
579 write!(f, ", {}", output)?;
581 writeln!(f, "\n}}")?;
588 impl<'ctx> IrFunction<'ctx> {
589 pub fn make<F: IrFunctionMaker<'ctx>>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self {
590 let old_operations = ctx.replace_operations(Vec::new());
591 let (v, inputs) = F::make_inputs(ctx);
592 let outputs = f.call(ctx, v).outputs_to_vec();
593 let operations = ctx.replace_operations(old_operations);
602 pub trait IrFunctionMaker<'ctx>: Sized {
604 type Outputs: IrFunctionMakerOutputs<'ctx>;
605 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs;
606 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>);
609 pub trait IrFunctionMakerOutputs<'ctx> {
610 fn outputs_to_vec(self) -> Vec<Value<'ctx>>;
613 impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T {
614 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
615 [self.value()].into()
619 impl<'ctx> IrFunctionMakerOutputs<'ctx> for () {
620 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
625 impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx>
626 for fn(&'ctx IrContext<'ctx>) -> R
630 fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs {
633 fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
638 macro_rules! impl_ir_function_maker_io {
640 ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => {
641 impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R
643 $first_arg_ty: IrValue<'ctx>,
644 $($arg_ty: IrValue<'ctx>,)*
645 R: IrFunctionMakerOutputs<'ctx>,
647 type Inputs = ($first_arg_ty, $($arg_ty,)*);
649 fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs {
650 let ($first_arg, $($arg,)*) = inputs;
651 self(ctx, $first_arg$(, $arg)*)
653 fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) {
654 let mut $first_arg = String::new();
655 $(let mut $arg = String::new();)*
656 for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() {
657 **arg = format!("arg_{}", index);
659 let $first_arg = $first_arg_ty::make_input(ctx, $first_arg);
660 $(let $arg = $arg_ty::make_input(ctx, $arg);)*
661 (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into())
664 impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*)
666 $first_arg_ty: IrValue<'ctx>,
667 $($arg_ty: IrValue<'ctx>,)*
669 fn outputs_to_vec(self) -> Vec<Value<'ctx>> {
670 let ($first_arg, $($arg,)*) = self;
671 [$first_arg.value() $(, $arg.value())*].into()
674 impl_ir_function_maker_io!($($arg: $arg_ty,)*);
678 impl_ir_function_maker_io!(
693 pub trait IrValue<'ctx>: Copy + Make<Context = &'ctx IrContext<'ctx>> {
695 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self;
696 fn make_input<N: Borrow<str> + Into<String>>(
697 ctx: &'ctx IrContext<'ctx>,
699 ) -> (Self, &'ctx Input<'ctx>) {
700 let input = ctx.make_input(name, Self::TYPE);
701 (Self::new(ctx, input.into()), input)
703 fn value(self) -> Value<'ctx>;
706 macro_rules! ir_value {
707 ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => {
708 #[derive(Clone, Copy, Debug)]
709 pub struct $name<'ctx> {
710 pub value: Value<'ctx>,
711 pub ctx: &'ctx IrContext<'ctx>,
714 impl<'ctx> IrValue<'ctx> for $name<'ctx> {
715 const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE);
716 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
717 assert_eq!(value.ty(), Self::TYPE);
720 fn value(self) -> Value<'ctx> {
725 impl<'ctx> $name<'ctx> {
726 pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type;
729 impl<'ctx> Make for $name<'ctx> {
731 type Context = &'ctx IrContext<'ctx>;
732 fn ctx(self) -> Self::Context {
735 fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
736 let value: ScalarConstant = $make;
737 let value = value.into();
742 #[derive(Clone, Copy, Debug)]
743 pub struct $vec_name<'ctx> {
744 pub value: Value<'ctx>,
745 pub ctx: &'ctx IrContext<'ctx>,
748 impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> {
749 const TYPE: Type = Type::Vector(Self::VECTOR_TYPE);
750 fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self {
751 assert_eq!(value.ty(), Self::TYPE);
754 fn value(self) -> Value<'ctx> {
759 impl<'ctx> $vec_name<'ctx> {
760 pub const VECTOR_TYPE: VectorType = VectorType {
761 element: ScalarType::$scalar_type,
765 impl<'ctx> Make for $vec_name<'ctx> {
767 type Context = &'ctx IrContext<'ctx>;
768 fn ctx(self) -> Self::Context {
771 fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self {
774 value: VectorSplatConstant { element }.into(),
780 impl<'ctx> Select<$name<'ctx>> for IrBool<'ctx> {
781 fn select(self, true_v: $name<'ctx>, false_v: $name<'ctx>) -> $name<'ctx> {
786 [self.value, true_v.value, false_v.value],
797 impl<'ctx> Select<$vec_name<'ctx>> for IrVecBool<'ctx> {
798 fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
803 [self.value, true_v.value, false_v.value],
814 impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> {
815 fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> {
820 [self.value, true_v.value, false_v.value],
831 impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> {
832 fn from(v: $name<'ctx>) -> Self {
835 .make_operation(Opcode::Splat, [v.value], $vec_name::TYPE)
837 Self { value, ctx: v.ctx }
843 macro_rules! impl_bit_ops {
845 impl<'ctx> BitAnd for $ty<'ctx> {
848 fn bitand(self, rhs: Self) -> Self::Output {
851 .make_operation(Opcode::And, [self.value, rhs.value], Self::TYPE)
859 impl<'ctx> BitOr for $ty<'ctx> {
862 fn bitor(self, rhs: Self) -> Self::Output {
865 .make_operation(Opcode::Or, [self.value, rhs.value], Self::TYPE)
873 impl<'ctx> BitXor for $ty<'ctx> {
876 fn bitxor(self, rhs: Self) -> Self::Output {
879 .make_operation(Opcode::Xor, [self.value, rhs.value], Self::TYPE)
887 impl<'ctx> Not for $ty<'ctx> {
890 fn not(self) -> Self::Output {
893 .make_operation(Opcode::Not, [self.value], Self::TYPE)
901 impl<'ctx> BitAndAssign for $ty<'ctx> {
902 fn bitand_assign(&mut self, rhs: Self) {
906 impl<'ctx> BitOrAssign for $ty<'ctx> {
907 fn bitor_assign(&mut self, rhs: Self) {
911 impl<'ctx> BitXorAssign for $ty<'ctx> {
912 fn bitxor_assign(&mut self, rhs: Self) {
919 macro_rules! impl_number_ops {
920 ($ty:ident, $bool:ident) => {
921 impl<'ctx> Add for $ty<'ctx> {
924 fn add(self, rhs: Self) -> Self::Output {
927 .make_operation(Opcode::Add, [self.value, rhs.value], Self::TYPE)
935 impl<'ctx> Sub for $ty<'ctx> {
938 fn sub(self, rhs: Self) -> Self::Output {
941 .make_operation(Opcode::Sub, [self.value, rhs.value], Self::TYPE)
949 impl<'ctx> Mul for $ty<'ctx> {
952 fn mul(self, rhs: Self) -> Self::Output {
955 .make_operation(Opcode::Mul, [self.value, rhs.value], Self::TYPE)
963 impl<'ctx> Div for $ty<'ctx> {
966 fn div(self, rhs: Self) -> Self::Output {
969 .make_operation(Opcode::Div, [self.value, rhs.value], Self::TYPE)
977 impl<'ctx> Rem for $ty<'ctx> {
980 fn rem(self, rhs: Self) -> Self::Output {
983 .make_operation(Opcode::Rem, [self.value, rhs.value], Self::TYPE)
991 impl<'ctx> AddAssign for $ty<'ctx> {
992 fn add_assign(&mut self, rhs: Self) {
996 impl<'ctx> SubAssign for $ty<'ctx> {
997 fn sub_assign(&mut self, rhs: Self) {
1001 impl<'ctx> MulAssign for $ty<'ctx> {
1002 fn mul_assign(&mut self, rhs: Self) {
1003 *self = *self * rhs;
1006 impl<'ctx> DivAssign for $ty<'ctx> {
1007 fn div_assign(&mut self, rhs: Self) {
1008 *self = *self / rhs;
1011 impl<'ctx> RemAssign for $ty<'ctx> {
1012 fn rem_assign(&mut self, rhs: Self) {
1013 *self = *self % rhs;
1016 impl<'ctx> Compare for $ty<'ctx> {
1017 type Bool = $bool<'ctx>;
1018 fn eq(self, rhs: Self) -> Self::Bool {
1021 .make_operation(Opcode::CompareEq, [self.value, rhs.value], $bool::TYPE)
1028 fn ne(self, rhs: Self) -> Self::Bool {
1031 .make_operation(Opcode::CompareNe, [self.value, rhs.value], $bool::TYPE)
1038 fn lt(self, rhs: Self) -> Self::Bool {
1041 .make_operation(Opcode::CompareLt, [self.value, rhs.value], $bool::TYPE)
1048 fn gt(self, rhs: Self) -> Self::Bool {
1051 .make_operation(Opcode::CompareGt, [self.value, rhs.value], $bool::TYPE)
1058 fn le(self, rhs: Self) -> Self::Bool {
1061 .make_operation(Opcode::CompareLe, [self.value, rhs.value], $bool::TYPE)
1068 fn ge(self, rhs: Self) -> Self::Bool {
1071 .make_operation(Opcode::CompareGe, [self.value, rhs.value], $bool::TYPE)
1082 macro_rules! impl_bool_compare {
1084 impl<'ctx> Compare for $ty<'ctx> {
1086 fn eq(self, rhs: Self) -> Self::Bool {
1089 fn ne(self, rhs: Self) -> Self::Bool {
1092 fn lt(self, rhs: Self) -> Self::Bool {
1095 fn gt(self, rhs: Self) -> Self::Bool {
1098 fn le(self, rhs: Self) -> Self::Bool {
1101 fn ge(self, rhs: Self) -> Self::Bool {
1108 impl_bool_compare!(IrBool);
1109 impl_bool_compare!(IrVecBool);
1111 macro_rules! impl_shift_ops {
1113 impl<'ctx> Shl for $ty<'ctx> {
1116 fn shl(self, rhs: Self) -> Self::Output {
1119 .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE)
1127 impl<'ctx> Shr for $ty<'ctx> {
1130 fn shr(self, rhs: Self) -> Self::Output {
1133 .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE)
1141 impl<'ctx> ShlAssign for $ty<'ctx> {
1142 fn shl_assign(&mut self, rhs: Self) {
1143 *self = *self << rhs;
1146 impl<'ctx> ShrAssign for $ty<'ctx> {
1147 fn shr_assign(&mut self, rhs: Self) {
1148 *self = *self >> rhs;
1154 macro_rules! impl_neg {
1156 impl<'ctx> Neg for $ty<'ctx> {
1159 fn neg(self) -> Self::Output {
1162 .make_operation(Opcode::Neg, [self.value], Self::TYPE)
1173 macro_rules! impl_int_trait {
1175 impl<'ctx> Int for $ty<'ctx> {
1176 fn leading_zeros(self) -> Self {
1179 .make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE)
1186 fn trailing_zeros(self) -> Self {
1189 .make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE)
1196 fn count_ones(self) -> Self {
1199 .make_operation(Opcode::CountSetBits, [self.value], Self::TYPE)
1210 macro_rules! impl_integer_ops {
1211 ($scalar:ident, $vec:ident) => {
1212 impl_bit_ops!($scalar);
1213 impl_number_ops!($scalar, IrBool);
1214 impl_shift_ops!($scalar);
1215 impl_bit_ops!($vec);
1216 impl_number_ops!($vec, IrVecBool);
1217 impl_shift_ops!($vec);
1218 impl_int_trait!($scalar);
1219 impl_int_trait!($vec);
1223 macro_rules! impl_uint_sint_ops {
1224 ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => {
1225 impl_integer_ops!($uint_scalar, $uint_vec);
1226 impl_integer_ops!($sint_scalar, $sint_vec);
1227 impl_neg!($sint_scalar);
1228 impl_neg!($sint_vec);
1230 impl<'ctx> UInt for $uint_scalar<'ctx> {
1231 type PrimUInt = Self::Prim;
1232 type SignedType = $sint_scalar<'ctx>;
1234 impl<'ctx> UInt for $uint_vec<'ctx> {
1235 type PrimUInt = Self::Prim;
1236 type SignedType = $sint_vec<'ctx>;
1238 impl<'ctx> SInt for $sint_scalar<'ctx> {
1239 type PrimSInt = Self::Prim;
1240 type UnsignedType = $uint_scalar<'ctx>;
1242 impl<'ctx> SInt for $sint_vec<'ctx> {
1243 type PrimSInt = Self::Prim;
1244 type UnsignedType = $uint_vec<'ctx>;
1249 impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8);
1250 impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16);
1251 impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32);
1252 impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64);
1254 macro_rules! impl_float {
1255 ($float:ident, $bits:ident, $signed_bits:ident) => {
1256 impl<'ctx> Float for $float<'ctx> {
1257 type PrimFloat = <$float<'ctx> as Make>::Prim;
1258 type BitsType = $bits<'ctx>;
1259 type SignedBitsType = $signed_bits<'ctx>;
1260 fn abs(self) -> Self {
1263 .make_operation(Opcode::Abs, [self.value], Self::TYPE)
1270 fn trunc(self) -> Self {
1273 .make_operation(Opcode::Trunc, [self.value], Self::TYPE)
1280 fn ceil(self) -> Self {
1283 .make_operation(Opcode::Ceil, [self.value], Self::TYPE)
1290 fn floor(self) -> Self {
1293 .make_operation(Opcode::Floor, [self.value], Self::TYPE)
1300 fn round(self) -> Self {
1303 .make_operation(Opcode::Round, [self.value], Self::TYPE)
1310 #[cfg(feature = "fma")]
1311 fn fma(self, a: Self, b: Self) -> Self {
1314 .make_operation(Opcode::Fma, [self.value, a.value, b.value], Self::TYPE)
1321 fn is_nan(self) -> Self::Bool {
1326 [self.value, self.value],
1335 fn is_infinite(self) -> Self::Bool {
1338 .make_operation(Opcode::IsInfinite, [self.value], Self::Bool::TYPE)
1345 fn is_finite(self) -> Self::Bool {
1348 .make_operation(Opcode::IsFinite, [self.value], Self::Bool::TYPE)
1355 fn from_bits(v: Self::BitsType) -> Self {
1358 .make_operation(Opcode::FromBits, [v.value], Self::TYPE)
1360 Self { value, ctx: v.ctx }
1362 fn to_bits(self) -> Self::BitsType {
1365 .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE)
1376 macro_rules! impl_float_ops {
1377 ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => {
1378 impl_number_ops!($scalar, IrBool);
1379 impl_number_ops!($vec, IrVecBool);
1382 impl_float!($scalar, $scalar_bits, $scalar_signed_bits);
1383 impl_float!($vec, $vec_bits, $vec_signed_bits);
1387 impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16);
1388 impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32);
1389 impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64);
1400 impl<'ctx> Bool for IrBool<'ctx> {}
1401 impl<'ctx> Bool for IrVecBool<'ctx> {}
1403 impl_bit_ops!(IrBool);
1404 impl_bit_ops!(IrVecBool);
1475 ScalarConstant::from_f16_bits(v.to_bits())
1483 ScalarConstant::from_f32_bits(v.to_bits())
1491 ScalarConstant::from_f64_bits(v.to_bits())
1495 macro_rules! impl_convert_from {
1496 ($src:ident -> $dest:ident) => {
1497 impl<'ctx> ConvertFrom<$src<'ctx>> for $dest<'ctx> {
1498 fn cvt_from(v: $src<'ctx>) -> Self {
1499 let value = if $src::TYPE == $dest::TYPE {
1504 .make_operation(Opcode::Cast, [v.value], $dest::TYPE)
1514 ($first:ident $(, $ty:ident)*) => {
1516 impl_convert_from!($first -> $ty);
1517 impl_convert_from!($ty -> $first);
1519 impl_convert_from![$($ty),*];
1524 impl_convert_from![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64];
1527 IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64,
1531 macro_rules! impl_from {
1532 ($src:ident => [$($dest:ident),*]) => {
1534 impl<'ctx> From<$src<'ctx>> for $dest<'ctx> {
1535 fn from(v: $src<'ctx>) -> Self {
1543 macro_rules! impl_froms {
1557 impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]);
1558 impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]);
1559 impl_from!($u32 => [$u64, $i64, $f64]);
1560 impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]);
1561 impl_from!($i16 => [$i32, $f32, $i64, $f64]);
1562 impl_from!($i32 => [$i64, $f64]);
1563 impl_from!($f16 => [$f32, $f64]);
1564 impl_from!($f32 => [$f64]);
1596 impl<'ctx> Context for &'ctx IrContext<'ctx> {
1597 type Bool = IrBool<'ctx>;
1598 type U8 = IrU8<'ctx>;
1599 type I8 = IrI8<'ctx>;
1600 type U16 = IrU16<'ctx>;
1601 type I16 = IrI16<'ctx>;
1602 type F16 = IrF16<'ctx>;
1603 type U32 = IrU32<'ctx>;
1604 type I32 = IrI32<'ctx>;
1605 type F32 = IrF32<'ctx>;
1606 type U64 = IrU64<'ctx>;
1607 type I64 = IrI64<'ctx>;
1608 type F64 = IrF64<'ctx>;
1609 type VecBool8 = IrVecBool<'ctx>;
1610 type VecU8 = IrVecU8<'ctx>;
1611 type VecI8 = IrVecI8<'ctx>;
1612 type VecBool16 = IrVecBool<'ctx>;
1613 type VecU16 = IrVecU16<'ctx>;
1614 type VecI16 = IrVecI16<'ctx>;
1615 type VecF16 = IrVecF16<'ctx>;
1616 type VecBool32 = IrVecBool<'ctx>;
1617 type VecU32 = IrVecU32<'ctx>;
1618 type VecI32 = IrVecI32<'ctx>;
1619 type VecF32 = IrVecF32<'ctx>;
1620 type VecBool64 = IrVecBool<'ctx>;
1621 type VecU64 = IrVecU64<'ctx>;
1622 type VecI64 = IrVecI64<'ctx>;
1623 type VecF64 = IrVecF64<'ctx>;
1628 use crate::algorithms;
1635 fn f<Ctx: Context>(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 {
1636 let a: Ctx::VecF32 = a.into();
1637 (a - (a + b - ctx.make(5f32)).floor()).to()
1639 let ctx = IrContext::new();
1640 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
1641 let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f;
1642 IrFunction::make(ctx, f)
1644 let text = format!("\n{}", make_it(&ctx));
1645 println!("{}", text);
1649 function(in<arg_0>: vec<U8>, in<arg_1>: vec<F32>) -> vec<F64> {
1650 op_0: vec<F32> = Cast in<arg_0>
1651 op_1: vec<F32> = Add op_0, in<arg_1>
1652 op_2: vec<F32> = Sub op_1, splat(0x40A00000_f32)
1653 op_3: vec<F32> = Floor op_2
1654 op_4: vec<F32> = Sub op_0, op_3
1655 op_5: vec<F64> = Cast op_4
1663 fn test_display_ilogb_f32() {
1664 let ctx = IrContext::new();
1665 fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> {
1666 let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> =
1667 algorithms::ilogb::ilogb_f32;
1668 IrFunction::make(ctx, f)
1670 let text = format!("\n{}", make_it(&ctx));
1671 println!("{}", text);
1675 function(in<arg_0>: vec<F32>) -> vec<I32> {
1676 op_0: vec<Bool> = IsFinite in<arg_0>
1677 op_1: vec<U32> = ToBits in<arg_0>
1678 op_2: vec<U32> = And op_1, splat(0x7F800000_u32)
1679 op_3: vec<U32> = Shr op_2, splat(0x17_u32)
1680 op_4: vec<Bool> = CompareEq op_3, splat(0x0_u32)
1681 op_5: vec<Bool> = CompareNe in<arg_0>, in<arg_0>
1682 op_6: vec<I32> = Splat 0x80000001_i32
1683 op_7: vec<I32> = Splat 0x7FFFFFFF_i32
1684 op_8: vec<I32> = Select op_5, op_6, op_7
1685 op_9: vec<F32> = Mul in<arg_0>, splat(0x4B000000_f32)
1686 op_10: vec<U32> = ToBits op_9
1687 op_11: vec<U32> = And op_10, splat(0x7F800000_u32)
1688 op_12: vec<U32> = Shr op_11, splat(0x17_u32)
1689 op_13: vec<I32> = Cast op_12
1690 op_14: vec<I32> = Sub op_13, splat(0x7F_i32)
1691 op_15: vec<U32> = ToBits in<arg_0>
1692 op_16: vec<U32> = And op_15, splat(0x7F800000_u32)
1693 op_17: vec<U32> = Shr op_16, splat(0x17_u32)
1694 op_18: vec<I32> = Cast op_17
1695 op_19: vec<I32> = Sub op_18, splat(0x7F_i32)
1696 op_20: vec<I32> = Select op_0, op_19, op_8
1697 op_21: vec<Bool> = CompareEq in<arg_0>, splat(0x0_f32)
1698 op_22: vec<I32> = Splat 0x80000000_i32
1699 op_23: vec<I32> = Sub op_14, splat(0x17_i32)
1700 op_24: vec<I32> = Select op_21, op_22, op_23
1701 op_25: vec<I32> = Select op_4, op_24, op_20