From fa3842779ffec3ebed85d9118ec56cffe17fe707 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 15 May 2024 01:03:12 -0700 Subject: [PATCH] hdl/gfbmadd: GFBMAddFSMStage works! --- src/nmigen_gf/hdl/gfbmadd.py | 370 +++++++++++++++++++++++-- src/nmigen_gf/hdl/test/test_gfbmadd.py | 102 ++++++- 2 files changed, 451 insertions(+), 21 deletions(-) diff --git a/src/nmigen_gf/hdl/gfbmadd.py b/src/nmigen_gf/hdl/gfbmadd.py index 7a29974..b188006 100644 --- a/src/nmigen_gf/hdl/gfbmadd.py +++ b/src/nmigen_gf/hdl/gfbmadd.py @@ -9,13 +9,14 @@ https://bugs.libre-soc.org/show_bug.cgi?id=785 """ -from nmigen.hdl.ast import Signal -from nmigen.hdl.ir import Elaboratable +from nmigen.hdl.ast import Shape, Signal, Value, unsigned, Mux, Const +from nmutil.singlepipe import ControlBase from nmigen.hdl.dsl import Module from nmutil.plain_data import plain_data, fields from nmigen_gf.reference.state import ST from nmigen_gf.reference.decode_reducing_polynomial import \ decode_reducing_polynomial +from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial from nmutil.clz import clz, CLZ @@ -55,6 +56,18 @@ class GFBMAddShape: def acc_width(self): return self.width + self.reduce_input_width + @property + def step_count(self): + return self.reduce_step_end + + @property + def step_shape(self): + return Shape.cast(range(self.step_count + 1)) + + @property + def sh_rpoly_clz_shape(self): + return Shape.cast(range(self.acc_width)) + @plain_data(frozen=True, unsafe_hash=True) class PyGFBMAddState: @@ -84,27 +97,39 @@ class PyGFBMAddState: return clz(self.sh_rpoly, self.shape.acc_width) @property - def next_state(self): + def _mul_next_state(self): # type: () -> PyGFBMAddState factor1 = self.factor1 acc = self.acc step = self.step - if self.step < self.shape.mul_step_end: - if factor1 & 1: - acc ^= self.factor2 << self.shape.width - acc >>= 1 - factor1 >>= 1 - if step == self.shape.mul_step_last: - acc ^= self.term - step += 1 - elif self.step < 
self.shape.reduce_step_end: - acc_clz = clz(acc, self.shape.acc_width) - if acc_clz == self.sh_rpoly_clz: - acc ^= self.sh_rpoly - acc <<= 1 - step += 1 - else: - return self + if factor1 & 1: + acc ^= self.factor2 << self.shape.width + acc >>= 1 + factor1 >>= 1 + if step == self.shape.mul_step_last: + acc ^= self.term + step += 1 + return PyGFBMAddState( + shape=self.shape, + rpoly=self.rpoly, + factor1=factor1, + factor2=self.factor2, + term=self.term, + acc=acc, + step=step, + ) + + @property + def _reduce_next_state(self): + # type: () -> PyGFBMAddState + factor1 = self.factor1 + acc = self.acc + step = self.step + acc_clz = clz(acc, self.shape.acc_width) + if acc_clz == self.sh_rpoly_clz: + acc ^= self.sh_rpoly + acc <<= 1 + step += 1 return PyGFBMAddState( shape=self.shape, rpoly=self.rpoly, @@ -115,6 +140,16 @@ class PyGFBMAddState: step=step, ) + @property + def next_state(self): + # type: () -> PyGFBMAddState + if self.step < self.shape.mul_step_end: + return self._mul_next_state + elif self.step < self.shape.reduce_step_end: + return self._reduce_next_state + else: + return self + @property def output(self): # type: () -> None | int @@ -158,3 +193,300 @@ def py_gfbmadd_algorithm(XLEN, REDPOLY, factor1, factor2, term): if output is not None: return output state = state.next_state + + +@plain_data(frozen=True, unsafe_hash=True) +class GFBMAddState: + __slots__ = ( + "shape", "rpoly", "factor1", "factor2", "term", "acc", "step", + "sh_rpoly_clz", "name", + ) + + def __init__( + self, shape, rpoly, factor1, factor2, term, acc, step, sh_rpoly_clz, *, + name=None, src_loc_at=0, + ): + assert isinstance(shape, GFBMAddShape) + if name is None: + name = Signal(src_loc_at=1 + src_loc_at).name + assert isinstance(name, str) + rpoly = Value.cast(rpoly) + factor1 = Value.cast(factor1) + factor2 = Value.cast(factor2) + term = Value.cast(term) + acc = Value.cast(acc) + step = Value.cast(step) + sh_rpoly_clz = Value.cast(sh_rpoly_clz) + assert rpoly.shape() == 
unsigned(shape.rpoly_width) + assert factor1.shape() == unsigned(shape.width) + assert factor2.shape() == unsigned(shape.width) + assert term.shape() == unsigned(shape.width) + assert acc.shape() == unsigned(shape.acc_width) + assert step.shape() == shape.step_shape + assert sh_rpoly_clz.shape() == shape.sh_rpoly_clz_shape + self.shape = shape + self.rpoly = rpoly + self.factor1 = factor1 + self.factor2 = factor2 + self.term = term + self.acc = acc + self.step = step + self.sh_rpoly_clz = sh_rpoly_clz + self.name = name + + @staticmethod + def signals(shape, *, name=None, src_loc_at=0): + assert isinstance(shape, GFBMAddShape) + if name is None: + name = Signal(src_loc_at=1 + src_loc_at).name + assert isinstance(name, str) + return GFBMAddState( + shape=shape, + rpoly=Signal(shape.rpoly_width, name=name + "_rpoly"), + factor1=Signal(shape.width, name=name + "_factor1"), + factor2=Signal(shape.width, name=name + "_factor2"), + term=Signal(shape.width, name=name + "_term"), + acc=Signal(shape.acc_width, name=name + "_acc"), + step=Signal(shape.step_shape, name=name + "_step"), + sh_rpoly_clz=Signal(shape.sh_rpoly_clz_shape, + name=name + "_sh_rpoly_clz"), + name=name, + ) + + def eq(self, rhs): + assert isinstance(rhs, GFBMAddState) + for f in fields(GFBMAddState): + if f in ("shape", "name"): + continue + l = getattr(self, f) + r = getattr(rhs, f) + yield l.eq(r) + + @property + def sh_rpoly(self): + return self.rpoly << (self.shape.reduce_input_width - 1) + + def _mul_next_state(self, m): + # type: (Module) -> GFBMAddState + factor1 = self.factor1 + acc = self.acc + step = self.step + acc = Mux(factor1[0], acc ^ (self.factor2 << self.shape.width), acc) + acc >>= 1 + factor1 >>= 1 + acc = Mux(step == self.shape.mul_step_last, acc ^ self.term, acc) + step += 1 + step_sig = Signal(self.shape.step_shape, + name=self.name + "_mul_next_state_step") + m.d.comb += step_sig.eq(step) + return GFBMAddState( + shape=self.shape, + rpoly=self.rpoly, + factor1=factor1, + 
factor2=self.factor2, + term=self.term, + acc=acc, + step=step_sig, + sh_rpoly_clz=self.sh_rpoly_clz, + name=self.name + "_mul_next_state", + ) + + def _reduce_next_state(self, m): + # type: (Module) -> GFBMAddState + factor1 = self.factor1 + acc = self.acc + step = self.step + clz = CLZ(self.shape.acc_width) + setattr(m.submodules, self.name + "_clz", clz) + m.d.comb += clz.sig_in.eq(acc) + acc_clz = clz.lz + acc = Mux(acc_clz == self.sh_rpoly_clz, acc ^ self.sh_rpoly, acc) + acc <<= 1 + step += 1 + acc_sig = Signal(self.shape.acc_width, + name=self.name + "_reduce_next_state_acc") + m.d.comb += acc_sig.eq(acc) + step_sig = Signal(self.shape.step_shape, + name=self.name + "_reduce_next_state_step") + m.d.comb += step_sig.eq(step) + return GFBMAddState( + shape=self.shape, + rpoly=self.rpoly, + factor1=factor1, + factor2=self.factor2, + term=self.term, + acc=acc_sig, + step=step_sig, + sh_rpoly_clz=self.sh_rpoly_clz, + name=self.name + "_reduce_next_state", + ) + + def next_state(self, m): + # type: (Module) -> GFBMAddState + next_state = GFBMAddState.signals(self.shape, + name=self.name + "_next_state") + with m.If(self.step < self.shape.mul_step_end): + m.d.comb += next_state.eq(self._mul_next_state(m)) + with m.Elif(self.step < self.shape.reduce_step_end): + m.d.comb += next_state.eq(self._reduce_next_state(m)) + with m.Else(): + m.d.comb += next_state.eq(self) + return next_state + + @property + def has_output(self): + return self.step == self.shape.reduce_step_end + + @property + def output(self): + return self.acc >> self.shape.reduce_input_width + + @staticmethod + def initial_state( + shape, m, rpoly, factor1, factor2, term, *, name=None, src_loc_at=0, + ): + assert isinstance(shape, GFBMAddShape) + if name is None: + name = Signal(src_loc_at=1 + src_loc_at).name + assert isinstance(name, str) + rpoly = Value.cast(rpoly) + factor1 = Value.cast(factor1) + factor2 = Value.cast(factor2) + term = Value.cast(term) + assert rpoly.shape() == 
unsigned(shape.rpoly_width) + assert factor1.shape() == unsigned(shape.width) + assert factor2.shape() == unsigned(shape.width) + assert term.shape() == unsigned(shape.width) + clz = CLZ(shape.acc_width) + setattr(m.submodules, name + "_sh_rpoly_clz", clz) + sh_rpoly_clz = clz.lz + retval = GFBMAddState( + shape=shape, + rpoly=rpoly, + acc=Const(0, shape.acc_width), + factor1=factor1, + factor2=factor2, + term=term, + step=Const(0, shape.step_shape), + sh_rpoly_clz=sh_rpoly_clz, + name=name, + ) + m.d.comb += clz.sig_in.eq(retval.sh_rpoly) + return retval + + +class GFBMAddInputData: + def __init__(self, shape): + assert isinstance(shape, GFBMAddShape) + self.shape = shape + self.REDPOLY = Signal(shape.width) + self.factor1 = Signal(shape.width) + self.factor2 = Signal(shape.width) + self.term = Signal(shape.width) + + def __iter__(self): + """ Get member signals. """ + yield self.REDPOLY + yield self.factor1 + yield self.factor2 + yield self.term + + def eq(self, rhs): + """ Assign member signals. """ + return [ + self.REDPOLY.eq(rhs.REDPOLY), + self.factor1.eq(rhs.factor1), + self.factor2.eq(rhs.factor2), + self.term.eq(rhs.term), + ] + + +class GFBMAddOutputData: + def __init__(self, shape): + assert isinstance(shape, GFBMAddShape) + self.shape = shape + self.output = Signal(shape.width) + + def __iter__(self): + """ Get member signals. """ + yield self.output + + def eq(self, rhs): + """ Assign member signals. """ + return self.output.eq(rhs.output) + + +class GFBMAddFSMStage(ControlBase): + """Galois-field multiply-add (gfbmadd: factor1 * factor2 + term) FSM stage + + Attributes: + shape: GFBMAddShape + the shape + pspec: + pipe-spec + empty: Signal() + true if nothing is stored in `self.saved_state` + saved_state: GFBMAddState() + the saved state that is currently being worked on. 
+ """ + + def __init__(self, pspec, shape): + assert isinstance(shape, GFBMAddShape) + self.shape = shape + self.pspec = pspec # store now: used in ispec and ospec + super().__init__(stage=self) + self.empty = Signal(reset=1) + self.saved_state = GFBMAddState.signals(shape) + + def ispec(self): + return GFBMAddInputData(self.shape) + + def ospec(self): + return GFBMAddOutputData(self.shape) + + def setup(self, m, i): + pass + + def elaborate(self, platform): + m = super().elaborate(platform) + i_data = self.p.i_data + o_data = self.n.o_data + + # TODO: handle cancellation + + m.d.comb += self.n.o_valid.eq( + ~self.empty & self.saved_state.has_output) + m.d.comb += self.p.o_ready.eq(self.empty) + + rpoly_dec = DecodeReducingPolynomial(self.shape.width) + m.submodules.rpoly_dec = rpoly_dec + m.d.comb += rpoly_dec.REDPOLY.eq(i_data.REDPOLY) + rpoly = rpoly_dec.reducing_polynomial + + m.d.comb += o_data.output.eq(self.saved_state.output) + + initial_state = GFBMAddState.initial_state( + shape=self.shape, + m=m, + rpoly=rpoly, + factor1=i_data.factor1, + factor2=i_data.factor2, + term=i_data.term, + ) + + with m.If(self.empty): + m.d.sync += self.saved_state.eq(initial_state) + with m.If(self.p.i_valid): + m.d.sync += self.empty.eq(0) + with m.Else(): + m.d.sync += self.saved_state.eq(self.saved_state.next_state(m)) + with m.If(self.n.i_ready & self.n.o_valid): + m.d.sync += self.empty.eq(1) + return m + + def __iter__(self): + yield from self.p + yield from self.n + + def ports(self): + return list(self) diff --git a/src/nmigen_gf/hdl/test/test_gfbmadd.py b/src/nmigen_gf/hdl/test/test_gfbmadd.py index 4aaa94f..6b8da57 100644 --- a/src/nmigen_gf/hdl/test/test_gfbmadd.py +++ b/src/nmigen_gf/hdl/test/test_gfbmadd.py @@ -13,8 +13,9 @@ import unittest from nmigen.hdl.ast import Const, unsigned from nmutil.formaltest import FHDLTestCase from nmigen_gf.reference.gfbmadd import gfbmadd -from nmigen_gf.hdl.gfbmadd import py_gfbmadd_algorithm -from nmigen.sim import Delay +from 
nmigen_gf.hdl.gfbmadd import \ + py_gfbmadd_algorithm, GFBMAddFSMStage, GFBMAddShape +from nmigen.sim import Delay, Tick from nmutil.sim_util import do_sim, hash_256 from nmigen_gf.reference.state import ST import itertools @@ -65,5 +66,102 @@ class TestPyGFBMAdd(FHDLTestCase): self.tst(XLEN=64, full=False) +class TestGFBMAdd(FHDLTestCase): + def tst(self, XLEN): + shape = GFBMAddShape(width=XLEN) + pspec = {} + dut = GFBMAddFSMStage(pspec, shape) + i_data = dut.p.i_data + o_data = dut.n.o_data + + def case(REDPOLY, factor1, factor2, term): + expected = py_gfbmadd_algorithm( + shape.width, REDPOLY, factor1, factor2, term) + with self.subTest(REDPOLY=hex(REDPOLY), + factor1=hex(factor1), + factor2=hex(factor2), + term=hex(term), + expected=hex(expected)): + yield dut.p.i_valid.eq(0) + yield Tick() + yield i_data.REDPOLY.eq(REDPOLY) + yield i_data.factor1.eq(factor1) + yield i_data.factor2.eq(factor2) + yield i_data.term.eq(term) + yield dut.p.i_valid.eq(1) + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertTrue(ready) + yield Tick() + yield i_data.REDPOLY.eq(-1) + yield i_data.factor1.eq(-1) + yield i_data.factor2.eq(-1) + yield i_data.term.eq(-1) + yield dut.p.i_valid.eq(0) + for step in range(shape.step_count): + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertFalse(ready) + yield Tick() + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertTrue(valid) + self.assertFalse(ready) + output = yield o_data.output + with self.subTest(output=hex(output)): + self.assertEqual(output, expected) + yield dut.n.i_ready.eq(1) + yield Tick() + yield Delay(0.1e-6) + valid = yield dut.n.o_valid + ready = yield dut.p.o_ready + with self.subTest(): + self.assertFalse(valid) + self.assertTrue(ready) + yield dut.n.i_ready.eq(0) + + def 
process(): + for i in range(100): + v = hash_256("gfbmadd fsm %i REDPOLY %i" % (XLEN, i)) + shift = hash_256("gfbmadd fsm %i REDPOLY shift %i" % (XLEN, i)) + v >>= shift % XLEN + REDPOLY = Const.normalize(v, unsigned(XLEN)) + v = hash_256("gfbmadd fsm %i factor1 %i" % (XLEN, i)) + factor1 = Const.normalize(v, unsigned(XLEN)) + v = hash_256("gfbmadd fsm %i factor2 %i" % (XLEN, i)) + factor2 = Const.normalize(v, unsigned(XLEN)) + v = hash_256("gfbmadd fsm %i term %i" % (XLEN, i)) + term = Const.normalize(v, unsigned(XLEN)) + yield from case(REDPOLY, factor1, factor2, term) + + with do_sim(self, dut, list(dut.ports())) as sim: + sim.add_process(process) + sim.add_clock(1e-6) + sim.run() + + def test_4(self): + self.tst(4) + + def test_8(self): + self.tst(8) + + def test_16(self): + self.tst(16) + + def test_32(self): + self.tst(32) + + def test_64(self): + self.tst(64) + + if __name__ == "__main__": unittest.main() -- 2.30.2