move sqrt/rsqrt code to roots.rs
[vector-math.git] / src / algorithms / roots.rs
diff --git a/src/algorithms/roots.rs b/src/algorithms/roots.rs
new file mode 100644 (file)
index 0000000..4318bc9
--- /dev/null
@@ -0,0 +1,55 @@
+use crate::{
+    prim::{PrimFloat, PrimUInt},
+    traits::{Context, ConvertTo, Float, Make, UInt},
+};
+
+pub fn initial_rsqrt_approximation<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
+    // where `CONST` is optimized for use for Goldschmidt's algorithm.
+    // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
+    // but using different constants.
+    const FACTOR: f64 = -0.5;
+    const TERM: f64 = 1.6;
+    v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
+}
+
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+pub fn sqrt_rsqrt_kernel<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+    iteration_count: usize,
+) -> (VecF, VecF) {
+    // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+    let y = initial_rsqrt_approximation(ctx, v);
+    let mut x = v * y;
+    let one_half: VecF = ctx.make(0.5.to());
+    let mut neg_h = y * -one_half;
+    for _ in 0..iteration_count {
+        let r = x.mul_add_fast(neg_h, one_half);
+        x = x.mul_add_fast(r, x);
+        neg_h = neg_h.mul_add_fast(r, neg_h);
+    }
+    (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // TODO: add tests for `sqrt_rsqrt_kernel`
+}