Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / regfile / regfile.py
1 """Specialist Regfiles
2
3 These are not "normal" binary-indexed regfiles (although that is included).
4 They include *unary* indexed regfiles as well as Dependency-tracked ones
5 (SPR files with 1024 registers, only around 4-5 of which need to be active)
6 and special "split" regfiles that have 8R8W for 8 4-bit quantities and a
7 1R1W to read/write *all* 8 4-bit registers in a single one-off 32-bit way.
8
9 Due to the way that the Dependency Matrices are set up (bit-vectors), the
10 primary focus here is on *unary* indexing.
11
12 Links:
13
14 * https://libre-soc.org/3d_gpu/architecture/regfile
15 * https://bugs.libre-soc.org/show_bug.cgi?id=345
16 * https://bugs.libre-soc.org/show_bug.cgi?id=351
17 * https://bugs.libre-soc.org/show_bug.cgi?id=352
18 """
19
20 from nmigen.compat.sim import run_simulation
21 from nmigen.back.pysim import Settle
22 from nmigen.cli import verilog, rtlil
23
24 from nmigen import Cat, Const, Array, Signal, Elaboratable, Module
25 from nmutil.iocontrol import RecordObject
26 from nmutil.util import treereduce
27 from nmigen.utils import log2_int
28 from nmigen import Memory
29
30 from math import log
31 import operator
32
33
class Register(Elaboratable):
    """A single register with any number of read and write ports.

    Ports are requested with read_port() / write_port() before the
    module is elaborated.  A read port has a 1-bit ``ren`` and an
    ``o_data`` of ``width`` bits; a write port has a 1-bit ``wen`` and
    an ``i_data`` of ``width`` bits.

    * width:     bit-width of the register
    * writethru: when True, a read that coincides with an asserted write
                 enable is given the data being written instead of the
                 stored value
    * synced:    when True the read data is assigned in the sync domain,
                 otherwise in the comb domain
    * resetval:  reset value given to the storage signal
    """

    def __init__(self, width, writethru=True, synced=True, resetval=0):
        self.width = width
        self.reset = resetval
        self.writethru = writethru
        self.synced = synced
        self._rdports = []
        self._wrports = []

    def read_port(self, name=None):
        """Create, record and return a new read port (ren, o_data)."""
        port = RecordObject([("ren", 1),
                             ("o_data", self.width)],
                            name=name)
        self._rdports.append(port)
        return port

    def write_port(self, name=None):
        """Create, record and return a new write port (wen, i_data)."""
        port = RecordObject([("wen", 1),
                             ("i_data", self.width)],
                            name=name)
        self._wrports.append(port)
        return port

    def elaborate(self, platform):
        m = Module()
        self.reg = reg = Signal(self.width, name="reg", reset=self.reset)

        # read data goes to the clocked or the combinatorial domain
        if self.synced:
            domain = m.d.sync
        else:
            domain = m.d.comb

        # read ports. has write-through detection (returns data written)
        for rp in self._rdports:
            # default: read data is zero unless the port is enabled
            domain += rp.o_data.eq(0)
            with m.If(rp.ren):
                if self.writethru:
                    # wr_detect goes high if any write port is active,
                    # in which case that port's data is forwarded
                    wr_detect = Signal(reset_less=False)
                    m.d.comb += wr_detect.eq(0)
                    for wp in self._wrports:
                        with m.If(wp.wen):
                            domain += rp.o_data.eq(wp.i_data)
                            m.d.comb += wr_detect.eq(1)
                    # no write in progress: return the stored value
                    with m.If(~wr_detect):
                        domain += rp.o_data.eq(reg)
                else:
                    domain += rp.o_data.eq(reg)

        # write ports, delayed by 1 cycle
        for wp in self._wrports:
            with m.If(wp.wen):
                m.d.sync += reg.eq(wp.i_data)

        return m

    def __iter__(self):
        # read ports first, then write ports: each port record is
        # expanded into whatever iterating it yields
        for p in self._rdports:
            yield from p
        for p in self._wrports:
            yield from p

    def ports(self):
        """Return a list of everything produced by iterating this Register.

        The list was previously built and then discarded, which made
        this method always return None.
        """
        return list(self)
97
98
def ortreereduce(tree, attr="o_data"):
    """Combine attribute *attr* of every item in *tree* with bitwise OR.

    The reduction itself is delegated to treereduce.
    """
    def pick(item):
        # fetch the attribute that takes part in the reduction
        return getattr(item, attr)

    return treereduce(tree, operator.or_, pick)
