move sqrt/rsqrt code to roots.rs
[vector-math.git] / src / algorithms / roots.rs
1 use crate::{
2 prim::{PrimFloat, PrimUInt},
3 traits::{Context, ConvertTo, Float, Make, UInt},
4 };
5
6 pub fn initial_rsqrt_approximation<
7 Ctx: Context,
8 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
9 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
10 PrimF: PrimFloat<BitsType = PrimU>,
11 PrimU: PrimUInt,
12 >(
13 ctx: Ctx,
14 v: VecF,
15 ) -> VecF {
16 // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
17 // where `CONST` is optimized for use for Goldschmidt's algorithm.
18 // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
19 // but using different constants.
20 const FACTOR: f64 = -0.5;
21 const TERM: f64 = 1.6;
22 v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
23 }
24
25 /// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
26 pub fn sqrt_rsqrt_kernel<
27 Ctx: Context,
28 VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
29 VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
30 PrimF: PrimFloat<BitsType = PrimU>,
31 PrimU: PrimUInt,
32 >(
33 ctx: Ctx,
34 v: VecF,
35 iteration_count: usize,
36 ) -> (VecF, VecF) {
37 // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
38 let y = initial_rsqrt_approximation(ctx, v);
39 let mut x = v * y;
40 let one_half: VecF = ctx.make(0.5.to());
41 let mut neg_h = y * -one_half;
42 for _ in 0..iteration_count {
43 let r = x.mul_add_fast(neg_h, one_half);
44 x = x.mul_add_fast(r, x);
45 neg_h = neg_h.mul_add_fast(r, neg_h);
46 }
47 (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
48 }
49
#[cfg(test)]
mod tests {
    use super::*;

    // TODO: add tests for `sqrt_rsqrt_kernel`.
    // NOTE(review): these need a concrete `Context` implementation from the
    // crate to drive the generic functions; none is referenced in this file.
    // Suggested cases:
    // - inputs near 1 (e.g. 0.5, 1.0, 2.0), where the linear seed
    //   `1.6 - 0.5 * v` from `initial_rsqrt_approximation` is closest, checking
    //   both halves of the returned `(sqrt(v), 1 / sqrt(v))` pair against
    //   reference values for a few `iteration_count`s;
    // - `iteration_count == 0`, which should yield the unrefined seed pair
    //   `(v * y0, y0)`;
    // - inputs at or above 3.2, where the seed stops being positive, to pin down
    //   the input range callers must stay within.
}