from nmigen.hdl.dsl import Module
from nmigen.hdl.ir import Elaboratable
from nmigen.cli import rtlil
+from dataclasses import dataclass
class BitwiseMux(Elaboratable):
return list(self.inputs) + [self.lut, self.output]
@dataclass(eq=False)
class _TreeMuxNode:
    """Mux in tree for `TreeBitwiseLut`.

    Nodes form a binary tree: each node knows its parent, its two optional
    children and its depth (the root is created with depth 0 and no parent).

    Equality and hashing are identity-based (``eq=False``).  Nodes reference
    their parent and children, so the field-by-field ``__eq__`` that
    ``@dataclass`` generates by default would treat distinct sibling nodes
    with equal fields as equal, would recurse around the parent/child cycle
    when comparing two trees, and would make nodes unhashable.
    """

    # Signal carrying this node's value.
    out: "Signal"
    # Owning LUT; this class reads its `width` and `input_count`.
    container: "TreeBitwiseLut"
    # Parent node, or None for the root.
    parent: "_TreeMuxNode | None"
    # Child in slot 0 / slot 1, or None while the slot is still empty.
    child0: "_TreeMuxNode | None"
    child1: "_TreeMuxNode | None"
    # Number of edges between this node and the root.
    depth: int

    @property
    def child_index(self):
        """Index of this node when looked up in its parent's children.

        Returns None for a node without a parent, 1 when this node is the
        parent's `child1`, and 0 otherwise.
        """
        if self.parent is None:
            return None
        return int(self.parent.child1 is self)

    def add_child(self, child_index):
        """Create a node one level deeper and attach it to this node.

        A truthy `child_index` fills the `child1` slot, a falsy one fills
        `child0`; the chosen slot must still be empty.  The new node gets a
        fresh `Signal` of the container's width, named after the node's
        path.  Returns the new node.
        """
        node = _TreeMuxNode(
            out=Signal(self.container.width),
            container=self.container, parent=self,
            child0=None, child1=None, depth=1 + self.depth)
        if child_index:
            assert self.child1 is None
            self.child1 = node
        else:
            assert self.child0 is None
            self.child0 = node
        # The name is derived from the node's path, which is only known
        # once the node has been linked into its parent above.
        node.out.name = "node_out_" + node.key_str
        return node

    @property
    def key(self):
        """Path from the root to this node as a list of child indexes.

        The first element is the slot taken at the root; the root itself
        has an empty key.
        """
        retval = []
        node = self
        # Walk upwards collecting child indexes, then flip to root-first.
        while node.parent is not None:
            retval.append(node.child_index)
            node = node.parent
        retval.reverse()
        return retval

    @property
    def key_str(self):
        """`key` rendered as a `0b...` string of `input_count` characters.

        Position i of the key becomes character i counted from the right
        ('1' or '0'); positions below this node's depth are shown as 'x'.
        """
        k = ['x'] * self.container.input_count
        for i, v in enumerate(self.key):
            k[i] = '1' if v else '0'
        return '0b' + ''.join(reversed(k))