101
102
class RegFileArray(Elaboratable):
    """ an array-based register file (register having write-through capability)
        that has no "address" decoder, instead it has individual write-en
        and read-en signals (per port).

        * width:        bit-width of every register
        * depth:        number of registers, which is also the width of
                        the ``ren`` / ``wen`` enable vectors
        * synced:       passed to every Register; also selects whether the
                        read output is gated by a one-cycle-delayed ``ren``
        * fwd_bus_mode: passed to every Register as ``writethru``
        * resets:       per-register reset values (defaults to all zero)
    """
    unary = True

    def __init__(self, width, depth, synced=True, fwd_bus_mode=True,
                 resets=None):
        if resets is None:
            resets = [0] * depth
        self.synced = synced
        self.width = width
        self.depth = depth
        self.regs = Array(Register(width, synced=synced,
                                   writethru=fwd_bus_mode,
                                   resetval=rst)
                          for rst in resets)
        self._rdports = []
        self._wrports = []

    def read_reg_port(self, name=None):
        """Add one read port to each register; return them as a list."""
        return [self.regs[i].read_port("%s%d" % (name, i))
                for i in range(self.depth)]

    def write_reg_port(self, name=None):
        """Add one write port to each register; return them as a list."""
        return [self.regs[i].write_port("%s%d" % (name, i))
                for i in range(self.depth)]

    def read_port(self, name=None):
        """Create a read port: ``ren`` has one bit per register."""
        regs = Array(self.read_reg_port(name))
        port = RecordObject([("ren", self.depth),
                             ("o_data", self.width)], name)
        self._rdports.append((regs, port))
        return port

    def write_port(self, name=None):
        """Create a write port: ``wen`` has one bit per register."""
        regs = Array(self.write_reg_port(name))
        # pass the name through, exactly as read_port does, so that
        # several write ports do not all end up with the same record name
        port = RecordObject([("wen", self.depth),
                             ("i_data", self.width)], name)
        self._wrports.append((regs, port))
        return port

    def _get_en_sig(self, port, typ):
        # concatenate the per-register enable (field `typ`) of each port
        return Cat(*[p[typ] for p in port])

    def elaborate(self, platform):
        m = Module()
        for i, reg in enumerate(self.regs):
            setattr(m.submodules, "reg_%d" % i, reg)

        for (regs, p) in self._rdports:
            # fan the enable vector out: one bit to each register's ren
            m.d.comb += self._get_en_sig(regs, 'ren').eq(p.ren)
            # OR together the o_data of all the per-register read ports
            ror = ortreereduce(list(regs))
            if self.synced:
                # output is only driven when ren was set one cycle ago
                ren_delay = Signal.like(p.ren)
                m.d.sync += ren_delay.eq(p.ren)
                with m.If(ren_delay):
                    m.d.comb += p.o_data.eq(ror)
            else:
                m.d.comb += p.o_data.eq(ror)
        for (regs, p) in self._wrports:
            m.d.comb += self._get_en_sig(regs, 'wen').eq(p.wen)
            # every register sees the same write data; wen selects
            for r in regs:
                m.d.comb += r.i_data.eq(p.i_data)

        return m

    def __iter__(self):
        for r in self.regs:
            yield from r

    def ports(self):
        """Return a list of everything produced by iterating the registers."""
        return list(self)
