From: Jacob Lifshay Date: Wed, 22 Dec 2021 04:02:12 +0000 (-0800) Subject: rewrite TreeBitwiseLut to actually use a tree rather than a dict, hopefully making... X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=efdf9de343cdd2e7f928e8b98ecf204d219cb247;p=nmutil.git rewrite TreeBitwiseLut to actually use a tree rather than a dict, hopefully making the code much easier to follow --- diff --git a/src/nmutil/lut.py b/src/nmutil/lut.py index 4eb790d..49c19a8 100644 --- a/src/nmutil/lut.py +++ b/src/nmutil/lut.py @@ -16,6 +16,7 @@ from nmigen.hdl.ast import Array, Cat, Repl, Signal from nmigen.hdl.dsl import Module from nmigen.hdl.ir import Elaboratable from nmigen.cli import rtlil +from dataclasses import dataclass class BitwiseMux(Elaboratable): @@ -82,11 +83,59 @@ class BitwiseLut(Elaboratable): return list(self.inputs) + [self.lut, self.output] +@dataclass +class _TreeMuxNode: + """Mux in tree for `TreeBitwiseLut`.""" + out: Signal + container: "TreeBitwiseLut" + parent: "_TreeMuxNode | None" + child0: "_TreeMuxNode | None" + child1: "_TreeMuxNode | None" + depth: int + + @property + def child_index(self): + """index of this node, when looked up in this node's parent's children. + """ + if self.parent is None: + return None + return int(self.parent.child1 is self) + + def add_child(self, child_index): + node = _TreeMuxNode( + out=Signal(self.container.width), + container=self.container, parent=self, + child0=None, child1=None, depth=1 + self.depth) + if child_index: + assert self.child1 is None + self.child1 = node + else: + assert self.child0 is None + self.child0 = node + node.out.name = "node_out_" + node.key_str + return node + + @property + def key(self): + retval = [] + node = self + while node.parent is not None: + retval.append(node.child_index) + node = node.parent + retval.reverse() + return retval + + @property + def key_str(self): + k = ['x'] * self.container.input_count + for i, v in enumerate(self.key): + k[i] = '1' if v else '0' + return '0b' + ''.join(reversed(k)) + + class TreeBitwiseLut(Elaboratable): - """Tree-based version of BitwiseLut. See BitwiseLut for API documentation. - (good enough reason to say "see bitwiselut", but mention that - the API is identical and explain why the second implementation - exists, despite it being identical) + """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut` + for API documentation. This version may produce more efficient hardware. """ def __init__(self, input_count, width): @@ -98,50 +147,39 @@ class TreeBitwiseLut(Elaboratable): self.inputs = tuple(inp(i) for i in range(input_count)) self.output = Signal(width) self.lut = Signal(2 ** input_count) - self._mux_inputs = {} - self._build_mux_inputs() - - def _make_key_str(self, *sel_values): - k = ['x'] * self.input_count - for i, v in enumerate(sel_values): - k[i] = '1' if v else '0' - return '0b' + ''.join(reversed(k)) - - def _build_mux_inputs(self, *sel_values): - # XXX yyyeah using PHP-style functions-in-text... blech :) - # XXX replace with name = mux_input_%s" % self._make_etcetc - name = f"mux_input_{self._make_key_str(*sel_values)}" - self._mux_inputs[sel_values] = Signal(self.width, name=name) - if len(sel_values) < self.input_count: - self._build_mux_inputs(*sel_values, False) - self._build_mux_inputs(*sel_values, True) + self._tree_root = _TreeMuxNode( + out=self.output, container=self, parent=None, + child0=None, child1=None, depth=0) + self._build_tree(self._tree_root) + + def _build_tree(self, node): + if node.depth < self.input_count: + self._build_tree(node.add_child(0)) + self._build_tree(node.add_child(1)) + + def _elaborate_tree(self, m, node): + if node.depth < self.input_count: + mux = BitwiseMux(self.width) + setattr(m.submodules, "mux_" + node.key_str, mux) + m.d.comb += [ + mux.f.eq(node.child0.out), + mux.t.eq(node.child1.out), + mux.sel.eq(self.inputs[node.depth]), + node.out.eq(mux.output), + ] + self._elaborate_tree(m, node.child0) + self._elaborate_tree(m, node.child1) + else: + index = int(node.key_str, base=2) + m.d.comb += node.out.eq(Repl(self.lut[index], self.width)) def elaborate(self, platform): m = Module() - m.d.comb += self.output.eq(self._mux_inputs[()]) - for sel_values, v in self._mux_inputs.items(): - if len(sel_values) < self.input_count: - # XXX yyyeah using PHP-style functions-in-text... blech :) - # XXX replace with name = mux_input_%s" % self._make_etcetc - mux_name = f"mux_{self._make_key_str(*sel_values)}" - mux = BitwiseMux(self.width) - setattr(m.submodules, mux_name, mux) - m.d.comb += [ - mux.f.eq(self._mux_inputs[(*sel_values, False)]), - mux.t.eq(self._mux_inputs[(*sel_values, True)]), - mux.sel.eq(self.inputs[len(sel_values)]), - v.eq(mux.output), - ] - else: - lut_index = 0 - for i in range(self.input_count): - if sel_values[i]: - lut_index |= 2 ** i - m.d.comb += v.eq(Repl(self.lut[lut_index], self.width)) + self._elaborate_tree(m, self._tree_root) return m def ports(self): - return [self.input, self.chunk_sizes, self.output] + return [*self.inputs, self.lut, self.output] # useful to see what is going on: diff --git a/src/nmutil/test/test_lut.py b/src/nmutil/test/test_lut.py index cdfaa2d..e0a9809 100644 --- a/src/nmutil/test/test_lut.py +++ b/src/nmutil/test/test_lut.py @@ -65,25 +65,6 @@ class TestBitwiseLut(FHDLTestCase): dut = cls(3, 16) mask = 2 ** dut.width - 1 lut_mask = 2 ** dut.lut.width - 1 - if cls is TreeBitwiseLut: - mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()} - self.assertEqual(mux_inputs, { - (): 'mux_input_0bxxx', - (False,): 'mux_input_0bxx0', - (False, False): 'mux_input_0bx00', - (False, False, False): 'mux_input_0b000', - (False, False, True): 'mux_input_0b100', - (False, True): 'mux_input_0bx10', - (False, True, False): 'mux_input_0b010', - (False, True, True): 'mux_input_0b110', - (True,): 'mux_input_0bxx1', - (True, False): 'mux_input_0bx01', - (True, False, False): 'mux_input_0b001', - (True, False, True): 'mux_input_0b101', - (True, True): 'mux_input_0bx11', - (True, True, False): 'mux_input_0b011', - (True, True, True): 'mux_input_0b111' - }) def case(in0, in1, in2, lut): expected = 0 @@ -109,6 +90,10 @@ class TestBitwiseLut(FHDLTestCase): self.assertEqual(expected, output) def process(): + for shift in range(dut.lut.width): + with self.subTest(shift=shift): + yield from case(in0=0xAAAA, in1=0xCCCC, in2=0xF0F0, + lut=1 << shift) for case_index in range(100): with self.subTest(case_index=case_index): in0 = hash_256(f"{case_index} in0") & mask