+ fn same<F: PrimFloat>(a: F, b: F) -> bool {
+ if a.is_finite() && b.is_finite() {
+ a.to_bits() == b.to_bits()
+ } else {
+ a == b || (a.is_nan() && b.is_nan())
+ }
+ }
+
+ struct DisplayValueAndBits<F: PrimFloat>(F);
+
+ impl<F: PrimFloat> fmt::Display for DisplayValueAndBits<F> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{} {:#X}", self.0, self.0.to_bits())
+ }
+ }
+
+ #[test]
+ #[cfg(not(test))] // TODO: broken for now -- rewrite sqrt_rsqrt_kernel to use float-float arithmetic
+ #[cfg_attr(
+ not(feature = "f16"),
+ should_panic(expected = "f16 feature is not enabled")
+ )]
+ fn test_sqrt_rsqrt_kernel_f16() {
+ let start = F16::to_bits(0.5.to());
+ let end = F16::to_bits(2.0.to());
+ for bits in start..=end {
+ let v = F16::from_bits(bits);
+ let expected_sqrt: F16 = f64::from(v).sqrt().to();
+ let expected_rsqrt: F16 = (1.0 / f64::from(v).sqrt()).to();
+ let (Value(result_sqrt), Value(result_rsqrt)) = sqrt_rsqrt_kernel(Scalar, Value(v), 3);
+ assert!(
+ same(expected_sqrt, result_sqrt) && same(expected_rsqrt, result_rsqrt),
+ "case failed: v={v}, expected_sqrt={expected_sqrt}, result_sqrt={result_sqrt}, expected_rsqrt={expected_rsqrt}, result_rsqrt={result_rsqrt}",
+ v = DisplayValueAndBits(v),
+ expected_sqrt = DisplayValueAndBits(expected_sqrt),
+ result_sqrt = DisplayValueAndBits(result_sqrt),
+ expected_rsqrt = DisplayValueAndBits(expected_rsqrt),
+ result_rsqrt = DisplayValueAndBits(result_rsqrt),
+ );
+ }
+ }