d38734091a90b41ebabcb6d24c4ebbe4e49212d6
[vector-math.git] / src / algorithms / base.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, Select, UInt},
4 };
5
6 pub fn abs<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
9 PrimF: PrimFloat<BitsType = PrimU>,
10 PrimU: PrimUInt,
11 >(
12 ctx: Ctx,
13 x: VecF,
14 ) -> VecF {
15 VecF::from_bits(x.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK))
16 }
17
18 pub fn copy_sign<
19 Ctx: Context,
20 VecF: Float<PrimFloat = PrimF> + Make<Context = Ctx>,
21 PrimF: PrimFloat<BitsType = PrimU>,
22 PrimU: PrimUInt,
23 >(
24 ctx: Ctx,
25 mag: VecF,
26 sign: VecF,
27 ) -> VecF {
28 let mag_bits = mag.to_bits() & ctx.make(!PrimF::SIGN_FIELD_MASK);
29 let sign_bit = sign.to_bits() & ctx.make(PrimF::SIGN_FIELD_MASK);
30 VecF::from_bits(mag_bits | sign_bit)
31 }
32
33 pub fn trunc<
34 Ctx: Context,
35 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
36 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
37 PrimF: PrimFloat<BitsType = PrimU>,
38 PrimU: PrimUInt,
39 >(
40 ctx: Ctx,
41 v: VecF,
42 ) -> VecF {
43 let big_limit: VecF = ctx.make(PrimF::IMPLICIT_MANTISSA_BIT.to());
44 let big = !v.abs().lt(big_limit); // use `lt` so nans are counted as big
45 let small = v.abs().lt(ctx.make(PrimF::cvt_from(1)));
46 let out_of_range = big | small;
47 let small_value = ctx.make::<VecF>(0.to()).copy_sign(v);
48 let out_of_range_value = small.select(small_value, v);
49 let exponent_field = v.extract_exponent_field();
50 let right_shift_amount: VecU = exponent_field - ctx.make(PrimF::EXPONENT_BIAS_UNSIGNED);
51 let mut mask: VecU = ctx.make(PrimF::MANTISSA_FIELD_MASK);
52 mask >>= right_shift_amount;
53 let in_range_value = VecF::from_bits(v.to_bits() & !mask);
54 out_of_range.select(out_of_range_value, in_range_value)
55 }
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        scalar::{Scalar, Value},
    };

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_abs_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f32() {
        for bits in (0..=u32::MAX).step_by(10001) {
            let v = f32::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    fn test_abs_f64() {
        for bits in (0..=u64::MAX).step_by(100_000_000_000_001) {
            let v = f64::from_bits(bits);
            let expected = v.abs();
            let result = abs(Scalar, Value(v)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_copy_sign_f16() {
        #[track_caller]
        fn check(mag_bits: u16, sign_bits: u16) {
            let mag = F16::from_bits(mag_bits);
            let sign = F16::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (16 - 4)),
                            sign_low_bits | (sign_high_bits << (16 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f32() {
        #[track_caller]
        fn check(mag_bits: u32, sign_bits: u32) {
            let mag = f32::from_bits(mag_bits);
            let sign = f32::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (32 - 4)),
                            sign_low_bits | (sign_high_bits << (32 - 4)),
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_copy_sign_f64() {
        #[track_caller]
        fn check(mag_bits: u64, sign_bits: u64) {
            let mag = f64::from_bits(mag_bits);
            let sign = f64::from_bits(sign_bits);
            let expected = mag.copysign(sign);
            let result = copy_sign(Scalar, Value(mag), Value(sign)).0;
            assert_eq!(expected.to_bits(), result.to_bits());
        }
        for mag_low_bits in 0..16 {
            for mag_high_bits in 0..16 {
                for sign_low_bits in 0..16 {
                    for sign_high_bits in 0..16 {
                        check(
                            mag_low_bits | (mag_high_bits << (64 - 4)),
                            sign_low_bits | (sign_high_bits << (64 - 4)),
                        );
                    }
                }
            }
        }
    }

    /// Returns `true` when `a` and `b` should count as the same result.
    ///
    /// Any two NaNs match, regardless of payload or sign. Otherwise the bit
    /// patterns produced by `to_bits` must be identical. Unlike `==`, this
    /// distinguishes `+0.0` from `-0.0`, so a `trunc` that loses the sign of
    /// a small negative input (e.g. `-0.5`) is reported as a failure.
    fn same<F: PrimFloat, Bits: PartialEq>(a: F, b: F, to_bits: impl Fn(F) -> Bits) -> bool {
        if a.is_nan() || b.is_nan() {
            a.is_nan() && b.is_nan()
        } else {
            to_bits(a) == to_bits(b)
        }
    }

    #[test]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_trunc_f16() {
        for bits in 0..=u16::MAX {
            let v = F16::from_bits(bits);
            let expected = v.trunc();
            let result = trunc(Scalar, Value(v)).0;
            assert!(
                same(expected, result, |x: F16| x.to_bits()),
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v=v,
                v_bits=v.to_bits(),
                expected=expected,
                expected_bits=expected.to_bits(),
                result=result,
                result_bits=result.to_bits(),
            );
        }
    }

    #[test]
    fn test_trunc_f32() {
        for bits in (0..=u32::MAX).step_by(0x10000) {
            let v = f32::from_bits(bits);
            let expected = v.trunc();
            let result = trunc(Scalar, Value(v)).0;
            assert!(
                same(expected, result, |x: f32| x.to_bits()),
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v=v,
                v_bits=v.to_bits(),
                expected=expected,
                expected_bits=expected.to_bits(),
                result=result,
                result_bits=result.to_bits(),
            );
        }
    }

    #[test]
    fn test_trunc_f64() {
        for bits in (0..=u64::MAX).step_by(1 << 48) {
            let v = f64::from_bits(bits);
            let expected = v.trunc();
            let result = trunc(Scalar, Value(v)).0;
            assert!(
                same(expected, result, |x: f64| x.to_bits()),
                "case failed: v={v}, v_bits={v_bits:#X}, expected={expected}, expected_bits={expected_bits:#X}, result={result}, result_bits={result_bits:#X}",
                v=v,
                v_bits=v.to_bits(),
                expected=expected,
                expected_bits=expected.to_bits(),
                result=result,
                result_bits=result.to_bits(),
            );
        }
    }
}