add BitwiseLut and tests
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Nov 2021 03:08:35 +0000 (19:08 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Nov 2021 03:08:35 +0000 (19:08 -0800)
src/nmutil/lut.py [new file with mode: 0644]
src/nmutil/test/test_lut.py [new file with mode: 0644]

diff --git a/src/nmutil/lut.py b/src/nmutil/lut.py
new file mode 100644 (file)
index 0000000..5705d3d
--- /dev/null
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# See Notices.txt for copyright information
+
+from nmigen.hdl.ast import Repl, Signal
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ir import Elaboratable
+
+
+class BitwiseMux(Elaboratable):
+    """ Mux, but treating input/output Signals as bit vectors, rather than
+    integers. This means each bit in the output is independently multiplexed
+    based on the corresponding bit in each of the inputs.
+    """
+
+    def __init__(self, width):
+        self.sel = Signal(width)
+        self.t = Signal(width)
+        self.f = Signal(width)
+        self.output = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.output.eq((~self.sel & self.f) | (self.sel & self.t))
+        return m
+
+
+class BitwiseLut(Elaboratable):
+    def __init__(self, input_count, width):
+        assert isinstance(input_count, int)
+        assert isinstance(width, int)
+        self.input_count = input_count
+        self.width = width
+
+        def inp(i):
+            return Signal(width, name=f"input{i}")
+        self.inputs = tuple(inp(i) for i in range(input_count))
+        self.output = Signal(width)
+        self.lut = Signal(2 ** input_count)
+        self._mux_inputs = {}
+        self._build_mux_inputs()
+
+    def _make_key_str(self, *sel_values):
+        k = ['x'] * self.input_count
+        for i, v in enumerate(sel_values):
+            k[i] = '1' if v else '0'
+        return '0b' + ''.join(reversed(k))
+
+    def _build_mux_inputs(self, *sel_values):
+        name = f"mux_input_{self._make_key_str(*sel_values)}"
+        self._mux_inputs[sel_values] = Signal(self.width, name=name)
+        if len(sel_values) < self.input_count:
+            self._build_mux_inputs(*sel_values, False)
+            self._build_mux_inputs(*sel_values, True)
+
+    def elaborate(self, platform):
+        m = Module()
+        m.d.comb += self.output.eq(self._mux_inputs[()])
+        for sel_values, v in self._mux_inputs.items():
+            if len(sel_values) < self.input_count:
+                mux_name = f"mux_{self._make_key_str(*sel_values)}"
+                mux = BitwiseMux(self.width)
+                setattr(m.submodules, mux_name, mux)
+                m.d.comb += [
+                    mux.f.eq(self._mux_inputs[(*sel_values, False)]),
+                    mux.t.eq(self._mux_inputs[(*sel_values, True)]),
+                    mux.sel.eq(self.inputs[len(sel_values)]),
+                    v.eq(mux.output),
+                ]
+            else:
+                lut_index = 0
+                for i in range(self.input_count):
+                    if sel_values[i]:
+                        lut_index |= 2 ** i
+                m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+        return m
diff --git a/src/nmutil/test/test_lut.py b/src/nmutil/test/test_lut.py
new file mode 100644 (file)
index 0000000..2863e83
--- /dev/null
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: LGPL-3-or-later
+# See Notices.txt for copyright information
+
+from contextlib import contextmanager
+import unittest
+
+from hashlib import sha256
+from nmutil.get_test_path import get_test_path
+from nmutil.lut import BitwiseMux, BitwiseLut
+from nmigen.sim import Simulator, Delay
+
+
+@contextmanager
+def do_sim(test_case, dut, traces=()):
+    sim = Simulator(dut)
+    path = get_test_path(test_case, "sim_test_out")
+    path.parent.mkdir(parents=True, exist_ok=True)
+    vcd_path = path.with_suffix(".vcd")
+    gtkw_path = path.with_suffix(".gtkw")
+    with sim.write_vcd(vcd_path.open("wt", encoding="utf-8"),
+                       gtkw_path.open("wt", encoding="utf-8"),
+                       traces=traces):
+        yield sim
+
+
+def hash_256(v):
+    return int.from_bytes(
+        sha256(bytes(v, encoding='utf-8')).digest(),
+        byteorder='little'
+    )
+
+
+class TestBitwiseMux(unittest.TestCase):
+    def test(self):
+        width = 2
+        dut = BitwiseMux(width)
+
+        def case(sel, t, f, expected):
+            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
+                yield dut.sel.eq(sel)
+                yield dut.t.eq(t)
+                yield dut.f.eq(f)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=bin(output), expected=bin(expected)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for sel in range(2 ** width):
+                for t in range(2 ** width):
+                    for f in range(2**width):
+                        expected = 0
+                        for i in range(width):
+                            if sel & 2 ** i:
+                                if t & 2 ** i:
+                                    expected |= 2 ** i
+                            elif f & 2 ** i:
+                                expected |= 2 ** i
+                        yield from case(sel, t, f, expected)
+        with do_sim(self, dut, [dut.sel, dut.t, dut.f, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+
+class TestBitwiseLut(unittest.TestCase):
+    def test(self):
+        dut = BitwiseLut(3, 16)
+        mask = 2 ** dut.width - 1
+        lut_mask = 2 ** dut.lut.width - 1
+        mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
+        self.assertEqual(mux_inputs, {
+            (): 'mux_input_0bxxx',
+            (False,): 'mux_input_0bxx0',
+            (False, False): 'mux_input_0bx00',
+            (False, False, False): 'mux_input_0b000',
+            (False, False, True): 'mux_input_0b100',
+            (False, True): 'mux_input_0bx10',
+            (False, True, False): 'mux_input_0b010',
+            (False, True, True): 'mux_input_0b110',
+            (True,): 'mux_input_0bxx1',
+            (True, False): 'mux_input_0bx01',
+            (True, False, False): 'mux_input_0b001',
+            (True, False, True): 'mux_input_0b101',
+            (True, True): 'mux_input_0bx11',
+            (True, True, False): 'mux_input_0b011',
+            (True, True, True): 'mux_input_0b111'
+        })
+
+        def case(in0, in1, in2, lut):
+            expected = 0
+            for i in range(dut.width):
+                lut_index = 0
+                if in0 & 2 ** i:
+                    lut_index |= 2 ** 0
+                if in1 & 2 ** i:
+                    lut_index |= 2 ** 1
+                if in2 & 2 ** i:
+                    lut_index |= 2 ** 2
+                if lut & 2 ** lut_index:
+                    expected |= 2 ** i
+            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
+                              lut=bin(lut)):
+                yield dut.inputs[0].eq(in0)
+                yield dut.inputs[1].eq(in1)
+                yield dut.inputs[2].eq(in2)
+                yield dut.lut.eq(lut)
+                yield Delay(1e-6)
+                output = yield dut.output
+                with self.subTest(output=bin(output), expected=bin(expected)):
+                    self.assertEqual(expected, output)
+
+        def process():
+            for case_index in range(100):
+                with self.subTest(case_index=case_index):
+                    in0 = hash_256(f"{case_index} in0") & mask
+                    in1 = hash_256(f"{case_index} in1") & mask
+                    in2 = hash_256(f"{case_index} in2") & mask
+                    lut = hash_256(f"{case_index} lut") & lut_mask
+                    yield from case(in0, in1, in2, lut)
+        with do_sim(self, dut, [*dut.inputs, dut.lut, dut.output]) as sim:
+            sim.add_process(process)
+            sim.run()
+
+
+if __name__ == "__main__":
+    unittest.main()