switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_core.py
old mode 100644 (file)
new mode 100755 (executable)
index 0ee3935..d7aeded
@@ -1,17 +1,22 @@
+#!/usr/bin/env python3
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
-from .core import (DivPipeCoreConfig, DivPipeCoreSetupStage,
-                   DivPipeCoreCalculateStage, DivPipeCoreFinalStage,
-                   DivPipeCoreOperation, DivPipeCoreInputData,
-                   DivPipeCoreInterstageData, DivPipeCoreOutputData)
-from .algorithm import (FixedUDivRemSqrtRSqrt, Fixed, Operation, div_rem,
+from ieee754.div_rem_sqrt_rsqrt.core import (DivPipeCoreConfig,
+                    DivPipeCoreSetupStage,
+                    DivPipeCoreCalculateStage, DivPipeCoreFinalStage,
+                    DivPipeCoreOperation, DivPipeCoreInputData,
+                    DivPipeCoreInterstageData, DivPipeCoreOutputData)
+from ieee754.div_rem_sqrt_rsqrt.algorithm import (FixedUDivRemSqrtRSqrt,
+                        Fixed, Operation, div_rem,
                         fixed_sqrt, fixed_rsqrt)
 import unittest
-from nmigen import Module, Elaboratable
+from nmigen import Module, Elaboratable, Signal
 from nmigen.hdl.ir import Fragment
 from nmigen.back import rtlil
 from nmigen.back.pysim import Simulator, Delay, Tick
+from itertools import chain
+import inspect
 
 
 def show_fixed(bits, fract_width, bit_width):
@@ -29,6 +34,8 @@ def get_core_op(alg_op):
 
 
 class TestCaseData:
+    __test__ = False  # make pytest ignore class
+
     def __init__(self,
                  dividend,
                  divisor_radicand,
@@ -50,18 +57,18 @@ class TestCaseData:
     def __str__(self):
         bit_width = self.core_config.bit_width
         fract_width = self.core_config.fract_width
-        dividend_str = show_fixed(dividend,
+        dividend_str = show_fixed(self.dividend,
                                   fract_width * 2,
                                   bit_width + fract_width)
-        divisor_radicand_str = show_fixed(divisor_radicand,
+        divisor_radicand_str = show_fixed(self.divisor_radicand,
                                           fract_width,
                                           bit_width)
-        quotient_root_str = self.show_fixed(quotient_root,
-                                            fract_width,
-                                            bit_width)
-        remainder_str = self.show_fixed(remainder,
-                                        fract_width * 3,
-                                        bit_width * 3)
+        quotient_root_str = show_fixed(self.quotient_root,
+                                       fract_width,
+                                       bit_width)
+        remainder_str = show_fixed(self.remainder,
+                                   fract_width * 3,
+                                   bit_width * 3)
         return f"{{dividend={dividend_str}, " \
             + f"divisor_radicand={divisor_radicand_str}, " \
             + f"op={self.alg_op.name}, " \
@@ -73,81 +80,119 @@ class TestCaseData:
 def generate_test_case(core_config, dividend, divisor_radicand, alg_op):
     bit_width = core_config.bit_width
     fract_width = core_config.fract_width
-    if alg_op is Operation.UDivRem:
-        if divisor_radicand == 0:
-            return
-        quotient_root, remainder = div_rem(dividend,
-                                           divisor_radicand,
-                                           bit_width * 3,
-                                           False)
-        remainder <<= fract_width
-    elif alg_op is Operation.SqrtRem:
-        root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
-                                                    fract_width,
-                                                    bit_width,
-                                                    False))
-        quotient_root = root_remainder.root.bits
-        remainder = root_remainder.remainder.bits << fract_width
-    else:
-        assert alg_op is Operation.RSqrtRem
-        if divisor_radicand == 0:
-            return
-        root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
-                                                     fract_width,
-                                                     bit_width,
-                                                     False))
-        quotient_root = root_remainder.root.bits
-        remainder = root_remainder.remainder.bits
-    if quotient_root >= (1 << bit_width):
-        return
+    obj = FixedUDivRemSqrtRSqrt(dividend,
+                                divisor_radicand,
+                                alg_op,
+                                bit_width,
+                                fract_width,
+                                core_config.log2_radix)
+    obj.calculate()  # expected quotient_root/remainder are read from obj below
     yield TestCaseData(dividend,
                        divisor_radicand,
                        alg_op,
-                       quotient_root,
-                       remainder,
+                       obj.quotient_root,
+                       obj.remainder,
                        core_config)
 
 
