use crate::{ f16::F16, traits::{ Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt, }, }; use std::{ borrow::Borrow, cell::{Cell, RefCell}, collections::HashMap, fmt::{self, Write as _}, format, ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }, string::String, vec::Vec, }; use typed_arena::Arena; macro_rules! make_enum { ( $vis:vis enum $enum:ident { $( $(#[$meta:meta])* $name:ident, )* } ) => { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] $vis enum $enum { $( $(#[$meta])* $name, )* } impl $enum { $vis const fn as_str(self) -> &'static str { match self { $( Self::$name => stringify!($name), )* } } } impl fmt::Display for $enum { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) } } }; } make_enum! { pub enum ScalarType { Bool, U8, I8, U16, I16, F16, U32, I32, F32, U64, I64, F64, } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct VectorType { pub element: ScalarType, } impl fmt::Display for VectorType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "vec<{}>", self.element) } } impl From for Type { fn from(v: ScalarType) -> Self { Type::Scalar(v) } } impl From for Type { fn from(v: VectorType) -> Self { Type::Vector(v) } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Type { Scalar(ScalarType), Vector(VectorType), } impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Type::Scalar(v) => v.fmt(f), Type::Vector(v) => v.fmt(f), } } } #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub enum ScalarConstant { Bool(bool), U8(u8), U16(u16), U32(u32), U64(u64), I8(i8), I16(i16), I32(i32), I64(i64), F16 { bits: u16 }, F32 { bits: u32 }, F64 { bits: u64 }, } macro_rules! make_scalar_constant_get { ($ty:ident, $enumerant:ident) => { pub fn $ty(self) -> Option<$ty> { if let Self::$enumerant(v) = self { Some(v) } else { None } } }; } macro_rules! make_scalar_constant_from { ($ty:ident, $enumerant:ident) => { impl From<$ty> for ScalarConstant { fn from(v: $ty) -> Self { Self::$enumerant(v) } } impl From<$ty> for Constant { fn from(v: $ty) -> Self { Self::Scalar(v.into()) } } impl From<$ty> for Value<'_> { fn from(v: $ty) -> Self { Self::Constant(v.into()) } } }; } make_scalar_constant_from!(bool, Bool); make_scalar_constant_from!(u8, U8); make_scalar_constant_from!(u16, U16); make_scalar_constant_from!(u32, U32); make_scalar_constant_from!(u64, U64); make_scalar_constant_from!(i8, I8); make_scalar_constant_from!(i16, I16); make_scalar_constant_from!(i32, I32); make_scalar_constant_from!(i64, I64); impl ScalarConstant { pub const fn ty(self) -> ScalarType { match self { ScalarConstant::Bool(_) => ScalarType::Bool, ScalarConstant::U8(_) => ScalarType::U8, ScalarConstant::U16(_) => ScalarType::U16, ScalarConstant::U32(_) => ScalarType::U32, ScalarConstant::U64(_) => ScalarType::U64, ScalarConstant::I8(_) => ScalarType::I8, ScalarConstant::I16(_) => ScalarType::I16, ScalarConstant::I32(_) => ScalarType::I32, ScalarConstant::I64(_) => ScalarType::I64, ScalarConstant::F16 { .. } => ScalarType::F16, ScalarConstant::F32 { .. } => ScalarType::F32, ScalarConstant::F64 { .. } => ScalarType::F64, } } pub const fn from_f16_bits(bits: u16) -> Self { Self::F16 { bits } } pub const fn from_f32_bits(bits: u32) -> Self { Self::F32 { bits } } pub const fn from_f64_bits(bits: u64) -> Self { Self::F64 { bits } } pub const fn f16_bits(self) -> Option { if let Self::F16 { bits } = self { Some(bits) } else { None } } pub const fn f32_bits(self) -> Option { if let Self::F32 { bits } = self { Some(bits) } else { None } } pub const fn f64_bits(self) -> Option { if let Self::F64 { bits } = self { Some(bits) } else { None } } make_scalar_constant_get!(bool, Bool); make_scalar_constant_get!(u8, U8); make_scalar_constant_get!(u16, U16); make_scalar_constant_get!(u32, U32); make_scalar_constant_get!(u64, U64); make_scalar_constant_get!(i8, I8); make_scalar_constant_get!(i16, I16); make_scalar_constant_get!(i32, I32); make_scalar_constant_get!(i64, I64); } impl fmt::Display for ScalarConstant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ScalarConstant::Bool(false) => write!(f, "false"), ScalarConstant::Bool(true) => write!(f, "true"), ScalarConstant::U8(v) => write!(f, "{:#X}_u8", v), ScalarConstant::U16(v) => write!(f, "{:#X}_u16", v), ScalarConstant::U32(v) => write!(f, "{:#X}_u32", v), ScalarConstant::U64(v) => write!(f, "{:#X}_u64", v), ScalarConstant::I8(v) => write!(f, "{:#X}_i8", v), ScalarConstant::I16(v) => write!(f, "{:#X}_i16", v), ScalarConstant::I32(v) => write!(f, "{:#X}_i32", v), ScalarConstant::I64(v) => write!(f, "{:#X}_i64", v), ScalarConstant::F16 { bits } => write!(f, "{:#X}_f16", bits), ScalarConstant::F32 { bits } => write!(f, "{:#X}_f32", bits), ScalarConstant::F64 { bits } => write!(f, "{:#X}_f64", bits), } } } impl fmt::Debug for ScalarConstant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self, f) } } #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct VectorSplatConstant { pub element: ScalarConstant, } impl VectorSplatConstant { pub const fn ty(self) -> VectorType { VectorType { element: self.element.ty(), } } } impl fmt::Display for VectorSplatConstant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "splat({})", self.element) } } impl From for Constant { fn from(v: ScalarConstant) -> Self { Constant::Scalar(v) } } impl From for Constant { fn from(v: VectorSplatConstant) -> Self { Constant::VectorSplat(v) } } impl From for Value<'_> { fn from(v: ScalarConstant) -> Self { Value::Constant(v.into()) } } impl From for Value<'_> { fn from(v: VectorSplatConstant) -> Self { Value::Constant(v.into()) } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Constant { Scalar(ScalarConstant), VectorSplat(VectorSplatConstant), } impl Constant { pub const fn ty(self) -> Type { match self { Constant::Scalar(v) => Type::Scalar(v.ty()), Constant::VectorSplat(v) => Type::Vector(v.ty()), } } } impl fmt::Display for Constant { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Constant::Scalar(v) => v.fmt(f), Constant::VectorSplat(v) => v.fmt(f), } } } #[derive(Debug)] pub struct Input<'ctx> { pub name: &'ctx str, pub ty: Type, } impl fmt::Display for Input<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "in<{}>", self.name) } } #[derive(Copy, Clone)] pub enum Value<'ctx> { Input(&'ctx Input<'ctx>), Constant(Constant), OpResult(&'ctx Operation<'ctx>), } impl<'ctx> Value<'ctx> { pub const fn ty(self) -> Type { match self { Value::Input(v) => v.ty, Value::Constant(v) => v.ty(), Value::OpResult(v) => v.result_type, } } } impl fmt::Debug for Value<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Value::Input(v) => v.fmt(f), Value::Constant(v) => v.fmt(f), Value::OpResult(v) => v.result_id.fmt(f), } } } impl fmt::Display for Value<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Value::Input(v) => v.fmt(f), Value::Constant(v) => v.fmt(f), Value::OpResult(v) => v.result_id.fmt(f), } } } impl<'ctx> From<&'ctx Input<'ctx>> for Value<'ctx> { fn from(v: &'ctx Input<'ctx>) -> Self { Value::Input(v) } } impl<'ctx> From<&'ctx Operation<'ctx>> for Value<'ctx> { fn from(v: &'ctx Operation<'ctx>) -> Self { Value::OpResult(v) } } impl<'ctx> From for Value<'ctx> { fn from(v: Constant) -> Self { Value::Constant(v) } } make_enum! { pub enum Opcode { Add, Sub, Mul, Div, Rem, Fma, Cast, And, Or, Xor, Not, Shl, Shr, CountSetBits, CountLeadingZeros, CountTrailingZeros, Neg, Abs, Trunc, Ceil, Floor, Round, IsInfinite, IsFinite, ToBits, FromBits, Splat, CompareEq, CompareNe, CompareLt, CompareLe, CompareGt, CompareGe, Select, } } #[derive(Debug)] pub struct Operation<'ctx> { pub opcode: Opcode, pub arguments: Vec>, pub result_type: Type, pub result_id: OperationId, } impl fmt::Display for Operation<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, "{}: {} = {}", self.result_id, self.result_type, self.opcode )?; let mut separator = " "; for i in &self.arguments { write!(f, "{}{}", separator, i)?; separator = ", "; } Ok(()) } } #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub struct OperationId(pub u64); impl fmt::Display for OperationId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "op_{}", self.0) } } #[derive(Default)] pub struct IrContext<'ctx> { bytes_arena: Arena, inputs_arena: Arena>, inputs: RefCell>>, operations_arena: Arena>, operations: RefCell>>, next_operation_result_id: Cell, } impl fmt::Debug for IrContext<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("IrContext { .. }") } } impl<'ctx> IrContext<'ctx> { pub fn new() -> Self { Self::default() } pub fn make_input + Into, T: Into>( &'ctx self, name: N, ty: T, ) -> &'ctx Input<'ctx> { let mut inputs = self.inputs.borrow_mut(); let name_str = name.borrow(); let ty = ty.into(); if !name_str.is_empty() && !inputs.contains_key(name_str) { let name = self.bytes_arena.alloc_str(name_str); let input = self.inputs_arena.alloc(Input { name, ty }); inputs.insert(name, input); return input; } let mut name: String = name.into(); if name.is_empty() { name = "in".into(); } let name_len = name.len(); let mut tag = 2usize; loop { name.truncate(name_len); write!(name, "_{}", tag).unwrap(); if !inputs.contains_key(&*name) { let name = self.bytes_arena.alloc_str(&name); let input = self.inputs_arena.alloc(Input { name, ty }); inputs.insert(name, input); return input; } tag += 1; } } pub fn make_operation>>, T: Into>( &'ctx self, opcode: Opcode, arguments: A, result_type: T, ) -> &'ctx Operation<'ctx> { let arguments = arguments.into(); let result_type = result_type.into(); let result_id = OperationId(self.next_operation_result_id.get()); self.next_operation_result_id.set(result_id.0 + 1); let operation = self.operations_arena.alloc(Operation { opcode, arguments, result_type, result_id, }); self.operations.borrow_mut().push(operation); operation } pub fn replace_operations( &'ctx self, new_operations: Vec<&'ctx Operation<'ctx>>, ) -> Vec<&'ctx Operation<'ctx>> { self.operations.replace(new_operations) } } #[derive(Debug)] pub struct IrFunction<'ctx> { pub inputs: Vec<&'ctx Input<'ctx>>, pub operations: Vec<&'ctx Operation<'ctx>>, pub outputs: Vec>, } impl fmt::Display for IrFunction<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "function(")?; let mut first = true; for input in &self.inputs { if first { first = false } else { write!(f, ", ")?; } write!(f, "{}: {}", input, input.ty)?; } match self.outputs.len() { 0 => writeln!(f, ") {{")?, 1 => writeln!(f, ") -> {} {{", self.outputs[0].ty())?, _ => { write!(f, ") -> ({}", self.outputs[0].ty())?; for output in self.outputs.iter().skip(1) { write!(f, ", {}", output.ty())?; } writeln!(f, ") {{")?; } } for operation in &self.operations { writeln!(f, " {}", operation)?; } match self.outputs.len() { 0 => writeln!(f, "}}")?, 1 => writeln!(f, " Return {}\n}}", self.outputs[0])?, _ => { write!(f, " Return {}", self.outputs[0])?; for output in self.outputs.iter().skip(1) { write!(f, ", {}", output)?; } writeln!(f, "\n}}")?; } } Ok(()) } } impl<'ctx> IrFunction<'ctx> { pub fn make>(ctx: &'ctx IrContext<'ctx>, f: F) -> Self { let old_operations = ctx.replace_operations(Vec::new()); let (v, inputs) = F::make_inputs(ctx); let outputs = f.call(ctx, v).outputs_to_vec(); let operations = ctx.replace_operations(old_operations); Self { inputs, operations, outputs, } } } pub trait IrFunctionMaker<'ctx>: Sized { type Inputs; type Outputs: IrFunctionMakerOutputs<'ctx>; fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs; fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>); } pub trait IrFunctionMakerOutputs<'ctx> { fn outputs_to_vec(self) -> Vec>; } impl<'ctx, T: IrValue<'ctx>> IrFunctionMakerOutputs<'ctx> for T { fn outputs_to_vec(self) -> Vec> { [self.value()].into() } } impl<'ctx> IrFunctionMakerOutputs<'ctx> for () { fn outputs_to_vec(self) -> Vec> { Vec::new() } } impl<'ctx, R: IrFunctionMakerOutputs<'ctx>> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>) -> R { type Inputs = (); type Outputs = R; fn call(self, ctx: &'ctx IrContext<'ctx>, _inputs: Self::Inputs) -> Self::Outputs { self(ctx) } fn make_inputs(_ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) { ((), Vec::new()) } } macro_rules! impl_ir_function_maker_io { () => {}; ($first_arg:ident: $first_arg_ty:ident, $($arg:ident: $arg_ty:ident,)*) => { impl<'ctx, $first_arg_ty, $($arg_ty,)* R> IrFunctionMaker<'ctx> for fn(&'ctx IrContext<'ctx>, $first_arg_ty $(, $arg_ty)*) -> R where $first_arg_ty: IrValue<'ctx>, $($arg_ty: IrValue<'ctx>,)* R: IrFunctionMakerOutputs<'ctx>, { type Inputs = ($first_arg_ty, $($arg_ty,)*); type Outputs = R; fn call(self, ctx: &'ctx IrContext<'ctx>, inputs: Self::Inputs) -> Self::Outputs { let ($first_arg, $($arg,)*) = inputs; self(ctx, $first_arg$(, $arg)*) } fn make_inputs(ctx: &'ctx IrContext<'ctx>) -> (Self::Inputs, Vec<&'ctx Input<'ctx>>) { let mut $first_arg = String::new(); $(let mut $arg = String::new();)* for (index, arg) in [&mut $first_arg $(, &mut $arg)*].iter_mut().enumerate() { **arg = format!("arg_{}", index); } let $first_arg = $first_arg_ty::make_input(ctx, $first_arg); $(let $arg = $arg_ty::make_input(ctx, $arg);)* (($first_arg.0, $($arg.0,)*), [$first_arg.1 $(, $arg.1)*].into()) } } impl<'ctx, $first_arg_ty, $($arg_ty),*> IrFunctionMakerOutputs<'ctx> for ($first_arg_ty, $($arg_ty,)*) where $first_arg_ty: IrValue<'ctx>, $($arg_ty: IrValue<'ctx>,)* { fn outputs_to_vec(self) -> Vec> { let ($first_arg, $($arg,)*) = self; [$first_arg.value() $(, $arg.value())*].into() } } impl_ir_function_maker_io!($($arg: $arg_ty,)*); }; } impl_ir_function_maker_io!( in0: In0, in1: In1, in2: In2, in3: In3, in4: In4, in5: In5, in6: In6, in7: In7, in8: In8, in9: In9, in10: In10, in11: In11, ); pub trait IrValue<'ctx>: Copy + Make> { const TYPE: Type; fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self; fn make_input + Into>( ctx: &'ctx IrContext<'ctx>, name: N, ) -> (Self, &'ctx Input<'ctx>) { let input = ctx.make_input(name, Self::TYPE); (Self::new(ctx, input.into()), input) } fn value(self) -> Value<'ctx>; } macro_rules! ir_value { ($name:ident, $vec_name:ident, TYPE = $scalar_type:ident, fn make($make_var:ident: $prim:ident) {$make:expr}) => { #[derive(Clone, Copy, Debug)] pub struct $name<'ctx> { pub value: Value<'ctx>, pub ctx: &'ctx IrContext<'ctx>, } impl<'ctx> IrValue<'ctx> for $name<'ctx> { const TYPE: Type = Type::Scalar(Self::SCALAR_TYPE); fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self { assert_eq!(value.ty(), Self::TYPE); Self { ctx, value } } fn value(self) -> Value<'ctx> { self.value } } impl<'ctx> $name<'ctx> { pub const SCALAR_TYPE: ScalarType = ScalarType::$scalar_type; } impl<'ctx> Make for $name<'ctx> { type Prim = $prim; type Context = &'ctx IrContext<'ctx>; fn ctx(self) -> Self::Context { self.ctx } fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self { let value: ScalarConstant = $make; let value = value.into(); Self { value, ctx } } } #[derive(Clone, Copy, Debug)] pub struct $vec_name<'ctx> { pub value: Value<'ctx>, pub ctx: &'ctx IrContext<'ctx>, } impl<'ctx> IrValue<'ctx> for $vec_name<'ctx> { const TYPE: Type = Type::Vector(Self::VECTOR_TYPE); fn new(ctx: &'ctx IrContext<'ctx>, value: Value<'ctx>) -> Self { assert_eq!(value.ty(), Self::TYPE); Self { ctx, value } } fn value(self) -> Value<'ctx> { self.value } } impl<'ctx> $vec_name<'ctx> { pub const VECTOR_TYPE: VectorType = VectorType { element: ScalarType::$scalar_type, }; } impl<'ctx> Make for $vec_name<'ctx> { type Prim = $prim; type Context = &'ctx IrContext<'ctx>; fn ctx(self) -> Self::Context { self.ctx } fn make(ctx: Self::Context, $make_var: Self::Prim) -> Self { let element = $make; Self { value: VectorSplatConstant { element }.into(), ctx, } } } impl<'ctx> Select<$name<'ctx>> for IrBool<'ctx> { fn select(self, true_v: $name<'ctx>, false_v: $name<'ctx>) -> $name<'ctx> { let value = self .ctx .make_operation( Opcode::Select, [self.value, true_v.value, false_v.value], $name::TYPE, ) .into(); $name { value, ctx: self.ctx, } } } impl<'ctx> Select<$vec_name<'ctx>> for IrVecBool<'ctx> { fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> { let value = self .ctx .make_operation( Opcode::Select, [self.value, true_v.value, false_v.value], $vec_name::TYPE, ) .into(); $vec_name { value, ctx: self.ctx, } } } impl<'ctx> Select<$vec_name<'ctx>> for IrBool<'ctx> { fn select(self, true_v: $vec_name<'ctx>, false_v: $vec_name<'ctx>) -> $vec_name<'ctx> { let value = self .ctx .make_operation( Opcode::Select, [self.value, true_v.value, false_v.value], $vec_name::TYPE, ) .into(); $vec_name { value, ctx: self.ctx, } } } impl<'ctx> From<$name<'ctx>> for $vec_name<'ctx> { fn from(v: $name<'ctx>) -> Self { let value = v .ctx .make_operation(Opcode::Splat, [v.value], $vec_name::TYPE) .into(); Self { value, ctx: v.ctx } } } }; } macro_rules! impl_bit_ops { ($ty:ident) => { impl<'ctx> BitAnd for $ty<'ctx> { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::And, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> BitOr for $ty<'ctx> { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Or, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> BitXor for $ty<'ctx> { type Output = Self; fn bitxor(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Xor, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Not for $ty<'ctx> { type Output = Self; fn not(self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Not, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> BitAndAssign for $ty<'ctx> { fn bitand_assign(&mut self, rhs: Self) { *self = *self & rhs; } } impl<'ctx> BitOrAssign for $ty<'ctx> { fn bitor_assign(&mut self, rhs: Self) { *self = *self | rhs; } } impl<'ctx> BitXorAssign for $ty<'ctx> { fn bitxor_assign(&mut self, rhs: Self) { *self = *self ^ rhs; } } }; } macro_rules! impl_number_ops { ($ty:ident, $bool:ident) => { impl<'ctx> Add for $ty<'ctx> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Add, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Sub for $ty<'ctx> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Sub, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Mul for $ty<'ctx> { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Mul, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Div for $ty<'ctx> { type Output = Self; fn div(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Div, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Rem for $ty<'ctx> { type Output = Self; fn rem(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Rem, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> AddAssign for $ty<'ctx> { fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } impl<'ctx> SubAssign for $ty<'ctx> { fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } impl<'ctx> MulAssign for $ty<'ctx> { fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } impl<'ctx> DivAssign for $ty<'ctx> { fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } impl<'ctx> RemAssign for $ty<'ctx> { fn rem_assign(&mut self, rhs: Self) { *self = *self % rhs; } } impl<'ctx> Compare for $ty<'ctx> { type Bool = $bool<'ctx>; fn eq(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareEq, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } fn ne(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareNe, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } fn lt(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareLt, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } fn gt(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareGt, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } fn le(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareLe, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } fn ge(self, rhs: Self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::CompareGe, [self.value, rhs.value], $bool::TYPE) .into(); $bool { value, ctx: self.ctx, } } } }; } macro_rules! impl_bool_compare { ($ty:ident) => { impl<'ctx> Compare for $ty<'ctx> { type Bool = Self; fn eq(self, rhs: Self) -> Self::Bool { !(self ^ rhs) } fn ne(self, rhs: Self) -> Self::Bool { self ^ rhs } fn lt(self, rhs: Self) -> Self::Bool { !self & rhs } fn gt(self, rhs: Self) -> Self::Bool { self & !rhs } fn le(self, rhs: Self) -> Self::Bool { !self | rhs } fn ge(self, rhs: Self) -> Self::Bool { self | !rhs } } }; } impl_bool_compare!(IrBool); impl_bool_compare!(IrVecBool); macro_rules! impl_shift_ops { ($ty:ident) => { impl<'ctx> Shl for $ty<'ctx> { type Output = Self; fn shl(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Shl, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> Shr for $ty<'ctx> { type Output = Self; fn shr(self, rhs: Self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Shr, [self.value, rhs.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } impl<'ctx> ShlAssign for $ty<'ctx> { fn shl_assign(&mut self, rhs: Self) { *self = *self << rhs; } } impl<'ctx> ShrAssign for $ty<'ctx> { fn shr_assign(&mut self, rhs: Self) { *self = *self >> rhs; } } }; } macro_rules! impl_neg { ($ty:ident) => { impl<'ctx> Neg for $ty<'ctx> { type Output = Self; fn neg(self) -> Self::Output { let value = self .ctx .make_operation(Opcode::Neg, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } }; } macro_rules! impl_int_trait { ($ty:ident) => { impl<'ctx> Int for $ty<'ctx> { fn leading_zeros(self) -> Self { let value = self .ctx .make_operation(Opcode::CountLeadingZeros, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn trailing_zeros(self) -> Self { let value = self .ctx .make_operation(Opcode::CountTrailingZeros, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn count_ones(self) -> Self { let value = self .ctx .make_operation(Opcode::CountSetBits, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } } }; } macro_rules! impl_integer_ops { ($scalar:ident, $vec:ident) => { impl_bit_ops!($scalar); impl_number_ops!($scalar, IrBool); impl_shift_ops!($scalar); impl_bit_ops!($vec); impl_number_ops!($vec, IrVecBool); impl_shift_ops!($vec); impl_int_trait!($scalar); impl_int_trait!($vec); }; } macro_rules! impl_uint_sint_ops { ($uint_scalar:ident, $uint_vec:ident, $sint_scalar:ident, $sint_vec:ident) => { impl_integer_ops!($uint_scalar, $uint_vec); impl_integer_ops!($sint_scalar, $sint_vec); impl_neg!($sint_scalar); impl_neg!($sint_vec); impl<'ctx> UInt for $uint_scalar<'ctx> { type PrimUInt = Self::Prim; type SignedType = $sint_scalar<'ctx>; } impl<'ctx> UInt for $uint_vec<'ctx> { type PrimUInt = Self::Prim; type SignedType = $sint_vec<'ctx>; } impl<'ctx> SInt for $sint_scalar<'ctx> { type PrimSInt = Self::Prim; type UnsignedType = $uint_scalar<'ctx>; } impl<'ctx> SInt for $sint_vec<'ctx> { type PrimSInt = Self::Prim; type UnsignedType = $uint_vec<'ctx>; } }; } impl_uint_sint_ops!(IrU8, IrVecU8, IrI8, IrVecI8); impl_uint_sint_ops!(IrU16, IrVecU16, IrI16, IrVecI16); impl_uint_sint_ops!(IrU32, IrVecU32, IrI32, IrVecI32); impl_uint_sint_ops!(IrU64, IrVecU64, IrI64, IrVecI64); macro_rules! impl_float { ($float:ident, $bits:ident, $signed_bits:ident) => { impl<'ctx> Float for $float<'ctx> { type PrimFloat = <$float<'ctx> as Make>::Prim; type BitsType = $bits<'ctx>; type SignedBitsType = $signed_bits<'ctx>; fn abs(self) -> Self { let value = self .ctx .make_operation(Opcode::Abs, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn trunc(self) -> Self { let value = self .ctx .make_operation(Opcode::Trunc, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn ceil(self) -> Self { let value = self .ctx .make_operation(Opcode::Ceil, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn floor(self) -> Self { let value = self .ctx .make_operation(Opcode::Floor, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn round(self) -> Self { let value = self .ctx .make_operation(Opcode::Round, [self.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { let value = self .ctx .make_operation(Opcode::Fma, [self.value, a.value, b.value], Self::TYPE) .into(); Self { value, ctx: self.ctx, } } fn is_nan(self) -> Self::Bool { let value = self .ctx .make_operation( Opcode::CompareNe, [self.value, self.value], Self::Bool::TYPE, ) .into(); Self::Bool { value, ctx: self.ctx, } } fn is_infinite(self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::IsInfinite, [self.value], Self::Bool::TYPE) .into(); Self::Bool { value, ctx: self.ctx, } } fn is_finite(self) -> Self::Bool { let value = self .ctx .make_operation(Opcode::IsFinite, [self.value], Self::Bool::TYPE) .into(); Self::Bool { value, ctx: self.ctx, } } fn from_bits(v: Self::BitsType) -> Self { let value = v .ctx .make_operation(Opcode::FromBits, [v.value], Self::TYPE) .into(); Self { value, ctx: v.ctx } } fn to_bits(self) -> Self::BitsType { let value = self .ctx .make_operation(Opcode::ToBits, [self.value], Self::BitsType::TYPE) .into(); Self::BitsType { value, ctx: self.ctx, } } } }; } macro_rules! impl_float_ops { ($scalar:ident, $scalar_bits:ident, $scalar_signed_bits:ident, $vec:ident, $vec_bits:ident, $vec_signed_bits:ident) => { impl_number_ops!($scalar, IrBool); impl_number_ops!($vec, IrVecBool); impl_neg!($scalar); impl_neg!($vec); impl_float!($scalar, $scalar_bits, $scalar_signed_bits); impl_float!($vec, $vec_bits, $vec_signed_bits); }; } impl_float_ops!(IrF16, IrU16, IrI16, IrVecF16, IrVecU16, IrVecI16); impl_float_ops!(IrF32, IrU32, IrI32, IrVecF32, IrVecU32, IrVecI32); impl_float_ops!(IrF64, IrU64, IrI64, IrVecF64, IrVecU64, IrVecI64); ir_value!( IrBool, IrVecBool, TYPE = Bool, fn make(v: bool) { v.into() } ); impl<'ctx> Bool for IrBool<'ctx> {} impl<'ctx> Bool for IrVecBool<'ctx> {} impl_bit_ops!(IrBool); impl_bit_ops!(IrVecBool); ir_value!( IrU8, IrVecU8, TYPE = U8, fn make(v: u8) { v.into() } ); ir_value!( IrU16, IrVecU16, TYPE = U16, fn make(v: u16) { v.into() } ); ir_value!( IrU32, IrVecU32, TYPE = U32, fn make(v: u32) { v.into() } ); ir_value!( IrU64, IrVecU64, TYPE = U64, fn make(v: u64) { v.into() } ); ir_value!( IrI8, IrVecI8, TYPE = I8, fn make(v: i8) { v.into() } ); ir_value!( IrI16, IrVecI16, TYPE = I16, fn make(v: i16) { v.into() } ); ir_value!( IrI32, IrVecI32, TYPE = I32, fn make(v: i32) { v.into() } ); ir_value!( IrI64, IrVecI64, TYPE = I64, fn make(v: i64) { v.into() } ); ir_value!( IrF16, IrVecF16, TYPE = F16, fn make(v: F16) { ScalarConstant::from_f16_bits(v.to_bits()) } ); ir_value!( IrF32, IrVecF32, TYPE = F32, fn make(v: f32) { ScalarConstant::from_f32_bits(v.to_bits()) } ); ir_value!( IrF64, IrVecF64, TYPE = F64, fn make(v: f64) { ScalarConstant::from_f64_bits(v.to_bits()) } ); macro_rules! impl_convert_from { ($src:ident -> $dest:ident) => { impl<'ctx> ConvertFrom<$src<'ctx>> for $dest<'ctx> { fn cvt_from(v: $src<'ctx>) -> Self { let value = if $src::TYPE == $dest::TYPE { v.value } else { v .ctx .make_operation(Opcode::Cast, [v.value], $dest::TYPE) .into() }; $dest { value, ctx: v.ctx, } } } }; ($first:ident $(, $ty:ident)*) => { $( impl_convert_from!($first -> $ty); impl_convert_from!($ty -> $first); )* impl_convert_from![$($ty),*]; }; () => { }; } impl_convert_from![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; impl_convert_from![ IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64 ]; macro_rules! impl_from { ($src:ident => [$($dest:ident),*]) => { $( impl<'ctx> From<$src<'ctx>> for $dest<'ctx> { fn from(v: $src<'ctx>) -> Self { v.to() } } )* }; } macro_rules! impl_froms { ( #[u8] $u8:ident; #[i8] $i8:ident; #[u16] $u16:ident; #[i16] $i16:ident; #[f16] $f16:ident; #[u32] $u32:ident; #[i32] $i32:ident; #[f32] $f32:ident; #[u64] $u64:ident; #[i64] $i64:ident; #[f64] $f64:ident; ) => { impl_from!($u8 => [$u16, $i16, $f16, $u32, $i32, $f32, $u64, $i64, $f64]); impl_from!($u16 => [$u32, $i32, $f32, $u64, $i64, $f64]); impl_from!($u32 => [$u64, $i64, $f64]); impl_from!($i8 => [$i16, $f16, $i32, $f32, $i64, $f64]); impl_from!($i16 => [$i32, $f32, $i64, $f64]); impl_from!($i32 => [$i64, $f64]); impl_from!($f16 => [$f32, $f64]); impl_from!($f32 => [$f64]); }; } impl_froms! { #[u8] IrU8; #[i8] IrI8; #[u16] IrU16; #[i16] IrI16; #[f16] IrF16; #[u32] IrU32; #[i32] IrI32; #[f32] IrF32; #[u64] IrU64; #[i64] IrI64; #[f64] IrF64; } impl_froms! { #[u8] IrVecU8; #[i8] IrVecI8; #[u16] IrVecU16; #[i16] IrVecI16; #[f16] IrVecF16; #[u32] IrVecU32; #[i32] IrVecI32; #[f32] IrVecF32; #[u64] IrVecU64; #[i64] IrVecI64; #[f64] IrVecF64; } impl<'ctx> Context for &'ctx IrContext<'ctx> { type Bool = IrBool<'ctx>; type U8 = IrU8<'ctx>; type I8 = IrI8<'ctx>; type U16 = IrU16<'ctx>; type I16 = IrI16<'ctx>; type F16 = IrF16<'ctx>; type U32 = IrU32<'ctx>; type I32 = IrI32<'ctx>; type F32 = IrF32<'ctx>; type U64 = IrU64<'ctx>; type I64 = IrI64<'ctx>; type F64 = IrF64<'ctx>; type VecBool8 = IrVecBool<'ctx>; type VecU8 = IrVecU8<'ctx>; type VecI8 = IrVecI8<'ctx>; type VecBool16 = IrVecBool<'ctx>; type VecU16 = IrVecU16<'ctx>; type VecI16 = IrVecI16<'ctx>; type VecF16 = IrVecF16<'ctx>; type VecBool32 = IrVecBool<'ctx>; type VecU32 = IrVecU32<'ctx>; type VecI32 = IrVecI32<'ctx>; type VecF32 = IrVecF32<'ctx>; type VecBool64 = IrVecBool<'ctx>; type VecU64 = IrVecU64<'ctx>; type VecI64 = IrVecI64<'ctx>; type VecF64 = IrVecF64<'ctx>; } #[cfg(test)] mod tests { use crate::algorithms; use super::*; use std::println; #[test] fn test_display() { fn f(ctx: Ctx, a: Ctx::VecU8, b: Ctx::VecF32) -> Ctx::VecF64 { let a: Ctx::VecF32 = a.into(); (a - (a + b - ctx.make(5f32)).floor()).to() } let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecU8<'ctx>, IrVecF32<'ctx>) -> IrVecF64<'ctx> = f; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec, in: vec) -> vec { op_0: vec = Cast in op_1: vec = Add op_0, in op_2: vec = Sub op_1, splat(0x40A00000_f32) op_3: vec = Floor op_2 op_4: vec = Sub op_0, op_3 op_5: vec = Cast op_4 Return op_5 } " ); } #[test] fn test_display_ilogb_f32() { let ctx = IrContext::new(); fn make_it<'ctx>(ctx: &'ctx IrContext<'ctx>) -> IrFunction<'ctx> { let f: fn(&'ctx IrContext<'ctx>, IrVecF32<'ctx>) -> IrVecI32<'ctx> = algorithms::ilogb::ilogb_f32; IrFunction::make(ctx, f) } let text = format!("\n{}", make_it(&ctx)); println!("{}", text); assert_eq!( text, r" function(in: vec) -> vec { op_0: vec = IsFinite in op_1: vec = ToBits in op_2: vec = And op_1, splat(0x7F800000_u32) op_3: vec = Shr op_2, splat(0x17_u32) op_4: vec = CompareEq op_3, splat(0x0_u32) op_5: vec = CompareNe in, in op_6: vec = Splat 0x80000001_i32 op_7: vec = Splat 0x7FFFFFFF_i32 op_8: vec = Select op_5, op_6, op_7 op_9: vec = Mul in, splat(0x4B000000_f32) op_10: vec = ToBits op_9 op_11: vec = And op_10, splat(0x7F800000_u32) op_12: vec = Shr op_11, splat(0x17_u32) op_13: vec = Cast op_12 op_14: vec = Sub op_13, splat(0x7F_i32) op_15: vec = ToBits in op_16: vec = And op_15, splat(0x7F800000_u32) op_17: vec = Shr op_16, splat(0x17_u32) op_18: vec = Cast op_17 op_19: vec = Sub op_18, splat(0x7F_i32) op_20: vec = Select op_0, op_19, op_8 op_21: vec = CompareEq in, splat(0x0_f32) op_22: vec = Splat 0x80000000_i32 op_23: vec = Sub op_14, splat(0x17_i32) op_24: vec = Select op_21, op_22, op_23 op_25: vec = Select op_4, op_24, op_20 Return op_25 } " ); } }