hdl/gfbinv: add gfbinv implementation, all tests pass!
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 16 May 2024 06:21:04 +0000 (23:21 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 16 May 2024 06:21:04 +0000 (23:21 -0700)
src/nmigen_gf/hdl/gfbinv.py [new file with mode: 0644]
src/nmigen_gf/hdl/test/test_gfbinv.py [new file with mode: 0644]

diff --git a/src/nmigen_gf/hdl/gfbinv.py b/src/nmigen_gf/hdl/gfbinv.py
new file mode 100644 (file)
index 0000000..712c2a6
--- /dev/null
@@ -0,0 +1,457 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2024 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+""" GF(2^n)
+
+https://bugs.libre-soc.org/show_bug.cgi?id=785
+"""
+
+from nmigen.hdl.ast import Shape, Signal, Value, unsigned, Mux, Const
+from nmutil.singlepipe import ControlBase
+from nmigen.hdl.dsl import Module
+from nmutil.plain_data import plain_data, fields
+from nmigen_gf.reference.state import ST
+from nmigen_gf.reference.decode_reducing_polynomial import \
+    decode_reducing_polynomial
+from nmigen_gf.hdl.decode_reducing_polynomial import DecodeReducingPolynomial
+from nmutil.clz import clz, CLZ
+from collections import defaultdict
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class GFBInvShape:
+    __slots__ = "width",
+
+    def __init__(self, width):
+        # type: (int) -> None
+        self.width = width
+
+    @property
+    def rpoly_max_degree(self):
+        return self.width
+
+    @property
+    def rpoly_width(self):
+        return self.width + 1
+
+    @property
+    def max_step_count(self):
+        return self.rpoly_max_degree * 2
+
+    @property
+    def step_shape(self):
+        return Shape.cast(range(self.max_step_count + 1))
+
+    @property
+    def rpoly_degree_shape(self):
+        return Shape.cast(range(self.rpoly_max_degree + 1))
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+class PyGFBInvState:
+    __slots__ = "shape", "a", "s", "m", "r", "v", "u", "delta", "step"
+
+    def __init__(self, shape, a, s, m, r, v, u, delta, step):
+        # type: (GFBInvShape, int, int, int, int, int, int, int, int) -> None
+        a = Const.normalize(a, unsigned(shape.width))
+        s = Const.normalize(s, unsigned(shape.rpoly_width))
+        m = Const.normalize(m, shape.rpoly_degree_shape)
+        r = Const.normalize(r, unsigned(shape.rpoly_width))
+        v = Const.normalize(v, unsigned(shape.rpoly_width))
+        u = Const.normalize(u, unsigned(shape.rpoly_width))
+        delta = Const.normalize(delta, shape.step_shape)
+        step = Const.normalize(step, shape.step_shape)
+        self.shape = shape
+        self.a = a
+        self.s = s
+        self.m = m
+        self.r = r
+        self.v = v
+        self.u = u
+        self.delta = delta
+        self.step = step
+
+    @property
+    def next_state(self):
+        # type: () -> PyGFBInvState
+        s = self.s
+        r = self.r
+        v = self.v
+        u = self.u
+        delta = self.delta
+        step = self.step
+        if step != self.m * 2:
+            if r >> self.m == 0:  # if the MSB of `r` is zero
+                r <<= 1
+                u <<= 1
+                delta += 1
+            else:
+                if s >> self.m != 0:  # if the MSB of `s` isn't zero
+                    s ^= r
+                    v ^= u
+                s <<= 1
+                if delta == 0:
+                    r, s = s, r  # swap r and s
+                    u, v = v << 1, u  # shift v and swap
+                    delta = 1
+                else:
+                    u >>= 1
+                    delta -= 1
+            step += 1
+            return PyGFBInvState(
+                shape=self.shape,
+                a=self.a,
+                s=s,
+                m=self.m,
+                r=r,
+                v=v,
+                u=u,
+                delta=delta,
+                step=step,
+            )
+        else:
+            return self
+
+    @property
+    def output(self):
+        # type: () -> None | int
+        if self.step == self.m * 2:
+            return 0 if self.a == 0 else self.u % 2 ** self.shape.width
+        else:
+            return None
+
+    @staticmethod
+    def initial_state(shape, rpoly, a):
+        # type: (GFBInvShape, int, int) -> PyGFBInvState
+        assert 0 < shape.width
+        assert 0 < rpoly < 2 ** shape.rpoly_width
+        assert 0 <= a < 2 ** shape.width
+        s = rpoly
+        m = shape.rpoly_width - 1 - clz(s, shape.rpoly_width)
+        return PyGFBInvState(
+            shape=shape,
+            a=a,
+            s=s,
+            m=m,
+            r=a,
+            v=0,
+            u=1,
+            delta=0,
+            step=0,
+        )
+
+    def __repr__(self):
+        field_strs = []
+        for name in fields(self):
+            v = getattr(self, name, None)
+            if v is None:
+                field_strs.append("%s=<not set>" % (name,))
+            elif name in ("shape", "m", "delta", "step"):
+                field_strs.append("%s=%s" % (name, v))
+            else:
+                field_strs.append("%s=0x%x" % (name, v))
+        return "PyGFBInvState(%s)" % (", ".join(field_strs),)
+
+
+def py_gfbinv_algorithm(XLEN, REDPOLY, a):
+    # type: (int, int, int) -> int
+    ST.reinit(XLEN=XLEN, GFBREDPOLY=REDPOLY)
+    rpoly = decode_reducing_polynomial()
+    shape = GFBInvShape(width=XLEN)
+    state = PyGFBInvState.initial_state(shape=shape, rpoly=rpoly, a=a)
+    while True:
+        output = state.output
+        if output is not None:
+            return output
+        state = state.next_state
+
+
+@plain_data(frozen=True)
+class TempSignalMaker:
+    __slots__ = "m", "base_name", "names_next_index"
+
+    def __init__(self, m, base_name, names_next_index=None):
+        # type: (Module, str, None | dict[str, int]) -> None
+        self.m = m
+        self.base_name = base_name
+        if names_next_index is None:
+            names_next_index = {}
+        self.names_next_index = names_next_index
+
+    def __call__(self, value, *, name=None, src_loc_at=0):
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        if name in self.names_next_index:
+            idx = self.names_next_index[name]
+            self.names_next_index[name] += 1
+        else:
+            self.names_next_index[name] = 2
+            idx = ""
+        name = "%s_%s_temp%s" % (self.base_name, name, idx)
+        retval = Signal.like(value, name=name)
+        self.m.d.comb += retval.eq(value)
+        return retval
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class GFBInvState:
+    __slots__ = "shape", "a", "s", "m", "r", "v", "u", "delta", "step", "name"
+
+    def __init__(
+        self, shape, a, s, m, r, v, u, delta, step, *, name=None, src_loc_at=0,
+    ):
+        assert isinstance(shape, GFBInvShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        a = Value.cast(a)
+        s = Value.cast(s)
+        m = Value.cast(m)
+        r = Value.cast(r)
+        v = Value.cast(v)
+        u = Value.cast(u)
+        delta = Value.cast(delta)
+        step = Value.cast(step)
+        assert a.shape() == unsigned(shape.width)
+        assert s.shape() == unsigned(shape.rpoly_width)
+        assert m.shape() == shape.rpoly_degree_shape
+        assert r.shape() == unsigned(shape.rpoly_width)
+        assert v.shape() == unsigned(shape.rpoly_width)
+        assert u.shape() == unsigned(shape.rpoly_width)
+        assert delta.shape() == shape.step_shape
+        assert step.shape() == shape.step_shape
+        self.shape = shape
+        self.a = a
+        self.s = s
+        self.m = m
+        self.r = r
+        self.v = v
+        self.u = u
+        self.delta = delta
+        self.step = step
+        self.name = name
+
+    @staticmethod
+    def signals(shape, *, name=None, src_loc_at=0):
+        assert isinstance(shape, GFBInvShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        return GFBInvState(
+            shape=shape,
+            a=Signal(shape.width, name=name + "_a"),
+            s=Signal(shape.rpoly_width, name=name + "_s"),
+            m=Signal(shape.rpoly_degree_shape, name=name + "_m"),
+            r=Signal(shape.rpoly_width, name=name + "_r"),
+            v=Signal(shape.rpoly_width, name=name + "_v"),
+            u=Signal(shape.rpoly_width, name=name + "_u"),
+            delta=Signal(shape.step_shape, name=name + "_delta"),
+            step=Signal(shape.step_shape, name=name + "_step"),
+            name=name,
+        )
+
+    def eq(self, rhs):
+        assert isinstance(rhs, GFBInvState)
+        for f in fields(GFBInvState):
+            if f in ("shape", "name"):
+                continue
+            l = getattr(self, f)
+            r = getattr(rhs, f)
+            yield l.eq(r)
+
+    def next_state(self, m):
+        # type: (Module) -> GFBInvState
+        next_state = GFBInvState.signals(self.shape,
+                                         name=self.name + "_next_state")
+        m.d.comb += next_state.eq(self)
+
+        tmp = TempSignalMaker(m, self.name + "_next_state")
+
+        with m.If(self.step != self.m << 1):
+            with m.If(self.r >> self.m == 0):  # if the MSB of `r` is zero
+                m.d.comb += [
+                    next_state.r.eq(self.r << 1),
+                    next_state.u.eq(self.u << 1),
+                    next_state.delta.eq(self.delta + 1),
+                ]
+            with m.Else():
+                s_msb_nonzero = self.s >> self.m != 0
+                s_msb_nonzero = tmp(s_msb_nonzero)
+                # if the MSB of `s` isn't zero
+                s = tmp(Mux(s_msb_nonzero, self.s ^ self.r, self.s))
+                v = tmp(Mux(s_msb_nonzero, self.v ^ self.u, self.v))
+                s = tmp(s << 1)
+                delta_z = tmp(self.delta == 0)
+                r, s = (Mux(delta_z, s, self.r),  # swap r and s
+                        Mux(delta_z, self.r, s))
+                r = tmp(r)
+                s = tmp(s)
+                u, v = (Mux(delta_z, v << 1, self.u >> 1),  # shift v and swap
+                        Mux(delta_z, self.u, v))
+                u = tmp(u)
+                v = tmp(v)
+                delta = tmp(Mux(delta_z, 1, self.delta - 1))
+                m.d.comb += [
+                    next_state.s.eq(s),
+                    next_state.r.eq(r),
+                    next_state.v.eq(v),
+                    next_state.u.eq(u),
+                    next_state.delta.eq(delta),
+                ]
+            m.d.comb += next_state.step.eq(self.step + 1)
+        return next_state
+
+    @property
+    def has_output(self):
+        return self.step == self.m << 1
+
+    @property
+    def output(self):
+        return Mux(self.a == 0, 0, self.u)
+
+    @staticmethod
+    def initial_state(shape, m, rpoly, a, *, name=None, src_loc_at=0):
+        assert isinstance(shape, GFBInvShape)
+        if name is None:
+            name = Signal(src_loc_at=1 + src_loc_at).name
+        assert isinstance(name, str)
+        rpoly = Value.cast(rpoly)
+        a = Value.cast(a)
+        assert rpoly.shape() == unsigned(shape.rpoly_width)
+        assert a.shape() == unsigned(shape.width)
+        s = rpoly
+        clz = CLZ(shape.rpoly_width)
+        setattr(m.submodules, name + "_clz", clz)
+        m.d.comb += clz.sig_in.eq(s)
+        state_m = Signal(shape.rpoly_degree_shape,
+                         name=name + "_initial_state_m")
+        m.d.comb += state_m.eq(
+            Const(shape.rpoly_width - 1, shape.rpoly_degree_shape) - clz.lz)
+        r = Signal(shape.rpoly_width,
+                   name=name + "_initial_state_r")
+        m.d.comb += r.eq(a)
+        return GFBInvState(
+            shape=shape,
+            a=a,
+            s=s,
+            m=state_m,
+            r=r,
+            v=Const(0, shape.rpoly_width),
+            u=Const(1, shape.rpoly_width),
+            delta=Const(0, shape.step_shape),
+            step=Const(0, shape.step_shape),
+            name=name,
+        )
+
+
+class GFBInvInputData:
+    def __init__(self, shape):
+        assert isinstance(shape, GFBInvShape)
+        self.shape = shape
+        self.REDPOLY = Signal(shape.width)
+        self.a = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.REDPOLY
+        yield self.a
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return [
+            self.REDPOLY.eq(rhs.REDPOLY),
+            self.a.eq(rhs.a),
+        ]
+
+
+class GFBInvOutputData:
+    def __init__(self, shape):
+        assert isinstance(shape, GFBInvShape)
+        self.shape = shape
+        self.output = Signal(shape.width)
+
+    def __iter__(self):
+        """ Get member signals. """
+        yield self.output
+
+    def eq(self, rhs):
+        """ Assign member signals. """
+        return self.output.eq(rhs.output)
+
+
+class GFBInvFSMStage(ControlBase):
+    """binary galois field inverse
+
+    Attributes:
+    shape: GFBInvShape
+        the shape
+    pspec:
+        pipe-spec
+    empty: Signal()
+        true if nothing is stored in `self.saved_state`
+    saved_state: GFBInvState()
+        the saved state that is currently being worked on.
+    """
+
+    def __init__(self, pspec, shape):
+        assert isinstance(shape, GFBInvShape)
+        self.shape = shape
+        self.pspec = pspec  # store now: used in ispec and ospec
+        super().__init__(stage=self)
+        self.empty = Signal(reset=1)
+        self.saved_state = GFBInvState.signals(shape)
+
+    def ispec(self):
+        return GFBInvInputData(self.shape)
+
+    def ospec(self):
+        return GFBInvOutputData(self.shape)
+
+    def setup(self, m, i):
+        pass
+
+    def elaborate(self, platform):
+        m = super().elaborate(platform)
+        i_data = self.p.i_data
+        o_data = self.n.o_data
+
+        # TODO: handle cancellation
+
+        m.d.comb += self.n.o_valid.eq(
+            ~self.empty & self.saved_state.has_output)
+        m.d.comb += self.p.o_ready.eq(self.empty)
+
+        rpoly_dec = DecodeReducingPolynomial(self.shape.width)
+        m.submodules.rpoly_dec = rpoly_dec
+        m.d.comb += rpoly_dec.REDPOLY.eq(i_data.REDPOLY)
+        rpoly = rpoly_dec.reducing_polynomial
+
+        m.d.comb += o_data.output.eq(self.saved_state.output)
+
+        initial_state = GFBInvState.initial_state(
+            shape=self.shape,
+            m=m,
+            rpoly=rpoly,
+            a=i_data.a,
+        )
+
+        with m.If(self.empty):
+            m.d.sync += self.saved_state.eq(initial_state)
+            with m.If(self.p.i_valid):
+                m.d.sync += self.empty.eq(0)
+        with m.Else():
+            m.d.sync += self.saved_state.eq(self.saved_state.next_state(m))
+            with m.If(self.n.i_ready & self.n.o_valid):
+                m.d.sync += self.empty.eq(1)
+        return m
+
+    def __iter__(self):
+        yield from self.p
+        yield from self.n
+
+    def ports(self):
+        return list(self)
diff --git a/src/nmigen_gf/hdl/test/test_gfbinv.py b/src/nmigen_gf/hdl/test/test_gfbinv.py
new file mode 100644 (file)
index 0000000..957ce35
--- /dev/null
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2024 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+""" GF(2^n)
+
+https://bugs.libre-soc.org/show_bug.cgi?id=785
+"""
+
+import unittest
+from nmigen.hdl.ast import Const, unsigned
+from nmutil.formaltest import FHDLTestCase
+from nmigen_gf.reference.gfbinv import gfbinv
+from nmigen_gf.reference.cldivrem import degree, cldivrem
+from nmigen_gf.hdl.gfbinv import \
+    py_gfbinv_algorithm, GFBInvShape, GFBInvFSMStage
+from nmigen_gf.reference.decode_reducing_polynomial import \
+    decode_reducing_polynomial
+from nmigen.sim import Delay, Tick
+from nmutil.sim_util import do_sim, hash_256
+from nmigen_gf.reference.state import ST
+import itertools
+
+
+def is_irreducible(poly):
+    # type: (int) -> bool
+    assert poly < 2 ** 33, "too slow for testing"
+    if poly == 0b10:
+        return True
+    if poly & 1 == 0:
+        return False
+    d = degree(poly)
+    half_degree = d // 2
+    for trial_divisor in range(0b11, 1 << (1 + half_degree), 2):
+        q, r = cldivrem(poly, trial_divisor, width=d + 1)
+        if q != 1 and r == 0:
+            return False
+    return True
+
+
+class TestPyGFBInv(FHDLTestCase):
+    def tst(self, XLEN, full):
+        # type: (int, bool) -> None
+        def case(REDPOLY, a):
+            # type: (int, int) -> None
+            ST.reinit(XLEN=XLEN, GFBREDPOLY=REDPOLY)
+            if not is_irreducible(decode_reducing_polynomial()):
+                # algorithm expects irreducible reducing polynomials, it can
+                # misbehave (intermediate values are bigger than the hardware
+                # implementation uses) if that's not satisfied. this isn't a
+                # problem for the instruction definition since we don't care
+                # what junk output you get for a junk input.
+                return
+            try:
+                expected = gfbinv(a)
+            except AssertionError as e:
+                if e.args != ("`a` is out-of-range",):
+                    raise
+                expected = None
+            output = py_gfbinv_algorithm(XLEN, REDPOLY, a)
+            if expected is not None:
+                with self.subTest(expected=hex(expected),
+                                  output=hex(output)):
+                    self.assertEqual(expected, output)
+        if full:
+            itr = itertools.product(range(1 << XLEN), repeat=2)
+            for REDPOLY, a in itr:
+                with self.subTest(REDPOLY=hex(REDPOLY), a=hex(a)):
+                    case(REDPOLY, a)
+        else:
+            for i in range(100):
+                v = hash_256("py_gfbinv %i REDPOLY %i" % (XLEN, i))
+                shift = hash_256("py_gfbinv %i REDPOLY shift %i" % (XLEN, i))
+                v >>= shift % XLEN
+                REDPOLY = Const.normalize(v, unsigned(XLEN))
+                v = hash_256("py_gfbinv %i a %i" % (XLEN, i))
+                a = Const.normalize(v, unsigned(XLEN))
+                with self.subTest(REDPOLY=hex(REDPOLY), a=hex(a)):
+                    case(REDPOLY, a)
+
+    def test_4(self):
+        self.tst(XLEN=4, full=True)
+
+    def test_8(self):
+        self.tst(XLEN=8, full=True)
+
+    def test_32(self):
+        self.tst(XLEN=32, full=False)
+
+
+class TestGFBInv(FHDLTestCase):
+    def tst(self, XLEN, full):
+        shape = GFBInvShape(width=XLEN)
+        pspec = {}
+        dut = GFBInvFSMStage(pspec, shape)
+        i_data = dut.p.i_data
+        o_data = dut.n.o_data
+
+        def case(REDPOLY, a):
+            expected = py_gfbinv_algorithm(shape.width, REDPOLY, a)
+            with self.subTest(REDPOLY=hex(REDPOLY),
+                              a=hex(a),
+                              expected=hex(expected)):
+                yield dut.p.i_valid.eq(0)
+                yield Tick()
+                yield i_data.REDPOLY.eq(REDPOLY)
+                yield i_data.a.eq(a)
+                yield dut.p.i_valid.eq(1)
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield Tick()
+                yield i_data.REDPOLY.eq(-1)
+                yield i_data.a.eq(-1)
+                yield dut.p.i_valid.eq(0)
+                for step in range(shape.max_step_count):
+                    yield Delay(0.1e-6)
+                    valid = yield dut.n.o_valid
+                    ready = yield dut.p.o_ready
+                    with self.subTest():
+                        self.assertFalse(ready)
+                    yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertTrue(valid)
+                    self.assertFalse(ready)
+                output = yield o_data.output
+                with self.subTest(output=hex(output)):
+                    self.assertEqual(output, expected)
+                yield dut.n.i_ready.eq(1)
+                yield Tick()
+                yield Delay(0.1e-6)
+                valid = yield dut.n.o_valid
+                ready = yield dut.p.o_ready
+                with self.subTest():
+                    self.assertFalse(valid)
+                    self.assertTrue(ready)
+                yield dut.n.i_ready.eq(0)
+
+        def process():
+            if full:
+                itr = itertools.product(range(1 << XLEN), repeat=2)
+                for REDPOLY, a in itr:
+                    with self.subTest(REDPOLY=hex(REDPOLY), a=hex(a)):
+                        yield from case(REDPOLY, a)
+            else:
+                for i in range(100):
+                    v = hash_256("gfbinv %i REDPOLY %i" % (XLEN, i))
+                    shift = hash_256("gfbinv %i REDPOLY shift %i" % (XLEN, i))
+                    v >>= shift % XLEN
+                    REDPOLY = Const.normalize(v, unsigned(XLEN))
+                    v = hash_256("gfbinv %i a %i" % (XLEN, i))
+                    a = Const.normalize(v, unsigned(XLEN))
+                    with self.subTest(REDPOLY=hex(REDPOLY), a=hex(a)):
+                        yield from case(REDPOLY, a)
+
+        with do_sim(self, dut, list(dut.ports())) as sim:
+            sim.add_process(process)
+            sim.add_clock(1e-6)
+            sim.run()
+
+    def test_4(self):
+        self.tst(XLEN=4, full=True)
+
+    def test_8(self):
+        self.tst(XLEN=8, full=False)
+
+    def test_32(self):
+        self.tst(XLEN=32, full=False)
+
+
+if __name__ == "__main__":
+    unittest.main()