working on adding CLDivRem
author Jacob Lifshay <programmerjake@gmail.com>
Tue, 5 Apr 2022 03:52:21 +0000 (20:52 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Tue, 5 Apr 2022 03:52:21 +0000 (20:52 -0700)
src/nmigen_gf/hdl/cldivrem.py [new file with mode: 0644]
src/nmigen_gf/hdl/test/test_cldivrem.py [new file with mode: 0644]

diff --git a/src/nmigen_gf/hdl/cldivrem.py b/src/nmigen_gf/hdl/cldivrem.py
new file mode 100644 (file)
index 0000000..9a89c43
--- /dev/null
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+""" Carry-less Division and Remainder.
+
+https://bugs.libre-soc.org/show_bug.cgi?id=784
+"""
+
+from nmigen.hdl.ir import Elaboratable
+from nmigen.hdl.ast import Signal, Cat
+from nmigen.hdl.dsl import Module
+
+
+def equal_leading_zero_count_reference(a, b, width):
+    """Return whether `a` and `b` have the same leading-zero count.
+
+    Pure-Python model of the bit scan used by `EqualLeadingZeroCount`;
+    `a` and `b` must be non-negative ints that fit in `width` bits.
+    """
+    assert isinstance(width, int) and 0 <= width
+    assert isinstance(a, int) and 0 <= a < (1 << width)
+    assert isinstance(b, int) and 0 <= b < (1 << width)
+    matches = True  # nothing scanned yet, so the counts agree so far
+    # scan from LSB to MSB so the highest decisive bit pair has the last say
+    for shift in range(width):
+        a_bit = (a >> shift) & 1
+        b_bit = (b >> shift) & 1
+        if a_bit != b_bit:
+            matches = False  # only one is set here: the counts differ
+        elif a_bit:
+            matches = True  # both are set here: the counts agree
+        # both clear: keep the verdict carried up from the lower bits
+    return matches
+
+
+class EqualLeadingZeroCount(Elaboratable):
+    """Combinatorial check of `clz(a) == clz(b)`.
+
+    Properties:
+    width: int
+        number of bits in each of `a` and `b`.
+    a: Signal of width `width`
+        first input
+    b: Signal of width `width`
+        second input
+    out: Signal of width `1`
+        output, high when `a` and `b` have the same number of leading
+        zeros.
+    """
+
+    def __init__(self, width):
+        assert isinstance(width, int)
+        self.width = width
+        self.a = Signal(width)
+        self.b = Signal(width)
+        self.out = Signal()
+
+    def elaborate(self, platform):
+        # The check is phrased as the carry-out of a binary addition so
+        # that FPGA carry-propagation logic can be re-used; only the top
+        # bit of the sum is read.
+        #
+        # For each bit `i` the addend bits are chosen as follows:
+        #
+        #   a[i] b[i] | addend1[i] addend2[i] | carry out of bit i
+        #   ----------+-----------------------+---------------------
+        #    1    1   |     1          1      | set
+        #    1    0   |     0          0      | clear
+        #    0    1   |     0          0      | clear
+        #    0    0   |     1          0      | carry into bit i
+        #
+        # See `equal_leading_zero_count_reference` for the same algorithm
+        # in plain Python, without conversion to carry-propagation.
+        m = Module()
+        addend1 = Signal(self.width)
+        addend2 = Signal(self.width)
+        for i in range(self.width):
+            a_bit = self.a[i]
+            b_bit = self.b[i]
+            m.d.comb += [
+                # 1 in the "set" and "carry into bit i" rows
+                addend1[i].eq(a_bit == b_bit),
+                # 1 in the "set" row only
+                addend2[i].eq(a_bit & b_bit),
+            ]
+        carry_in = 1  # nothing examined yet, so start out as "equal"
+        # one bit wider than the addends so that the carry-out is kept;
+        # named explicitly so the local need not shadow builtin `sum`
+        total = Signal(self.width + 1, name="sum")
+        m.d.comb += [
+            total.eq(addend1 + addend2 + carry_in),
+            self.out.eq(total[self.width]),  # out is the carry-out
+        ]
+        return m
+
+# TODO: add CLDivRem
diff --git a/src/nmigen_gf/hdl/test/test_cldivrem.py b/src/nmigen_gf/hdl/test/test_cldivrem.py
new file mode 100644 (file)
index 0000000..4c90e41
--- /dev/null
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# Copyright 2022 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
+
+import unittest
+from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, Cat)
+from nmigen.hdl.dsl import Module
+from nmutil.formaltest import FHDLTestCase
+from nmigen_gf.hdl.cldivrem import (equal_leading_zero_count_reference,
+                                    EqualLeadingZeroCount)
+from nmigen.sim import Delay
+from nmutil.sim_util import do_sim, hash_256
+
+
+class TestEqualLeadingZeroCount(FHDLTestCase):
+    """Simulation and formal checks of `EqualLeadingZeroCount`."""
+
+    def tst(self, width, full):
+        """Simulate every input pair if `full`, else 100 hashed pairs."""
+        dut = EqualLeadingZeroCount(width)
+        for port, bits in ((dut.a, width), (dut.b, width), (dut.out, 1)):
+            self.assertEqual(port.shape(), unsigned(bits))
+
+        def case(a, b):
+            assert isinstance(a, int)
+            assert isinstance(b, int)
+            # equal bit lengths mean equal leading-zero counts
+            expected = a.bit_length() == b.bit_length()
+            labels = dict(a=hex(a), b=hex(b), expected=expected)
+            with self.subTest(**labels):
+                reference = equal_leading_zero_count_reference(a, b, width)
+                with self.subTest(reference=reference):
+                    self.assertEqual(expected, reference)
+            with self.subTest(**labels):
+                yield dut.a.eq(a)
+                yield dut.b.eq(b)
+                yield Delay(1e-6)
+                out = yield dut.out
+                with self.subTest(out=out):
+                    self.assertEqual(expected, out)
+
+        def process():
+            if full:
+                pairs = ((a, b) for a in range(1 << width)
+                         for b in range(1 << width))
+            else:
+                pairs = ((Const.normalize(hash_256(f"eqlzc input a {i}"),
+                                          dut.a.shape()),
+                          Const.normalize(hash_256(f"eqlzc input b {i}"),
+                                          dut.b.shape()))
+                         for i in range(100))
+            for a, b in pairs:
+                yield from case(a, b)
+
+        with do_sim(self, dut, [dut.a, dut.b, dut.out]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+    def tst_formal(self, width):
+        """Formally check a `width`-bit instance against a pattern table."""
+        dut = EqualLeadingZeroCount(width)
+        m = Module()
+        m.submodules.dut = dut
+        for port in (dut.a, dut.b):
+            m.d.comb += port.eq(AnyConst(width))
+        # all zeros first, then `i` leading zeros followed by a one
+        same_clz = ['0' * width] + [
+            '0' * i + '1' + '-' * (width - i - 1) for i in range(width)]
+        expected = Signal()
+        with m.Switch(Cat(dut.a, dut.b)):
+            for pattern in same_clz:
+                with m.Case(pattern * 2):
+                    m.d.comb += expected.eq(1)
+            with m.Default():
+                m.d.comb += expected.eq(0)
+        m.d.comb += Assert(dut.out == expected)
+        self.assertFormal(m)
+
+    def test_64(self):
+        self.tst(64, full=False)
+
+    def test_8(self):
+        self.tst(8, full=False)
+
+    def test_3(self):
+        self.tst(3, full=True)
+
+    def test_formal_16(self):
+        # yosys crashes with 32 or 64
+        self.tst_formal(16)
+
+    def test_formal_8(self):
+        self.tst_formal(8)
+
+    def test_formal_3(self):
+        self.tst_formal(3)
+
+# TODO: add TestCLDivRem
+
+
+if __name__ == "__main__":
+    unittest.main()  # run this module's test cases when executed directly