more experimenting with mmu READ_WAIT state
# MMU
#
# License for original copyright mmu.vhdl by microwatt authors: CC4
# License for copyrighted modifications made in mmu.py: LGPLv3+
#
# Although this derivative work includes CC4-licensed material, it is
# covered by the LGPLv3+.

"""MMU

based on Anton Blanchard microwatt mmu.vhdl

"""
from enum import Enum, unique
from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl)
from nmigen.cli import main
from nmigen.cli import rtlil
from nmutil.iocontrol import RecordObject
from nmutil.byterev import byte_reverse
from nmutil.mask import Mask, masked
from nmutil.util import Display

if True:
    from nmigen.back.pysim import Simulator, Delay, Settle
else:
    from nmigen.sim.cxxsim import Simulator, Delay, Settle
from nmutil.util import wrap

from soc.experiment.mem_types import (LoadStore1ToMMUType,
                                      MMUToLoadStore1Type,
                                      MMUToDCacheType,
                                      DCacheToMMUType,
                                      MMUToICacheType)


@unique
class State(Enum):
    IDLE = 0  # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9
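
    # Typical walk, inferred from the state transitions in elaborate()
    # below: IDLE -> (PROC_TBL_READ -> PROC_TBL_WAIT, when the process-
    # table entry is not yet cached) -> SEGMENT_CHECK -> RADIX_LOOKUP <->
    # RADIX_READ_WAIT (one pass per tree level) -> RADIX_LOAD_TLB ->
    # TLB_WAIT -> RADIX_FINISH (dside) or straight back to IDLE (iside);
    # all error paths divert to RADIX_FINISH.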


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State)  # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()
        self.l_out = MMUToLoadStore1Type()
        self.d_out = MMUToDCacheType()
        self.d_in = DCacheToMMUType()
        self.i_out = MMUToICacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        sync = m.d.sync

        pt_valid = Signal()
        pgtbl = Signal(64)
        rts = Signal(6)
        mbits = Signal(6)

        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
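
        # Worked example (an assumption, following the Power ISA radix
        # encoding, and matching the PROCESS_TABLE_3 entry used in the
        # simulation below): the reassembled rts is the address-space
        # size minus 31.  For a 52-bit space rts = 21 = 0b10101: the low
        # three bits (0b101) come from pgtbl[5:8] and the high two (0b10)
        # from pgtbl[61:63].  With mbits (RPDS) = 13 the top level is
        # indexed by 13 address bits, and pgbase is the 256-byte-aligned
        # table address recovered from pgtbl[8:56] << 8.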

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))
            comb += v.priv.eq(l_in.priv)

            comb += Display("l_in.valid addr %x iside %d store %d "
                            "rts %x mbits %x pt_valid %d",
                            v.addr, v.iside, v.store,
                            rts, mbits, pt_valid)

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)

                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        rts = Signal(6)
        mbits = Signal(6)

        # rts == radix tree size, # address bits being translated
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits):
            comb += v.state.eq(State.SEGMENT_CHECK)
        with m.Else():
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        sync = m.d.sync

        perm_ok = Signal()
        rc_ok = Signal()
        mbits = Signal(6)
        valid = Signal()
        leaf = Signal()
        badtree = Signal()

        with m.If(d_in.done & (r.state == State.RADIX_READ_WAIT)):
            comb += Display("RDW %016x done %d "
                            "perm %d rc %d mbits %d rshift %d "
                            "valid %d leaf %d badtree %d",
                            data, d_in.done, perm_ok, rc_ok,
                            mbits, r.shift, valid, leaf, badtree)

        # set pde
        comb += v.pde.eq(data)

        # test valid and leaf bits
        comb += valid.eq(data[63])  # valid=data[63]
        comb += leaf.eq(data[62])   # leaf=data[62]
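
        # PTE bit usage in the checks below (an assumption, per the Power
        # ISA radix leaf PTE layout): bit 8 = Reference, 7 = Change,
        # 5 = cache-inhibited attribute, 3 = privileged-access-only,
        # 2 = read, 1 = write, 0 = execute.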

        # valid & leaf
        with m.If(valid):
            with m.If(leaf):
                # check permissions and RC bits
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)

            # valid & !leaf
            with m.Else():
                comb += mbits.eq(data[0:5])
                comb += badtree.eq((mbits < 5) |
                                   (mbits > 16) |
                                   (mbits > r.shift))
                with m.If(badtree):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(v.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
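
        # Worked example (an assumption): with rts = 21 (a 52-bit space)
        # held in r.shift and a top-level mbits of 13, v.shift becomes
        # 21 + 19 - 13 = 27 address bits still to translate below the
        # top-level index.  nonzero flags address bits set between bit 31
        # and bit 61 but outside the translated range, and
        # addr[63] ^ addr[62] rejects non-canonical addresses (the two
        # top bits must match, i.e. quadrant 0b00 or 0b11).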
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31 - 12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            sync += Display("MMU completing op with err invalid=%d "
                            "badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            sync += Display("send load addr=%x addrsh=%d mask=%x",
                            d_out.addr, addrsh, mask)
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)
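        # note: addrsh appears never to be driven in this revision (an
        # observation, possibly the point of the experiment), so it holds
        # its reset value of zero and the page-table index derived from
        # it (pg16 below) always selects entry 0.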

        self.rin = rin = RegStage("r_in")
        r = RegStage("r")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtb_adr = Signal(64)
        pgtb_adr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
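
        # Worked example (an assumption about nmutil's Mask: it sets the
        # `shift` lowest bits of its output): with mask_size = 13,
        # pte_mask.shift = 8 and mask = 0b11111 plus 8 more set bits,
        # i.e. 13 ones.  masked() below then merges the index bits
        # (addrsh) into pgbase[3:19], selecting one 8-byte PTE out of a
        # 2**13-entry table.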

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        with m.If(r.state != State.IDLE):
            sync += Display("MMU state %d %016x", r.state, data)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                sync += Display(" TBL_READ %016x", prtb_adr)
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                sync += Display(" RADIX_LOOKUP")
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                sync += Display(" READ_WAIT")
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                sync += Display(" RADIX_FINISH")
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        pr24 = Signal(24, reset_less=True)
        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))
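
        # Worked example (an assumption): process-table entries are 16
        # bytes, so bits 0:4 of prtb_adr are zero and the pid lands at
        # bits 4:12, giving base + pid * 16.  With prtbl = 0x1000000 and
        # pid = 0 (as in the simulation below), pr24 = prtbl[12:36] =
        # 0x1000 and prtb_adr = 0x1000 << 12 = 0x1000000, entry 0.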

        pg16 = Signal(16, reset_less=True)
        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
        comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))

        pd44 = Signal(44, reset_less=True)
        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
        comb += pte.eq(Cat(r.pde[0:12], pd44))
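
        # The TLB entry keeps the flag bits pde[0:12] of the leaf PDE;
        # for pages larger than 4kB, finalmask passes the untranslated
        # low address bits from r.addr straight into the real-address
        # field, so only the upper RPN bits come from the PDE.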

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtb_adr)
        with m.Else():
            comb += addr.eq(pgtb_adr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


stop = False


def dcache_get(dut):
    """simulator process for getting memory load requests
    """

    global stop

    def b(x):
        return int.from_bytes(x.to_bytes(8, byteorder='little'),
                              byteorder='big', signed=False)
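
    # b() reverses the byte order of a 64-bit constant: the radix tables
    # are big-endian in memory and the MMU byte-swaps what it reads, so
    # storing b(value) makes the walker see `value` as written, e.g.
    # b(0x0123456789abcdef) == 0xefcdab8967452301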

    mem = {0x10000:    # PARTITION_TABLE_2
           # PATB_GR=1 PRTB=0x1000 PRTS=0xb
           b(0x800000000100000b),

           0x30000:    # RADIX_ROOT_PTE
           # V = 1 L = 0 NLB = 0x400 NLS = 9
           b(0x8000000000040009),

           0x1000000:  # PROCESS_TABLE_3
           # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
           b(0x40000000000300ad),
           }
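
    # How the entries tie together (sketch): mmu_sim() sets
    # prtbl = 0x1000000, so the walk first reads the process-table entry
    # at 0x1000000; its RPDB = 0x300 places the radix root page directory
    # at 0x300 << 8 = 0x30000, whose NLB field points in turn at the next
    # level of the tree.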

    while not stop:
        while True:  # wait for dc_valid
            if stop:
                return
            dc_valid = yield (dut.d_out.valid)
            if dc_valid:
                break
            yield
        addr = yield dut.d_out.addr
        if addr not in mem:
            print("    DCACHE LOOKUP FAIL %x" % (addr))
            stop = True
            return

        data = mem[addr]
        yield dut.d_in.data.eq(data)
        print("dcache get %x data %x" % (addr, data))
        yield
        yield dut.d_in.done.eq(1)
        yield
        yield dut.d_in.done.eq(0)


def mmu_sim(dut):
    yield dut.rin.prtbl.eq(0x1000000)  # set process table
    yield

    yield dut.l_in.load.eq(1)
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    while True:  # wait for done / err
        l_done = yield (dut.l_out.done)
        l_err = yield (dut.l_out.err)
        l_badtree = yield (dut.l_out.badtree)
        l_permerr = yield (dut.l_out.perm_error)
        l_rc_err = yield (dut.l_out.rc_error)
        l_segerr = yield (dut.l_out.segerr)
        l_invalid = yield (dut.l_out.invalid)
        if (l_done or l_err or l_badtree or
                l_permerr or l_rc_err or l_segerr or l_invalid):
            break
        yield
    addr = yield dut.d_out.addr
    pte = yield dut.d_out.pte
    print("translated done %d err %d badtree %d addr %x pte %x" %
          (l_done, l_err, l_badtree, addr, pte))

    global stop
    stop = True


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[])  # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    m = Module()
    m.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(m)
    sim.add_clock(1e-6)

    sim.add_sync_process(wrap(mmu_sim(dut)))
    sim.add_sync_process(wrap(dcache_get(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()


if __name__ == '__main__':
    test_mmu()