move sqrt/rsqrt code to roots.rs
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 21 May 2021 02:28:29 +0000 (19:28 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 21 May 2021 02:28:29 +0000 (19:28 -0700)
src/algorithms.rs
src/algorithms/base.rs
src/algorithms/roots.rs [new file with mode: 0644]

index 4278ac2da94d1623c3e78f9f4dd81a51bad48821..4c2aa82803b16bc11c84784d49466cf4d551fd0f 100644 (file)
@@ -1,4 +1,5 @@
 pub mod base;
 pub mod ilogb;
 pub mod integer;
+pub mod roots;
 pub mod trig_pi;
index b6e342604962cf666677a7454a47e450fd8f2f29..4ebd8493ee3d7dc12e0753e469db896253e6f14f 100644 (file)
@@ -116,50 +116,6 @@ pub fn ceil<
     big.select(v, in_range_value)
 }
 
-pub fn initial_rsqrt_approximation<
-    Ctx: Context,
-    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
-    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
-    PrimF: PrimFloat<BitsType = PrimU>,
-    PrimU: PrimUInt,
->(
-    ctx: Ctx,
-    v: VecF,
-) -> VecF {
-    // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
-    // where `CONST` is optimized for use for Goldschmidt's algorithm.
-    // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
-    // but using different constants.
-    const FACTOR: f64 = -0.5;
-    const TERM: f64 = 1.6;
-    v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
-}
-
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
-pub fn sqrt_rsqrt_kernel<
-    Ctx: Context,
-    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
-    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
-    PrimF: PrimFloat<BitsType = PrimU>,
-    PrimU: PrimUInt,
->(
-    ctx: Ctx,
-    v: VecF,
-    iteration_count: usize,
-) -> (VecF, VecF) {
-    // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
-    let y = initial_rsqrt_approximation(ctx, v);
-    let mut x = v * y;
-    let one_half: VecF = ctx.make(0.5.to());
-    let mut neg_h = y * -one_half;
-    for _ in 0..iteration_count {
-        let r = x.mul_add_fast(neg_h, one_half);
-        x = x.mul_add_fast(r, x);
-        neg_h = neg_h.mul_add_fast(r, neg_h);
-    }
-    (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -612,6 +568,4 @@ mod tests {
             );
         }
     }
-
-    // TODO: add tests for `sqrt_rsqrt_kernel`
 }
diff --git a/src/algorithms/roots.rs b/src/algorithms/roots.rs
new file mode 100644 (file)
index 0000000..4318bc9
--- /dev/null
@@ -0,0 +1,55 @@
+use crate::{
+    prim::{PrimFloat, PrimUInt},
+    traits::{Context, ConvertTo, Float, Make, UInt},
+};
+
+pub fn initial_rsqrt_approximation<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
+    // where `CONST` is optimized for use for Goldschmidt's algorithm.
+    // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
+    // but using different constants.
+    const FACTOR: f64 = -0.5;
+    const TERM: f64 = 1.6;
+    v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
+}
+
+/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
+pub fn sqrt_rsqrt_kernel<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+    iteration_count: usize,
+) -> (VecF, VecF) {
+    // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+    let y = initial_rsqrt_approximation(ctx, v);
+    let mut x = v * y;
+    let one_half: VecF = ctx.make(0.5.to());
+    let mut neg_h = y * -one_half;
+    for _ in 0..iteration_count {
+        let r = x.mul_add_fast(neg_h, one_half);
+        x = x.mul_add_fast(r, x);
+        neg_h = neg_h.mul_add_fast(r, neg_h);
+    }
+    (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // TODO: add tests for `sqrt_rsqrt_kernel`
+}