+def shifted_ints(total_bits, int_bits):
+    """ Yield every ``v << i`` below ``1 << total_bits``, smallest first.
+
+        Here ``i`` and ``v`` are non-negative integers and
+        ``v < (1 << int_bits)``; each value is yielded exactly once.
+
+        This is like a generalized binary version of A037124,
+        see https://oeis.org/A037124
+
+        Example: ``list(shifted_ints(4, 2)) == [0, 1, 2, 3, 4, 6, 8, 12]``.
+    """
+    value = 0
+    while value < (1 << total_bits):
+        yield value
+        dense = value < (1 << int_bits)
+        # above the dense range only the top ``int_bits`` bits may vary
+        value += 1 if dense else 1 << (value.bit_length() - int_bits)
+
+
+def partitioned_ints(bit_width):
+    """ Yield, per split point, the high-ones mask then the low-ones mask. """
+    for low_zeros in range(bit_width):
+        high_ones = (-1 << low_zeros) & ((1 << bit_width) - 1)
+        yield from (high_ones, (1 << (low_zeros + 1)) - 1)
+
+
+class TestShiftedInts(unittest.TestCase):
+    """ Pin the exact output of ``shifted_ints(12, 4)``. """
+    def test(self):
+        # every value below 0x10 is present; from 0x10 upward each row
+        # holds 8 values and the step doubles from one row to the next
+        want = [0x000, 0x001, 0x002, 0x003, 0x004, 0x005, 0x006, 0x007,
+                0x008, 0x009, 0x00A, 0x00B, 0x00C, 0x00D, 0x00E, 0x00F,
+                0x010, 0x012, 0x014, 0x016, 0x018, 0x01A, 0x01C, 0x01E,
+                0x020, 0x024, 0x028, 0x02C, 0x030, 0x034, 0x038, 0x03C,
+                0x040, 0x048, 0x050, 0x058, 0x060, 0x068, 0x070, 0x078,
+                0x080, 0x090, 0x0A0, 0x0B0, 0x0C0, 0x0D0, 0x0E0, 0x0F0,
+                0x100, 0x120, 0x140, 0x160, 0x180, 0x1A0, 0x1C0, 0x1E0,
+                0x200, 0x240, 0x280, 0x2C0, 0x300, 0x340, 0x380, 0x3C0,
+                0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780,
+                0x800, 0x900, 0xA00, 0xB00, 0xC00, 0xD00, 0xE00, 0xF00]
+        self.assertEqual([*shifted_ints(12, 4)], want)
+
+
 def get_test_cases(core_config,
-                   dividend_range=None,
-                   divisor_range=None,
-                   radicand_range=None):
-    if dividend_range is None:
-        dividend_range = range(1 << (core_config.bit_width
-                                     + core_config.fract_width))
-    if divisor_range is None:
-        divisor_range = range(1 << core_config.bit_width)
-    if radicand_range is None:
-        radicand_range = range(1 << core_config.bit_width)
-
-    for alg_op in Operation:
+                   dividends=None,
+                   divisors=None,
+                   radicands=None):
+    if dividends is None:  # default: sampled values, not every bit pattern
+        dividend_width = core_config.bit_width + core_config.fract_width
+        dividends = [*shifted_ints(dividend_width,
+                                   max(3, core_config.log2_radix)),
+                     *partitioned_ints(dividend_width)]
+    else:
+        assert isinstance(dividends, list)
+    if divisors is None:
+        divisors = [*shifted_ints(core_config.bit_width,
+                                  max(3, core_config.log2_radix)),
+                    *partitioned_ints(core_config.bit_width)]
+    else:
+        assert isinstance(divisors, list)  # re-iterated for every dividend
+    if radicands is None:
+        radicands = [*shifted_ints(core_config.bit_width, 5),
+                     *partitioned_ints(core_config.bit_width)]
+    else:
+        assert isinstance(radicands, list)  # re-iterated per non-UDivRem op
+
+    for alg_op in reversed(Operation):  # put UDivRem at end
+        if get_core_op(alg_op) not in core_config.supported:
+            continue  # operation is not in this core_config's supported set
         if alg_op is Operation.UDivRem:
-            for dividend in dividend_range:
-                for divisor in divisor_range:
+            for dividend in dividends:
+                for divisor in divisors:
                     yield from generate_test_case(core_config,
                                                   dividend,
                                                   divisor,
                                                   alg_op)
         else:
-            for radicand in radicand_range:
+            for radicand in radicands:
                 yield from generate_test_case(core_config,
-                                              dividend,
+                                              0,  # dividend for non-UDivRem ops
                                               radicand,
                                               alg_op)
 
 
 class DivPipeCoreTestPipeline(Elaboratable):
