big.select(v, in_range_value)
}
-pub fn initial_rsqrt_approximation<
- Ctx: Context,
- VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
- VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
- PrimF: PrimFloat<BitsType = PrimU>,
- PrimU: PrimUInt,
->(
- ctx: Ctx,
- v: VecF,
-) -> VecF {
- // TODO: change to using `from_bits(CONST - v.to_bits() >> 1)` approximation
- // where `CONST` is optimized for use for Goldschmidt's algorithm.
- // Similar to https://en.wikipedia.org/wiki/Fast_inverse_square_root
- // but using different constants.
- const FACTOR: f64 = -0.5;
- const TERM: f64 = 1.6;
- v.mul_add_fast(ctx.make(FACTOR.to()), ctx.make(TERM.to()))
-}
-
-/// calculate `(sqrt(v), 1 / sqrt(v))` using Goldschmidt's algorithm
-pub fn sqrt_rsqrt_kernel<
- Ctx: Context,
- VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
- VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
- PrimF: PrimFloat<BitsType = PrimU>,
- PrimU: PrimUInt,
->(
- ctx: Ctx,
- v: VecF,
- iteration_count: usize,
-) -> (VecF, VecF) {
- // based on second algorithm of https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
- let y = initial_rsqrt_approximation(ctx, v);
- let mut x = v * y;
- let one_half: VecF = ctx.make(0.5.to());
- let mut neg_h = y * -one_half;
- for _ in 0..iteration_count {
- let r = x.mul_add_fast(neg_h, one_half);
- x = x.mul_add_fast(r, x);
- neg_h = neg_h.mul_add_fast(r, neg_h);
- }
- (x, neg_h * ctx.make(PrimF::cvt_from(-2)))
-}
-
#[cfg(test)]
mod tests {
use super::*;
);
}
}
-
- // TODO: add tests for `sqrt_rsqrt_kernel`
}
--- /dev/null
+use crate::{
+ prim::{PrimFloat, PrimUInt},
+ traits::{Context, ConvertTo, Float, Make, UInt},
+};
+
+/// Produce a rough first guess at `1 / sqrt(v)`, intended as the seed for an
+/// iterative refinement such as Goldschmidt's algorithm.
+///
+/// The guess is a linear function of `v` with slope `-0.5` and intercept `1.6`,
+/// evaluated through a single `mul_add_fast` call (presumably
+/// `v * slope + intercept` -- confirm against the `Float` trait).
+///
+/// `ctx` is only used to build the two vector constants.
+pub fn initial_rsqrt_approximation<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+) -> VecF {
+    // TODO: switch to a `from_bits(CONST - (v.to_bits() >> 1))` estimate with
+    // `CONST` tuned for Goldschmidt's algorithm. This is the same idea as
+    // https://en.wikipedia.org/wiki/Fast_inverse_square_root
+    // but with different constants.
+    const SLOPE: f64 = -0.5;
+    const INTERCEPT: f64 = 1.6;
+    let slope: VecF = ctx.make(SLOPE.to());
+    let intercept: VecF = ctx.make(INTERCEPT.to());
+    v.mul_add_fast(slope, intercept)
+}
+
+/// Compute `sqrt(v)` and `1 / sqrt(v)` together using Goldschmidt's algorithm.
+///
+/// The iteration is seeded from [`initial_rsqrt_approximation`] and then
+/// refined `iteration_count` times; with `iteration_count == 0` the result is
+/// derived from the seed alone.
+///
+/// Returns the pair `(sqrt(v), 1 / sqrt(v))`.
+pub fn sqrt_rsqrt_kernel<
+    Ctx: Context,
+    VecF: Float<PrimFloat = PrimF, BitsType = VecU> + Make<Context = Ctx>,
+    VecU: UInt<PrimUInt = PrimU> + Make<Context = Ctx>,
+    PrimF: PrimFloat<BitsType = PrimU>,
+    PrimU: PrimUInt,
+>(
+    ctx: Ctx,
+    v: VecF,
+    iteration_count: usize,
+) -> (VecF, VecF) {
+    // Follows the second algorithm described at
+    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+    // The running half-reciprocal-root is kept negated, so every update in a
+    // step is a single `mul_add_fast`.
+    let rsqrt_seed = initial_rsqrt_approximation(ctx, v);
+    let sqrt_seed = v * rsqrt_seed;
+    let half: VecF = ctx.make(0.5.to());
+    let neg_half_rsqrt_seed = rsqrt_seed * -half;
+
+    let (sqrt_est, neg_half_rsqrt) = (0..iteration_count).fold(
+        (sqrt_seed, neg_half_rsqrt_seed),
+        |(sqrt_est, neg_half_rsqrt), _| {
+            // residual = 0.5 - sqrt_est * (half of the rsqrt estimate)
+            let residual = sqrt_est.mul_add_fast(neg_half_rsqrt, half);
+            (
+                sqrt_est.mul_add_fast(residual, sqrt_est),
+                neg_half_rsqrt.mul_add_fast(residual, neg_half_rsqrt),
+            )
+        },
+    );
+
+    // Undo both the negation and the halving to recover `1 / sqrt(v)`.
+    let neg_two: VecF = ctx.make(PrimF::cvt_from(-2));
+    (sqrt_est, neg_half_rsqrt * neg_two)
+}
+
+// Test module for this file. It is currently a placeholder: it only imports
+// the parent module's items and contains no test functions yet.
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // TODO: add tests for `sqrt_rsqrt_kernel`
+}