mmu code-morph
[soc.git] / src / soc / experiment / mmu.py
1 # MMU
2 #
3 # License for original copyright mmu.vhdl by microwatt authors: CC4
4 # License for copyrighted modifications made in mmu.py: LGPLv3+
5 #
6 # This derivative work although includes CC4 licensed material is
7 # covered by the LGPLv3+
8
9 """MMU
10
11 based on Anton Blanchard microwatt mmu.vhdl
12
13 """
14 from enum import Enum, unique
15 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
16 from nmigen.cli import main
17 from nmigen.cli import rtlil
18 from nmutil.iocontrol import RecordObject
19 from nmutil.byterev import byte_reverse
20 from nmutil.mask import Mask, masked
21 from nmutil.util import Display
22
23 if True:
24 from nmigen.back.pysim import Simulator, Delay, Settle
25 else:
26 from nmigen.sim.cxxsim import Simulator, Delay, Settle
27 from nmutil.util import wrap
28
29 from soc.experiment.mem_types import (LoadStore1ToMMUType,
30 MMUToLoadStore1Type,
31 MMUToDCacheType,
32 DCacheToMMUType,
33 MMUToICacheType)
34
35
@unique
class State(Enum):
    """Radix-walk FSM states, held in RegStage.state.

    Each value corresponds to one m.Case() branch in MMU.elaborate().
    """
    IDLE = 0 # zero is default on reset for r.state
    DO_TLBIE = 1 # send TLB-invalidate request to the dcache
    TLB_WAIT = 2 # wait for dcache to signal d_in.done
    PROC_TBL_READ = 3 # issue read of the process-table entry
    PROC_TBL_WAIT = 4 # wait for the process-table entry data
    SEGMENT_CHECK = 5 # validate address range / tree geometry
    RADIX_LOOKUP = 6 # issue read of next radix-tree level
    RADIX_READ_WAIT = 7 # wait for the PDE/PTE read to complete
    RADIX_LOAD_TLB = 8 # write the computed PTE into the d/i TLB
    RADIX_FINISH = 9 # completion: report done or error, back to IDLE
48
49
class RegStage(RecordObject):
    """All MMU state carried from one clock cycle to the next.

    elaborate() builds a current-state instance ("r"), a next-state
    instance ("r_in") and a combinatorial working copy ("v"); mmu_0()
    latches r <= rin each clock.
    """
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()      # a request is in flight
        self.iside = Signal()      # instruction-side (iTLB) request
        self.store = Signal()      # request is a store (not load/ifetch)
        self.priv = Signal()       # privileged access, checked vs PTE bit 3
        self.addr = Signal(64)     # effective address being translated
        self.inval_all = Signal()  # invalidate-all flag for tlbie/slbia
        # config SPRs
        self.prtbl = Signal(64)    # process table base/size (PRTBL)
        self.pid = Signal(32)      # current PID SPR value
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()       # completion strobe -> l_out.done
        self.err = Signal()        # error strobe -> l_out.err
        self.pgtbl0 = Signal(64)   # cached quadrant-0 process-table entry
        self.pt0_valid = Signal()  # pgtbl0 cache is valid
        self.pgtbl3 = Signal(64)   # cached quadrant-3 process-table entry
        self.pt3_valid = Signal()  # pgtbl3 cache is valid
        self.shift = Signal(6)     # address bits still to be translated
        self.mask_size = Signal(5) # bits used to index the current level
        self.pgbase = Signal(56)   # base address of the current tree level
        self.pde = Signal(64)      # latest page directory/table entry read
        self.invalid = Signal()    # error flags reported at RADIX_FINISH:
        self.badtree = Signal()    #   malformed tree geometry
        self.segerror = Signal()   #   non-canonical address
        self.perm_err = Signal()   #   permission denied
        self.rc_error = Signal()   #   reference/change bits not set