194
195
class RegFileMem(Elaboratable):
    """Address-indexed register file built around an nmigen Memory.

    * width:        bit-width of each entry
    * depth:        number of entries (address width is derived from it)
    * fwd_bus_mode: when True, a read whose address matches an active
                    write is given the data being written
    * synced:       selects the "sync" or "comb" domain of the Memory
                    read ports
    """
    unary = False

    def __init__(self, width, depth, fwd_bus_mode=False, synced=True):
        self.fwd_bus_mode = fwd_bus_mode
        self.synced = synced
        self.width, self.depth = width, depth
        self.memory = Memory(width=width, depth=depth)
        self._rdports = {}
        self._wrports = {}

    @staticmethod
    def _port_key(name, ports, prefix):
        # Ports are stored by name, and that name is later used to build
        # a submodule name.  An unnamed port (name=None) therefore gets a
        # generated key: None cannot be concatenated onto a string, and
        # a second unnamed port would overwrite the first.
        if name is not None:
            return name
        return "%s%d" % (prefix, len(ports))

    def read_port(self, name=None):
        """Create and return a read port (addr, ren, o_data)."""
        bsz = log2_int(self.depth, False)
        port = RecordObject([("addr", bsz),
                             ("ren", 1),
                             ("o_data", self.width)], name=name)
        if self.synced:
            domain = "sync"
        else:
            domain = "comb"
        key = self._port_key(name, self._rdports, "rd")
        self._rdports[key] = (port, self.memory.read_port(domain=domain))
        return port

    def write_port(self, name=None):
        """Create and return a write port (addr, wen, i_data)."""
        bsz = log2_int(self.depth, False)
        port = RecordObject([("addr", bsz),
                             ("wen", 1),
                             ("i_data", self.width)], name=name)
        key = self._port_key(name, self._wrports, "wr")
        self._wrports[key] = (port, self.memory.write_port())
        return port

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # read ports. has write-through detection (returns data written)
        for name, (rp, rport) in self._rdports.items():
            setattr(m.submodules, "rp_"+name, rport)
            comb += rport.addr.eq(rp.addr)
            if self.fwd_bus_mode:
                wr_detect = Signal(reset_less=False)
                with m.If(rp.ren):
                    m.d.comb += wr_detect.eq(0)
                    for wp, _ in self._wrports.values():
                        # forward the write data when a write to the
                        # address being read is in progress
                        addrmatch = Signal(reset_less=False)
                        m.d.comb += addrmatch.eq(wp.addr == rp.addr)
                        with m.If(wp.wen & addrmatch):
                            m.d.comb += rp.o_data.eq(wp.i_data)
                            m.d.comb += wr_detect.eq(1)
                    # nothing to forward: use the Memory read data
                    with m.If(~wr_detect):
                        m.d.comb += rp.o_data.eq(rport.data)
            else:
                if self.synced:
                    # output is only driven when ren was set one cycle ago
                    ren_delay = Signal.like(rp.ren)
                    m.d.sync += ren_delay.eq(rp.ren)
                    with m.If(ren_delay):
                        m.d.comb += rp.o_data.eq(rport.data)
                else:
                    m.d.comb += rp.o_data.eq(rport.data)

        # write ports, delayed by one cycle (in the memory itself)
        for name, (port, wp) in self._wrports.items():
            setattr(m.submodules, "wp_"+name, wp)
            comb += wp.addr.eq(port.addr)
            comb += wp.en.eq(port.wen)
            comb += wp.data.eq(port.i_data)

        return m
263
264
class RegFile(Elaboratable):
    """Address-indexed register file built from an Array of Signals.

    * width: bit-width of each register
    * depth: number of registers

    A read whose address matches an active write is given the data
    being written (write-through); the stored value changes in the sync
    domain.
    """
    unary = False

    def __init__(self, width, depth):
        self.width = width
        self.depth = depth
        self._rdports = []
        self._wrports = []

    def _addr_width(self):
        # The address selects one of `depth` registers, so its width
        # must come from depth.  It was previously derived from the data
        # width with floating-point logs, which sized the address
        # wrongly whenever width != depth.
        return log2_int(self.depth, False)

    def read_port(self, name=None):
        """Create, record and return a read port (addr, ren, o_data)."""
        port = RecordObject([("addr", self._addr_width()),
                             ("ren", 1),
                             ("o_data", self.width)], name=name)
        self._rdports.append(port)
        return port

    def write_port(self, name=None):
        """Create, record and return a write port (addr, wen, i_data)."""
        port = RecordObject([("addr", self._addr_width()),
                             ("wen", 1),
                             ("i_data", self.width)], name=name)
        self._wrports.append(port)
        return port

    def elaborate(self, platform):
        m = Module()
        regs = Array(Signal(self.width, name="reg") for _ in range(self.depth))

        # read ports. has write-through detection (returns data written)
        for rp in self._rdports:
            wr_detect = Signal(reset_less=False)
            with m.If(rp.ren):
                m.d.comb += wr_detect.eq(0)
                for wp in self._wrports:
                    # forward the write data when a write to the address
                    # being read is in progress
                    addrmatch = Signal(reset_less=False)
                    m.d.comb += addrmatch.eq(wp.addr == rp.addr)
                    with m.If(wp.wen & addrmatch):
                        m.d.comb += rp.o_data.eq(wp.i_data)
                        m.d.comb += wr_detect.eq(1)
                # nothing to forward: return the stored value
                with m.If(~wr_detect):
                    m.d.comb += rp.o_data.eq(regs[rp.addr])

        # write ports, delayed by one cycle
        for wp in self._wrports:
            with m.If(wp.wen):
                m.d.sync += regs[wp.addr].eq(wp.i_data)

        return m

    def __iter__(self):
        yield from self._rdports
        yield from self._wrports

    def ports(self):
        """Yield the port contents, expanding each RecordObject port."""
        res = list(self)
        for r in res:
            if isinstance(r, RecordObject):
                yield from r
            else:
                yield r
