From 4ee201e6cae8475621647f7b3b9839292ed0b46f Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 5 May 2022 20:10:32 -0700 Subject: [PATCH] split step counter into clock and substep this allows substep to be completely optimized away by yosys for CLDivRemFSMStage --- src/nmigen_gf/hdl/cldivrem.py | 189 +++++++++++++++++------- src/nmigen_gf/hdl/test/test_cldivrem.py | 147 +++++++++++------- 2 files changed, 229 insertions(+), 107 deletions(-) diff --git a/src/nmigen_gf/hdl/cldivrem.py b/src/nmigen_gf/hdl/cldivrem.py index 48105fa..31650e5 100644 --- a/src/nmigen_gf/hdl/cldivrem.py +++ b/src/nmigen_gf/hdl/cldivrem.py @@ -10,13 +10,13 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784 """ from dataclasses import dataclass, field, fields -from nmigen.hdl.ast import Signal, Value +from nmigen.hdl.ast import Signal, Value, Assert from nmigen.hdl.dsl import Module from nmutil.singlepipe import ControlBase from nmutil.clz import CLZ, clz -def cldivrem_shifting(n, d, width): +def cldivrem_shifting(n, d, shape): """ Carry-less Division and Remainder based on shifting at start and end allowing us to get away with checking a single bit each iteration rather than checking for equal degrees every iteration. @@ -24,57 +24,104 @@ def cldivrem_shifting(n, d, width): each input/output. Returns a tuple `q, r` of the quotient and remainder. """ - assert isinstance(width, int) and width >= 1 - assert isinstance(n, int) and 0 <= n < 1 << width - assert isinstance(d, int) and 0 <= d < 1 << width + assert isinstance(shape, CLDivRemShape) + assert isinstance(n, int) and 0 <= n < 1 << shape.width + assert isinstance(d, int) and 0 <= d < 1 << shape.width assert d != 0, "TODO: decide what happens on division by zero" - shape = CLDivRemShape(width) - - # `clz(d, width)`, but maxes out at `width - 1` instead of `width` in - # order to both fit in `shape.shift_width` bits and to not shift by more - # than needed. - shift = clz(d >> 1, width - 1) - assert 0 <= shift < 1 << shape.shift_width, "shift overflow" - d <<= shift - assert 0 <= d < 1 << shape.d_width, "d overflow" - r = n << shift - assert 0 <= r < 1 << shape.r_width, "r overflow" - q = 0 - for step in range(width): + # declare locals so nonlocal works + r = q = shift = clock = substep = NotImplemented + + # functions match up to HDL parts: + + def set_to_initial(): + nonlocal d, r, q, clock, substep, shift + # `clz(d, shape.width)`, but maxes out at `shape.width - 1` instead of + # `shape.width` in order to both fit in `shape.shift_width` bits and + # to not shift by more than needed. + shift = clz(d >> 1, shape.width - 1) + assert 0 <= shift < 1 << shape.shift_width, "shift overflow" + d <<= shift + assert 0 <= d < 1 << shape.d_width, "d overflow" + r = n << shift + assert 0 <= r < 1 << shape.r_width, "r overflow" + q = 0 + clock = 0 + substep = 0 + + def done(): + return clock == shape.done_clock + + def set_to_next(): + nonlocal r, q, clock, substep + substep += 1 + substep %= shape.steps_per_clock + if done(): + return + elif substep == 0: + clock += 1 + if clock == shape.width // shape.steps_per_clock \ + and substep >= shape.width % shape.steps_per_clock: + clock = shape.done_clock q <<= 1 r <<= 1 - if r >> (width * 2 - 1) != 0: - r ^= d << width + if r >> (shape.width * 2 - 1) != 0: + r ^= d << shape.width q |= 1 assert 0 <= q < 1 << shape.q_width, "q overflow" assert 0 <= r < 1 << shape.r_width, "r overflow" - r >>= width - r >>= shift - return q, r + + def get_output(): + return q, (r >> shape.width) >> shift + + set_to_initial() + + # one clock-cycle per outer loop + while not done(): + for expected_substep in range(shape.steps_per_clock): + assert substep == expected_substep + set_to_next() + + return get_output() @dataclass(frozen=True, unsafe_hash=True) class CLDivRemShape: width: int + """bit-width of each of the carry-less div/rem inputs and outputs""" + + steps_per_clock: int = 8 + """number of steps that should be taken per clock cycle""" def __post_init__(self): assert isinstance(self.width, int) and self.width >= 1, "invalid width" + assert (isinstance(self.steps_per_clock, int) + and self.steps_per_clock >= 1), "invalid steps_per_clock" @property - def done_step(self): - """the step number when iteration is finished - -- the largest `CLDivRemState.step` will get + def done_clock(self): + """the clock tick number when iteration is finished + -- the largest `CLDivRemState.clock` will get """ - return self.width + if self.width % self.steps_per_clock == 0: + return self.width // self.steps_per_clock + return self.width // self.steps_per_clock + 1 @property - def step_range(self): - """the range that `CLDivRemState.step` will fall in. + def clock_range(self): + """the range that `CLDivRemState.clock` will fall in. returns: range """ - return range(self.done_step + 1) + return range(self.done_clock + 1) + + @property + def substep_range(self): + """the range that `CLDivRemState.substep` will fall in. + + returns: range + """ + return range(self.steps_per_clock) @property def d_width(self): @@ -101,7 +148,8 @@ class CLDivRemShape: class CLDivRemState: shape: CLDivRemShape name: str - step: Signal = field(init=False) + clock: Signal = field(init=False) + substep: Signal = field(init=False) d: Signal = field(init=False) r: Signal = field(init=False) q: Signal = field(init=False) @@ -112,14 +160,16 @@ class CLDivRemState: if name is None: name = Signal(src_loc_at=1 + src_loc_at).name assert isinstance(name, str) - step = Signal(shape.step_range, name=f"{name}_step") + clock = Signal(shape.clock_range, name=f"{name}_clock") + substep = Signal(shape.substep_range, name=f"{name}_substep", reset=0) d = Signal(shape.d_width, name=f"{name}_d") r = Signal(shape.r_width, name=f"{name}_r") q = Signal(shape.q_width, name=f"{name}_q") shift = Signal(shape.shift_width, name=f"{name}_shift") object.__setattr__(self, "shape", shape) object.__setattr__(self, "name", name) - object.__setattr__(self, "step", step) + object.__setattr__(self, "clock", clock) + object.__setattr__(self, "substep", substep) object.__setattr__(self, "d", d) object.__setattr__(self, "r", r) object.__setattr__(self, "q", q) @@ -141,13 +191,7 @@ class CLDivRemState: @property def done(self): - return self.will_be_done_after(steps=0) - - def will_be_done_after(self, steps): - """ Returns True if this state will be done after - another `steps` passes through `set_to_next`.""" - assert isinstance(steps, int) and steps >= 0 - return self.step >= max(0, self.shape.done_step - steps) + return self.clock == self.shape.done_clock def get_output(self): return self.q, (self.r >> self.shape.width) >> self.shift @@ -168,21 +212,51 @@ class CLDivRemState: self.d.eq(d << self.shift), self.r.eq(n << self.shift), self.q.eq(0), - self.step.eq(0), + self.clock.eq(0), + self.substep.eq(0), ] + def eq_but_zero_substep(self, rhs, do_assert): + assert isinstance(rhs, CLDivRemState) + for f in fields(CLDivRemState): + if f.name in ("shape", "name"): + continue + l = getattr(self, f.name) + r = getattr(rhs, f.name) + if f.name == "substep": + if do_assert: + yield Assert(r == 0) + r = 0 + yield l.eq(r) + def set_to_next(self, m, state_in): assert isinstance(m, Module) assert isinstance(state_in, CLDivRemState) assert state_in.shape == self.shape assert self is not state_in, "a.set_to_next(m, a) is not allowed" width = self.shape.width + substep_wraps = state_in.substep >= self.shape.steps_per_clock - 1 + with m.If(substep_wraps): + m.d.comb += self.substep.eq(0) + with m.Else(): + m.d.comb += self.substep.eq(state_in.substep + 1) with m.If(state_in.done): - m.d.comb += self.eq(state_in) + m.d.comb += [ + self.clock.eq(state_in.clock), + self.d.eq(state_in.d), + self.r.eq(state_in.r), + self.q.eq(state_in.q), + self.shift.eq(state_in.shift), + ] with m.Else(): + clock = state_in.clock + substep_wraps + with m.If((clock == width // self.shape.steps_per_clock) + & (self.substep >= width % self.shape.steps_per_clock)): + m.d.comb += self.clock.eq(self.shape.done_clock) + with m.Else(): + m.d.comb += self.clock.eq(clock) m.d.comb += [ - self.step.eq(state_in.step + 1), self.d.eq(state_in.d), self.shift.eq(state_in.shift), ] @@ -239,6 +313,12 @@ class CLDivRemOutputData: self.r.eq(rhs.r), ] + def eq_output(self, state): + assert isinstance(state, CLDivRemState) + assert state.shape == self.shape + q, r = state.get_output() + return [self.q.eq(q), self.r.eq(r)] + class CLDivRemFSMStage(ControlBase): """carry-less div/rem @@ -246,8 +326,6 @@ class CLDivRemFSMStage(ControlBase): Attributes: shape: CLDivRemShape the shape - steps_per_clock: int - number of steps that should be taken per clock cycle pspec: pipe-spec empty: Signal() @@ -256,11 +334,9 @@ class CLDivRemFSMStage(ControlBase): the saved state that is currently being worked on. """ - def __init__(self, pspec, shape, *, steps_per_clock=8): + def __init__(self, pspec, shape): assert isinstance(shape, CLDivRemShape) - assert isinstance(steps_per_clock, int) and steps_per_clock >= 1 self.shape = shape - self.steps_per_clock = steps_per_clock self.pspec = pspec # store now: used in ispec and ospec super().__init__(stage=self) self.empty = Signal(reset=1) @@ -279,6 +355,7 @@ class CLDivRemFSMStage(ControlBase): m = super().elaborate(platform) i_data: CLDivRemInputData = self.p.i_data o_data: CLDivRemOutputData = self.n.o_data + steps_per_clock = self.shape.steps_per_clock # TODO: handle cancellation @@ -287,22 +364,24 @@ class CLDivRemFSMStage(ControlBase): def make_nc(i): return CLDivRemState(self.shape, name=f"next_chain_{i}") - next_chain = [make_nc(i) for i in range(self.steps_per_clock + 1)] - for i in range(self.steps_per_clock): + next_chain = [make_nc(i) for i in range(steps_per_clock + 1)] + for i in range(steps_per_clock): next_chain[i + 1].set_to_next(m, next_chain[i]) m.d.comb += next_chain[0].eq(self.saved_state) - out_q, out_r = self.saved_state.get_output() - m.d.comb += o_data.q.eq(out_q) - m.d.comb += o_data.r.eq(out_r) + m.d.comb += o_data.eq_output(self.saved_state) initial_state = CLDivRemState(self.shape) initial_state.set_to_initial(m, n=i_data.n, d=i_data.d) + do_assert = platform == "formal" + with m.If(self.empty): - m.d.sync += self.saved_state.eq(initial_state) + m.d.sync += self.saved_state.eq_but_zero_substep(initial_state, + do_assert) with m.If(self.p.i_valid): m.d.sync += self.empty.eq(0) with m.Else(): - m.d.sync += self.saved_state.eq(next_chain[-1]) + m.d.sync += self.saved_state.eq_but_zero_substep(next_chain[-1], + do_assert) with m.If(self.n.i_ready & self.n.o_valid): m.d.sync += self.empty.eq(1) return m diff --git a/src/nmigen_gf/hdl/test/test_cldivrem.py b/src/nmigen_gf/hdl/test/test_cldivrem.py index 438b547..efcd23e 100644 --- a/src/nmigen_gf/hdl/test/test_cldivrem.py +++ b/src/nmigen_gf/hdl/test/test_cldivrem.py @@ -17,39 +17,68 @@ from nmigen_gf.reference.cldivrem import cldivrem class TestCLDivRemShifting(FHDLTestCase): - def tst(self, width, full): + def tst(self, shape, full): + assert isinstance(shape, CLDivRemShape) + def case(n, d): assert isinstance(n, int) assert isinstance(d, int) if d != 0: - expected_q, expected_r = cldivrem(n, d, width=width) - q, r = cldivrem_shifting(n, d, width=width) + expected_q, expected_r = cldivrem(n, d, width=shape.width) + q, r = cldivrem_shifting(n, d, shape) else: expected_q = expected_r = 0 q = r = 0 - with self.subTest(n=hex(n), d=hex(d), - expected_q=hex(expected_q), + with self.subTest(expected_q=hex(expected_q), expected_r=hex(expected_r), q=hex(q), r=hex(r)): self.assertEqual(expected_q, q) self.assertEqual(expected_r, r) if full: - for n in range(1 << width): - for d in range(1 << width): - case(n, d) + for n in range(1 << shape.width): + for d in range(1 << shape.width): + with self.subTest(n=hex(n), d=hex(d)): + case(n, d) else: for i in range(100): n = hash_256(f"cldivrem comb n {i}") - n = Const.normalize(n, unsigned(width)) + n = Const.normalize(n, unsigned(shape.width)) d = hash_256(f"cldivrem comb d {i}") - d = Const.normalize(d, unsigned(width)) + d = Const.normalize(d, unsigned(shape.width)) case(n, d) - def test_6(self): - self.tst(6, full=True) + def test_6_step_1(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=True) + + def test_6_step_2(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=True) + + def test_6_step_3(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=3), full=True) - def test_64(self): - self.tst(64, full=False) + def test_6_step_4(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=4), full=True) + + def test_6_step_6(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=True) + + def test_6_step_10(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=10), full=True) + + def test_64_step_1(self): + self.tst(CLDivRemShape(width=64, steps_per_clock=1), full=False) + + def test_64_step_2(self): + self.tst(CLDivRemShape(width=64, steps_per_clock=2), full=False) + + def test_64_step_3(self): + self.tst(CLDivRemShape(width=64, steps_per_clock=3), full=False) + + def test_64_step_4(self): + self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False) + + def test_64_step_8(self): + self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False) class TestCLDivRemComb(FHDLTestCase): @@ -62,7 +91,7 @@ class TestCLDivRemComb(FHDLTestCase): q_out = Signal(width) r_out = Signal(width) states: "list[CLDivRemState]" = [] - for i in shape.step_range: + for i in range(shape.width + 1): states.append(CLDivRemState(shape, name=f"state_{i}")) if i == 0: states[i].set_to_initial(m, n=n_in, d=d_in) @@ -75,7 +104,7 @@ class TestCLDivRemComb(FHDLTestCase): assert isinstance(n, int) assert isinstance(d, int) if d != 0: - expected_q, expected_r = cldivrem_shifting(n, d, width) + expected_q, expected_r = cldivrem_shifting(n, d, shape) else: expected_q = expected_r = 0 with self.subTest(n=hex(n), d=hex(d), @@ -84,12 +113,17 @@ class TestCLDivRemComb(FHDLTestCase): yield n_in.eq(n) yield d_in.eq(d) yield Delay(1e-6) - for i in shape.step_range: + for i, state in enumerate(states): with self.subTest(i=i): - done = yield states[i].done - step = yield states[i].step - self.assertEqual(done, i >= shape.done_step) - self.assertEqual(step, i) + done = yield state.done + substep = yield state.substep + clock = yield state.clock + self.assertEqual(done, i >= shape.width) + if i % shape.steps_per_clock == 0: + self.assertEqual(substep, 0) + if i < shape.width: + self.assertEqual(substep + + clock * shape.steps_per_clock, i) q = yield q_out r = yield r_out with self.subTest(q=hex(q), r=hex(r)): @@ -114,22 +148,45 @@ class TestCLDivRemComb(FHDLTestCase): sim.add_process(process) sim.run() - def test_4(self): - self.tst(CLDivRemShape(width=4), full=True) + def test_4_step_1(self): + self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True) + + def test_4_step_2(self): + self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True) + + def test_4_step_3(self): + self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True) + + def test_4_step_4(self): + self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True) - def test_6(self): - self.tst(CLDivRemShape(width=6), full=True) + def test_6_step_1(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=1), full=False) - def test_8(self): - self.tst(CLDivRemShape(width=8), full=False) + def test_6_step_2(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=2), full=False) + + def test_6_step_6(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=6), full=False) + + def test_6_step_8(self): + self.tst(CLDivRemShape(width=6, steps_per_clock=8), full=False) + + def test_8_step_1(self): + self.tst(CLDivRemShape(width=8, steps_per_clock=1), full=False) + + def test_8_step_4(self): + self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False) + + def test_8_step_8(self): + self.tst(CLDivRemShape(width=8, steps_per_clock=8), full=False) class TestCLDivRemFSM(FHDLTestCase): - def tst(self, shape, full, steps_per_clock): + def tst(self, shape, full): assert isinstance(shape, CLDivRemShape) - assert isinstance(steps_per_clock, int) and steps_per_clock >= 1 pspec = {} - dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock) + dut = CLDivRemFSMStage(pspec, shape) i_data: CLDivRemInputData = dut.p.i_data o_data: CLDivRemOutputData = dut.n.o_data self.assertEqual(i_data.n.shape(), unsigned(shape.width)) @@ -162,7 +219,7 @@ class TestCLDivRemFSM(FHDLTestCase): yield i_data.n.eq(-1) yield i_data.d.eq(-1) yield dut.p.i_valid.eq(0) - for step in range(0, shape.done_step, steps_per_clock): + for step in range(shape.done_clock): yield Delay(0.1e-6) valid = yield dut.n.o_valid ready = yield dut.p.o_ready @@ -212,39 +269,25 @@ class TestCLDivRemFSM(FHDLTestCase): sim.run() def test_4_step_1(self): - self.tst(CLDivRemShape(width=4), - full=True, - steps_per_clock=1) + self.tst(CLDivRemShape(width=4, steps_per_clock=1), full=True) def test_4_step_2(self): - self.tst(CLDivRemShape(width=4), - full=True, - steps_per_clock=2) + self.tst(CLDivRemShape(width=4, steps_per_clock=2), full=True) def test_4_step_3(self): - self.tst(CLDivRemShape(width=4), - full=True, - steps_per_clock=3) + self.tst(CLDivRemShape(width=4, steps_per_clock=3), full=True) def test_4_step_4(self): - self.tst(CLDivRemShape(width=4), - full=True, - steps_per_clock=4) + self.tst(CLDivRemShape(width=4, steps_per_clock=4), full=True) def test_8_step_4(self): - self.tst(CLDivRemShape(width=8), - full=False, - steps_per_clock=4) + self.tst(CLDivRemShape(width=8, steps_per_clock=4), full=False) def test_64_step_4(self): - self.tst(CLDivRemShape(width=64), - full=False, - steps_per_clock=4) + self.tst(CLDivRemShape(width=64, steps_per_clock=4), full=False) def test_64_step_8(self): - self.tst(CLDivRemShape(width=64), - full=False, - steps_per_clock=8) + self.tst(CLDivRemShape(width=64, steps_per_clock=8), full=False) if __name__ == "__main__": -- 2.30.2