speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2021 Jacob Lifshay
3 # Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7
8 """Bitwise logic operators implemented using a look-up table, like LUTs in
9 FPGAs. Inspired by x86's `vpternlog[dq]` instructions.
10
11 https://bugs.libre-soc.org/show_bug.cgi?id=745
12 https://www.felixcloutier.com/x86/vpternlogd:vpternlogq
13 """
14
15 from nmigen.hdl.ast import Array, Cat, Repl, Signal
16 from nmigen.hdl.dsl import Module
17 from nmigen.hdl.ir import Elaboratable
18 from dataclasses import dataclass
19
20
class BitwiseMux(Elaboratable):
    """Per-bit multiplexer.

    An ordinary ``Mux`` picks a whole value based on its selector. This one
    treats ``sel``, ``t``, ``f`` and ``output`` as bit vectors instead:
    ``output[i]`` is ``t[i]`` when ``sel[i]`` is 1 and ``f[i]`` when it is 0,
    independently for every bit position ``i``.
    """

    def __init__(self, width):
        """
        width: int
            the number of bits in every port (sel, t, f and output).
        """
        self.sel = Signal(width)     # per-bit select
        self.t = Signal(width)       # chosen where sel bit is 1
        self.f = Signal(width)       # chosen where sel bit is 0
        self.output = Signal(width)  # per-bit mux result

    def elaborate(self, platform):
        module = Module()
        take_true = self.sel & self.t
        take_false = ~self.sel & self.f
        module.d.comb += self.output.eq(take_false | take_true)
        return module
37
38
class BitwiseLut(Elaboratable):
    """Look-up-table-driven bitwise logic, in the style of FPGA LUTs and of
    x86's `vpternlog[dq]` instructions.

    For each bit position ``i``, bit ``i`` of every input is concatenated
    (``inputs[0]`` first) into an index, and ``output[i]`` is the LUT entry at
    that index: ``lut[Cat(inp[i] for inp in self.inputs)]``.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            how many inputs there are; ternlog-style instructions use 3.
        width: int
            bit width of each input and of the output.
        """
        self.input_count = input_count
        self.width = width
        self.inputs = tuple(Signal(width, name=f"input{n}")
                            for n in range(input_count))
        # one LUT entry per possible combination of input bits
        self.lut = Signal(2 ** input_count)
        self.output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        # wrapping the LUT's bits in an Array makes it indexable by a Signal
        table = Array(self.lut)
        result_bits = []

        for pos in range(self.width):
            # LUT index for this position: bit `pos` of each input
            idx = Signal(self.input_count, name=f"index{pos}")
            m.d.comb += idx.eq(Cat(inp[pos] for inp in self.inputs))
            # a separate named Signal per output bit, joined at the end,
            # keeps the graphviz rendering simple
            bit_out = Signal(name=f"out{pos}")
            m.d.comb += bit_out.eq(table[idx])
            result_bits.append(bit_out)

        m.d.comb += self.output.eq(Cat(*result_bits))
        return m

    def ports(self):
        """Return all ports: every input, then the LUT, then the output."""
        return [*self.inputs, self.lut, self.output]
