From ba67f8a36240e8ee2e95bb79aa7a58ae77577197 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 7 May 2021 20:14:19 -0700 Subject: [PATCH] add sin_pi_f16, cos_pi_f16, and sin_cos_pi_f16 --- maxima/sin_cos_pi.mac | 45 +++ src/algorithms/ilogb.rs | 46 +-- src/algorithms/trig_pi.rs | 127 +++++++- src/f16.rs | 125 ++++--- src/ieee754.rs | 22 +- src/ir.rs | 28 +- src/scalar.rs | 525 ++++++++++++++++++++++++++++-- src/stdsimd.rs | 74 +++-- src/traits.rs | 190 ++--------- vector-math-proc-macro/src/lib.rs | 6 +- 10 files changed, 860 insertions(+), 328 deletions(-) create mode 100644 maxima/sin_cos_pi.mac diff --git a/maxima/sin_cos_pi.mac b/maxima/sin_cos_pi.mac new file mode 100644 index 0000000..754e516 --- /dev/null +++ b/maxima/sin_cos_pi.mac @@ -0,0 +1,45 @@ +load("bitwise")$ + +sin_kernel(x) := if abs(x) <= 1 / 4 then sin(%pi * x) else 2$ +cos_kernel(x) := if abs(x) <= 1 / 4 then cos(%pi * x) else 2$ + +sc(arg) := block( + [x:arg, xi, xk, sk, ck, st, ct, s, c], + xi:round(x*2), + xk:x - xi / 2, + sk:sin_kernel(xk), + ck:cos_kernel(xk), + st:if bit_and(xi, 1) < 1 then sk else ck, + ct:if bit_and(xi, 1) < 1 then ck else sk, + s:if bit_and(xi, 2) < 1 then st else -st, + c:if bit_and(xi + 1, 2) < 1 then ct else -ct, + [ + xk, + s, + c, + sin(%pi * x) + 1/64, + cos(%pi * x) + 1/64 + ])$ + +sc(0); +sc(1/4); +sc(1/2); +sc(3/4); +sc(1); +sc(-1/4); +sc(-1/2); +sc(-3/4); +sc(-1); + +plot2d( + sc(x), + [x, -1.5, 1.5], + [ + legend, + "xk", + "s", + "c", + "sin(%pi * x) + 1/64", + "cos(%pi * x) + 1/64" + ], + [png_file, "./sin_cos_pi.png"]); diff --git a/src/algorithms/ilogb.rs b/src/algorithms/ilogb.rs index 6196b2d..7da2733 100644 --- a/src/algorithms/ilogb.rs +++ b/src/algorithms/ilogb.rs @@ -88,7 +88,7 @@ impl_ilogb! { #[cfg(test)] mod tests { use super::*; - use crate::scalar::Scalar; + use crate::scalar::{Scalar, Value}; #[test] #[cfg_attr( @@ -98,7 +98,7 @@ mod tests { fn test_ilogb_f16() { fn ilogb(arg: f32) -> i16 { let arg: F16 = arg.to(); - ilogb_f16(Scalar, arg) + ilogb_f16(Scalar, Value(arg)).0 } assert_eq!(ilogb(0.), ILOGB_UNDERFLOW_RESULT_F16); assert_eq!(ilogb(1.), 0); @@ -113,27 +113,33 @@ mod tests { #[test] fn test_ilogb_f32() { - assert_eq!(ilogb_f32(Scalar, 0f32), ILOGB_UNDERFLOW_RESULT_F32); - assert_eq!(ilogb_f32(Scalar, 1f32), 0); - assert_eq!(ilogb_f32(Scalar, 2f32), 1); - assert_eq!(ilogb_f32(Scalar, 3f32), 1); - assert_eq!(ilogb_f32(Scalar, 3.99999f32), 1); - assert_eq!(ilogb_f32(Scalar, 0.5f32), -1); - assert_eq!(ilogb_f32(Scalar, 0.5f32.powi(130)), -130); - assert_eq!(ilogb_f32(Scalar, f32::INFINITY), ILOGB_OVERFLOW_RESULT_F32); - assert_eq!(ilogb_f32(Scalar, f32::NAN), ILOGB_NAN_RESULT_F32); + fn ilogb(arg: f32) -> i32 { + ilogb_f32(Scalar, Value(arg)).0 + } + assert_eq!(ilogb(0f32), ILOGB_UNDERFLOW_RESULT_F32); + assert_eq!(ilogb(1f32), 0); + assert_eq!(ilogb(2f32), 1); + assert_eq!(ilogb(3f32), 1); + assert_eq!(ilogb(3.99999f32), 1); + assert_eq!(ilogb(0.5f32), -1); + assert_eq!(ilogb(0.5f32.powi(130)), -130); + assert_eq!(ilogb(f32::INFINITY), ILOGB_OVERFLOW_RESULT_F32); + assert_eq!(ilogb(f32::NAN), ILOGB_NAN_RESULT_F32); } #[test] fn test_ilogb_f64() { - assert_eq!(ilogb_f64(Scalar, 0f64), ILOGB_UNDERFLOW_RESULT_F64); - assert_eq!(ilogb_f64(Scalar, 1f64), 0); - assert_eq!(ilogb_f64(Scalar, 2f64), 1); - assert_eq!(ilogb_f64(Scalar, 3f64), 1); - assert_eq!(ilogb_f64(Scalar, 3.99999f64), 1); - assert_eq!(ilogb_f64(Scalar, 0.5f64), -1); - assert_eq!(ilogb_f64(Scalar, 0.5f64.powi(1030)), -1030); - assert_eq!(ilogb_f64(Scalar, f64::INFINITY), ILOGB_OVERFLOW_RESULT_F64); - assert_eq!(ilogb_f64(Scalar, f64::NAN), ILOGB_NAN_RESULT_F64); + fn ilogb(arg: f64) -> i64 { + ilogb_f64(Scalar, Value(arg)).0 + } + assert_eq!(ilogb(0f64), ILOGB_UNDERFLOW_RESULT_F64); + assert_eq!(ilogb(1f64), 0); + assert_eq!(ilogb(2f64), 1); + assert_eq!(ilogb(3f64), 1); + assert_eq!(ilogb(3.99999f64), 1); + assert_eq!(ilogb(0.5f64), -1); + assert_eq!(ilogb(0.5f64.powi(1030)), -1030); + assert_eq!(ilogb(f64::INFINITY), ILOGB_OVERFLOW_RESULT_F64); + assert_eq!(ilogb(f64::NAN), ILOGB_NAN_RESULT_F64); } } diff --git a/src/algorithms/trig_pi.rs b/src/algorithms/trig_pi.rs index 28e67ef..38104a6 100644 --- a/src/algorithms/trig_pi.rs +++ b/src/algorithms/trig_pi.rs @@ -1,4 +1,8 @@ -use crate::traits::{Context, ConvertTo, Float}; +use crate::{ + f16::F16, + ieee754::FloatEncoding, + traits::{Compare, Context, ConvertFrom, ConvertTo, Float, Select}, +}; mod consts { #![allow(clippy::excessive_precision)] @@ -84,18 +88,64 @@ pub fn cos_pi_kernel_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 /// computes `(sin(pi * x), cos(pi * x))` /// not guaranteed to give correct sign for zero results /// has an error of up to 2ULP -pub fn sin_cos_pi_f16(_ctx: Ctx, _x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) { - todo!() +pub fn sin_cos_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> (Ctx::VecF16, Ctx::VecF16) { + let two_f16: Ctx::VecF16 = ctx.make(2.0.to()); + let one_half: Ctx::VecF16 = ctx.make(0.5.to()); + let max_contiguous_integer: Ctx::VecF16 = + ctx.make((1u16 << (F16::MANTISSA_FIELD_WIDTH + 1)).to()); + // if `x` is finite and bigger than `max_contiguous_integer`, then x is an even integer + let in_range = x.abs().lt(max_contiguous_integer); // use `lt` so nans are counted as out-of-range + let is_finite = x.is_finite(); + let nan: Ctx::VecF16 = ctx.make(f32::NAN.to()); + let zero_f16: Ctx::VecF16 = ctx.make(0.to()); + let one_f16: Ctx::VecF16 = ctx.make(1.to()); + let zero_i16: Ctx::VecI16 = ctx.make(0.to()); + let one_i16: Ctx::VecI16 = ctx.make(1.to()); + let two_i16: Ctx::VecI16 = ctx.make(2.to()); + let out_of_range_sin = is_finite.select(zero_f16, nan); + let out_of_range_cos = is_finite.select(one_f16, nan); + let xi = (x * two_f16).round(); + let xk = x - xi * one_half; + let sk = sin_pi_kernel_f16(ctx, xk); + let ck = cos_pi_kernel_f16(ctx, xk); + let xi = Ctx::VecI16::cvt_from(xi); + let bit_0_clear = (xi & one_i16).eq(zero_i16); + let st = bit_0_clear.select(sk, ck); + let ct = bit_0_clear.select(ck, sk); + let s = (xi & two_i16).eq(zero_i16).select(st, -st); + let c = ((xi + one_i16) & two_i16).eq(zero_i16).select(ct, -ct); + ( + in_range.select(s, out_of_range_sin), + in_range.select(c, out_of_range_cos), + ) +} + +/// computes `sin(pi * x)` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn sin_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 { + sin_cos_pi_f16(ctx, x).0 +} + +/// computes `cos(pi * x)` +/// not guaranteed to give correct sign for zero results +/// has an error of up to 2ULP +pub fn cos_pi_f16(ctx: Ctx, x: Ctx::VecF16) -> Ctx::VecF16 { + sin_cos_pi_f16(ctx, x).1 } #[cfg(test)] mod tests { use super::*; - use crate::{f16::F16, scalar::Scalar}; + use crate::{ + f16::F16, + scalar::{Scalar, Value}, + }; use std::f64; struct CheckUlpCallbackArg { distance_in_ulp: I, + x: F, expected: F, result: F, } @@ -103,7 +153,7 @@ mod tests { #[track_caller] fn check_ulp_f16( x: F16, - is_ok: impl Fn(CheckUlpCallbackArg) -> bool, + is_ok: impl Fn(CheckUlpCallbackArg) -> bool, fn_f16: impl Fn(F16) -> F16, fn_f64: impl Fn(f64) -> f64, ) { @@ -114,25 +164,34 @@ mod tests { if result == expected { return; } - let distance_in_ulp = (expected.to_bits() as i16).wrapping_sub(result.to_bits() as i16); - if is_ok(CheckUlpCallbackArg { - distance_in_ulp, - expected, - result, - }) { + if result.is_nan() && expected.is_nan() { + return; + } + let distance_in_ulp = (expected.to_bits() as i32 - result.to_bits() as i32).unsigned_abs(); + if !result.is_nan() + && !expected.is_nan() + && is_ok(CheckUlpCallbackArg { + distance_in_ulp, + x, + expected, + result, + }) + { return; } panic!( "error is too big: \ x = {x:?} {x_bits:#X}, \ result = {result:?} {result_bits:#X}, \ - expected = {expected:?} {expected_bits:#X}", + expected = {expected:?} {expected_bits:#X}, \ + distance_in_ulp = {distance_in_ulp}", x = x, x_bits = x.to_bits(), result = result, result_bits = result.to_bits(), expected = expected, expected_bits = expected.to_bits(), + distance_in_ulp = distance_in_ulp, ); } @@ -146,7 +205,7 @@ mod tests { check_ulp_f16( x, |arg| arg.distance_in_ulp <= if arg.expected == 0.to() { 0 } else { 2 }, - |x| sin_pi_kernel_f16(Scalar, x), + |x| sin_pi_kernel_f16(Scalar, Value(x)).0, |x| (f64::consts::PI * x).sin(), ) }; @@ -167,7 +226,7 @@ mod tests { check_ulp_f16( x, |arg| arg.distance_in_ulp <= 2 && arg.result <= 1.to(), - |x| cos_pi_kernel_f16(Scalar, x), + |x| cos_pi_kernel_f16(Scalar, Value(x)).0, |x| (f64::consts::PI * x).cos(), ) }; @@ -177,4 +236,44 @@ mod tests { check(-F16::from_bits(bits)); } } + + fn sin_cos_pi_check_ulp_callback_f16(arg: CheckUlpCallbackArg) -> bool { + if f32::cvt_from(arg.x) % 0.5 == 0.0 { + arg.distance_in_ulp == 0 + } else { + arg.distance_in_ulp <= 2 && arg.result.abs() <= 1.to() + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_sin_pi_f16() { + for bits in 0..=u16::MAX { + check_ulp_f16( + F16::from_bits(bits), + sin_cos_pi_check_ulp_callback_f16, + |x| sin_pi_f16(Scalar, Value(x)).0, + |x| (f64::consts::PI * x).sin(), + ); + } + } + + #[test] + #[cfg_attr( + not(feature = "f16"), + should_panic(expected = "f16 feature is not enabled") + )] + fn test_cos_pi_f16() { + for bits in 0..=u16::MAX { + check_ulp_f16( + F16::from_bits(bits), + sin_cos_pi_check_ulp_callback_f16, + |x| cos_pi_f16(Scalar, Value(x)).0, + |x| (f64::consts::PI * x).cos(), + ); + } + } } diff --git a/src/f16.rs b/src/f16.rs index 5bc1119..e9541b4 100644 --- a/src/f16.rs +++ b/src/f16.rs @@ -2,7 +2,10 @@ use core::ops::{ Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, }; -use crate::traits::{ConvertTo, Float}; +use crate::{ + scalar::Value, + traits::{ConvertFrom, ConvertTo, Float}, +}; #[cfg(feature = "f16")] use half::f16 as F16Impl; @@ -64,9 +67,9 @@ macro_rules! impl_f16_from { } } - impl ConvertTo for $ty { - fn to(self) -> F16 { - self.into() + impl ConvertFrom<$ty> for F16 { + fn cvt_from(v: $ty) -> F16 { + v.into() } } )* @@ -82,9 +85,9 @@ macro_rules! impl_from_f16 { } } - impl ConvertTo<$ty> for F16 { - fn to(self) -> $ty { - self.into() + impl ConvertFrom for $ty { + fn cvt_from(v: F16) -> Self { + v.into() } } )* @@ -98,12 +101,12 @@ impl_from_f16![f32, f64,]; macro_rules! impl_int_to_f16 { ($($int:ident),*) => { $( - impl ConvertTo for $int { - fn to(self) -> F16 { + impl ConvertFrom<$int> for F16 { + fn cvt_from(v: $int) -> Self { // f32 has enough mantissa bits such that f16 overflows to // infinity before f32 stops being able to properly // represent integer values, making the below conversion correct. - (self as f32).to() + F16::cvt_from(v as f32) } } )* @@ -113,9 +116,9 @@ macro_rules! impl_int_to_f16 { macro_rules! impl_f16_to_int { ($($int:ident),*) => { $( - impl ConvertTo<$int> for F16 { - fn to(self) -> $int { - f32::from(self) as $int + impl ConvertFrom for $int { + fn cvt_from(v: F16) -> Self { + f32::from(v) as $int } } )* @@ -125,15 +128,15 @@ macro_rules! impl_f16_to_int { impl_int_to_f16![i16, u16, i32, u32, i64, u64, i128, u128]; impl_f16_to_int![i8, u8, i16, u16, i32, u32, i64, u64, i128, u128]; -impl ConvertTo for f32 { - fn to(self) -> F16 { - f16_impl!(F16(F16Impl::from_f32(self)), []) +impl ConvertFrom for F16 { + fn cvt_from(v: f32) -> Self { + f16_impl!(F16(F16Impl::from_f32(v)), [v]) } } -impl ConvertTo for f64 { - fn to(self) -> F16 { - f16_impl!(F16(F16Impl::from_f64(self)), []) +impl ConvertFrom for F16 { + fn cvt_from(v: f64) -> Self { + f16_impl!(F16(F16Impl::from_f64(v)), [v]) } } @@ -173,60 +176,104 @@ impl_bin_op_using_f32! { Rem, rem, RemAssign, rem_assign; } -impl Float for F16 { +impl F16 { + pub fn from_bits(v: u16) -> Self { + #[cfg(feature = "f16")] + return F16(F16Impl::from_bits(v)); + #[cfg(not(feature = "f16"))] + return F16(v); + } + pub fn to_bits(self) -> u16 { + #[cfg(feature = "f16")] + return self.0.to_bits(); + #[cfg(not(feature = "f16"))] + return self.0; + } + pub fn abs(self) -> Self { + f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) + } + pub fn trunc(self) -> Self { + f32::from(self).trunc().to() + } + + pub fn ceil(self) -> Self { + f32::from(self).ceil().to() + } + + pub fn floor(self) -> Self { + f32::from(self).floor().to() + } + + pub fn round(self) -> Self { + f32::from(self).round().to() + } + + #[cfg(feature = "fma")] + pub fn fma(self, a: Self, b: Self) -> Self { + (f64::from(self) * f64::from(a) + f64::from(b)).to() + } + + pub fn is_nan(self) -> bool { + f16_impl!(self.0.is_nan(), []) + } + + pub fn is_infinite(self) -> bool { + f16_impl!(self.0.is_infinite(), []) + } + + pub fn is_finite(self) -> bool { + f16_impl!(self.0.is_finite(), []) + } +} + +impl Float for Value { type FloatEncoding = F16; - type BitsType = u16; - type SignedBitsType = i16; + type BitsType = Value; + type SignedBitsType = Value; fn abs(self) -> Self { - f16_impl!(Self::from_bits(self.to_bits() & 0x7FFF), []) + Value(self.0.abs()) } fn trunc(self) -> Self { - f32::from(self).trunc().to() + Value(self.0.trunc()) } fn ceil(self) -> Self { - f32::from(self).ceil().to() + Value(self.0.ceil()) } fn floor(self) -> Self { - f32::from(self).floor().to() + Value(self.0.floor()) } fn round(self) -> Self { - f32::from(self).round().to() + Value(self.0.round()) } #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { - (f64::from(self) * f64::from(a) + f64::from(b)).to() + Value(self.0.fma(a.0, b.0)) } fn is_nan(self) -> Self::Bool { - f16_impl!(self.0.is_nan(), []) + Value(self.0.is_nan()) } fn is_infinite(self) -> Self::Bool { - f16_impl!(self.0.is_infinite(), []) + Value(self.0.is_infinite()) } fn is_finite(self) -> Self::Bool { - f16_impl!(self.0.is_finite(), []) + Value(self.0.is_finite()) } fn from_bits(v: Self::BitsType) -> Self { - #[cfg(feature = "f16")] - return F16(F16Impl::from_bits(v)); - #[cfg(not(feature = "f16"))] - return F16(v); + Value(F16::from_bits(v.0)) } fn to_bits(self) -> Self::BitsType { - #[cfg(feature = "f16")] - return self.0.to_bits(); - #[cfg(not(feature = "f16"))] - return self.0; + Value(self.0.to_bits()) } } diff --git a/src/ieee754.rs b/src/ieee754.rs index 6f9fea7..3d70468 100644 --- a/src/ieee754.rs +++ b/src/ieee754.rs @@ -1,8 +1,4 @@ -use crate::{ - f16::F16, - scalar::Scalar, - traits::{Float, Make}, -}; +use crate::f16::F16; mod sealed { use crate::f16::F16; @@ -13,9 +9,9 @@ mod sealed { impl Sealed for f64 {} } -pub trait FloatEncoding: - sealed::Sealed + Copy + 'static + Send + Sync + Float + Make -{ +pub trait FloatEncoding: sealed::Sealed + Copy + 'static + Send + Sync { + type BitsType; + type SignedBitsType; const EXPONENT_BIAS_UNSIGNED: Self::BitsType; const EXPONENT_BIAS_SIGNED: Self::SignedBitsType; const SIGN_FIELD_WIDTH: Self::BitsType; @@ -37,11 +33,15 @@ pub trait FloatEncoding: macro_rules! impl_float_encoding { ( impl FloatEncoding for $float:ident { + type BitsType = $bits_type:ident; + type SignedBitsType = $signed_bits_type:ident; const EXPONENT_FIELD_WIDTH: u32 = $exponent_field_width:literal; const MANTISSA_FIELD_WIDTH: u32 = $mantissa_field_width:literal; } ) => { impl FloatEncoding for $float { + type BitsType = $bits_type; + type SignedBitsType = $signed_bits_type; const EXPONENT_BIAS_UNSIGNED: Self::BitsType = (1 << (Self::EXPONENT_FIELD_WIDTH - 1)) - 1; const EXPONENT_BIAS_SIGNED: Self::SignedBitsType = Self::EXPONENT_BIAS_UNSIGNED as _; @@ -69,6 +69,8 @@ macro_rules! impl_float_encoding { impl_float_encoding! { impl FloatEncoding for F16 { + type BitsType = u16; + type SignedBitsType = i16; const EXPONENT_FIELD_WIDTH: u32 = 5; const MANTISSA_FIELD_WIDTH: u32 = 10; } @@ -76,6 +78,8 @@ impl_float_encoding! { impl_float_encoding! { impl FloatEncoding for f32 { + type BitsType = u32; + type SignedBitsType = i32; const EXPONENT_FIELD_WIDTH: u32 = 8; const MANTISSA_FIELD_WIDTH: u32 = 23; } @@ -83,6 +87,8 @@ impl_float_encoding! { impl_float_encoding! { impl FloatEncoding for f64 { + type BitsType = u64; + type SignedBitsType = i64; const EXPONENT_FIELD_WIDTH: u32 = 11; const MANTISSA_FIELD_WIDTH: u32 = 52; } diff --git a/src/ir.rs b/src/ir.rs index a31bece..e2b4a0e 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,6 +1,8 @@ use crate::{ f16::F16, - traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt}, + traits::{ + Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt, + }, }; use std::{ borrow::Borrow, @@ -1489,38 +1491,38 @@ ir_value!( } ); -macro_rules! impl_convert_to { +macro_rules! impl_convert_from { ($src:ident -> $dest:ident) => { - impl<'ctx> ConvertTo<$dest<'ctx>> for $src<'ctx> { - fn to(self) -> $dest<'ctx> { + impl<'ctx> ConvertFrom<$src<'ctx>> for $dest<'ctx> { + fn cvt_from(v: $src<'ctx>) -> Self { let value = if $src::TYPE == $dest::TYPE { - self.value + v.value } else { - self + v .ctx - .make_operation(Opcode::Cast, [self.value], $dest::TYPE) + .make_operation(Opcode::Cast, [v.value], $dest::TYPE) .into() }; $dest { value, - ctx: self.ctx, + ctx: v.ctx, } } } }; ($first:ident $(, $ty:ident)*) => { $( - impl_convert_to!($first -> $ty); - impl_convert_to!($ty -> $first); + impl_convert_from!($first -> $ty); + impl_convert_from!($ty -> $first); )* - impl_convert_to![$($ty),*]; + impl_convert_from![$($ty),*]; }; () => { }; } -impl_convert_to![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; +impl_convert_from![IrU8, IrI8, IrU16, IrI16, IrF16, IrU32, IrI32, IrU64, IrI64, IrF32, IrF64]; -impl_convert_to![ +impl_convert_from![ IrVecU8, IrVecI8, IrVecU16, IrVecI16, IrVecF16, IrVecU32, IrVecI32, IrVecU64, IrVecI64, IrVecF32, IrVecF64 ]; diff --git a/src/scalar.rs b/src/scalar.rs index fb83af6..4eb5b98 100644 --- a/src/scalar.rs +++ b/src/scalar.rs @@ -1,30 +1,487 @@ -use crate::traits::{Context, Make}; +use crate::{ + f16::F16, + traits::{Bool, Compare, Context, ConvertFrom, Float, Int, Make, SInt, Select, UInt}, +}; +use core::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, + Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, +}; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)] pub struct Scalar; +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Default)] +#[repr(transparent)] +pub struct Value(pub T); + +macro_rules! impl_convert_from { + ($first:ident $(, $ty:ident)*) => { + $( + impl ConvertFrom> for Value<$ty> { + fn cvt_from(v: Value<$first>) -> Self { + Value(ConvertFrom::cvt_from(v.0)) + } + } + impl ConvertFrom> for Value<$first> { + fn cvt_from(v: Value<$ty>) -> Self { + Value(ConvertFrom::cvt_from(v.0)) + } + } + )* + impl_convert_from![$($ty),*]; + }; + () => { + }; +} + +impl_convert_from![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64]; + +macro_rules! impl_bit_ops { + ($ty:ident) => { + impl BitAnd for Value<$ty> { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + Value(self.0 & rhs.0) + } + } + + impl BitOr for Value<$ty> { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + Value(self.0 | rhs.0) + } + } + + impl BitXor for Value<$ty> { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self { + Value(self.0 ^ rhs.0) + } + } + + impl Not for Value<$ty> { + type Output = Self; + + fn not(self) -> Self { + Value(!self.0) + } + } + + impl BitAndAssign for Value<$ty> { + fn bitand_assign(&mut self, rhs: Self) { + self.0 &= rhs.0; + } + } + + impl BitOrAssign for Value<$ty> { + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } + } + + impl BitXorAssign for Value<$ty> { + fn bitxor_assign(&mut self, rhs: Self) { + self.0 ^= rhs.0; + } + } + }; +} + +macro_rules! impl_wrapping_int_ops { + ($ty:ident) => { + impl Add for Value<$ty> { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Value(self.0.wrapping_add(rhs.0)) + } + } + + impl Sub for Value<$ty> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Value(self.0.wrapping_sub(rhs.0)) + } + } + + impl Mul for Value<$ty> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Value(self.0.wrapping_mul(rhs.0)) + } + } + + impl Div for Value<$ty> { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + Value(self.0.wrapping_div(rhs.0)) + } + } + + impl Rem for Value<$ty> { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + Value(self.0.wrapping_rem(rhs.0)) + } + } + + impl Shl for Value<$ty> { + type Output = Self; + + fn shl(self, rhs: Self) -> Self { + Value(self.0.wrapping_shl(rhs.0 as u32)) + } + } + + impl Shr for Value<$ty> { + type Output = Self; + + fn shr(self, rhs: Self) -> Self { + Value(self.0.wrapping_shr(rhs.0 as u32)) + } + } + + impl Neg for Value<$ty> { + type Output = Self; + + fn neg(self) -> Self { + Value(self.0.wrapping_neg()) + } + } + + impl AddAssign for Value<$ty> { + fn add_assign(&mut self, rhs: Self) { + *self = self.add(rhs); + } + } + + impl SubAssign for Value<$ty> { + fn sub_assign(&mut self, rhs: Self) { + *self = self.sub(rhs); + } + } + + impl MulAssign for Value<$ty> { + fn mul_assign(&mut self, rhs: Self) { + *self = self.mul(rhs); + } + } + + impl DivAssign for Value<$ty> { + fn div_assign(&mut self, rhs: Self) { + *self = self.div(rhs); + } + } + + impl RemAssign for Value<$ty> { + fn rem_assign(&mut self, rhs: Self) { + *self = self.rem(rhs); + } + } + + impl ShlAssign for Value<$ty> { + fn shl_assign(&mut self, rhs: Self) { + *self = self.shl(rhs); + } + } + + impl ShrAssign for Value<$ty> { + fn shr_assign(&mut self, rhs: Self) { + *self = self.shr(rhs); + } + } + }; +} +macro_rules! impl_int { + ($ty:ident) => { + impl_bit_ops!($ty); + impl_wrapping_int_ops!($ty); + impl Int for Value<$ty> { + fn leading_zeros(self) -> Self { + Value(self.0.leading_zeros() as $ty) + } + fn leading_ones(self) -> Self { + Value(self.0.leading_ones() as $ty) + } + fn trailing_zeros(self) -> Self { + Value(self.0.trailing_zeros() as $ty) + } + fn trailing_ones(self) -> Self { + Value(self.0.trailing_ones() as $ty) + } + fn count_zeros(self) -> Self { + Value(self.0.count_zeros() as $ty) + } + fn count_ones(self) -> Self { + Value(self.0.count_ones() as $ty) + } + } + }; +} + +macro_rules! impl_uint { + ($($ty:ident),*) => { + $( + impl_int!($ty); + impl UInt for Value<$ty> {} + )* + }; +} + +impl_uint![u8, u16, u32, u64]; + +macro_rules! impl_sint { + ($($ty:ident),*) => { + $( + impl_int!($ty); + impl SInt for Value<$ty> {} + )* + }; +} + +impl_sint![i8, i16, i32, i64]; + +macro_rules! impl_float_ops { + ($ty:ident) => { + impl Add for Value<$ty> { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Value(self.0.add(rhs.0)) + } + } + + impl Sub for Value<$ty> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Value(self.0.sub(rhs.0)) + } + } + + impl Mul for Value<$ty> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Value(self.0.mul(rhs.0)) + } + } + + impl Div for Value<$ty> { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + Value(self.0.div(rhs.0)) + } + } + + impl Rem for Value<$ty> { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + Value(self.0.rem(rhs.0)) + } + } + + impl Neg for Value<$ty> { + type Output = Self; + + fn neg(self) -> Self { + Value(self.0.neg()) + } + } + + impl AddAssign for Value<$ty> { + fn add_assign(&mut self, rhs: Self) { + *self = self.add(rhs); + } + } + + impl SubAssign for Value<$ty> { + fn sub_assign(&mut self, rhs: Self) { + *self = self.sub(rhs); + } + } + + impl MulAssign for Value<$ty> { + fn mul_assign(&mut self, rhs: Self) { + *self = self.mul(rhs); + } + } + + impl DivAssign for Value<$ty> { + fn div_assign(&mut self, rhs: Self) { + *self = self.div(rhs); + } + } + + impl RemAssign for Value<$ty> { + fn rem_assign(&mut self, rhs: Self) { + *self = self.rem(rhs); + } + } + }; +} + +impl_float_ops!(F16); + +macro_rules! impl_float { + ($ty:ident, $bits:ty, $signed_bits:ty) => { + impl_float_ops!($ty); + impl Float for Value<$ty> { + type FloatEncoding = $ty; + type BitsType = Value<$bits>; + type SignedBitsType = Value<$signed_bits>; + fn abs(self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.abs()); + #[cfg(not(feature = "std"))] + todo!(); + } + fn trunc(self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.trunc()); + #[cfg(not(feature = "std"))] + todo!(); + } + fn ceil(self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.ceil()); + #[cfg(not(feature = "std"))] + todo!(); + } + fn floor(self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.floor()); + #[cfg(not(feature = "std"))] + todo!(); + } + fn round(self) -> Self { + #[cfg(feature = "std")] + return Value(self.0.round()); + #[cfg(not(feature = "std"))] + todo!(); + } + #[cfg(feature = "fma")] + fn fma(self, a: Self, b: Self) -> Self { + Value(self.0.mul_add(a.0, b.0)) + } + fn is_nan(self) -> Self::Bool { + Value(self.0.is_nan()) + } + fn is_infinite(self) -> Self::Bool { + Value(self.0.is_infinite()) + } + fn is_finite(self) -> Self::Bool { + Value(self.0.is_finite()) + } + fn from_bits(v: Self::BitsType) -> Self { + Value(<$ty>::from_bits(v.0)) + } + fn to_bits(self) -> Self::BitsType { + Value(self.0.to_bits()) + } + } + }; +} + +impl_float!(f32, u32, i32); +impl_float!(f64, u64, i64); + +macro_rules! impl_compare_using_partial_cmp { + ($($ty:ty),*) => { + $( + impl Compare for Value<$ty> { + type Bool = Value; + fn eq(self, rhs: Self) -> Self::Bool { + Value(self == rhs) + } + fn ne(self, rhs: Self) -> Self::Bool { + Value(self != rhs) + } + fn lt(self, rhs: Self) -> Self::Bool { + Value(self < rhs) + } + fn gt(self, rhs: Self) -> Self::Bool { + Value(self > rhs) + } + fn le(self, rhs: Self) -> Self::Bool { + Value(self <= rhs) + } + fn ge(self, rhs: Self) -> Self::Bool { + Value(self >= rhs) + } + } + )* + }; +} + +impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64]; + +impl Bool for Value {} + +impl_bit_ops!(bool); + +impl Select> for Value { + fn select(self, true_v: Value, false_v: Value) -> Value { + if self.0 { + true_v + } else { + false_v + } + } +} + +macro_rules! impl_from { + ($src:ident => [$($dest:ident),*]) => { + $( + impl From> for Value<$dest> { + fn from(v: Value<$src>) -> Self { + Value(v.0.into()) + } + } + )* + }; +} + +impl_from!(u8 => [u16, i16, F16, u32, i32, f32, u64, i64, f64]); +impl_from!(u16 => [u32, i32, f32, u64, i64, f64]); +impl_from!(u32 => [u64, i64, f64]); +impl_from!(i8 => [i16, F16, i32, f32, i64, f64]); +impl_from!(i16 => [i32, f32, i64, f64]); +impl_from!(i32 => [i64, f64]); +impl_from!(F16 => [f32, f64]); +impl_from!(f32 => [f64]); + macro_rules! impl_context { ( impl Context for Scalar { - $(type $name:ident = $ty:ty;)* + $(type $name:ident = Value<$ty:ident>;)* #[vec] - $(type $vec_name:ident = $vec_ty:ty;)* + $(type $vec_name:ident = Value<$vec_ty:ident>;)* } ) => { impl Context for Scalar { - $(type $name = $ty;)* - $(type $vec_name = $vec_ty;)* + $(type $name = Value<$ty>;)* + $(type $vec_name = Value<$vec_ty>;)* } $( - impl Make for $ty { + impl Make for Value<$ty> { type Prim = $ty; type Context = Scalar; fn ctx(self) -> Self::Context { Scalar } fn make(_ctx: Self::Context, v: Self::Prim) -> Self { - v + Value(v) } } )* @@ -33,33 +490,33 @@ macro_rules! impl_context { impl_context! { impl Context for Scalar { - type Bool = bool; - type U8 = u8; - type I8 = i8; - type U16 = u16; - type I16 = i16; - type F16 = crate::f16::F16; - type U32 = u32; - type I32 = i32; - type F32 = f32; - type U64 = u64; - type I64 = i64; - type F64 = f64; + type Bool = Value; + type U8 = Value; + type I8 = Value; + type U16 = Value; + type I16 = Value; + type F16 = Value; + type U32 = Value; + type I32 = Value; + type F32 = Value; + type U64 = Value; + type I64 = Value; + type F64 = Value; #[vec] - type VecBool8 = bool; - type VecU8 = u8; - type VecI8 = i8; - type VecBool16 = bool; - type VecU16 = u16; - type VecI16 = i16; - type VecF16 = crate::f16::F16; - type VecBool32 = bool; - type VecU32 = u32; - type VecI32 = i32; - type VecF32 = f32; - type VecBool64 = bool; - type VecU64 = u64; - type VecI64 = i64; - type VecF64 = f64; + type VecBool8 = Value; + type VecU8 = Value; + type VecI8 = Value; + type VecBool16 = Value; + type VecU16 = Value; + type VecI16 = Value; + type VecF16 = Value; + type VecBool32 = Value; + type VecU32 = Value; + type VecI32 = Value; + type VecF32 = Value; + type VecBool64 = Value; + type VecU64 = Value; + type VecI64 = Value; + type VecF64 = Value; } } diff --git a/src/stdsimd.rs b/src/stdsimd.rs index b691f7c..e1a389e 100644 --- a/src/stdsimd.rs +++ b/src/stdsimd.rs @@ -2,7 +2,11 @@ use crate::f16::panic_f16_feature_disabled; use crate::{ f16::F16, - traits::{Bool, Compare, Context, ConvertTo, Float, Int, Make, SInt, Select, UInt}, + ieee754::FloatEncoding, + scalar, + traits::{ + Bool, Compare, Context, ConvertFrom, ConvertTo, Float, Int, Make, SInt, Select, UInt, + }, }; use core::{ marker::PhantomData, @@ -293,7 +297,11 @@ where Mask64: Mask, { fn select(self, true_v: V, false_v: V) -> V { - self.0.select(true_v, false_v) + if self.0 { + true_v + } else { + false_v + } } } @@ -319,27 +327,27 @@ macro_rules! impl_scalar_compare { type Bool = Wrapper; fn eq(self, rhs: Self) -> Self::Bool { - self.0.eq(rhs.0).into() + self.0.eq(&rhs.0).into() } fn ne(self, rhs: Self) -> Self::Bool { - self.0.ne(rhs.0).into() + self.0.ne(&rhs.0).into() } fn lt(self, rhs: Self) -> Self::Bool { - self.0.lt(rhs.0).into() + self.0.lt(&rhs.0).into() } fn gt(self, rhs: Self) -> Self::Bool { - self.0.gt(rhs.0).into() + self.0.gt(&rhs.0).into() } fn le(self, rhs: Self) -> Self::Bool { - self.0.le(rhs.0).into() + self.0.le(&rhs.0).into() } fn ge(self, rhs: Self) -> Self::Bool { - self.0.ge(rhs.0).into() + self.0.ge(&rhs.0).into() } } }; @@ -695,9 +703,9 @@ macro_rules! impl_float { { type FloatEncoding = $prim; - type BitsType = Wrapper<<$prim as Float>::BitsType, LANES>; + type BitsType = Wrapper<<$prim as FloatEncoding>::BitsType, LANES>; - type SignedBitsType = Wrapper<<$prim as Float>::SignedBitsType, LANES>; + type SignedBitsType = Wrapper<<$prim as FloatEncoding>::SignedBitsType, LANES>; fn abs(self) -> Self { self.0.abs().into() @@ -721,7 +729,9 @@ macro_rules! impl_float { #[cfg(feature = "fma")] fn fma(self, a: Self, b: Self) -> Self { - self.0.fma(a.0, b.0).into() + let a = scalar::Value(a.0); + let b = scalar::Value(b.0); + scalar::Value(self.0).fma(a, b).0.into() } fn is_finite(self) -> Self::Bool { @@ -806,9 +816,9 @@ impl_float!(SimdF16, F16, SimdU16, SimdI16); impl_float!(SimdF32, f32, SimdU32, SimdI32); impl_float!(SimdF64, f64, SimdU64, SimdI64); -macro_rules! impl_scalar_convert_to_helper { +macro_rules! impl_vector_convert_from_helper { ($src:ty => $dest:ty) => { - impl ConvertTo> for Wrapper<$src, LANES> + impl ConvertFrom> for Wrapper<$dest, LANES> where SimdI8: LanesAtMost32, SimdU8: LanesAtMost32, @@ -825,31 +835,31 @@ macro_rules! impl_scalar_convert_to_helper { SimdF64: LanesAtMost32, Mask64: Mask, { - fn to(self) -> Wrapper<$dest, LANES> { - let v: $dest = self.0.to(); + fn cvt_from(v: Wrapper<$src, LANES>) -> Self { + let v: $dest = v.0.to(); v.into() } } }; } -macro_rules! impl_scalar_convert_to { +macro_rules! impl_vector_convert_from { ($first:ty $(, $ty:ty)*) => { $( - impl_scalar_convert_to_helper!($first => $ty); - impl_scalar_convert_to_helper!($ty => $first); + impl_vector_convert_from_helper!($first => $ty); + impl_vector_convert_from_helper!($ty => $first); )* - impl_scalar_convert_to![$($ty),*]; + impl_vector_convert_from![$($ty),*]; }; () => {}; } -impl_scalar_convert_to![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64]; +impl_vector_convert_from![u8, i8, u16, i16, F16, u32, i32, u64, i64, f32, f64]; -macro_rules! impl_vector_convert_to_helper { +macro_rules! impl_vector_convert_from_helper { (($(#[From = $From:ident])? $src:ident, $src_prim:ident) => ($(#[From = $From2:ident])? $dest:ident, $dest_prim:ident)) => { - impl ConvertTo, LANES>> - for Wrapper<$src, LANES> + impl ConvertFrom, LANES>> + for Wrapper<$dest, LANES> where SimdI8: LanesAtMost32, SimdU8: LanesAtMost32, @@ -866,9 +876,9 @@ macro_rules! impl_vector_convert_to_helper { SimdF64: LanesAtMost32, Mask64: Mask, { - fn to(self) -> Wrapper<$dest, LANES> { + fn cvt_from(v: Wrapper<$src, LANES>) -> Self { // FIXME(programmerjake): workaround https://github.com/rust-lang/stdsimd/issues/116 - let src: [$src_prim; LANES] = self.0.into(); + let src: [$src_prim; LANES] = v.0.into(); let mut dest: [$dest_prim; LANES] = [Default::default(); LANES]; for i in 0..LANES { dest[i] = src[i].to(); @@ -901,18 +911,18 @@ macro_rules! impl_vector_convert_to_helper { }; } -macro_rules! impl_vector_convert_to { +macro_rules! impl_vector_convert_from { ($first:tt $(, $ty:tt)*) => { $( - impl_vector_convert_to_helper!($first => $ty); - impl_vector_convert_to_helper!($ty => $first); + impl_vector_convert_from_helper!($first => $ty); + impl_vector_convert_from_helper!($ty => $first); )* - impl_vector_convert_to![$($ty),*]; + impl_vector_convert_from![$($ty),*]; }; () => {}; } -impl_vector_convert_to![ +impl_vector_convert_from![ (SimdU8, u8), (SimdI8, i8), (SimdU16, u16), @@ -926,7 +936,7 @@ impl_vector_convert_to![ (SimdF64, f64) ]; -impl_vector_convert_to![ +impl_vector_convert_from![ ( #[From = From] Mask8, @@ -969,7 +979,7 @@ macro_rules! impl_from_helper { Mask64: Mask, { fn from(v: $src) -> Self { - <$src as ConvertTo<$dest>>::to(v) + <$dest>::cvt_from(v) } } }; diff --git a/src/traits.rs b/src/traits.rs index 6555016..1877c21 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,4 +1,4 @@ -use crate::{f16::F16, ieee754::FloatEncoding, scalar::Scalar}; +use crate::{f16::F16, ieee754::FloatEncoding}; use core::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, @@ -19,37 +19,47 @@ pub trait Make: Copy { fn make(ctx: Self::Context, v: Self::Prim) -> Self; } +pub trait ConvertFrom: Sized { + fn cvt_from(v: T) -> Self; +} + +impl ConvertFrom for T { + fn cvt_from(v: T) -> Self { + v + } +} + pub trait ConvertTo { fn to(self) -> T; } -impl ConvertTo for T { +impl> ConvertTo for F { fn to(self) -> T { - self + T::cvt_from(self) } } -macro_rules! impl_convert_to_using_as { +macro_rules! impl_convert_from_using_as { ($first:ident $(, $ty:ident)*) => { $( - impl ConvertTo<$first> for $ty { - fn to(self) -> $first { - self as $first + impl ConvertFrom<$first> for $ty { + fn cvt_from(v: $first) -> Self { + v as _ } } - impl ConvertTo<$ty> for $first { - fn to(self) -> $ty { - self as $ty + impl ConvertFrom<$ty> for $first { + fn cvt_from(v: $ty) -> Self { + v as _ } } )* - impl_convert_to_using_as![$($ty),*]; + impl_convert_from_using_as![$($ty),*]; }; () => { }; } -impl_convert_to_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; +impl_convert_from_using_as![u8, i8, u16, i16, u32, i32, u64, i64, f32, f64]; pub trait Number: Compare @@ -126,61 +136,14 @@ pub trait UInt: Int {} pub trait SInt: Int + Neg {} -macro_rules! impl_int { - ($ty:ident) => { - impl Int for $ty { - fn leading_zeros(self) -> Self { - self.leading_zeros() as Self - } - fn leading_ones(self) -> Self { - self.leading_ones() as Self - } - fn trailing_zeros(self) -> Self { - self.trailing_zeros() as Self - } - fn trailing_ones(self) -> Self { - self.trailing_ones() as Self - } - fn count_zeros(self) -> Self { - self.count_zeros() as Self - } - fn count_ones(self) -> Self { - self.count_ones() as Self - } - } - }; -} - -macro_rules! impl_uint { - ($($ty:ident),*) => { - $( - impl_int!($ty); - impl UInt for $ty {} - )* - }; -} - -impl_uint![u8, u16, u32, u64]; - -macro_rules! impl_sint { - ($($ty:ident),*) => { - $( - impl_int!($ty); - impl SInt for $ty {} - )* - }; -} - -impl_sint![i8, i16, i32, i64]; - pub trait Float: Number + Neg { - type FloatEncoding: FloatEncoding + Make::Prim>; + type FloatEncoding: FloatEncoding + From<::Prim> + Into<::Prim>; type BitsType: UInt - + Make::BitsType> + + Make::BitsType> + ConvertTo + Compare; type SignedBitsType: SInt - + Make::SignedBitsType> + + Make::SignedBitsType> + ConvertTo + Compare; fn abs(self) -> Self; @@ -246,85 +209,12 @@ pub trait Float: Number + Neg { } } -macro_rules! impl_float { - ($ty:ty, $bits:ty, $signed_bits:ty) => { - impl Float for $ty { - type FloatEncoding = $ty; - type BitsType = $bits; - type SignedBitsType = $signed_bits; - fn abs(self) -> Self { - #[cfg(feature = "std")] - return self.abs(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn trunc(self) -> Self { - #[cfg(feature = "std")] - return self.trunc(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn ceil(self) -> Self { - #[cfg(feature = "std")] - return self.ceil(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn floor(self) -> Self { - #[cfg(feature = "std")] - return self.floor(); - #[cfg(not(feature = "std"))] - todo!(); - } - fn round(self) -> Self { - #[cfg(feature = "std")] - return self.round(); - #[cfg(not(feature = "std"))] - todo!(); - } - #[cfg(feature = "fma")] - fn fma(self, a: Self, b: Self) -> Self { - self.mul_add(a, b) - } - fn is_nan(self) -> Self::Bool { - self.is_nan() - } - fn is_infinite(self) -> Self::Bool { - self.is_infinite() - } - fn is_finite(self) -> Self::Bool { - self.is_finite() - } - fn from_bits(v: Self::BitsType) -> Self { - <$ty>::from_bits(v) - } - fn to_bits(self) -> Self::BitsType { - self.to_bits() - } - } - }; -} - -impl_float!(f32, u32, i32); -impl_float!(f64, u64, i64); - pub trait Bool: Make + BitOps {} -impl Bool for bool {} - pub trait Select: Bool { fn select(self, true_v: T, false_v: T) -> T; } -impl Select for bool { - fn select(self, true_v: T, false_v: T) -> T { - if self { - true_v - } else { - false_v - } - } -} pub trait Compare: Make { type Bool: Bool + Select; fn eq(self, rhs: Self) -> Self::Bool; @@ -334,33 +224,3 @@ pub trait Compare: Make { fn le(self, rhs: Self) -> Self::Bool; fn ge(self, rhs: Self) -> Self::Bool; } - -macro_rules! impl_compare_using_partial_cmp { - ($($ty:ty),*) => { - $( - impl Compare for $ty { - type Bool = bool; - fn eq(self, rhs: Self) -> Self::Bool { - self == rhs - } - fn ne(self, rhs: Self) -> Self::Bool { - self != rhs - } - fn lt(self, rhs: Self) -> Self::Bool { - self < rhs - } - fn gt(self, rhs: Self) -> Self::Bool { - self > rhs - } - fn le(self, rhs: Self) -> Self::Bool { - self <= rhs - } - fn ge(self, rhs: Self) -> Self::Bool { - self >= rhs - } - } - )* - }; -} - -impl_compare_using_partial_cmp![bool, u8, i8, u16, i16, F16, u32, i32, f32, u64, i64, f64]; diff --git a/vector-math-proc-macro/src/lib.rs b/vector-math-proc-macro/src/lib.rs index 1f5797a..5c4de02 100644 --- a/vector-math-proc-macro/src/lib.rs +++ b/vector-math-proc-macro/src/lib.rs @@ -317,10 +317,10 @@ impl TraitSets { } if convertibility.possible() { self.add_trait( - other_type_kind, - other_bits, + type_kind, + bits, vector_scalar, - quote! { ConvertTo }, + quote! { ConvertFrom }, ); } } -- 2.30.2