00044760b16ffe794cf3c3ddbd5499fe1b6849cc
[vector-math.git] / src / algorithms / roots.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, UInt},
4 };
5
6 pub fn initial_rsqrt_approximation<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
9 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
10 PrimF: PrimFloat<BitsType = PrimU>,
11 PrimU: PrimUInt,
12 >(
13 ctx: Ctx,
14 v: VecF,
15 ) -> VecF {
16 // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
17 // where `CONST` is optimized for use for Goldschmidt's algorithm.
18 // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
19 // but using different constants.
20 const FACTOR: f64 = -0.5;
21 const TERM: f64 = 1.6;
22 v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
23 }
24
25 /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm followed by one step of Newton's method
26 /// TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
27 pub fn sqrt_rsqrt_kernel<
28 Ctx: Context,
29 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
30 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
31 PrimF: PrimFloat<BitsType = PrimU>,
32 PrimU: PrimUInt,
33 >(
34 ctx: Ctx,
35 v: VecF,
36 iteration_count: usize,
37 ) -> (VecF, VecF) {
38 // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
39 let y = initial_rsqrt_approximation(ctx, v);
40 let mut x = v * y;
41 let one_half: VecF = ctx.make(0.5.to());
42 let three_halves: VecF = ctx.make((3.0 / 2.0).to());
43 let mut neg_h = y * -one_half;
44 for _ in 0..iteration_count {
45 let r = x.mul_add_fast(neg_h, one_half);
46 x = x.mul_add_fast(r, x);
47 neg_h = neg_h.mul_add_fast(r, neg_h);
48 }
49 let sqrt = x;
50 let rsqrt = neg_h * ctx.make(PrimF::cvt_from(-2));
51 // do one step of Newton's method
52 let rsqrt = rsqrt * (three_halves - (v * one_half) * (rsqrt * rsqrt));
53 (sqrt, rsqrt)
54 }
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        f16::F16,
        scalar::{Scalar, Value},
    };
    use std::fmt;

    /// Returns whether `a` and `b` represent the same value.
    ///
    /// When both are finite they are compared bit-for-bit, so `0.0` and `-0.0`
    /// are *not* the same. Otherwise `==` is used, and any two NaNs count as
    /// the same.
    fn same<F: PrimFloat>(a: F, b: F) -> bool {
        if a.is_finite() && b.is_finite() {
            a.to_bits() == b.to_bits()
        } else {
            a == b || (a.is_nan() && b.is_nan())
        }
    }

    /// `Display` adapter that prints a float's value followed by its raw bits
    /// in `0x`-prefixed upper-case hex, for readable assertion messages.
    struct DisplayValueAndBits<F: PrimFloat>(F);

    impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "{} {:#X}", self.0, self.0.to_bits())
        }
    }

    /// Checks `sqrt_rsqrt_kernel` (3 Goldschmidt iterations) against the `f64`
    /// reference results converted to `F16`. Every `F16` bit pattern from
    /// `0.5` to `2.0` inclusive is tested.
    ///
    /// The test is marked `#[ignore]` rather than compiled out, so it and the
    /// helpers above keep type-checking while the kernel is known to be broken.
    /// Run it explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore = "broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic"]
    #[cfg_attr(
        not(feature = "f16"),
        should_panic(expected = "f16 feature is not enabled")
    )]
    fn test_sqrt_rsqrt_kernel_f16() {
        let start = F16::to_bits(0.5.to());
        let end = F16::to_bits(2.0.to());
        // Positive floats are ordered by their bit patterns, so this walks every
        // F16 value in [0.5, 2.0].
        for bits in start..=end {
            let v = F16::from_bits(bits);
            let expected_sqrt: F16 = f64::from(v).sqrt().to();
            let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to();
            let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3);
            assert!(
                same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt),
                "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}",
                v = DisplayValueAndBits(v),
                expected_sqrt = DisplayValueAndBits(expected_sqrt),
                result_sqrt = DisplayValueAndBits(result_sqrt),
                expected_rsqrt = DisplayValueAndBits(expected_rsqrt),
                result_rsqrt = DisplayValueAndBits(result_rsqrt),
            );
        }
    }
}