+
class TreeBitwiseLut(Elaboratable):
- """Tree-based version of BitwiseLut. See BitwiseLut for API documentation.
- (good enough reason to say "see bitwiselut", but mention that
- the API is identical and explain why the second implementation
- exists, despite it being identical)
+ """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
+ for API documentation. This version may produce more efficient hardware.
"""
def __init__(self, input_count, width):
self.inputs = tuple(inp(i) for i in range(input_count))
self.output = Signal(width)
self.lut = Signal(2 ** input_count)
- self._mux_inputs = {}
- self._build_mux_inputs()
-
- def _make_key_str(self, *sel_values):
- k = ['x'] * self.input_count
- for i, v in enumerate(sel_values):
- k[i] = '1' if v else '0'
- return '0b' + ''.join(reversed(k))
-
- def _build_mux_inputs(self, *sel_values):
- # XXX yyyeah using PHP-style functions-in-text... blech :)
- # XXX replace with name = mux_input_%s" % self._make_etcetc
- name = f"mux_input_{self._make_key_str(*sel_values)}"
- self._mux_inputs[sel_values] = Signal(self.width, name=name)
- if len(sel_values) < self.input_count:
- self._build_mux_inputs(*sel_values, False)
- self._build_mux_inputs(*sel_values, True)
+ self._tree_root = _TreeMuxNode(
+ out=self.output, container=self, parent=None,
+ child0=None, child1=None, depth=0)
+ self._build_tree(self._tree_root)
+
def _build_tree(self, node):
    """Recursively grow the subtree below `node`.

    Nodes whose depth has reached `input_count` are left without children;
    every shallower node gets child 0 (and its subtree) first, then child 1.
    """
    if node.depth >= self.input_count:
        return
    for slot in (0, 1):
        self._build_tree(node.add_child(slot))
+
def _elaborate_tree(self, m, node):
    """Recursively add the logic for `node` and its subtree to module `m`.

    A node at depth `input_count` drives its output with one bit of `lut`
    replicated to `width` bits.  Any shallower node gets a `BitwiseMux`
    submodule with `child0` on `f`, `child1` on `t`, and the input whose
    index equals the node's depth as the select.
    """
    if node.depth >= self.input_count:
        # At this depth the node's key string is parsed as a base-2
        # literal to obtain the LUT bit position.
        lut_bit = self.lut[int(node.key_str, base=2)]
        m.d.comb += node.out.eq(Repl(lut_bit, self.width))
        return

    mux = BitwiseMux(self.width)
    setattr(m.submodules, "mux_" + node.key_str, mux)
    m.d.comb += [
        mux.f.eq(node.child0.out),
        mux.t.eq(node.child1.out),
        mux.sel.eq(self.inputs[node.depth]),
        node.out.eq(mux.output),
    ]
    for child in (node.child0, node.child1):
        self._elaborate_tree(m, child)
def elaborate(self, platform):
# Creates a fresh Module, adds this LUT's combinational logic to it and
# returns it. `platform` is not used by any line shown here.
# NOTE(review): the `-` lines are the removed dict-based wiring that looped
# over `self._mux_inputs`; the single `+` line replaces all of it with a
# recursive walk starting at `self._tree_root`.
m = Module()
- m.d.comb += self.output.eq(self._mux_inputs[()])
- for sel_values, v in self._mux_inputs.items():
- if len(sel_values) < self.input_count:
- # XXX yyyeah using PHP-style functions-in-text... blech :)
- # XXX replace with name = mux_input_%s" % self._make_etcetc
- mux_name = f"mux_{self._make_key_str(*sel_values)}"
- mux = BitwiseMux(self.width)
- setattr(m.submodules, mux_name, mux)
- m.d.comb += [
- mux.f.eq(self._mux_inputs[(*sel_values, False)]),
- mux.t.eq(self._mux_inputs[(*sel_values, True)]),
- mux.sel.eq(self.inputs[len(sel_values)]),
- v.eq(mux.output),
- ]
- else:
# Removed leaf case: bit i of lut_index is set when sel_values[i] is truthy.
- lut_index = 0
- for i in range(self.input_count):
- if sel_values[i]:
- lut_index |= 2 ** i
- m.d.comb += v.eq(Repl(self.lut[lut_index], self.width))
+ self._elaborate_tree(m, self._tree_root)
return m
def ports(self):
# Returns a list of signals: every entry of `self.inputs`, then `self.lut`,
# then `self.output`.
# NOTE(review): the removed line referenced `self.input` and
# `self.chunk_sizes`, neither of which is assigned in the visible part of
# this class -- presumably a copy/paste leftover; confirm against the full file.
- return [self.input, self.chunk_sizes, self.output]
+ return [*self.inputs, self.lut, self.output]
# useful to see what is going on:
dut = cls(3, 16)
mask = 2 ** dut.width - 1
lut_mask = 2 ** dut.lut.width - 1
- if cls is TreeBitwiseLut:
- mux_inputs = {k: s.name for k, s in dut._mux_inputs.items()}
- self.assertEqual(mux_inputs, {
- (): 'mux_input_0bxxx',
- (False,): 'mux_input_0bxx0',
- (False, False): 'mux_input_0bx00',
- (False, False, False): 'mux_input_0b000',
- (False, False, True): 'mux_input_0b100',
- (False, True): 'mux_input_0bx10',
- (False, True, False): 'mux_input_0b010',
- (False, True, True): 'mux_input_0b110',
- (True,): 'mux_input_0bxx1',
- (True, False): 'mux_input_0bx01',
- (True, False, False): 'mux_input_0b001',
- (True, False, True): 'mux_input_0b101',
- (True, True): 'mux_input_0bx11',
- (True, True, False): 'mux_input_0b011',
- (True, True, True): 'mux_input_0b111'
- })
def case(in0, in1, in2, lut):
expected = 0
self.assertEqual(expected, output)
def process():
+ for shift in range(dut.lut.width):
+ with self.subTest(shift=shift):
+ yield from case(in0=0xAAAA, in1=0xCCCC, in2=0xF0F0,
+ lut=1 << shift)
for case_index in range(100):
with self.subTest(case_index=case_index):
in0 = hash_256(f"{case_index} in0") & mask