80
81
class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMMUType()   # requests from loadstore1
        self.l_out = MMUToLoadStore1Type()  # replies to loadstore1
        self.d_out = MMUToDCacheType()      # load/tlb requests to dcache
        self.d_in = DCacheToMMUType()       # load replies from dcache
        self.i_out = MMUToICacheType()      # iTLB load/invalidate to icache

    def radix_tree_idle(self, m, l_in, r, v):
        """IDLE state: decode an incoming loadstore1 request.

        Latches the request into v and selects the next state:
        DO_TLBIE for tlbie/mtspr, PROC_TBL_READ when the cached
        process-table entry is stale, RADIX_FINISH+invalid when
        RPDS == 0, otherwise SEGMENT_CHECK.
        """
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        rts = Signal(6)
        mbits = Signal(6)

        # quadrant 0 (addr[63]=0) uses pgtbl0, quadrant 3 (addr[63]=1)
        # uses pgtbl3
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            # BUGFIX: previously copied r.pt3_valid (a 1-bit flag) into
            # the 64-bit pgtbl; must use the cached quadrant-3 entry
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                # BUGFIX: this was a plain m.If, whose m.Else clobbered
                # the PROC_TBL_READ state set just above (in nmigen the
                # last comb assignment wins).  microwatt mmu.vhdl uses
                # "elsif mbits = 0", i.e. an Elif chain.
                with m.Elif(~mbits):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        # NOTE(review): loadstore1 is presumed never to assert valid and
        # mtspr in the same cycle -- confirm, else this overrides above
        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        """PROC_TBL_WAIT: capture the process-table entry just read,
        cache it for the appropriate quadrant, and size-check the top
        level of the tree before proceeding to SEGMENT_CHECK.
        """
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        rts = Signal(6)
        mbits = Signal(6)

        # rts == radix tree size, # address bits being translated
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits):
            comb += v.state.eq(State.SEGMENT_CHECK)
        with m.Else():
            # RPDS == 0: radix walks disabled
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def radix_read_wait(self, m, v, r, d_in, data):
        """RADIX_READ_WAIT: examine the PDE/PTE just read from memory.

        A valid leaf is permission/RC-checked and loaded into the TLB;
        a valid non-leaf descends one more level; anything else is a
        non-present PTE and raises a DSI (invalid).
        """
        comb = m.d.comb
        comb += v.pde.eq(data)

        perm_ok = Signal()
        rc_ok = Signal()
        mbits = Signal(6)
        vbit = Signal(2)

        # test valid bit
        comb += vbit.eq(data[62:]) # leaf=data[62], valid=data[63]

        # valid & leaf
        with m.If(vbit == 0b11):
            # check permissions and RC bits
            with m.If(r.priv | ~data[3]):
                with m.If(~r.iside):
                    comb += perm_ok.eq(data[1:3].bool() & ~r.store)
                with m.Else():
                    # no IAMR, so no KUEP support for now
                    # deny execute permission if cache inhibited
                    comb += perm_ok.eq(data[0] & ~data[5])

            comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
            with m.If(perm_ok & rc_ok):
                comb += v.state.eq(State.RADIX_LOAD_TLB)
            with m.Else():
                comb += v.state.eq(State.RADIX_FINISH)
                comb += v.perm_err.eq(~perm_ok)
                # permission error takes precedence over RC error
                comb += v.rc_error.eq(perm_ok)

        # valid & !leaf
        with m.Elif(vbit == 0b10):
            comb += mbits.eq(data[0:5])
            # sanity-check the next level's index width
            with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                comb += v.state.eq(State.RADIX_FINISH)
                comb += v.badtree.eq(1)
            with m.Else():
                comb += v.shift.eq(v.shift - mbits)
                comb += v.mask_size.eq(mbits[0:5])
                comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # BUGFIX: this case was previously an m.Else on ~d_in.done
            # in elaborate(), which aborted the walk on every cycle the
            # dcache had not yet answered.  Per microwatt mmu.vhdl the
            # DSI applies to a *completed* read whose valid bit is clear.
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        """SEGMENT_CHECK: verify the address is canonical and the tree
        geometry is in range before starting the radix lookup."""
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        # any address bits set above the translated range?
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            # non-canonical address
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
        """SPR read-back mux, debug Displays, and the register-stage
        latch (r <= rin every clock)."""
        comb = m.d.comb
        sync = m.d.sync

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            # BUGFIX: format string previously concatenated to
            # "...invalid%d badtree=%d" with no '=' separator
            sync += Display("MMU completing op with err invalid=%d"
                            " badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            sync += Display("radix lookup shift=%d msize=%d",
                            rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            # BUGFIX: dropped a stray f-string prefix: Display performs
            # %-style formatting at simulation time, not Python-side
            sync += Display("send load addr=%x addrsh=%d mask=%d",
                            d_out.addr, addrsh, mask)

        # register stage: next state becomes current state each clock
        sync += r.eq(rin)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtb_adr = Signal(64)
        pgtb_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        # combinatorial defaults: next state is a copy of the current
        # state with all single-cycle strobes and error flags cleared
        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                # BUGFIX: the non-present-PTE (DSI) m.Else that used to
                # live here fired on every cycle d_in.done was low; that
                # handling now lives inside radix_read_wait() where it
                # applies only to a completed read (see microwatt)
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    # dTLB load goes through the dcache
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    # iTLB load goes straight to the icache
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        # raise done/err as the walk completes (iTLB loads finish
        # without a TLB_WAIT round-trip)
        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        # quadrant 3 (kernel) uses PID 0; quadrant 0 uses the PID SPR
        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        # process-table entry address
        pr24 = Signal(24, reset_less=True)
        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))

        # next-level page-table entry address
        # NOTE(review): addrsh is never assigned in this snapshot, so it
        # reads as 0 here -- confirm against microwatt mmu.vhdl
        pg16 = Signal(16, reset_less=True)
        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
        comb += pgtb_addr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))

        # final TLB entry: real-page bits from the PDE merged with the
        # untranslated low address bits
        pd44 = Signal(44, reset_less=True)
        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
        comb += pte.eq(Cat(r.pde[0:12], pd44))

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtb_adr)
        with m.Else():
            comb += addr.eq(pgtb_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m
479
480 stop = False
481
def dcache_get(dut):
    """Simulator process standing in for the dcache.

    Polls d_out.valid each cycle; on a request, looks the address up
    in a tiny fake memory and pulses d_in.done with the data.
    Terminates when the global ``stop`` flag is set.
    """
    fake_mem = {0x10000: 0x12345678}

    while not stop:
        # poll until a load request appears (or the sim is stopped)
        while True:
            if stop:
                return
            req_valid = yield (dut.d_out.valid)
            if req_valid:
                break
            yield
        req_addr = yield dut.d_out.addr
        # answer the request: present the data and pulse done for
        # exactly one cycle
        yield dut.d_in.data.eq(fake_mem[req_addr])
        yield dut.d_in.done.eq(1)
        yield
        yield dut.d_in.done.eq(0)
501
502
def mmu_sim(dut):
    """Main simulator process: idle for three clock cycles, then set
    the global ``stop`` flag so the other processes terminate."""
    global stop
    for _ in range(3):
        yield
    stop = True
509
def test_mmu():
    """Emit the MMU as RTLIL (test_mmu.il) and run a short nmigen
    simulation of it, writing a VCD trace to test_mmu.vcd."""
    dut = MMU()
    il_text = rtlil.convert(dut, ports=[])  # TODO: pass dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(il_text)

    top = Module()
    top.submodules.mmu = dut

    # nmigen Simulation
    sim = Simulator(top)
    sim.add_clock(1e-6)

    # main process plus the fake-dcache responder
    sim.add_sync_process(wrap(mmu_sim(dut)))
    sim.add_sync_process(wrap(dcache_get(dut)))
    with sim.write_vcd('test_mmu.vcd'):
        sim.run()


if __name__ == '__main__':
    test_mmu()