# MMU
#
# License for original copyright mmu.vhdl by microwatt authors: CC4
# License for copyrighted modifications made in mmu.py: LGPLv3+
#
# Although this derivative work includes CC4-licensed material, it is
# covered by the LGPLv3+

9 """MMU
10
11 based on Anton Blanchard microwatt mmu.vhdl
12
13 """
14 from enum import Enum, unique
15 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
16 from nmigen.cli import main
17 from nmigen.cli import rtlil
18 from nmutil.iocontrol import RecordObject
19 from nmutil.byterev import byte_reverse
20 from nmutil.mask import Mask, masked
21 from nmutil.util import Display
22
if True:
    from nmigen.back.pysim import Simulator, Delay, Settle
else:
    from nmigen.sim.cxxsim import Simulator, Delay, Settle
from nmutil.util import wrap

from soc.experiment.mem_types import (LoadStore1ToMMUType,
                                      MMUToLoadStore1Type,
                                      MMUToDCacheType,
                                      DCacheToMMUType,
                                      MMUToICacheType)

@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

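# A typical dside TLB miss walks the states roughly as follows (a
# sketch derived from elaborate() below; the PROC_TBL_* states are
# skipped when the process-table entry is already cached):
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT -> SEGMENT_CHECK
#        -> RADIX_LOOKUP -> RADIX_READ_WAIT (repeated per tree level)
#        -> RADIX_LOAD_TLB -> TLB_WAIT -> RADIX_FINISH -> IDLE
# iside translations skip TLB_WAIT: RADIX_LOAD_TLB goes straight
# back to IDLE after loading the iTLB.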

class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        sync = m.d.sync

        pt_valid = Signal()
        pgtbl = Signal(64)
        rts = Signal(6)
        mbits = Signal(6)

        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
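        # (per Power ISA v3.0B the RTS field encodes the total EA size
        # minus 31: e.g. rts = 21 selects a 52-bit address space)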

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))
            comb += v.priv.eq(l_in.priv)

            comb += Display("l_in.valid addr %x iside %d store %d "
                            "rts %x mbits %x pt_valid %d",
                            v.addr, v.iside, v.store,
                            rts, mbits, pt_valid)

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)

                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value. Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        rts = Signal(6)
        mbits = Signal(6)

        # rts == radix tree size, # address bits being translated
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
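        # worked example, decoding the PROCESS_TABLE_3 entry
        # 0x40000000000300ad used by the testbench below:
        #   rts    = 5 | (2 << 3) = 21 (a 52-bit address space)
        #   mbits  = 0xad & 0x1f  = 13
        #   pgbase = 0x300 << 8   = 0x30000 (the RADIX_ROOT_PTE address)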

        with m.If(mbits):
            comb += v.state.eq(State.SEGMENT_CHECK)
        with m.Else():
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        sync = m.d.sync

        perm_ok = Signal()
        rc_ok = Signal()
        mbits = Signal(6)
        valid = Signal()
        leaf = Signal()
        badtree = Signal()

        comb += Display("RDW %016x done %d "
                        "perm %d rc %d mbits %d shf %d "
                        "valid %d leaf %d bad %d",
                        data, d_in.done, perm_ok, rc_ok,
                        mbits, r.shift, valid, leaf, badtree)

        # set pde
        comb += v.pde.eq(data)

        # test valid and leaf bits
        comb += valid.eq(data[63]) # valid=data[63]
        comb += leaf.eq(data[62])  # leaf=data[62]

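        # sketch decode of the testbench leaf PTE 0xc000000000000187
        # (see dcache_get below): valid=1 leaf=1, R=data[8]=1,
        # C=data[7]=1, EAA=0x7 i.e. data[0]=execute, data[1]=write,
        # data[2]=read all granted, so perm_ok and rc_ok both pass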
        # valid & leaf
        with m.If(valid):
            with m.If(leaf):
                # check permissions and RC bits
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)

            # valid & !leaf
            with m.Else():
                comb += mbits.eq(data[0:5])
                comb += badtree.eq((mbits < 5) |
                                   (mbits > 16) |
                                   (mbits > r.shift))
                with m.If(badtree):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
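        # (intent, as a sketch: EA bits above the rts+31 bits covered
        # by the radix tree must be zero, and EA bit 62 must match
        # bit 63, restricting EAs to quadrants 0 and 3)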
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d "
                            "badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%x",
                            d_out.addr, addrsh, mask)
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        self.rin = rin = RegStage("r_in")
        r = RegStage("r")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtb_adr = Signal(64)
        pgtb_adr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)
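        # (byte_reverse swaps the eight bytes of each doubleword, e.g.
        # 0x0123456789abcdef <-> 0xefcdab8967452301)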

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
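        # (worked sketch, assuming nmutil's Mask sets the low `shift`
        # bits of .mask: r.mask_size=13 gives pte_mask.shift=8, so
        # mask = 0x1fff, i.e. 13 index bits with the low 5 always set)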

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        with m.If(r.state != State.IDLE):
            sync += Display("MMU state %d %016x", r.state, data)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                sync += Display("   TBL_READ %016x", prtb_adr)
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                sync += Display("   RADIX_LOOKUP")
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                sync += Display("   READ_WAIT")
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)
                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                sync += Display("   RADIX_FINISH")
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        pr24 = Signal(24, reset_less=True)
        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))
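        # (process-table entries are 16 bytes, hence the 4 zero LSBs;
        # the low 8 PID bits index the table directly, and the upper
        # PID bits are masked into the base according to the table size)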

        pg16 = Signal(16, reset_less=True)
        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
        comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))

        pd44 = Signal(44, reset_less=True)
        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
        comb += pte.eq(Cat(r.pde[0:12], pd44))
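        # (the PTE handed to the TLB: the PDE's RPN with the EA bits
        # below the final level substituted under finalmask, which is
        # what supports page sizes larger than 4kB)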

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtb_adr)
        with m.Else():
            comb += addr.eq(pgtb_adr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m

stop = False

def dcache_get(dut):
    """simulator process for getting memory load requests
    """

    global stop

    def b(x):
        return int.from_bytes(x.to_bytes(8, byteorder='little'),
                              byteorder='big', signed=False)

    mem = {0x10000:    # PARTITION_TABLE_2
                       # PATB_GR=1 PRTB=0x1000 PRTS=0xb
           b(0x800000000100000b),

           0x30000:    # RADIX_ROOT_PTE
                       # V = 1 L = 0 NLB = 0x400 NLS = 9
           b(0x8000000000040009),

           0x40000:    # RADIX_SECOND_LEVEL
                       # V = 1 L = 1 SW = 0 RPN = 0
                       # R = 1 C = 1 ATT = 0 EAA 0x7
           b(0xc000000000000187),

           0x1000000:  # PROCESS_TABLE_3
                       # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
           b(0x40000000000300ad),
          }
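    # the entries above form a two-level radix walk: prtbl=0x1000000
    # selects PROCESS_TABLE_3, whose RPDB=0x300 places the root PTE at
    # 0x300 << 8 = 0x30000; the root's NLB=0x400 places the second
    # (leaf) level at 0x400 << 8 = 0x40000.  b() byte-swaps each
    # doubleword so the MMU's byte_reverse recovers the big-endian
    # table values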

    while not stop:
        while True: # wait for dc_valid
            if stop:
                return
            dc_valid = yield (dut.d_out.valid)
            if dc_valid:
                break
            yield
        addr = yield dut.d_out.addr
        if addr not in mem:
            print("    DCACHE LOOKUP FAIL %x" % (addr))
            stop = True
            return

        yield
        data = mem[addr]
        yield dut.d_in.data.eq(data)
        print("dcache get %x data %x" % (addr, data))
        yield dut.d_in.done.eq(1)
        yield
        yield dut.d_in.done.eq(0)


def mmu_sim(dut):
    global stop
    yield dut.rin.prtbl.eq(0x1000000) # set process table
    yield

    yield dut.l_in.load.eq(1)
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    while not stop: # wait for done / err
        l_done = yield (dut.l_out.done)
        l_err = yield (dut.l_out.err)
        l_badtree = yield (dut.l_out.badtree)
        l_permerr = yield (dut.l_out.perm_error)
        l_rc_err = yield (dut.l_out.rc_error)
        l_segerr = yield (dut.l_out.segerr)
        l_invalid = yield (dut.l_out.invalid)
        if (l_done or l_err or l_badtree or
                l_permerr or l_rc_err or l_segerr or l_invalid):
            break
        yield
    addr = yield dut.d_out.addr
    pte = yield dut.d_out.pte
    print("translated done %d err %d badtree %d addr %x pte %x" %
          (l_done, l_err, l_badtree, addr, pte))

    stop = True


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    sim.add_sync_process(wrap(dcache_get(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()

if __name__ == '__main__':
    test_mmu()