big reorg, shuffle code to functions, makes the FSM clearer
[soc.git] / src / soc / experiment / mmu.py
1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12
13 from soc.experiment.mem_types import (LoadStore1ToMmuType,
14 MmuToLoadStore1Type,
15 MmuToDcacheType,
16 DcacheToMmuType,
17 MmuToIcacheType)

# -- Radix MMU
# -- Supports 4-level trees as in arch 3.0B, but not the
# -- two-step translation
# -- for guests under a hypervisor (i.e. there is no gRA -> hRA translation).

@unique
class State(Enum):
    IDLE = 0 # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

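# typical flow through the FSM for a dside TLB miss, as implemented by
# the state transitions below (each radix level repeats the
# RADIX_LOOKUP/RADIX_READ_WAIT pair; PROC_TBL_READ/PROC_TBL_WAIT only
# occur when the process-table entry is not already cached):
#
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT -> SEGMENT_CHECK
#        -> (RADIX_LOOKUP -> RADIX_READ_WAIT)* -> RADIX_LOAD_TLB
#        -> TLB_WAIT -> RADIX_FINISH -> IDLE
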
# generate mask for extracting address fields for PTE address generation
class AddrMaskGen(Elaboratable):
    def __init__(self):
        self.mask_size = Signal(5)  # input: directory size field (RPDS)
        self.mask = Signal(16)      # output

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        msk = Signal(16)

        # mask_count has to be >= 5
        comb += msk[0:5].eq(0x1f)

        for i in range(5, 16):
            with m.If(i < self.mask_size):
                comb += msk[i].eq(1)

        comb += self.mask.eq(msk)

        return m

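# a quick pure-Python sanity model of AddrMaskGen (a hypothetical helper,
# not part of the hardware): bits 0-4 are always set, and bits 5-15 are
# set while below mask_size.
def _addr_mask_model(mask_size):
    m = 0x001f  # low five bits always set
    for i in range(5, 16):
        if i < mask_size:
            m |= 1 << i
    return m

# e.g. a 9-bit directory index gives a 9-bit mask:
assert _addr_mask_model(9) == 0x01ff
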
# generate mask for extracting address bits to go in
# TLB entry in order to support pages > 4kB
class FinalMaskGen(Elaboratable):
    def __init__(self):
        self.shift = Signal(6)       # input: remaining shift count
        self.finalmask = Signal(44)  # output

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        msk = Signal(44)

        for i in range(44):
            with m.If(i < self.shift):
                comb += msk[i].eq(1)

        comb += self.finalmask.eq(msk)

        return m
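
# equivalent pure-Python sketch of FinalMaskGen (hypothetical helper,
# not part of the hardware): the result is simply the lowest `shift`
# bits set.
def _final_mask_model(shift):
    return (1 << min(shift, 44)) - 1

assert _final_mask_model(12) == 0xfff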


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal(reset_less=True)
        self.iside = Signal(reset_less=True)
        self.store = Signal(reset_less=True)
        self.priv = Signal(reset_less=True)
        self.addr = Signal(64, reset_less=True)
        self.inval_all = Signal(reset_less=True)
        # config SPRs
        self.prtbl = Signal(64, reset_less=True)
        self.pid = Signal(32, reset_less=True)
        # internal state
        self.state = Signal(State) # resets to State.IDLE (zero)
        self.done = Signal(reset_less=True)
        self.err = Signal(reset_less=True)
        self.pgtbl0 = Signal(64, reset_less=True)
        self.pt0_valid = Signal(reset_less=True)
        self.pgtbl3 = Signal(64, reset_less=True)
        self.pt3_valid = Signal(reset_less=True)
        self.shift = Signal(6, reset_less=True)
        self.mask_size = Signal(5, reset_less=True)
        self.pgbase = Signal(56, reset_less=True)
        self.pde = Signal(64, reset_less=True)
        self.invalid = Signal(reset_less=True)
        self.badtree = Signal(reset_less=True)
        self.segerror = Signal(reset_less=True)
        self.perm_err = Signal(reset_less=True)
        self.rc_error = Signal(reset_less=True)


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb

        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])
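
        # worked example (hypothetical pgtbl entry 0x4000_0000_0000_00AD):
        # bits 61:63 = 0b10 with bits 5:8 = 0b101 give rts = 0b10101 = 21,
        # i.e. a 21+31 = 52-bit effective address space, and bits 0:5 = 13
        # give mbits = 13: the top level of the tree is indexed by 13
        # address bits (8192 entries of 8 bytes = 64kB, the Linux radix
        # PGD size).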

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)

            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value. Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb

        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)

        # test valid bit
        with m.If(data[63]):
            with m.If(data[62]):
                # this is a leaf: check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq((data[1] | data[2])
                                           & (~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now;
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                # a directory entry: descend one level
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31 - 12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

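    # segment-check intuition (hypothetical 52-bit space, rts = 21):
    # finalmask then has its low 21 bits set, so `nonzero` flags any set
    # bits in addr[52:62] (above the translated range), and the XOR test
    # requires addr bits 62 and 63 to be equal - i.e. the address must be
    # canonical and lie in quadrant 0 (user) or quadrant 3 (kernel).
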
    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # wire up the mask generators: as in microwatt's maskgen and
        # finalmaskgen processes, they are driven from the current
        # mask_size and shift counts
        m.submodules.maskgen = maskgen = AddrMaskGen()
        m.submodules.finalmaskgen = finalmaskgen = FinalMaskGen()
        comb += maskgen.mask_size.eq(r.mask_size)
        comb += mask.eq(maskgen.mask)
        comb += finalmaskgen.shift.eq(r.shift)
        comb += finalmask.eq(finalmaskgen.finalmask)

        # shift address bits 61--12 right by 0--47 bits and supply the
        # least significant 16 bits of the result (a behavioural
        # equivalent of microwatt's two-stage addrshifter)
        comb += addrsh.eq(r.addr[12:62] >> r.shift)

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        # debug output stubs, disabled for now
        with m.If(rin.valid):
            pass
            #sync += Display(f"MMU got tlb miss for {rin.addr}")

        with m.If(l_out.done):
            pass
            # sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            # sync += Display(f"MMU completing op with err invalid"
            #                 "{l_out.invalid} badtree={l_out.badtree}")

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"radix lookup shift={rin.shift}"
            #                 "msize={rin.mask_size}")

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"send load addr={d_out.addr}"
            #                 "addrsh={addrsh} mask={mask}")

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait also handles the non-present
                    # PTE (DSI) case, when the valid bit is clear
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32]   &  finalmask[0:24]),
                    r.prtbl[36:56]
                ))

        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh         &  mask),
                    r.pgbase[19:56]
                ))

        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56]  & ~finalmask) |
                    (r.addr[12:56] &  finalmask),
                ))

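        # PTE composition, in words: attribute/RC bits 0-11 come from
        # the PDE; of the 44-bit real-page field, the bits above the
        # page size (~finalmask) come from the PDE and the bits below
        # it (finalmask) from the faulting address. e.g. for a
        # hypothetical 2MB leaf, finalmask covers bits 0-8 of that
        # field, so address bits 12-20 pass straight through and PDE
        # bits 21-55 supply the physical page number.
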
        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # a minimal smoke-test stimulus, a placeholder until a proper
    # testbench with a dcache model exists: issue a privileged load
    # request on the loadstore1 interface, then let the FSM run for a
    # few cycles (address and cycle count are arbitrary).
    yield dut.l_in.load.eq(1)
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    yield
    yield dut.l_in.valid.eq(0)

    for i in range(10):
        yield

    done = yield dut.l_out.done
    print("done", done)

def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # TODO: dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()