1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
from enum import Enum, unique
from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl)
from nmigen.cli import main
from nmigen.cli import rtlil
# nmigen's migen-compatibility simulator, used by test_mmu() below
from nmigen.compat.sim import run_simulation
from nmutil.iocontrol import RecordObject
from nmutil.byterev import byte_reverse
from nmutil.mask import Mask


from soc.experiment.mem_types import (LoadStore1ToMmuType,
                                      MmuToLoadStore1Type,
                                      MmuToDcacheType,
                                      DcacheToMmuType,
                                      MmuToIcacheType)


@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

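# the translation walk, as implemented by the state machine in
# MMU.elaborate():
#
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT (only on a process-table
#                                           cache miss)
#        -> SEGMENT_CHECK -> RADIX_LOOKUP -> RADIX_READ_WAIT
#                            (the last two repeating, one pass per
#                             tree level)
#        -> RADIX_LOAD_TLB -> TLB_WAIT -> RADIX_FINISH -> IDLE
#
# an instruction-side walk returns to IDLE directly from
# RADIX_LOAD_TLB; errors jump straight to RADIX_FINISH; tlbie and
# mtspr requests go via DO_TLBIE and TLB_WAIT instead.

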
class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        # cached process-table entries (quadrant 0 / quadrant 3)
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        # radix-walk working values
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        # error flags reported via l_out
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()
    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb

        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5])
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

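    # proc_tbl_wait: a process-table entry has arrived from the dcache;
    # cache it per-quadrant and decode its RTS and RPDS fields, using
    # the same extraction as radix_tree_idle above.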
    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        # decode a (byte-reversed) radix PDE/PTE doubleword returned by
        # the dcache.  bit usage follows the PowerISA v3.0B radix
        # layout: bit 63 valid, bit 62 leaf, bits 8/7 reference/change,
        # bit 5 cache-inhibited, bit 3 privileged, bits 0-2 the
        # execute/write/read permissions; for a non-leaf entry bits 0-4
        # give the next level's size and bits 8-55 its base
        comb = m.d.comb
        comb += v.pde.eq(data)

        # test valid bit
        with m.If(data[63]):
            # test leaf bit
            with m.If(data[62]):
                # check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now;
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                # a next-level directory pointer: sanity-check the
                # size field before following it
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb

        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            # addr[62] must equal addr[63], and every address bit
            # between the top of the translated range and bit 62
            # must be zero
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

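    # a note on shift bookkeeping across the walk: r.shift holds the
    # number of effective-address bits above bit 12 still to be
    # translated.  it starts out as RTS (so finalmask can double as
    # the segment-check mask), is rebased by segment_check with the
    # +(31-12) adjustment, and each tree level then subtracts its
    # mbits until only the (large-)page offset remains.
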
    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        # debug hooks, disabled until a Display equivalent is wired up
        with m.If(rin.valid):
            pass
            # sync += Display(f"MMU got tlb miss for {rin.addr}")

        with m.If(l_out.done):
            pass
            # sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            # sync += Display(f"MMU completing op with err invalid="
            #                 f"{l_out.invalid} badtree={l_out.badtree}")

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"radix lookup shift={rin.shift} "
            #                 f"msize={rin.mask_size}")

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"send load addr={d_out.addr} "
            #                 f"addrsh={addrsh} mask={mask}")

        # r latches rin (the next-state value) every clock
        sync += r.eq(rin)

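        # v is the working copy of the next state: it starts each cycle
        # as a combinatorial copy of r, the state machine below edits
        # it, and it is latched back into r (via rin) at the clock
        # edge.  dcreq, tlb_load, itlb_load, tlbie_req and prtbl_rd
        # are single-cycle pulses.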
        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        # shift address bits 61--12 right by 0--47 bits and supply the
        # least significant 16 bits of the result, giving the index
        # into the current level of the tree (as in microwatt mmu.vhdl)
        comb += addrsh.eq(r.addr[12:62] >> r.shift)

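        # worked example (assuming the usual radix geometry): for a 2MB
        # large page the walk stops with r.shift = 9 (21 offset bits
        # minus the 12 always-untranslated ones), so finalmask has its
        # low 9 bits set and effective-address bits 12--20 pass straight
        # through into the TLB entry; for 4kB pages r.shift = 0 and
        # finalmask is all zeros.
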
        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait also handles the non-present
                    # (invalid) PDE case, generating a DSI
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        # a walk is complete when it reaches RADIX_FINISH (or when an
        # instruction-side TLB load finishes); report success or the
        # accumulated error type
        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        # quadrant 3 (kernel, addr[63] set) always uses PID 0
        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        # address of the process-table entry for effpid: entries are
        # 16 bytes, the low 8 PID bits always index the table, and
        # finalmask (derived from the PRTS size field while reading
        # the process table) selects how many upper PID bits the
        # table size allows
        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24])
                    | (effpid[8:32] & finalmask[0:24]),
                    r.prtbl[36:56]))

        # address of the next PDE: the index bits (addrsh) are merged
        # into the current page-table base under the level-size mask
        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask)
                    | (addrsh & mask),
                    r.pgbase[19:56]))

        # TLB entry: merge the untranslated low address bits (large
        # pages) into the real page number from the leaf PDE
        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56] & ~finalmask)
                    | (r.addr[12:56] & finalmask)))

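        # the three addresses above feed the output mux below:
        # prtable_addr goes to the dcache during PROC_TBL_READ,
        # pgtable_addr during RADIX_LOOKUP, and pte is the payload
        # loaded into the TLB in RADIX_LOAD_TLB.  worked example with
        # made-up numbers: r.prtbl = 0x10000 (size field 0, i.e. a
        # 256-entry table) and pid = 5 gives
        # prtable_addr = 0x10000 + 5 * 16 = 0x10050.
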
        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # minimal stimulus sketch only: the generator this replaces was
    # left over from a register-file test and drove wp/rp ports that
    # the MMU does not have.  all we do here is an mtspr to PRTBL
    # (l_in.sprn bit 9 set, as decoded in elaborate()) followed by a
    # read-back of the SPR; a real translation-walk test would need a
    # dcache model answering on d_in.
    prtbl_val = 0x10000  # arbitrary process-table base for the test

    yield dut.l_in.mtspr.eq(1)
    yield dut.l_in.sprn.eq(1 << 9)
    yield dut.l_in.rs.eq(prtbl_val)
    yield
    yield dut.l_in.mtspr.eq(0)
    yield dut.l_in.rs.eq(0)
    yield

    # with sprn bit 9 still set, l_out.sprval mirrors r.prtbl
    sprval = yield dut.l_out.sprval
    print("sprval", hex(sprval))
    assert sprval == prtbl_val
    yield

def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # TODO: dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()