radix: reading first page table entry
[soc.git] / src / soc / experiment / mmu.py
# MMU
#
# License for original copyright mmu.vhdl by microwatt authors: CC4
# License for copyrighted modifications made in mmu.py: LGPLv3+
#
# Although this derivative work includes CC4-licensed material, it is
# covered by the LGPLv3+

9 """MMU
10
11 based on Anton Blanchard microwatt mmu.vhdl
12
13 """
14 from enum import Enum, unique
15 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
16 from nmigen.cli import main
17 from nmigen.cli import rtlil
18 from nmutil.iocontrol import RecordObject
19 from nmutil.byterev import byte_reverse
20 from nmutil.mask import Mask, masked
21 from nmutil.util import Display
22
23 # NOTE: to use cxxsim, export NMIGEN_SIM_MODE=cxxsim from the shell
24 # Also, check out the cxxsim nmigen branch, and latest yosys from git
25 from nmutil.sim_tmp_alternative import Simulator, Settle
26
27 from nmutil.util import wrap
28
29 from soc.experiment.mem_types import (LoadStore1ToMMUType,
30 MMUToLoadStore1Type,
31 MMUToDCacheType,
32 DCacheToMMUType,
33 MMUToICacheType)


@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        sync = m.d.sync

        pt_valid = Signal()
        pgtbl = Signal(64)
        rts = Signal(6)
        mbits = Signal(6)

        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        comb += mbits.eq(pgtbl[0:5])
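
        # Worked example (using the process-table entry that the
        # simulator below provides, 0x40000000000300ad): pgtbl[5:8]
        # (RTS2) = 0b101 and pgtbl[61:63] (RTS1) = 0b10, so
        # rts = 0b10101 = 21, which per the Power ISA RTS encoding
        # corresponds to a 2^(21+31) = 2^52-byte address space;
        # mbits = pgtbl[0:5] (RPDS) = 13, so the top level of the
        # tree is indexed by 13 address bits.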

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))
            comb += v.priv.eq(l_in.priv)

            comb += Display("state %d l_in.valid addr %x iside %d store %d "
                            "rts %x mbits %x pt_valid %d",
                            v.state, v.addr, v.iside, v.store,
                            rts, mbits, pt_valid)

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)

                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value. Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        rts = Signal(6)
        mbits = Signal(6)

        # rts == radix tree size, # address bits being translated
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
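
        # Worked example (values from the simulator below): with the
        # process-table entry data = 0x40000000000300ad, RPDB = 0x300,
        # so pgbase = 0x300 << 8 = 0x30000 (the RADIX_ROOT_PTE address
        # in the test memory), rts = 21 and mbits = RPDS = 13.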

        with m.If(mbits):
            comb += v.state.eq(State.SEGMENT_CHECK)
        with m.Else():
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        sync = m.d.sync

        perm_ok = Signal()
        rc_ok = Signal()
        mbits = Signal(6)
        valid = Signal()
        leaf = Signal()
        badtree = Signal()

        comb += Display("RDW %016x done %d "
                        "perm %d rc %d mbits %d shf %d "
                        "valid %d leaf %d bad %d",
                        data, d_in.done, perm_ok, rc_ok,
                        mbits, r.shift, valid, leaf, badtree)

        # set pde
        comb += v.pde.eq(data)

        # test valid and leaf bits
        comb += valid.eq(data[63])  # valid=data[63]
        comb += leaf.eq(data[62])   # leaf=data[62]

        # valid & leaf
        with m.If(valid):
            with m.If(leaf):
                # check permissions and RC bits
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)
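
                # Worked example (values from the simulator below): the
                # leaf PTE 0xc000000000000187 has data[0]=data[1]=data[2]=1
                # (EAA = 0x7) and data[3]=0, so the privilege gate passes
                # and perm_ok=1 for loads and stores alike; R=data[8]=1
                # and C=data[7]=1 give rc_ok=1, so the next state is
                # RADIX_LOAD_TLB.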

            # valid & !leaf
            with m.Else():
                comb += mbits.eq(data[0:5])
                comb += badtree.eq((mbits < 5) |
                                   (mbits > 16) |
                                   (mbits > r.shift))
                with m.If(badtree):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)
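
        # Worked example (values from the simulator below): after the
        # process-table read, r.shift = rts = 21 and r.mask_size = 13,
        # so v.shift becomes 21 + (31-12) - 13 = 27, the bit position
        # of the top-level index for addr = 0x10000. All bits above
        # bit 51 of that address are zero, so neither the
        # sign-extension check nor the nonzero check fires, and the
        # walk proceeds to RADIX_LOOKUP.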

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d "
                            "badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%x",
                            d_out.addr, addrsh, mask)
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        self.rin = rin = RegStage("r_in")
        r = RegStage("r")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtb_adr = Signal(64)
        pgtb_adr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
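
        # Sketch of the intent (assuming nmutil's Mask drives its
        # `mask` output with `shift` low-order one-bits): the low 5
        # bits of `mask` are always set and pte_mask contributes up to
        # another 11, giving mask_size one-bits in total. With
        # mask_size = 13 (the test value below), mask = 0x1fff: a
        # 13-bit index into an 8192-entry, 64KiB top-level table of
        # 8-byte PDEs.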

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        with m.If(r.state != State.IDLE):
            sync += Display("MMU state %d %016x", r.state, data)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                sync += Display(" TBL_READ %016x", prtb_adr)
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                sync += Display(" RADIX_LOOKUP")
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                sync += Display(" READ_WAIT")
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)
                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                sync += Display(" RADIX_FINISH")
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        pr24 = Signal(24, reset_less=True)
        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))

        pg16 = Signal(16, reset_less=True)
        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
        comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))

        pd44 = Signal(44, reset_less=True)
        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
        comb += pte.eq(Cat(r.pde[0:12], pd44))
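
        # Worked example (values from the simulator below, assuming
        # masked(x, y, m) selects bits of y where m is set): after
        # mtspr sets prtbl = 0x1000000 (PRTS = 0, so finalmask = 0)
        # and with pid = 0, pr24 = prtbl[12:36] = 0x1000 and
        # prtb_adr = 0x1000 << 12 = 0x1000000, the PROCESS_TABLE_3
        # address in the test memory. Likewise, with pgbase = 0x30000
        # and a zero index, pgtb_adr = 0x30000, the RADIX_ROOT_PTE
        # address.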

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtb_adr)
        with m.Else():
            comb += addr.eq(pgtb_adr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m

stop = False

def dcache_get(dut):
    """simulator process for getting memory load requests
    """

    global stop

    def b(x):
        return int.from_bytes(x.to_bytes(8, byteorder='little'),
                              byteorder='big', signed=False)
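
    # b() byte-swaps a 64-bit value: for example,
    # b(0x0123456789abcdef) == 0xefcdab8967452301. The page-table
    # entries below are stored byte-swapped because the MMU
    # byte-reverses the radix structures it reads from memory.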

    mem = {0x0: 0x000000, # to get mtspr prtbl working

           0x10000: # PARTITION_TABLE_2
                    # PATB_GR=1 PRTB=0x1000 PRTS=0xb
                    b(0x800000000100000b),

           0x30000: # RADIX_ROOT_PTE
                    # V = 1 L = 0 NLB = 0x400 NLS = 9
                    b(0x8000000000040009),

           0x40000: # RADIX_SECOND_LEVEL
                    # V = 1 L = 1 SW = 0 RPN = 0
                    # R = 1 C = 1 ATT = 0 EAA 0x7
                    b(0xc000000000000187),

           0x1000000: # PROCESS_TABLE_3
                      # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
                      b(0x40000000000300ad),
          }
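
    # Expected walk when translating addr 0x10000 (a sketch based on
    # the state machine above):
    #   1. PROC_TBL_READ reads 0x1000000 (PROCESS_TABLE_3): root table
    #      at RPDB << 8 = 0x30000, RPDS = 13
    #   2. RADIX_LOOKUP reads 0x30000 (RADIX_ROOT_PTE): non-leaf,
    #      next level at NLB << 8 = 0x40000, NLS = 9
    #   3. RADIX_LOOKUP reads 0x40000 (RADIX_SECOND_LEVEL): leaf PTE,
    #      loaded into the TLB via RADIX_LOAD_TLB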

    while not stop:
        while True: # wait for dc_valid
            if stop:
                return
            dc_valid = yield (dut.d_out.valid)
            if dc_valid:
                break
            yield
        addr = yield dut.d_out.addr
        if addr not in mem:
            print(" DCACHE LOOKUP FAIL %x" % (addr))
            stop = True
            return

        yield
        data = mem[addr]
        yield dut.d_in.data.eq(data)
        print(" DCACHE GET %x data %x" % (addr, data))
        yield dut.d_in.done.eq(1)
        yield
        yield dut.d_in.done.eq(0)

def mmu_wait(dut):
    global stop
    while not stop: # wait for dc_valid / err
        l_done = yield (dut.l_out.done)
        l_err = yield (dut.l_out.err)
        l_badtree = yield (dut.l_out.badtree)
        l_permerr = yield (dut.l_out.perm_error)
        l_rc_err = yield (dut.l_out.rc_error)
        l_segerr = yield (dut.l_out.segerr)
        l_invalid = yield (dut.l_out.invalid)
        if (l_done or l_err or l_badtree or
                l_permerr or l_rc_err or l_segerr or l_invalid):
            break
        yield
    yield dut.l_in.valid.eq(0) # data already in MMU by now
    yield dut.l_in.mtspr.eq(0) # captured by RegStage(s)
    yield dut.l_in.load.eq(0)  # can reset everything safely

def mmu_sim(dut):
    global stop

    # MMU MTSPR set prtbl
    yield dut.l_in.mtspr.eq(1)
    yield dut.l_in.sprn[9].eq(1) # totally fake way to set SPR=prtbl
    yield dut.l_in.rs.eq(0x1000000) # set process table
    yield dut.l_in.valid.eq(1)
    yield from mmu_wait(dut)
    yield
    yield dut.l_in.sprn.eq(0)
    yield dut.l_in.rs.eq(0)
    yield

    prtbl = yield (dut.rin.prtbl)
    print("prtbl after MTSPR %x" % prtbl)
    assert prtbl == 0x1000000

    #yield dut.rin.prtbl.eq(0x1000000) # manually set process table
    #yield

    # MMU PTE request
    yield dut.l_in.load.eq(1)
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    yield from mmu_wait(dut)

    addr = yield dut.d_out.addr
    pte = yield dut.d_out.pte
    l_done = yield (dut.l_out.done)
    l_err = yield (dut.l_out.err)
    l_badtree = yield (dut.l_out.badtree)
    print("translated done %d err %d badtree %d addr %x pte %x" %
          (l_done, l_err, l_badtree, addr, pte))
    yield
    yield dut.l_in.priv.eq(0)
    yield dut.l_in.addr.eq(0)

    stop = True


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    sim.add_sync_process(wrap(dcache_get(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()

if __name__ == '__main__':
    test_mmu()