add BitwiseLut and tests
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from nmigen.hdl.ast import Repl, Signal
5 from nmigen.hdl.dsl import Module
6 from nmigen.hdl.ir import Elaboratable
7
8
class BitwiseMux(Elaboratable):
    """Per-bit multiplexer over equal-width bit vectors.

    Unlike an integer ``Mux``, every output bit is chosen independently:
    bit ``j`` of ``output`` is bit ``j`` of ``t`` when bit ``j`` of ``sel``
    is 1, and bit ``j`` of ``f`` when it is 0.

    Attributes (each a ``Signal(width)``):
        sel: per-bit select vector.
        t: value passed through where ``sel`` is 1.
        f: value passed through where ``sel`` is 0.
        output: the per-bit multiplexed result (driven combinationally).
    """

    def __init__(self, width):
        # NOTE(review): these Signals get no explicit name; nMigen presumably
        # derives one from the assignment target, so keep each as a plain
        # single-attribute assignment, created in this order.
        self.sel = Signal(width)
        self.t = Signal(width)
        self.f = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # Bits of f survive where sel is 0; bits of t survive where sel is 1.
        false_part = ~self.sel & self.f
        true_part = self.sel & self.t
        m.d.comb += self.output.eq(false_part | true_part)
        return m
25
26
class BitwiseLut(Elaboratable):
    """Bitwise look-up table over ``input_count`` bit-vector inputs.

    Each output bit is computed independently. For bit position ``j``, the
    ``j``-th bits of ``inputs[0] .. inputs[input_count - 1]`` form an index
    ``k``, with ``inputs[i]`` supplying bit ``i`` of ``k``. Bit ``j`` of
    ``output`` is then ``lut[k]``. A single ``lut`` is shared by all bit
    positions.

    The logic is a binary tree of ``BitwiseMux`` instances. The root selects
    on ``inputs[0]``, the next level on ``inputs[1]``, and so on. The leaves
    are driven by ``lut`` bits replicated to ``width``.

    Parameters:
        input_count: number of inputs (must be an ``int``); ``lut`` is
            ``2 ** input_count`` bits wide.
        width: width of every input and of ``output`` (must be an ``int``).
    """

    def __init__(self, input_count, width):
        assert isinstance(input_count, int)
        assert isinstance(width, int)
        self.input_count = input_count
        self.width = width

        self.inputs = tuple(Signal(width, name=f"input{i}")
                            for i in range(input_count))
        # NOTE(review): output/lut carry no explicit name; nMigen presumably
        # derives one from the assignment target, so keep these as plain
        # single-attribute assignments.
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # Maps a tuple of already-decided select bits (one bool per input,
        # starting from inputs[0]) to the Signal carrying that subtree's value.
        self._mux_inputs = {}
        self._build_mux_inputs()

    def _make_key_str(self, *sel_values):
        """Render a tree-node key as a binary-literal-like string.

        The highest-numbered input comes first. Decided inputs show as
        '0'/'1'; inputs not yet decided at this node show as 'x'.
        """
        digits = ['x'] * self.input_count
        for index, value in enumerate(sel_values):
            digits[index] = '1' if value else '0'
        return '0b' + ''.join(digits[::-1])

    def _build_mux_inputs(self, *sel_values):
        """Create one Signal per node of the mux tree rooted at *sel_values*.

        Nodes are registered in depth-first pre-order, visiting the False
        child before the True child.
        """
        pending = [sel_values]
        while pending:
            prefix = pending.pop()
            key_str = self._make_key_str(*prefix)
            self._mux_inputs[prefix] = Signal(self.width,
                                              name=f"mux_input_{key_str}")
            if len(prefix) < self.input_count:
                # Push True first so the False child is popped (visited) first.
                pending.append((*prefix, True))
                pending.append((*prefix, False))

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.output.eq(self._mux_inputs[()])
        for sel_values, node in self._mux_inputs.items():
            depth = len(sel_values)
            if depth >= self.input_count:
                # Leaf: every input is decided. The decision for inputs[i]
                # becomes bit i of the LUT index.
                lut_index = sum(1 << bit
                                for bit, chosen in enumerate(sel_values)
                                if chosen)
                m.d.comb += node.eq(Repl(self.lut[lut_index], self.width))
                continue
            # Interior node: choose between the two children per bit, using
            # the next undecided input as the select.
            mux = BitwiseMux(self.width)
            setattr(m.submodules, f"mux_{self._make_key_str(*sel_values)}",
                    mux)
            m.d.comb += [
                mux.f.eq(self._mux_inputs[(*sel_values, False)]),
                mux.t.eq(self._mux_inputs[(*sel_values, True)]),
                mux.sel.eq(self.inputs[depth]),
                node.eq(mux.output),
            ]
        return m