From 49f1c51b676e692f1aa6964fa6e97f6af3464932 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 3 May 2022 23:30:59 -0700 Subject: [PATCH] implement CLDivRemFSMStage --- src/nmigen_gf/hdl/cldivrem.py | 243 +++++++++++++++++++++++- src/nmigen_gf/hdl/test/test_cldivrem.py | 177 ++++++++++++++++- 2 files changed, 415 insertions(+), 5 deletions(-) diff --git a/src/nmigen_gf/hdl/cldivrem.py b/src/nmigen_gf/hdl/cldivrem.py index f6ca4e8..bff3676 100644 --- a/src/nmigen_gf/hdl/cldivrem.py +++ b/src/nmigen_gf/hdl/cldivrem.py @@ -9,9 +9,11 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784 """ +from dataclasses import dataclass, field, fields from nmigen.hdl.ir import Elaboratable -from nmigen.hdl.ast import Signal +from nmigen.hdl.ast import Signal, Value from nmigen.hdl.dsl import Module +from nmutil.singlepipe import ControlBase def equal_leading_zero_count_reference(a, b, width): @@ -104,4 +106,241 @@ class EqualLeadingZeroCount(Elaboratable): return m -# TODO: add CLDivRem + +@dataclass(frozen=True, unsafe_hash=True) +class CLDivRemShape: + width: int + n_width: int + + def __post_init__(self): + assert self.n_width >= self.width > 0 + + @property + def done_step(self): + return self.width + + @property + def step_range(self): + return range(self.done_step + 1) + + +@dataclass(frozen=True, eq=False) +class CLDivRemState: + shape: CLDivRemShape + name: str + d: Signal = field(init=False) + r: Signal = field(init=False) + q: Signal = field(init=False) + step: Signal = field(init=False) + + def __init__(self, shape, *, name=None, src_loc_at=0): + assert isinstance(shape, CLDivRemShape) + if name is None: + name = Signal(src_loc_at=1 + src_loc_at).name + assert isinstance(name, str) + d = Signal(2 * shape.width, name=f"{name}_d") + r = Signal(shape.n_width, name=f"{name}_r") + q = Signal(shape.width, name=f"{name}_q") + step = Signal(shape.width, name=f"{name}_step") + object.__setattr__(self, "shape", shape) + object.__setattr__(self, "name", name) + object.__setattr__(self, "d", d) + object.__setattr__(self, "r", r) + object.__setattr__(self, "q", q) + object.__setattr__(self, "step", step) + + def eq(self, rhs): + assert isinstance(rhs, CLDivRemState) + for f in fields(CLDivRemState): + if f.name in ("shape", "name"): + continue + l = getattr(self, f.name) + r = getattr(rhs, f.name) + yield l.eq(r) + + @staticmethod + def like(other, *, name=None, src_loc_at=0): + assert isinstance(other, CLDivRemState) + return CLDivRemState(other.shape, name=name, src_loc_at=1 + src_loc_at) + + @property + def done(self): + return self.will_be_done_after(steps=0) + + def will_be_done_after(self, steps): + """ Returns True if this state will be done after + another `steps` passes through `set_to_next`.""" + assert isinstance(steps, int) and steps >= 0 + return self.step >= max(0, self.shape.done_step - steps) + + def set_to_initial(self, m, n, d): + assert isinstance(m, Module) + m.d.comb += [ + self.d.eq(Value.cast(d) << self.shape.width), + self.r.eq(n), + self.q.eq(0), + self.step.eq(0), + ] + + def set_to_next(self, m, state_in): + assert isinstance(m, Module) + assert isinstance(state_in, CLDivRemState) + assert state_in.shape == self.shape + assert self is not state_in, "a.set_to_next(m, a) is not allowed" + + equal_leading_zero_count = EqualLeadingZeroCount(self.shape.n_width) + # can't name submodule since it would conflict if this function is + # called multiple times in a Module + m.submodules += equal_leading_zero_count + + with m.If(state_in.done): + m.d.comb += self.eq(state_in) + with m.Else(): + m.d.comb += [ + self.step.eq(state_in.step + 1), + self.d.eq(state_in.d >> 1), + equal_leading_zero_count.a.eq(self.d), + equal_leading_zero_count.b.eq(state_in.r), + ] + d_top = self.d[self.shape.n_width:] + with m.If(equal_leading_zero_count.out & (d_top == 0)): + m.d.comb += [ + self.r.eq(state_in.r ^ self.d), + self.q.eq((state_in.q << 1) | 1), + ] + with m.Else(): + m.d.comb += [ + self.r.eq(state_in.r), + self.q.eq(state_in.q << 1), + ] + + +class CLDivRemInputData: + def __init__(self, shape): + assert isinstance(shape, CLDivRemShape) + self.shape = shape + self.n = Signal(shape.n_width) + self.d = Signal(shape.width) + + def __iter__(self): + """ Get member signals. """ + yield self.n + yield self.d + + def eq(self, rhs): + """ Assign member signals. """ + return [ + self.n.eq(rhs.n), + self.d.eq(rhs.d), + ] + + +class CLDivRemOutputData: + def __init__(self, shape): + assert isinstance(shape, CLDivRemShape) + self.shape = shape + self.q = Signal(shape.width) + self.r = Signal(shape.width) + + def __iter__(self): + """ Get member signals. """ + yield self.q + yield self.r + + def eq(self, rhs): + """ Assign member signals. """ + return [ + self.q.eq(rhs.q), + self.r.eq(rhs.r), + ] + + +class CLDivRemFSMStage(ControlBase): + """carry-less div/rem + + Attributes: + shape: CLDivRemShape + the shape + steps_per_clock: int + number of steps that should be taken per clock cycle + in_valid: Signal() + input. true when the data inputs (`n` and `d`) are valid. + data transfer in occurs when `in_valid & in_ready`. + in_ready: Signal() + output. true when this FSM is ready to accept input. + data transfer in occurs when `in_valid & in_ready`. + n: Signal(shape.n_width) + numerator in, the value must be small enough that `q` and `r` don't + overflow. having `n_width == width` is sufficient. + d: Signal(shape.width) + denominator in, must be non-zero. + q: Signal(shape.width) + quotient out. + r: Signal(shape.width) + remainder out. + out_valid: Signal() + output. true when the data outputs (`q` and `r`) are valid + (or are junk because the inputs were out of range). + data transfer out occurs when `out_valid & out_ready`. + out_ready: Signal() + input. true when the output can be read. + data transfer out occurs when `out_valid & out_ready`. + """ + + def __init__(self, pspec, shape, *, steps_per_clock=4): + assert isinstance(shape, CLDivRemShape) + assert isinstance(steps_per_clock, int) and steps_per_clock >= 1 + self.shape = shape + self.steps_per_clock = steps_per_clock + self.pspec = pspec # store now: used in ispec and ospec + super().__init__(stage=self) + self.empty = Signal(reset=1) + self.saved_state = CLDivRemState(shape) + + def ispec(self): + return CLDivRemInputData(self.shape) + + def ospec(self): + return CLDivRemOutputData(self.shape) + + def setup(self, m, i): + pass + + def elaborate(self, platform): + m = super().elaborate(platform) + i_data: CLDivRemInputData = self.p.i_data + o_data: CLDivRemOutputData = self.n.o_data + + # TODO: handle cancellation + + state_will_be_done = self.saved_state.will_be_done_after( + self.steps_per_clock) + m.d.comb += self.n.o_valid.eq(~self.empty & state_will_be_done) + m.d.comb += self.p.o_ready.eq(self.empty) + + def make_nc(i): + return CLDivRemState(self.shape, name=f"next_chain_{i}") + next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)] + for i in range(self.steps_per_clock): + next_chain[i + 1].set_to_next(m, next_chain[i]) + m.d.sync += self.saved_state.eq(next_chain[-1]) + m.d.comb += o_data.q.eq(next_chain[-1].q) + m.d.comb += o_data.r.eq(next_chain[-1].r) + + with m.If(self.empty): + next_chain[0].set_to_initial(m, n=i_data.n, d=i_data.d) + with m.If(self.p.i_valid): + m.d.sync += self.empty.eq(0) + with m.Else(): + m.d.comb += next_chain[0].eq(self.saved_state) + with m.If(self.n.i_ready & self.n.o_valid): + m.d.sync += self.empty.eq(1) + + return m + + def __iter__(self): + yield from self.p + yield from self.n + + def ports(self): + return list(self) diff --git a/src/nmigen_gf/hdl/test/test_cldivrem.py b/src/nmigen_gf/hdl/test/test_cldivrem.py index b6e4cfb..fa93812 100644 --- a/src/nmigen_gf/hdl/test/test_cldivrem.py +++ b/src/nmigen_gf/hdl/test/test_cldivrem.py @@ -8,10 +8,13 @@ import unittest from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned from nmigen.hdl.dsl import Module from nmutil.formaltest import FHDLTestCase -from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference, +from nmigen_gf.hdl.cldivrem import (CLDivRemFSMStage, CLDivRemInputData, + CLDivRemOutputData, CLDivRemShape, CLDivRemState, + equal_leading_zero_count_reference, EqualLeadingZeroCount) -from nmigen.sim import Delay +from nmigen.sim import Delay, Tick from nmutil.sim_util import do_sim, hash_256 +from nmigen_gf.reference.cldivrem import cldivrem class TestEqualLeadingZeroCount(FHDLTestCase): @@ -100,7 +103,175 @@ class TestEqualLeadingZeroCount(FHDLTestCase): def test_formal_3(self): self.tst_formal(3) -# TODO: add TestCLDivRem + +class TestCLDivRemComb(FHDLTestCase): + def tst(self, shape, full): + assert isinstance(shape, CLDivRemShape) + m = Module() + n_in = Signal(shape.n_width) + d_in = Signal(shape.width) + states: "list[CLDivRemState]" = [] + for i in shape.step_range: + states.append(CLDivRemState(shape, name=f"state_{i}")) + if i == 0: + states[i].set_to_initial(m, n=n_in, d=d_in) + else: + states[i].set_to_next(m, states[i - 1]) + + def case(n, d): + assert isinstance(n, int) + assert isinstance(d, int) + max_width = max(shape.width, shape.n_width) + if d != 0: + expected_q, expected_r = cldivrem(n, d, width=max_width) + else: + expected_q = expected_r = 0 + with self.subTest(n=hex(n), d=hex(d), + expected_q=hex(expected_q), + expected_r=hex(expected_r)): + yield n_in.eq(n) + yield d_in.eq(d) + yield Delay(1e-6) + for i in shape.step_range: + with self.subTest(i=i): + done = yield states[i].done + step = yield states[i].step + self.assertEqual(done, i >= shape.done_step) + self.assertEqual(step, i) + q = yield states[-1].q + r = yield states[-1].r + with self.subTest(q=hex(q), r=hex(r)): + # only check results when inputs are valid + if d != 0 and (expected_q >> shape.width) == 0: + self.assertEqual(q, expected_q) + self.assertEqual(r, expected_r) + + def process(): + if full: + for n in range(1 << shape.n_width): + for d in range(1 << shape.width): + yield from case(n, d) + else: + for i in range(100): + n = hash_256(f"cldivrem comb n {i}") + n = Const.normalize(n, unsigned(shape.n_width)) + d = hash_256(f"cldivrem comb d {i}") + d = Const.normalize(d, unsigned(shape.width)) + yield from case(n, d) + with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim: + sim.add_process(process) + sim.run() + + def test_4(self): + self.tst(CLDivRemShape(width=4, n_width=4), full=True) + + def test_8_by_4(self): + self.tst(CLDivRemShape(width=4, n_width=8), full=True) + + +class TestCLDivRemFSM(FHDLTestCase): + def tst(self, shape, full, steps_per_clock): + assert isinstance(shape, CLDivRemShape) + assert isinstance(steps_per_clock, int) and steps_per_clock >= 1 + pspec = {} + dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock) + i_data: CLDivRemInputData = dut.p.i_data + o_data: CLDivRemOutputData = dut.n.o_data + self.assertEqual(i_data.n.shape(), unsigned(shape.n_width)) + self.assertEqual(i_data.d.shape(), unsigned(shape.width)) + self.assertEqual(o_data.q.shape(), unsigned(shape.width)) + self.assertEqual(o_data.r.shape(), unsigned(shape.width)) + + def case(n, d): + assert isinstance(n, int) + assert isinstance(d, int) + max_width = max(shape.width, shape.n_width) + if d != 0: + expected_q, expected_r = cldivrem(n, d, width=max_width) + else: + expected_q = expected_r = 0 + with self.subTest(n=hex(n), d=hex(d), + expected_q=hex(expected_q), + expected_r=hex(expected_r)): + yield dut.p.i_valid.eq(0) + yield Tick() + yield i_data.n.eq(n) + yield i_data.d.eq(d) + yield dut.p.i_valid.eq(1) + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertTrue(ready) + yield Tick() + yield i_data.n.eq(-1) + yield i_data.d.eq(-1) + yield dut.p.i_valid.eq(0) + for i in range(steps_per_clock * 2, shape.done_step, + steps_per_clock): + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertFalse(ready) + yield Tick() + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertTrue(valid) + self.assertFalse(ready) + q = yield o_data.q + r = yield o_data.r + with self.subTest(q=hex(q), r=hex(r)): + # only check results when inputs are valid + if d != 0 and (expected_q >> shape.width) == 0: + self.assertEqual(q, expected_q) + self.assertEqual(r, expected_r) + yield dut.n.i_ready.eq(1) + yield Tick() + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertTrue(ready) + yield dut.n.i_ready.eq(0) + + def process(): + if full: + for n in range(1 << shape.n_width): + for d in range(1 << shape.width): + yield from case(n, d) + else: + for i in range(100): + n = hash_256(f"cldivrem fsm n {i}") + n = Const.normalize(n, unsigned(shape.n_width)) + d = hash_256(f"cldivrem fsm d {i}") + d = Const.normalize(d, unsigned(shape.width)) + yield from case(n, d) + + with do_sim(self, dut, list(dut.ports())) as sim: + sim.add_process(process) + sim.add_clock(1e-6) + sim.run() + + def test_4_step_1(self): + self.tst(CLDivRemShape(width=4, n_width=4), + full=True, + steps_per_clock=1) + + def test_4_step_2(self): + self.tst(CLDivRemShape(width=4, n_width=4), + full=True, + steps_per_clock=2) + + def test_4_step_3(self): + self.tst(CLDivRemShape(width=4, n_width=4), + full=True, + steps_per_clock=3) if __name__ == "__main__": -- 2.30.2