326
327
def regfile_sim(dut, rp, wp):
    """Simulation stimulus for an address-indexed regfile.

    Drives one write port (*wp*) and one read port (*rp*), printing each
    value read back and asserting on the expected ones.
    """
    def sample(sig):
        # read the current value of sig, print it and hand it back
        value = yield sig
        print(value)
        return value

    # write 2 to address 1
    yield wp.addr.eq(1)
    yield wp.i_data.eq(2)
    yield wp.wen.eq(1)
    yield
    # stop writing, then let two more steps pass
    yield wp.wen.eq(0)
    yield wp.addr.eq(0)
    yield
    yield
    # read address 1 back, sampling three times
    yield rp.ren.eq(1)
    yield rp.addr.eq(1)
    yield Settle()
    yield from sample(rp.o_data)
    yield
    second = yield from sample(rp.o_data)
    yield
    yield from sample(rp.o_data)
    assert second == 2
    yield

    # write 6 to address 5 while reading address 5
    yield wp.addr.eq(5)
    yield rp.addr.eq(5)
    yield rp.ren.eq(1)
    yield wp.wen.eq(1)
    yield wp.i_data.eq(6)
    yield
    forwarded = yield from sample(rp.o_data)
    assert forwarded == 6
    yield
    # with both enables dropped the read data is expected to be zero
    yield wp.wen.eq(0)
    yield rp.ren.eq(0)
    yield
    idle = yield from sample(rp.o_data)
    assert idle == 0
    yield
    yield from sample(rp.o_data)
370
371
def regfile_array_sim(dut, rp1, rp2, wp, wp2):
    """Simulation stimulus for a unary-indexed register file.

    Enables are bit-vectors with one bit per register.  Only *rp1*,
    *rp2* and *wp* are driven; *wp2* is accepted but left untouched.
    """
    sel_reg1 = 1 << 1   # enable bit for register 1
    sel_reg5 = 1 << 5   # enable bit for register 5

    print("regfile_array_sim")
    # write 2 into register 1, then read it back
    yield wp.i_data.eq(2)
    yield wp.wen.eq(sel_reg1)
    yield
    yield wp.wen.eq(0)
    yield rp1.ren.eq(sel_reg1)
    yield Settle()
    readback = yield rp1.o_data
    print(readback)
    assert readback == 2
    yield

    # write 6 into register 5 while rp1 reads register 5
    yield rp1.ren.eq(sel_reg5)
    yield rp2.ren.eq(sel_reg1)
    yield wp.wen.eq(sel_reg5)
    yield wp.i_data.eq(6)
    yield Settle()
    forwarded = yield rp1.o_data
    assert forwarded == 6
    print(forwarded)
    yield
    # with every enable dropped both read ports are expected to be zero
    yield wp.wen.eq(0)
    yield rp1.ren.eq(0)
    yield rp2.ren.eq(0)
    yield Settle()
    idle1 = yield rp1.o_data
    print(idle1)
    assert idle1 == 0
    idle2 = yield rp2.o_data
    print(idle2)
    assert idle2 == 0

    yield
    final = yield rp1.o_data
    print(final)
    assert final == 0
409
410
def test_regfile():
    """Dump RTLIL for, and simulate, each of the three regfile flavours."""
    def dump_il(design, filename, **convert_args):
        # convert the design to RTLIL and save it
        text = rtlil.convert(design, **convert_args)
        with open(filename, "w") as f:
            f.write(text)

    # Signal-array regfile
    dut = RegFile(32, 8)
    rp = dut.read_port()
    wp = dut.write_port()
    dump_il(dut, "test_regfile.il")
    run_simulation(dut, regfile_sim(dut, rp, wp), vcd_name='test_regfile.vcd')

    # Memory-based regfile: fwd_bus_mode=True, synced=False
    dut = RegFileMem(32, 8, True, False)
    rp = dut.read_port("rp1")
    wp = dut.write_port("wp1")
    dump_il(dut, "test_regmem.il")
    run_simulation(dut, regfile_sim(dut, rp, wp), vcd_name='test_regmem.vcd')

    # unary-indexed regfile, synced=False, two read and two write ports
    dut = RegFileArray(32, 8, False)
    rp1 = dut.read_port("read1")
    rp2 = dut.read_port("read2")
    wp = dut.write_port("write")
    wp2 = dut.write_port("write2")
    ports = dut.ports()
    print("ports", ports)
    dump_il(dut, "test_regfile_array.il", ports=ports)
    run_simulation(dut, regfile_array_sim(dut, rp1, rp2, wp, wp2),
                   vcd_name='test_regfile_array.vcd')
443
444
# run the regfile self-test when this file is executed as a script
if __name__ == "__main__":
    test_regfile()