From 892a3cd2f7d76a842c9f838d808af70ea87bf7f6 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 17 Nov 2021 10:58:22 -0800 Subject: [PATCH] add Array-based version of BitwiseLut, renaming old version to TreeBitwiseLut in case we need it --- src/nmutil/lut.py | 31 ++++++++++++++++++- src/nmutil/test/test_lut.py | 59 ++++++++++++++++++++++--------------- 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/src/nmutil/lut.py b/src/nmutil/lut.py index 5705d3d..35b6186 100644 --- a/src/nmutil/lut.py +++ b/src/nmutil/lut.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3-or-later # See Notices.txt for copyright information -from nmigen.hdl.ast import Repl, Signal +from nmigen.hdl.ast import Array, Cat, Repl, Signal from nmigen.hdl.dsl import Module from nmigen.hdl.ir import Elaboratable @@ -25,6 +25,35 @@ class BitwiseMux(Elaboratable): class BitwiseLut(Elaboratable): + def __init__(self, input_count, width): + assert isinstance(input_count, int) + assert isinstance(width, int) + self.input_count = input_count + self.width = width + + def inp(i): + return Signal(width, name=f"input{i}") + self.inputs = tuple(inp(i) for i in range(input_count)) + self.output = Signal(width) + self.lut = Signal(2 ** input_count) + + def lut_index(i): + return Signal(input_count, name=f"lut_index_{i}") + self._lut_indexes = [lut_index(i) for i in range(width)] + + def elaborate(self, platform): + m = Module() + lut = Array(self.lut[i] for i in range(self.lut.width)) + for i in range(self.width): + for j in range(self.input_count): + m.d.comb += self._lut_indexes[i][j].eq(self.inputs[j][i]) + m.d.comb += self.output[i].eq(lut[self._lut_indexes[i]]) + return m + + +class TreeBitwiseLut(Elaboratable): + """tree-based version of BitwiseLut""" + def __init__(self, input_count, width): assert isinstance(input_count, int) assert isinstance(width, int) diff --git a/src/nmutil/test/test_lut.py b/src/nmutil/test/test_lut.py index d36cb4e..14896da 100644 --- a/src/nmutil/test/test_lut.py +++ b/src/nmutil/test/test_lut.py @@ -12,7 +12,7 @@ from nmigen.hdl.ast import AnyConst, Assert, Signal from nmigen.hdl.dsl import Module from nmigen.hdl.ir import Fragment from nmutil.get_test_path import get_test_path -from nmutil.lut import BitwiseMux, BitwiseLut +from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut from nmigen.sim import Simulator, Delay @@ -121,28 +121,29 @@ class TestBitwiseMux(unittest.TestCase): class TestBitwiseLut(unittest.TestCase): - def test(self): - dut = BitwiseLut(3, 16) + def tst(self, cls): + dut = cls(3, 16) mask = 2 ** dut.width - 1 lut_mask = 2 ** dut.lut.width - 1 - mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()} - self.assertEqual(mux_inputs, { - (): 'mux_input_0bxxx', - (False,): 'mux_input_0bxx0', - (False, False): 'mux_input_0bx00', - (False, False, False): 'mux_input_0b000', - (False, False, True): 'mux_input_0b100', - (False, True): 'mux_input_0bx10', - (False, True, False): 'mux_input_0b010', - (False, True, True): 'mux_input_0b110', - (True,): 'mux_input_0bxx1', - (True, False): 'mux_input_0bx01', - (True, False, False): 'mux_input_0b001', - (True, False, True): 'mux_input_0b101', - (True, True): 'mux_input_0bx11', - (True, True, False): 'mux_input_0b011', - (True, True, True): 'mux_input_0b111' - }) + if cls is TreeBitwiseLut: + mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()} + self.assertEqual(mux_inputs, { + (): 'mux_input_0bxxx', + (False,): 'mux_input_0bxx0', + (False, False): 'mux_input_0bx00', + (False, False, False): 'mux_input_0b000', + (False, False, True): 'mux_input_0b100', + (False, True): 'mux_input_0bx10', + (False, True, False): 'mux_input_0b010', + (False, True, True): 'mux_input_0b110', + (True,): 'mux_input_0bxx1', + (True, False): 'mux_input_0bx01', + (True, False, False): 'mux_input_0b001', + (True, False, True): 'mux_input_0b101', + (True, True): 'mux_input_0bx11', + (True, True, False): 'mux_input_0b011', + (True, True, True): 'mux_input_0b111' + }) def case(in0, in1, in2, lut): expected = 0 @@ -179,8 +180,8 @@ class TestBitwiseLut(unittest.TestCase): sim.add_process(process) sim.run() - def test_formal(self): - dut = BitwiseLut(3, 16) + def tst_formal(self, cls): + dut = cls(3, 16) m = Module() m.submodules.dut = dut m.d.comb += dut.inputs[0].eq(AnyConst(dut.width)) @@ -196,6 +197,18 @@ class TestBitwiseLut(unittest.TestCase): m.d.comb += Assert(dut.lut[j] == dut.output[i]) formal(self, m) + def test(self): + self.tst(BitwiseLut) + + def test_tree(self): + self.tst(TreeBitwiseLut) + + def test_formal(self): + self.tst_formal(BitwiseLut) + + def test_tree_formal(self): + self.tst_formal(TreeBitwiseLut) + if __name__ == "__main__": unittest.main() -- 2.30.2