add the goldschmidt sqrt/rsqrt algorithm, still need code to calculate good parameters
[soc.git] / src / soc / fu / div / experiment / test / test_goldschmidt_div_sqrt.py
index 9e2763410f65e20aa25888b8094bbcab3b80104b..e2984dc16db684cc31ea20973d235ba72e9aa01d 100644 (file)
@@ -4,10 +4,12 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
+    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
+    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -19,6 +21,41 @@ class TestFixedPoint(FHDLTestCase):
                     round_trip_value = FixedPoint.cast(str(value))
                     self.assertEqual(value, round_trip_value)
 
+    @staticmethod
+    def trap(f):
+        try:
+            return f(), None
+        except (ValueError, ZeroDivisionError) as e:
+            return None, e.__class__.__name__
+
+    def test_sqrt(self):
+        for frac_wid in range(8):
+            for bits in range(1 << 9):
+                for round_dir in RoundDir:
+                    radicand = FixedPoint(bits, frac_wid)
+                    expected_f = math.sqrt(float(radicand))
+                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
+                        expected_f, frac_wid, round_dir))
+                    with self.subTest(radicand=repr(radicand),
+                                      round_dir=str(round_dir),
+                                      expected=repr(expected)):
+                        result = self.trap(lambda: radicand.sqrt(round_dir))
+                        self.assertEqual(result, expected)
+
+    def test_rsqrt(self):
+        for frac_wid in range(8):
+            for bits in range(1, 1 << 9):
+                for round_dir in RoundDir:
+                    radicand = FixedPoint(bits, frac_wid)
+                    expected_f = 1 / math.sqrt(float(radicand))
+                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
+                        expected_f, frac_wid, round_dir))
+                    with self.subTest(radicand=repr(radicand),
+                                      round_dir=str(round_dir),
+                                      expected=repr(expected)):
+                        result = self.trap(lambda: radicand.rsqrt(round_dir))
+                        self.assertEqual(result, expected)
+
 
 class TestGoldschmidtDiv(FHDLTestCase):
     def test_case1(self):
@@ -257,5 +294,44 @@ class TestGoldschmidtDiv(FHDLTestCase):
         self.tst_params(64)
 
 
+class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
+    def tst(self, io_width, frac_wid, extra_precision,
+            table_addr_bits, table_data_bits, iter_count):
+        assert isinstance(io_width, int)
+        assert isinstance(frac_wid, int)
+        assert isinstance(extra_precision, int)
+        assert isinstance(table_addr_bits, int)
+        assert isinstance(table_data_bits, int)
+        assert isinstance(iter_count, int)
+        with self.subTest(io_width=io_width, frac_wid=frac_wid,
+                          extra_precision=extra_precision,
+                          table_addr_bits=table_addr_bits,
+                          table_data_bits=table_data_bits,
+                          iter_count=iter_count):
+            for bits in range(1 << io_width):
+                radicand = FixedPoint(bits, frac_wid)
+                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
+                expected_rsqrt = FixedPoint(0, frac_wid)
+                if radicand > 0:
+                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
+                with self.subTest(radicand=repr(radicand),
+                                  expected_sqrt=repr(expected_sqrt),
+                                  expected_rsqrt=repr(expected_rsqrt)):
+                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
+                        radicand=radicand, io_width=io_width,
+                        frac_wid=frac_wid,
+                        extra_precision=extra_precision,
+                        table_addr_bits=table_addr_bits,
+                        table_data_bits=table_data_bits,
+                        iter_count=iter_count)
+                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
+                        self.assertEqual((sqrt, rsqrt),
+                                         (expected_sqrt, expected_rsqrt))
+
+    def test1(self):
+        self.tst(io_width=16, frac_wid=8, extra_precision=20,
+                 table_addr_bits=4, table_data_bits=28, iter_count=4)
+
+
 if __name__ == "__main__":
     unittest.main()