big.select(v, in_range_value)
}
+/// Produce the starting estimate of `1 / sqrt(v)` that `sqrt_rsqrt_kernel`
+/// refines.
+///
+/// The estimate is the linear function `v * -0.5 + 1.6`, evaluated with a
+/// single `mul_add_fast`.
+///
+/// * `ctx` - context used to build the constant vectors.
+/// * `v` - input value; presumably expected to already be reduced to a range
+///   where this line is a usable approximation -- TODO confirm against callers.
+pub fn initial_rsqrt_approximation<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    // TODO: switch to the `from_bits(CONST - (v.to_bits() >> 1))` approximation,
+    // with `CONST` tuned for Goldschmidt's algorithm. This is the same idea as
+    // https://en.wikipedia.org/wiki/Fast_inverse_square_root, just with
+    // different constants.
+    const SLOPE: f64 = -0.5;
+    const INTERCEPT: f64 = 1.6;
+    let slope: VecF = ctx.make(SLOPE.to());
+    let intercept: VecF = ctx.make(INTERCEPT.to());
+    v.mul_add_fast(slope, intercept)
+}
+
+/// Calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm.
+///
+/// Based on the second algorithm of
+/// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm>.
+///
+/// * `ctx` - context used to build the constant vectors.
+/// * `v` - the value to take the square root of.
+/// * `iteration_count` - number of refinement steps applied to the estimate
+///   obtained from [`initial_rsqrt_approximation`]. With `0`, the unrefined
+///   estimates `v * y` and `y` are returned.
+///
+/// Returns `(x, y)` where `x` approximates `sqrt(v)` and `y` approximates
+/// `1 / sqrt(v)`.
+pub fn sqrt_rsqrt_kernel<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+    iteration_count: usize,
+) -> (VecF, VecF) {
+    // Seed values: `y` estimates `1 / sqrt(v)`, so `x = v * y` estimates `sqrt(v)`.
+    let y = initial_rsqrt_approximation(ctx, v);
+    let mut x = v * y;
+    let one_half: VecF = ctx.make(0.5.to());
+    // `h = y / 2` is tracked negated so that the residual below is a single
+    // multiply-add rather than a multiply followed by a subtraction.
+    let mut neg_h = y * -one_half;
+    for _ in 0..iteration_count {
+        // r = 0.5 - x * h
+        let r = x.mul_add_fast(neg_h, one_half);
+        // x = x + x * r
+        x = x.mul_add_fast(r, x);
+        // -h = -h + -h * r
+        neg_h = neg_h.mul_add_fast(r, neg_h);
+    }
+    // `2 * h` is the reciprocal square root; multiplying by -2 undoes the negation.
+    (x, neg_h * ctx.make(-2.to()))
+}
+
#[cfg(test)]
mod tests {
use super::*;
);
}
}
+
+ // TODO: add tests for `sqrt_rsqrt_kernel`
}