implement CLDivRemFSMStage
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 4 May 2022 06:30:59 +0000 (23:30 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 4 May 2022 06:30:59 +0000 (23:30 -0700)
src/nmigen_gf/hdl/cldivrem.py
src/nmigen_gf/hdl/test/test_cldivrem.py

index f6ca4e83fe85887aae425ad4b523f9801c41500b..bff3676a2a069eeb6db3a6d8d6815bb71c659831 100644 (file)
@@ -9,9 +9,11 @@
 https://bugs.libre-soc.org/show_bug.cgi?id=784
 """
 
+from dataclasses import dataclass, field, fields
 from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal
+from nmigen.hdl.ast import Signal, Value
 from nmigen.hdl.dsl import Module
+from nmutil.singlepipe import ControlBase
 
 
 def equal_leading_zero_count_reference(a, b, width):
@@ -104,4 +106,241 @@ class EqualLeadingZeroCount(Elaboratable):
 
         return m
 
-# TODO: add CLDivRem
+
+@dataclass(frozen=True, unsafe_hash=True)
+class CLDivRemShape:
+    width: int
+    n_width: int
+
+    def __post_init__(self):
+        assert self.n_width >= self.width > 0
+
+    @property
+    def done_step(self):
+        return self.width
+
+    @property
+    def step_range(self):
+        return range(self.done_step + 1)
+
+
+@dataclass(frozen=True, eq=False)
+class CLDivRemState:
+    shape: CLDivRemShape
+    name: str
+    d: Signal = field(init=False)
+    r: Signal = field(init=False)
+    q: Signal = field(init=False)
+    step: Signal = field(init=False)
+
+    def __init__(self, shape, *, name=None, src_loc_at=0):
+        assert isinstance(shape, CLDivRemShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        d = Signal(2 * shape.width, name=f"{name}_d")
+        r = Signal(shape.n_width, name=f"{name}_r")
+        q = Signal(shape.width, name=f"{name}_q")
+        step = Signal(shape.width, name=f"{name}_step")
+        object.__setattr__(self, "shape", shape)
+        object.__setattr__(self, "name", name)
+        object.__setattr__(self, "d", d)
+        object.__setattr__(self, "r", r)
+        object.__setattr__(self, "q", q)
+        object.__setattr__(self, "step", step)
+
+    def eq(self, rhs):
+        assert isinstance(rhs, CLDivRemState)
+        for f in fields(CLDivRemState):
+            if f.name in ("shape", "name"):
+                continue
+            l = getattr(self, f.name)
+            r = getattr(rhs, f.name)
+            yield l.eq(r)
+
+    @staticmethod
+    def like(other, *, name=None, src_loc_at=0):
+        assert isinstance(other, CLDivRemState)
+        return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at)
+
+    @property
+    def done(self):
+        return self.will_be_done_after(steps=0)
+
+    def will_be_done_after(self, steps):
+        """ Returns True if this state will be done after
+            another `steps` passes through `set_to_next`."""
+        assert isinstance(steps, int) and steps >= 0
+        return self.step >= max(0, self.shape.done_step - steps)
+
+    def set_to_initial(self, m, n, d):
+        assert isinstance(m, Module)
+        m.d.comb += [
+            self.d.eq(Value.cast(d) << self.shape.width),
+            self.r.eq(n),
+            self.q.eq(0),
+            self.step.eq(0),
+        ]
+
+    def set_to_next(self, m, state_in):
+        assert isinstance(m, Module)
+        assert isinstance(state_in, CLDivRemState)
+        assert state_in.shape == self.shape
+        assert self is not state_in, "a.set_to_next(m, a) is not allowed"
+
+        equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width)
+        # can't name submodule since it would conflict if this function is
+        # called multiple times in a Module
+        m.submodules += equal_leading_zero_count
+
+        with m.If(state_in.done):
+            m.d.comb += self.eq(state_in)
+        with m.Else():
+            m.d.comb += [
+                self.step.eq(state_in.step + 1),
+                self.d.eq(state_in.d >> 1),
+                equal_leading_zero_count.a.eq(self.d),
+                equal_leading_zero_count.b.eq(state_in.r),
+            ]
+            d_top = self.d[self.shape.n_width:]
+            with m.If(equal_leading_zero_count.out & (d_top == 0)):
+                m.d.comb += [
+                    self.r.eq(state_in.r ^ self.d),
+                    self.q.eq((state_in.q << 1) | 1),
+                ]
+            with m.Else():
+                m.d.comb += [
+                    self.r.eq(state_in.r),
+                    self.q.eq(state_in.q << 1),
+                ]
+
+
+class CLDivRemInputData:
+    def __init__(self, shape):
+        assert isinstance(shape, CLDivRemShape)
+        self.shape = shape
+        self.n = Signal(shape.n_width)
+        self.d = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.n
+        yield self.d
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [
+            self.n.eq(rhs.n),
+            self.d.eq(rhs.d),
+        ]
+
+
+class CLDivRemOutputData:
+    def __init__(self, shape):
+        assert isinstance(shape, CLDivRemShape)
+        self.shape = shape
+        self.q = Signal(shape.width)
+        self.r = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.q
+        yield self.r
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [
+            self.q.eq(rhs.q),
+            self.r.eq(rhs.r),
+        ]
+
+
+class CLDivRemFSMStage(ControlBase):
+    """carry-less div/rem
+
+    Attributes:
+    shape: CLDivRemShape
+        the shape
+    steps_per_clock: int
+        number of steps that should be taken per clock cycle
+    in_valid: Signal()
+        input. true when the data inputs (`n` and `d`) are valid.
+        data transfer in occurs when `in_valid & in_ready`.
+    in_ready: Signal()
+        output. true when this FSM is ready to accept input.
+        data transfer in occurs when `in_valid & in_ready`.
+    n: Signal(shape.n_width)
+        numerator in, the value must be small enough that `q` and `r` don't
+        overflow. having `n_width == width` is sufficient.
+    d: Signal(shape.width)
+        denominator in, must be non-zero.
+    q: Signal(shape.width)
+        quotient out.
+    r: Signal(shape.width)
+        remainder out.
+    out_valid: Signal()
+        output. true when the data outputs (`q` and `r`) are valid
+        (or are junk because the inputs were out of range).
+        data transfer out occurs when `out_valid & out_ready`.
+    out_ready: Signal()
+        input. true when the output can be read.
+        data transfer out occurs when `out_valid & out_ready`.
+    """
+
+    def __init__(self, pspec, shape, *, steps_per_clock=4):
+        assert isinstance(shape, CLDivRemShape)
+        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+        self.shape = shape
+        self.steps_per_clock = steps_per_clock
+        self.pspec = pspec  # store now: used in ispec and ospec
+        super().__init__(stage=self)
+        self.empty = Signal(reset=1)
+        self.saved_state = CLDivRemState(shape)
+
+    def ispec(self):
+        return CLDivRemInputData(self.shape)
+
+    def ospec(self):
+        return CLDivRemOutputData(self.shape)
+
+    def setup(self, m, i):
+        pass
+
+    def elaborate(self, platform):
+        m = super().elaborate(platform)
+        i_data: CLDivRemInputData = self.p.i_data
+        o_data: CLDivRemOutputData = self.n.o_data
+
+        # TODO: handle cancellation
+
+        state_will_be_done = self.saved_state.will_be_done_after(
+            self.steps_per_clock)
+        m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done)
+        m.d.comb += self.p.o_ready.eq(self.empty)
+
+        def make_nc(i):
+            return CLDivRemState(self.shape, name=f"next_chain_{i}")
+        next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)]
+        for i in range(self.steps_per_clock):
+            next_chain[i + 1].set_to_next(m, next_chain[i])
+        m.d.sync += self.saved_state.eq(next_chain[-1])
+        m.d.comb += o_data.q.eq(next_chain[-1].q)
+        m.d.comb += o_data.r.eq(next_chain[-1].r)
+
+        with m.If(self.empty):
+            next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d)
+            with m.If(self.p.i_valid):
+                m.d.sync += self.empty.eq(0)
+        with m.Else():
+            m.d.comb += next_chain[0].eq(self.saved_state)
+            with m.If(self.n.i_ready & self.n.o_valid):
+                m.d.sync += self.empty.eq(1)
+
+        return m
+
+    def __iter__(self):
+        yield from self.p
+        yield from self.n
+
+    def ports(self):
+        return list(self)
index b6e4cfb7b19d023fccc02e14a1e03dca28655e08..fa93812df04b301cc21e9d7dd0fdb5e988f41555 100644 (file)
@@ -8,10 +8,13 @@ import unittest
 from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned
 from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
-from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference,
+from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData,
+                                    CLDivRemOutputData, CLDivRemShape, CLDivRemState,
+                                    equal_leading_zero_count_reference,
                                     EqualLeadingZeroCount)
-from nmigen.sim import Delay
+from nmigen.sim import Delay, Tick
 from nmutil.sim_util import do_sim, hash_256
+from nmigen_gf.reference.cldivrem import cldivrem
 
 
 class TestEqualLeadingZeroCount(FHDLTestCase):
@@ -100,7 +103,175 @@ class TestEqualLeadingZeroCount(FHDLTestCase):
     def test_formal_3(self):
         self.tst_formal(3)
 
-# TODO: add TestCLDivRem
+
+class TestCLDivRemComb(FHDLTestCase):
+    def tst(self, shape, full):
+        assert isinstance(shape, CLDivRemShape)
+        m = Module()
+        n_in = Signal(shape.n_width)
+        d_in = Signal(shape.width)
+        states: "list[CLDivRemState]" = []
+        for i in shape.step_range:
+            states.append(CLDivRemState(shape, name=f"state_{i}"))
+            if i == 0:
+                states[i].set_to_initial(m, n=n_in, d=d_in)
+            else:
+                states[i].set_to_next(m, states[i - 1])
+
+        def case(n, d):
+            assert isinstance(n, int)
+            assert isinstance(d, int)
+            max_width = max(shape.width, shape.n_width)
+            if d != 0:
+                expected_q, expected_r = cldivrem(n, d, width=max_width)
+            else:
+                expected_q = expected_r = 0
+            with self.subTest(n=hex(n), d=hex(d),
+                              expected_q=hex(expected_q),
+                              expected_r=hex(expected_r)):
+                yield n_in.eq(n)
+                yield d_in.eq(d)
+                yield Delay(1e-6)
+                for i in shape.step_range:
+                    with self.subTest(i=i):
+                        done = yield states[i].done
+                        step = yield states[i].step
+                        self.assertEqual(done, i >= shape.done_step)
+                        self.assertEqual(step, i)
+                q = yield states[-1].q
+                r = yield states[-1].r
+                with self.subTest(q=hex(q), r=hex(r)):
+                    # only check results when inputs are valid
+                    if d != 0 and (expected_q >> shape.width) == 0:
+                        self.assertEqual(q, expected_q)
+                        self.assertEqual(r, expected_r)
+
+        def process():
+            if full:
+                for n in range(1 << shape.n_width):
+                    for d in range(1 << shape.width):
+                        yield from case(n, d)
+            else:
+                for i in range(100):
+                    n = hash_256(f"cldivrem comb n {i}")
+                    n = Const.normalize(n, unsigned(shape.n_width))
+                    d = hash_256(f"cldivrem comb d {i}")
+                    d = Const.normalize(d, unsigned(shape.width))
+                    yield from case(n, d)
+        with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def test_4(self):
+        self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+
+    def test_8_by_4(self):
+        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+
+
+class TestCLDivRemFSM(FHDLTestCase):
+    def tst(self, shape, full, steps_per_clock):
+        assert isinstance(shape, CLDivRemShape)
+        assert isinstance(steps_per_clock, int) and steps_per_clock >= 1
+        pspec = {}
+        dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
+        i_data: CLDivRemInputData = dut.p.i_data
+        o_data: CLDivRemOutputData = dut.n.o_data
+        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+        self.assertEqual(i_data.d.shape(), unsigned(shape.width))
+        self.assertEqual(o_data.q.shape(), unsigned(shape.width))
+        self.assertEqual(o_data.r.shape(), unsigned(shape.width))
+
+        def case(n, d):
+            assert isinstance(n, int)
+            assert isinstance(d, int)
+            max_width = max(shape.width, shape.n_width)
+            if d != 0:
+                expected_q, expected_r = cldivrem(n, d, width=max_width)
+            else:
+                expected_q = expected_r = 0
+            with self.subTest(n=hex(n), d=hex(d),
+                              expected_q=hex(expected_q),
+                              expected_r=hex(expected_r)):
+                yield dut.p.i_valid.eq(0)
+                yield Tick()
+                yield i_data.n.eq(n)
+                yield i_data.d.eq(d)
+                yield dut.p.i_valid.eq(1)
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield Tick()
+                yield i_data.n.eq(-1)
+                yield i_data.d.eq(-1)
+                yield dut.p.i_valid.eq(0)
+                for i in range(steps_per_clock * 2, shape.done_step,
+                               steps_per_clock):
+                    yield Delay(0.1e-6)
+                    valid = yield dut.n.o_valid
+                    ready = yield dut.p.o_ready
+                    with self.subTest():
+                        self.assertFalse(valid)
+                        self.assertFalse(ready)
+                    yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertTrue(valid)
+                    self.assertFalse(ready)
+                q = yield o_data.q
+                r = yield o_data.r
+                with self.subTest(q=hex(q), r=hex(r)):
+                    # only check results when inputs are valid
+                    if d != 0 and (expected_q >> shape.width) == 0:
+                        self.assertEqual(q, expected_q)
+                        self.assertEqual(r, expected_r)
+                yield dut.n.i_ready.eq(1)
+                yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield dut.n.i_ready.eq(0)
+
+        def process():
+            if full:
+                for n in range(1 << shape.n_width):
+                    for d in range(1 << shape.width):
+                        yield from case(n, d)
+            else:
+                for i in range(100):
+                    n = hash_256(f"cldivrem fsm n {i}")
+                    n = Const.normalize(n, unsigned(shape.n_width))
+                    d = hash_256(f"cldivrem fsm d {i}")
+                    d = Const.normalize(d, unsigned(shape.width))
+                    yield from case(n, d)
+
+        with do_sim(self, dut, list(dut.ports())) as sim:
+            sim.add_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_4_step_1(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=1)
+
+    def test_4_step_2(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=2)
+
+    def test_4_step_3(self):
+        self.tst(CLDivRemShape(width=4, n_width=4),
+                 full=True,
+                 steps_per_clock=3)
 
 
 if __name__ == "__main__":