working on adding sqrt and rsqrt implementation
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 20 May 2021 03:40:43 +0000 (20:40 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 20 May 2021 03:40:43 +0000 (20:40 -0700)
src/algorithms/base.rs

index 4ebd8493ee3d7dc12e0753e469db896253e6f14f..255b7ebc0ee0732ecf5d73934725b896d560c5c1 100644 (file)
@@ -116,6 +116,50 @@ pub fn ceil<
     big.select(v, in_range_value)
 }
 
+pub fn initial_rsqrt_approximation<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
+    // where `CONST` is optimized for use for Goldschmidt's algorithm.
+    // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
+    // but using different constants.
+    const FACTOR: f64 = -0.5;
+    const TERM: f64 = 1.6;
+    v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
+}
+
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+pub fn sqrt_rsqrt_kernel<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+    iteration_count: usize,
+) -> (VecF, VecF) {
+    /// based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+    let y = initial_rsqrt_approximation(ctx, v);
+    let mut x = v * y;
+    let one_half: VecF = ctx.make(0.5.to());
+    let mut neg_h = y * -one_half;
+    for _ in 0..iteration_count {
+        let r = x.mul_add_fast(neg_h, one_half);
+        x = x.mul_add_fast(r, x);
+        neg_h = neg_h.mul_add_fast(r, neg_h);
+    }
+    (x, neg_h * ctx.make(-2.to()))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -568,4 +612,6 @@ mod tests {
             );
         }
     }
+
+    // TODO: add tests for `sqrt_rsqrt_kernel`
 }