-    def __init__(self, core_config):
+    def __init__(self, core_config, sync):
         self.setup_stage = DivPipeCoreSetupStage(core_config)
         self.calculate_stages = [
             DivPipeCoreCalculateStage(core_config, stage_index)
-            for stage_index in range(core_config.num_calculate_stages)]
+            for stage_index in range(core_config.n_stages)]
         self.final_stage = DivPipeCoreFinalStage(core_config)
         self.interstage_signals = [
             DivPipeCoreInterstageData(core_config, reset_less=True)
-            for i in range(core_config.num_calculate_stages + 1)]
+            for i in range(core_config.n_stages + 1)]
         self.i = DivPipeCoreInputData(core_config, reset_less=True)
         self.o = DivPipeCoreOutputData(core_config, reset_less=True)
+        self.sync = sync
 
     def elaborate(self, platform):
         m = Module()
@@ -156,74 +201,201 @@ class DivPipeCoreTestPipeline(Elaboratable):
         stage_outputs = [*self.interstage_signals, self.o]
         for stage, input, output in zip(stages, stage_inputs, stage_outputs):
             stage.setup(m, input)
-            m.d.sync += output.eq(stage.process(input))
-
+            assignments = output.eq(stage.process(input))
+            if self.sync:
+                m.d.sync += assignments
+            else:
+                m.d.comb += assignments
         return m
 
     def traces(self):
         yield from self.i
-        for interstage_signal in self.interstage_signals:
-            yield from interstage_signal
+        for interstage_signal in self.interstage_signals:
+            yield from interstage_signal
         yield from self.o
 
 
+def trace_process(process, prefix="trace:", silent=False):
+    """ Return a generator function that forwards ``process`` and, unless
+        ``silent``, prints each command, response and raised exception. """
+    def log(*args):
+        if not silent:
+            print(prefix, *args)
+
+    def generator():
+        proc = process() if inspect.isgeneratorfunction(process) else process
+        response = None
+        while True:
+            try:
+                command = proc.send(response)
+                log(command)
+            except StopIteration:
+                return
+            except Exception as e:
+                log("raised:", e)
+                raise e
+            response = (yield command)
+            log("->", response)
+    return generator
+
+
 class TestDivPipeCore(unittest.TestCase):
-    def handle_case(self,
-                    core_config,
-                    dividend_range=None,
-                    divisor_range=None,
-                    radicand_range=None):
-        def gen_test_cases():
-            yield from get_test_cases(core_config,
-                                      dividend_range,
-                                      divisor_range,
-                                      radicand_range)
-        base_name = f"div_pipe_core_bit_width_{core_config.bit_width}"
+    def handle_config(self,
+                      core_config,
+                      test_cases=None,
+                      sync=True):
+        if test_cases is None:
+            test_cases = get_test_cases(core_config)
+        test_cases = list(test_cases)  # iterated by both processes below
+        base_name = f"test_div_pipe_core_bit_width_{core_config.bit_width}"
         base_name += f"_fract_width_{core_config.fract_width}"
         base_name += f"_radix_{1 << core_config.log2_radix}"
+        if not sync:
+            base_name += "_comb"
+        if core_config.supported != frozenset(DivPipeCoreOperation):
+            name_map = {
+                DivPipeCoreOperation.UDivRem: "div",
+                DivPipeCoreOperation.SqrtRem: "sqrt",
+                DivPipeCoreOperation.RSqrtRem: "rsqrt",
+            }
+            # loop using iter(DivPipeCoreOperation) to maintain order
+            for op in DivPipeCoreOperation:
+                if op in core_config.supported:
+                    base_name += f"_{name_map[op]}"
+            base_name+="_only"
+
         with self.subTest(part="synthesize"):
-            dut = DivPipeCoreTestPipeline(core_config)
+            dut = DivPipeCoreTestPipeline(core_config, sync)
             vl = rtlil.convert(dut, ports=[*dut.i, *dut.o])
             with open(f"{base_name}.il", "w") as f:
                 f.write(vl)
-        dut = DivPipeCoreTestPipeline(core_config)
-        with Simulator(dut,
-                       vcd_file=f"{base_name}.vcd",
-                       gtkw_file=f"{base_name}.gtkw",
-                       traces=[*dut.traces()]) as sim:
+        dut = DivPipeCoreTestPipeline(core_config, sync)
+        sim = Simulator(dut)
+        with sim.write_vcd(vcd_file=open(f"{base_name}.vcd", "w"),
+                           gtkw_file=open(f"{base_name}.gtkw", "w"),
+                           traces=[*dut.traces()]):
             def generate_process():