83
84
@dataclass(eq=False)
class _TreeMuxNode:
    """Mux in tree for `TreeBitwiseLut`.

    ``eq=False`` makes nodes compare and hash by identity. The
    dataclass-generated ``__eq__`` would compare the ``Signal`` fields with
    ``==``. In nmigen that produces a hardware expression rather than a
    ``bool``, so comparing two distinct nodes fails. That ``__eq__`` would also
    follow the parent/child links, and ``eq=True`` sets ``__hash__`` to
    ``None``, making nodes unhashable.
    """
    out: Signal  # this node's output signal
    container: "TreeBitwiseLut"  # the tree this node belongs to
    parent: "_TreeMuxNode | None"  # None for the root node
    child0: "_TreeMuxNode | None"  # subtree for child index 0
    child1: "_TreeMuxNode | None"  # subtree for child index 1
    depth: int  # 0 for the root, parent's depth + 1 otherwise

    @property
    def child_index(self):
        """index of this node, when looked up in this node's parent's children.

        Returns 1 if this node is its parent's ``child1``, 0 otherwise, and
        None for the root.
        """
        if self.parent is None:
            return None
        return int(self.parent.child1 is self)

    def add_child(self, child_index):
        """Create a new child in slot ``child_index`` (truthy -> child1,
        falsy -> child0) and return it.

        The slot must currently be empty (asserted). The child's output signal
        is named after its position in the tree (see `key_str`).
        """
        node = _TreeMuxNode(
            out=Signal(self.container.width),
            container=self.container, parent=self,
            child0=None, child1=None, depth=1 + self.depth)
        if child_index:
            assert self.child1 is None
            self.child1 = node
        else:
            assert self.child0 is None
            self.child0 = node
        # named only after linking: key_str walks the parent chain
        node.out.name = "node_out_" + node.key_str
        return node

    @property
    def key(self):
        """List of child indexes (0/1) on the path from the root to this node,
        ordered root-first; empty for the root."""
        retval = []
        node = self
        while node.parent is not None:
            retval.append(node.child_index)
            node = node.parent
        retval.reverse()
        return retval

    @property
    def key_str(self):
        """`key` as a binary-literal-style string, ``"0b"`` followed by
        ``container.input_count`` characters.

        Character for path position ``i`` is placed at binary weight ``i``
        (last character = root's choice); positions below this node's depth
        are ``'x'``.
        """
        k = ['x'] * self.container.input_count
        for i, v in enumerate(self.key):
            k[i] = '1' if v else '0'
        return '0b' + ''.join(reversed(k))
133
134
class TreeBitwiseLut(Elaboratable):
    """Tree-based version of BitwiseLut. Has identical API, so see `BitwiseLut`
    for API documentation. This version may produce more efficient hardware.

    The LUT is evaluated by a binary tree of `BitwiseMux` instances:

    - a node at depth ``d`` selects between its two children using
      ``inputs[d]``;
    - each leaf (depth ``input_count``) drives all output bits with the single
      LUT entry whose index is spelled out by the path from the root to that
      leaf.
    """

    def __init__(self, input_count, width):
        """
        input_count: int
            the number of inputs; the tree has this many levels of muxes.
        width: int
            the number of bits in each input/output.
        """
        self.input_count = input_count
        self.width = width

        def inp(i):
            return Signal(width, name=f"input{i}")
        self.inputs = tuple(inp(i) for i in range(input_count))
        self.output = Signal(width)
        self.lut = Signal(2 ** input_count)
        # the root node's output is the module output itself
        self._tree_root = _TreeMuxNode(
            out=self.output, container=self, parent=None,
            child0=None, child1=None, depth=0)
        self._build_tree(self._tree_root)

    def _build_tree(self, node):
        """Recursively give every node above leaf depth its two children."""
        if node.depth < self.input_count:
            self._build_tree(node.add_child(0))
            self._build_tree(node.add_child(1))

    def _elaborate_tree(self, m, node):
        """Add the logic for `node` and everything below it to module `m`.

        An inner node becomes a `BitwiseMux`; a leaf becomes a
        replicated LUT bit.
        """
        if node.depth < self.input_count:
            mux = BitwiseMux(self.width)
            setattr(m.submodules, "mux_" + node.key_str, mux)
            m.d.comb += [
                mux.f.eq(node.child0.out),
                mux.t.eq(node.child1.out),
                mux.sel.eq(self.inputs[node.depth]),
                node.out.eq(mux.output),
            ]
            self._elaborate_tree(m, node.child0)
            self._elaborate_tree(m, node.child1)
        else:
            # Leaf: bit i of the LUT index is the child choice taken at
            # depth i, i.e. the value of inputs[i] on this path. Computed from
            # `key` directly instead of parsing `key_str`: with
            # input_count == 0 the latter is "0b", which int(..., 2) rejects.
            index = sum(bit << i for i, bit in enumerate(node.key))
            m.d.comb += node.out.eq(Repl(self.lut[index], self.width))

    def elaborate(self, platform):
        m = Module()
        self._elaborate_tree(m, self._tree_root)
        return m

    def ports(self):
        return [*self.inputs, self.lut, self.output]
182
183
184 # useful to see what is going on:
185 # python3 src/nmutil/test/test_lut.py
186 # yosys <<<"read_ilang sim_test_out/__main__.TestBitwiseLut.test_tree/0.il; proc;;; show top"