add Array-based version of BitwiseLut, renaming old version to TreeBitwiseLut in...
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Nov 2021 18:58:22 +0000 (10:58 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 17 Nov 2021 18:58:22 +0000 (10:58 -0800)
src/nmutil/lut.py
src/nmutil/test/test_lut.py

index 5705d3d96e21c050b229e65b9b88d249dbdfe356..35b61868aa061d15dad4efbe82c020378fbed3d3 100644 (file)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # See Notices.txt for copyright information
 
-from nmigen.hdl.ast import Repl, Signal
+from nmigen.hdl.ast import Array, Cat, Repl, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Elaboratable
 
@@ -25,6 +25,35 @@ class BitwiseMux(Elaboratable):
 
 
 class BitwiseLut(Elaboratable):
+    def __init__(self, input_count, width):
+        assert isinstance(input_count, int)
+        assert isinstance(width, int)
+        self.input_count = input_count
+        self.width = width
+
+        def inp(i):
+            return Signal(width, name=f"input{i}")
+        self.inputs = tuple(inp(i) for i in range(input_count))
+        self.output = Signal(width)
+        self.lut = Signal(2 ** input_count)
+
+        def lut_index(i):
+            return Signal(input_count, name=f"lut_index_{i}")
+        self._lut_indexes = [lut_index(i) for i in range(width)]
+
+    def elaborate(self, platform):
+        m = Module()
+        lut = Array(self.lut[i] for i in range(self.lut.width))
+        for i in range(self.width):
+            for j in range(self.input_count):
+                m.d.comb += self._lut_indexes[i][j].eq(self.inputs[j][i])
+            m.d.comb += self.output[i].eq(lut[self._lut_indexes[i]])
+        return m
+
+
+class TreeBitwiseLut(Elaboratable):
+    """tree-based version of BitwiseLut"""
+
     def __init__(self, input_count, width):
         assert isinstance(input_count, int)
         assert isinstance(width, int)
index d36cb4e9af40b980ddf9dffb6684ed92668fd057..14896da28d3eff637048c746d1e7a2aa5865f133 100644 (file)
@@ -12,7 +12,7 @@ from nmigen.hdl.ast import AnyConst, Assert, Signal
 from nmigen.hdl.dsl import Module
 from nmigen.hdl.ir import Fragment
 from nmutil.get_test_path import get_test_path
-from nmutil.lut import BitwiseMux, BitwiseLut
+from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
 from nmigen.sim import Simulator, Delay
 
 
@@ -121,28 +121,29 @@ class TestBitwiseMux(unittest.TestCase):
 
 
 class TestBitwiseLut(unittest.TestCase):
-    def test(self):
-        dut = BitwiseLut(3, 16)
+    def tst(self, cls):
+        dut = cls(3, 16)
         mask = 2 ** dut.width - 1
         lut_mask = 2 ** dut.lut.width - 1
-        mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
-        self.assertEqual(mux_inputs, {
-            (): 'mux_input_0bxxx',
-            (False,): 'mux_input_0bxx0',
-            (False, False): 'mux_input_0bx00',
-            (False, False, False): 'mux_input_0b000',
-            (False, False, True): 'mux_input_0b100',
-            (False, True): 'mux_input_0bx10',
-            (False, True, False): 'mux_input_0b010',
-            (False, True, True): 'mux_input_0b110',
-            (True,): 'mux_input_0bxx1',
-            (True, False): 'mux_input_0bx01',
-            (True, False, False): 'mux_input_0b001',
-            (True, False, True): 'mux_input_0b101',
-            (True, True): 'mux_input_0bx11',
-            (True, True, False): 'mux_input_0b011',
-            (True, True, True): 'mux_input_0b111'
-        })
+        if cls is TreeBitwiseLut:
+            mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
+            self.assertEqual(mux_inputs, {
+                (): 'mux_input_0bxxx',
+                (False,): 'mux_input_0bxx0',
+                (False, False): 'mux_input_0bx00',
+                (False, False, False): 'mux_input_0b000',
+                (False, False, True): 'mux_input_0b100',
+                (False, True): 'mux_input_0bx10',
+                (False, True, False): 'mux_input_0b010',
+                (False, True, True): 'mux_input_0b110',
+                (True,): 'mux_input_0bxx1',
+                (True, False): 'mux_input_0bx01',
+                (True, False, False): 'mux_input_0b001',
+                (True, False, True): 'mux_input_0b101',
+                (True, True): 'mux_input_0bx11',
+                (True, True, False): 'mux_input_0b011',
+                (True, True, True): 'mux_input_0b111'
+            })
 
         def case(in0, in1, in2, lut):
             expected = 0
@@ -179,8 +180,8 @@ class TestBitwiseLut(unittest.TestCase):
             sim.add_process(process)
             sim.run()
 
-    def test_formal(self):
-        dut = BitwiseLut(3, 16)
+    def tst_formal(self, cls):
+        dut = cls(3, 16)
         m = Module()
         m.submodules.dut = dut
         m.d.comb += dut.inputs[0].eq(AnyConst(dut.width))
@@ -196,6 +197,18 @@ class TestBitwiseLut(unittest.TestCase):
                     m.d.comb += Assert(dut.lut[j] == dut.output[i])
         formal(self, m)
 
+    def test(self):
+        self.tst(BitwiseLut)
+
+    def test_tree(self):
+        self.tst(TreeBitwiseLut)
+
+    def test_formal(self):
+        self.tst_formal(BitwiseLut)
+
+    def test_tree_formal(self):
+        self.tst_formal(TreeBitwiseLut)
+
 
 if __name__ == "__main__":
     unittest.main()