-                for test_case in gen_test_cases():
+                if not sync:
+                    yield Delay(1e-6)
+                for test_case in test_cases:
+                    if sync:
+                        yield Tick()
                     yield dut.i.dividend.eq(test_case.dividend)
                     yield dut.i.divisor_radicand.eq(test_case.divisor_radicand)
-                    yield dut.i.operation.eq(test_case.core_op)
-                    yield Delay(1e-6)
-                    yield Tick()
+                    yield dut.i.operation.eq(int(test_case.core_op))
+                    if sync:
+                        yield Delay(0.9e-6)
+                    else:
+                        yield Delay(1e-6)
 
             def check_process():
                 # sync with generator
-                yield
-                for _ in core_config.num_calculate_stages:
-                    yield
-                yield
+                if sync:
+                    yield Tick()
+                    for _ in range(core_config.n_stages):
+                        yield Tick()
+                    yield Tick()
+                else:
+                    yield Delay(0.5e-6)
 
                 # now synched with generator
-                for test_case in gen_test_cases():
-                    yield Delay(1e-6)
+                for test_case in test_cases:
+                    if sync:
+                        yield Tick()
+                        yield Delay(0.9e-6)
+                    else:
+                        yield Delay(1e-6)
                     quotient_root = (yield dut.o.quotient_root)
                     remainder = (yield dut.o.remainder)
                     with self.subTest(test_case=str(test_case)):
                         self.assertEqual(quotient_root,
-                                         test_case.quotient_root)
-                        self.assertEqual(remainder, test_case.remainder)
-                    yield Tick()
-            sim.add_clock(2e-6)
-            sim.add_sync_process(generate_process)
-            sim.add_sync_process(check_process)
+                                         test_case.quotient_root,
+                                         str(test_case))
+                        self.assertEqual(remainder, test_case.remainder,
+                                         str(test_case))
+            if sync:
+                sim.add_clock(2e-6)
+            silent = True  # set False to print every command the processes yield
+            sim.add_process(trace_process(generate_process, "generate:", silent=silent))
+            sim.add_process(trace_process(check_process, "check:", silent=silent))
             sim.run()
 
+    def test_bit_width_2_fract_width_1_radix_2_comb(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=2,
+                                             fract_width=1,
+                                             log2_radix=1),
+                           sync=False)
+
+    def test_bit_width_2_fract_width_1_radix_2(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=2,
+                                             fract_width=1,
+                                             log2_radix=1))
+
+    def test_bit_width_8_fract_width_4_radix_2_comb(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=1),
+                           sync=False)
+
     def test_bit_width_8_fract_width_4_radix_2(self):
-        self.handle_case(DivPipeCoreConfig(bit_width=8,
-                                           fract_width=4,
-                                           log2_radix=1))
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=1))
+
+    def test_bit_width_8_fract_width_4_radix_4_comb(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=2),
+                           sync=False)
+
+    def test_bit_width_8_fract_width_4_radix_4(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=2))
+
+    def test_bit_width_8_fract_width_4_radix_4_div_only(self):
+        supported = (DivPipeCoreOperation.UDivRem,)
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=2,
+                                             supported=supported))
+
+    def test_bit_width_8_fract_width_4_radix_4_comb_div_only(self):
+        supported = (DivPipeCoreOperation.UDivRem,)
+        self.handle_config(DivPipeCoreConfig(bit_width=8,
+                                             fract_width=4,
+                                             log2_radix=2,
+                                             supported=supported),
+                           sync=False)
+
+    @unittest.skip("really slow")
+    def test_bit_width_32_fract_width_24_radix_8_comb(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=32,
+                                             fract_width=24,
+                                             log2_radix=3),
+                           sync=False)
+
+    @unittest.skip("really slow")
+    def test_bit_width_32_fract_width_24_radix_8(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=32,
+                                             fract_width=24,
+                                             log2_radix=3))
+
+    @unittest.skip("really slow")
+    def test_bit_width_32_fract_width_28_radix_8_comb(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=32,
+                                             fract_width=28,
+                                             log2_radix=3),
+                           sync=False)
+
+    @unittest.skip("really slow")
+    def test_bit_width_32_fract_width_28_radix_8(self):
+        self.handle_config(DivPipeCoreConfig(bit_width=32,
+                                             fract_width=28,
+                                             log2_radix=3))
 
     # FIXME: add more test_* functions
+
+
+if __name__ == '__main__':
+    unittest.main()