hdl/gfbmadd: GFBMAddFSMStage works!
author Jacob Lifshay <programmerjake@gmail.com>
Wed, 15 May 2024 08:03:12 +0000 (01:03 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Wed, 15 May 2024 08:06:41 +0000 (01:06 -0700)
src/nmigen_gf/hdl/gfbmadd.py
src/nmigen_gf/hdl/test/test_gfbmadd.py

index 7a2997465b208995a89e5ef756df8cb688acba9a..b1880062c3c55acd37937924b790b3933b81e08d 100644 (file)
@@ -9,13 +9,14 @@
 https://bugs.libre-soc.org/show_bug.cgi?id=785
 """
 
-from nmigen.hdl.ast import Signal
-from nmigen.hdl.ir import Elaboratable
+from nmigen.hdl.ast import Shape, Signal, Value, unsigned, Mux, Const
+from nmutil.singlepipe import ControlBase
 from nmigen.hdl.dsl import Module
 from nmutil.plain_data import plain_data, fields
 from nmigen_gf.reference.state import ST
 from nmigen_gf.reference.decode_reducing_polynomial import \
     decode_reducing_polynomial
+from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial
 from nmutil.clz import clz, CLZ
 
 
@@ -55,6 +56,18 @@ class GFBMAddShape:
     def acc_width(self):
         return self.width + self.reduce_input_width
 
+    @property
+    def step_count(self):
+        return self.reduce_step_end
+
+    @property
+    def step_shape(self):
+        return Shape.cast(range(self.step_count + 1))
+
+    @property
+    def sh_rpoly_clz_shape(self):
+        return Shape.cast(range(self.acc_width))
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 class PyGFBMAddState:
@@ -84,27 +97,39 @@ class PyGFBMAddState:
         return clz(self.sh_rpoly, self.shape.acc_width)
 
     @property
-    def next_state(self):
+    def _mul_next_state(self):
         # type: () -> PyGFBMAddState
         factor1 = self.factor1
         acc = self.acc
         step = self.step
-        if self.step < self.shape.mul_step_end:
-            if factor1 & 1:
-                acc ^= self.factor2 << self.shape.width
-            acc >>= 1
-            factor1 >>= 1
-            if step == self.shape.mul_step_last:
-                acc ^= self.term
-            step += 1
-        elif self.step < self.shape.reduce_step_end:
-            acc_clz = clz(acc, self.shape.acc_width)
-            if acc_clz == self.sh_rpoly_clz:
-                acc ^= self.sh_rpoly
-            acc <<= 1
-            step += 1
-        else:
-            return self
+        if factor1 & 1:
+            acc ^= self.factor2 << self.shape.width
+        acc >>= 1
+        factor1 >>= 1
+        if step == self.shape.mul_step_last:
+            acc ^= self.term
+        step += 1
+        return PyGFBMAddState(
+            shape=self.shape,
+            rpoly=self.rpoly,
+            factor1=factor1,
+            factor2=self.factor2,
+            term=self.term,
+            acc=acc,
+            step=step,
+        )
+
+    @property
+    def _reduce_next_state(self):
+        # type: () -> PyGFBMAddState
+        factor1 = self.factor1
+        acc = self.acc
+        step = self.step
+        acc_clz = clz(acc, self.shape.acc_width)
+        if acc_clz == self.sh_rpoly_clz:
+            acc ^= self.sh_rpoly
+        acc <<= 1
+        step += 1
         return PyGFBMAddState(
             shape=self.shape,
             rpoly=self.rpoly,
@@ -115,6 +140,16 @@ class PyGFBMAddState:
             step=step,
         )
 
+    @property
+    def next_state(self):
+        # type: () -> PyGFBMAddState
+        if self.step < self.shape.mul_step_end:
+            return self._mul_next_state
+        elif self.step < self.shape.reduce_step_end:
+            return self._reduce_next_state
+        else:
+            return self
+
     @property
     def output(self):
         # type: () -> None | int
@@ -158,3 +193,300 @@ def py_gfbmadd_algorithm(XLEN, REDPOLY, factor1, factor2, term):
         if output is not None:
             return output
         state = state.next_state
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class GFBMAddState:
+    __slots__ = (
+        "shape", "rpoly", "factor1", "factor2", "term", "acc", "step",
+        "sh_rpoly_clz", "name",
+    )
+
+    def __init__(
+        self, shape, rpoly, factor1, factor2, term, acc, step, sh_rpoly_clz, *,
+        name=None, src_loc_at=0,
+    ):
+        assert isinstance(shape, GFBMAddShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        rpoly = Value.cast(rpoly)
+        factor1 = Value.cast(factor1)
+        factor2 = Value.cast(factor2)
+        term = Value.cast(term)
+        acc = Value.cast(acc)
+        step = Value.cast(step)
+        sh_rpoly_clz = Value.cast(sh_rpoly_clz)
+        assert rpoly.shape() == unsigned(shape.rpoly_width)
+        assert factor1.shape() == unsigned(shape.width)
+        assert factor2.shape() == unsigned(shape.width)
+        assert term.shape() == unsigned(shape.width)
+        assert acc.shape() == unsigned(shape.acc_width)
+        assert step.shape() == shape.step_shape
+        assert sh_rpoly_clz.shape() == shape.sh_rpoly_clz_shape
+        self.shape = shape
+        self.rpoly = rpoly
+        self.factor1 = factor1
+        self.factor2 = factor2
+        self.term = term
+        self.acc = acc
+        self.step = step
+        self.sh_rpoly_clz = sh_rpoly_clz
+        self.name = name
+
+    @staticmethod
+    def signals(shape, *, name=None, src_loc_at=0):
+        assert isinstance(shape, GFBMAddShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        return GFBMAddState(
+            shape=shape,
+            rpoly=Signal(shape.rpoly_width, name=name + "_rpoly"),
+            factor1=Signal(shape.width, name=name + "_factor1"),
+            factor2=Signal(shape.width, name=name + "_factor2"),
+            term=Signal(shape.width, name=name + "_term"),
+            acc=Signal(shape.acc_width, name=name + "_acc"),
+            step=Signal(shape.step_shape, name=name + "_step"),
+            sh_rpoly_clz=Signal(shape.sh_rpoly_clz_shape,
+                                name=name + "_sh_rpoly_clz"),
+            name=name,
+        )
+
+    def eq(self, rhs):
+        assert isinstance(rhs, GFBMAddState)
+        for f in fields(GFBMAddState):
+            if f in ("shape", "name"):
+                continue
+            l = getattr(self, f)
+            r = getattr(rhs, f)
+            yield l.eq(r)
+
+    @property
+    def sh_rpoly(self):
+        return self.rpoly << (self.shape.reduce_input_width - 1)
+
+    def _mul_next_state(self, m):
+        # type: (Module) -> GFBMAddState
+        factor1 = self.factor1
+        acc = self.acc
+        step = self.step
+        acc = Mux(factor1[0], acc ^ (self.factor2 << self.shape.width), acc)
+        acc >>= 1
+        factor1 >>= 1
+        acc = Mux(step == self.shape.mul_step_last, acc ^ self.term, acc)
+        step += 1
+        step_sig = Signal(self.shape.step_shape,
+                          name=self.name + "_mul_next_state_step")
+        m.d.comb += step_sig.eq(step)
+        return GFBMAddState(
+            shape=self.shape,
+            rpoly=self.rpoly,
+            factor1=factor1,
+            factor2=self.factor2,
+            term=self.term,
+            acc=acc,
+            step=step_sig,
+            sh_rpoly_clz=self.sh_rpoly_clz,
+            name=self.name + "_mul_next_state",
+        )
+
+    def _reduce_next_state(self, m):
+        # type: (Module) -> GFBMAddState
+        factor1 = self.factor1
+        acc = self.acc
+        step = self.step
+        clz = CLZ(self.shape.acc_width)
+        setattr(m.submodules, self.name + "_clz", clz)
+        m.d.comb += clz.sig_in.eq(acc)
+        acc_clz = clz.lz
+        acc = Mux(acc_clz == self.sh_rpoly_clz, acc ^ self.sh_rpoly, acc)
+        acc <<= 1
+        step += 1
+        acc_sig = Signal(self.shape.acc_width,
+                         name=self.name + "_reduce_next_state_acc")
+        m.d.comb += acc_sig.eq(acc)
+        step_sig = Signal(self.shape.step_shape,
+                          name=self.name + "_reduce_next_state_step")
+        m.d.comb += step_sig.eq(step)
+        return GFBMAddState(
+            shape=self.shape,
+            rpoly=self.rpoly,
+            factor1=factor1,
+            factor2=self.factor2,
+            term=self.term,
+            acc=acc_sig,
+            step=step_sig,
+            sh_rpoly_clz=self.sh_rpoly_clz,
+            name=self.name + "_reduce_next_state",
+        )
+
+    def next_state(self, m):
+        # type: (Module) -> GFBMAddState
+        next_state = GFBMAddState.signals(self.shape,
+                                          name=self.name + "_next_state")
+        with m.If(self.step < self.shape.mul_step_end):
+            m.d.comb += next_state.eq(self._mul_next_state(m))
+        with m.Elif(self.step < self.shape.reduce_step_end):
+            m.d.comb += next_state.eq(self._reduce_next_state(m))
+        with m.Else():
+            m.d.comb += next_state.eq(self)
+        return next_state
+
+    @property
+    def has_output(self):
+        return self.step == self.shape.reduce_step_end
+
+    @property
+    def output(self):
+        return self.acc >> self.shape.reduce_input_width
+
+    @staticmethod
+    def initial_state(
+        shape, m, rpoly, factor1, factor2, term, *, name=None, src_loc_at=0,
+    ):
+        assert isinstance(shape, GFBMAddShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        rpoly = Value.cast(rpoly)
+        factor1 = Value.cast(factor1)
+        factor2 = Value.cast(factor2)
+        term = Value.cast(term)
+        assert rpoly.shape() == unsigned(shape.rpoly_width)
+        assert factor1.shape() == unsigned(shape.width)
+        assert factor2.shape() == unsigned(shape.width)
+        assert term.shape() == unsigned(shape.width)
+        clz = CLZ(shape.acc_width)
+        setattr(m.submodules, name + "_sh_rpoly_clz", clz)
+        sh_rpoly_clz = clz.lz
+        retval = GFBMAddState(
+            shape=shape,
+            rpoly=rpoly,
+            acc=Const(0, shape.acc_width),
+            factor1=factor1,
+            factor2=factor2,
+            term=term,
+            step=Const(0, shape.step_shape),
+            sh_rpoly_clz=sh_rpoly_clz,
+            name=name,
+        )
+        m.d.comb += clz.sig_in.eq(retval.sh_rpoly)
+        return retval
+
+
+class GFBMAddInputData:
+    def __init__(self, shape):
+        assert isinstance(shape, GFBMAddShape)
+        self.shape = shape
+        self.REDPOLY = Signal(shape.width)
+        self.factor1 = Signal(shape.width)
+        self.factor2 = Signal(shape.width)
+        self.term = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.REDPOLY
+        yield self.factor1
+        yield self.factor2
+        yield self.term
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [
+            self.REDPOLY.eq(rhs.REDPOLY),
+            self.factor1.eq(rhs.factor1),
+            self.factor2.eq(rhs.factor2),
+            self.term.eq(rhs.term),
+        ]
+
+
+class GFBMAddOutputData:
+    def __init__(self, shape):
+        assert isinstance(shape, GFBMAddShape)
+        self.shape = shape
+        self.output = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.output
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return self.output.eq(rhs.output)
+
+
+class GFBMAddFSMStage(ControlBase):
+    """gfbmadd FSM: carry-less factor1 * factor2, xor term, reduce by rpoly
+
+    Attributes:
+    shape: GFBMAddShape
+        the shape
+    pspec:
+        pipe-spec
+    empty: Signal()
+        true if nothing is stored in `self.saved_state`
+    saved_state: GFBMAddState()
+        the saved state that is currently being worked on.
+    """
+
+    def __init__(self, pspec, shape):
+        assert isinstance(shape, GFBMAddShape)
+        self.shape = shape
+        self.pspec = pspec  # store now: used in ispec and ospec
+        super().__init__(stage=self)
+        self.empty = Signal(reset=1)
+        self.saved_state = GFBMAddState.signals(shape)
+
+    def ispec(self):
+        return GFBMAddInputData(self.shape)
+
+    def ospec(self):
+        return GFBMAddOutputData(self.shape)
+
+    def setup(self, m, i):
+        pass
+
+    def elaborate(self, platform):
+        m = super().elaborate(platform)
+        i_data = self.p.i_data
+        o_data = self.n.o_data
+
+        # TODO: handle cancellation
+
+        m.d.comb += self.n.o_valid.eq(
+            ~self.empty & self.saved_state.has_output)
+        m.d.comb += self.p.o_ready.eq(self.empty)
+
+        rpoly_dec = DecodeReducingPolynomial(self.shape.width)
+        m.submodules.rpoly_dec = rpoly_dec
+        m.d.comb += rpoly_dec.REDPOLY.eq(i_data.REDPOLY)
+        rpoly = rpoly_dec.reducing_polynomial
+
+        m.d.comb += o_data.output.eq(self.saved_state.output)
+
+        initial_state = GFBMAddState.initial_state(
+            shape=self.shape,
+            m=m,
+            rpoly=rpoly,
+            factor1=i_data.factor1,
+            factor2=i_data.factor2,
+            term=i_data.term,
+        )
+
+        with m.If(self.empty):
+            m.d.sync += self.saved_state.eq(initial_state)
+            with m.If(self.p.i_valid):
+                m.d.sync += self.empty.eq(0)
+        with m.Else():
+            m.d.sync += self.saved_state.eq(self.saved_state.next_state(m))
+            with m.If(self.n.i_ready & self.n.o_valid):
+                m.d.sync += self.empty.eq(1)
+        return m
+
+    def __iter__(self):
+        yield from self.p
+        yield from self.n
+
+    def ports(self):
+        return list(self)
index 4aaa94f2168142da178b208242eb51afcb401882..6b8da5744ab2bd175f068c33a626e3199db844f8 100644 (file)
@@ -13,8 +13,9 @@ import unittest
 from nmigen.hdl.ast import Const, unsigned
 from nmutil.formaltest import FHDLTestCase
 from nmigen_gf.reference.gfbmadd import gfbmadd
-from nmigen_gf.hdl.gfbmadd import py_gfbmadd_algorithm
-from nmigen.sim import Delay
+from nmigen_gf.hdl.gfbmadd import \
+    py_gfbmadd_algorithm, GFBMAddFSMStage, GFBMAddShape
+from nmigen.sim import Delay, Tick
 from nmutil.sim_util import do_sim, hash_256
 from nmigen_gf.reference.state import ST
 import itertools
@@ -65,5 +66,102 @@ class TestPyGFBMAdd(FHDLTestCase):
         self.tst(XLEN=64, full=False)
 
 
+class TestGFBMAdd(FHDLTestCase):
+    def tst(self, XLEN):
+        shape = GFBMAddShape(width=XLEN)
+        pspec = {}
+        dut = GFBMAddFSMStage(pspec, shape)
+        i_data = dut.p.i_data
+        o_data = dut.n.o_data
+
+        def case(REDPOLY, factor1, factor2, term):
+            expected = py_gfbmadd_algorithm(
+                shape.width, REDPOLY, factor1, factor2, term)
+            with self.subTest(REDPOLY=hex(REDPOLY),
+                              factor1=hex(factor1),
+                              factor2=hex(factor2),
+                              term=hex(term),
+                              expected=hex(expected)):
+                yield dut.p.i_valid.eq(0)
+                yield Tick()
+                yield i_data.REDPOLY.eq(REDPOLY)
+                yield i_data.factor1.eq(factor1)
+                yield i_data.factor2.eq(factor2)
+                yield i_data.term.eq(term)
+                yield dut.p.i_valid.eq(1)
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield Tick()
+                yield i_data.REDPOLY.eq(-1)
+                yield i_data.factor1.eq(-1)
+                yield i_data.factor2.eq(-1)
+                yield i_data.term.eq(-1)
+                yield dut.p.i_valid.eq(0)
+                for step in range(shape.step_count):
+                    yield Delay(0.1e-6)
+                    valid = yield dut.n.o_valid
+                    ready = yield dut.p.o_ready
+                    with self.subTest():
+                        self.assertFalse(valid)
+                        self.assertFalse(ready)
+                    yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertTrue(valid)
+                    self.assertFalse(ready)
+                output = yield o_data.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(output, expected)
+                yield dut.n.i_ready.eq(1)
+                yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield dut.n.i_ready.eq(0)
+
+        def process():
+            for i in range(100):
+                v = hash_256("gfbmadd fsm %i REDPOLY %i" % (XLEN, i))
+                shift = hash_256("gfbmadd fsm %i REDPOLY shift %i" % (XLEN, i))
+                v >>= shift % XLEN
+                REDPOLY = Const.normalize(v, unsigned(XLEN))
+                v = hash_256("gfbmadd fsm %i factor1 %i" % (XLEN, i))
+                factor1 = Const.normalize(v, unsigned(XLEN))
+                v = hash_256("gfbmadd fsm %i factor2 %i" % (XLEN, i))
+                factor2 = Const.normalize(v, unsigned(XLEN))
+                v = hash_256("gfbmadd fsm %i term %i" % (XLEN, i))
+                term = Const.normalize(v, unsigned(XLEN))
+                yield from case(REDPOLY, factor1, factor2, term)
+
+        with do_sim(self, dut, list(dut.ports())) as sim:
+            sim.add_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_4(self):
+        self.tst(4)
+
+    def test_8(self):
+        self.tst(8)
+
+    def test_16(self):
+        self.tst(16)
+
+    def test_32(self):
+        self.tst(32)
+
+    def test_64(self):
+        self.tst(64)
+
+
 if __name__ == "__main__":
     unittest.main()