1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12
13 from soc.experiment.mem_types import (LoadStore1ToMmuType,
14 MmuToLoadStore1Type,
15 MmuToDcacheType,
16 DcacheToMmuType,
17 MmuToIcacheType)

# -- Radix MMU
# -- Supports 4-level trees as in arch 3.0B, but not the
# -- two-step translation
# -- for guests under a hypervisor (i.e. there is no gRA -> hRA translation).

@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            # quadrant 3: use the kernel page table
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
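        # worked example (assumption, for illustration): Linux uses
        # a 52-bit radix address space, so RTS = 52 - 31 = 21 =
        # 0b10101, stored as pgtbl[61:63] = 0b10 and pgtbl[5:8] =
        # 0b101; a 13-bit root directory (8192 eight-byte PDEs,
        # i.e. 64kB) has RPDS = pgtbl[0:5] = 13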

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5])
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)
        # test valid bit
        with m.If(data[63]):
            # test leaf bit
            with m.If(data[62]):
                # check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        # write permission (data[1]) covers both
                        # loads and stores; read permission
                        # (data[2]) covers loads only
                        comb += perm_ok.eq(data[1] |
                                           (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now;
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
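                # (assumption, following the PTE layout microwatt
                # checks: data[8] is the Reference bit, which must
                # already be set, and data[7] is the Change bit,
                # which must also be set before a store is allowed)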
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    # subtract from r.shift: v.shift on the
                    # right-hand side would be a combinatorial loop
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb
        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
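        # the address must lie in quadrant 0 (both top bits zero) or
        # quadrant 3 (both top bits one), and nonzero flags any set
        # address bits between the top of the radix tree and bit 62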
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31 - 12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            pass
            #sync += Display(f"MMU got tlb miss for {rin.addr}")

        with m.If(l_out.done):
            pass
            # sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            # sync += Display(f"MMU completing op with err invalid"
            #                 "{l_out.invalid} badtree={l_out.badtree}")

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"radix lookup shift={rin.shift}"
            #                 "msize={rin.mask_size}")

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"send load addr={d_out.addr}"
            #                 "addrsh={addrsh} mask={mask}")

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)


        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr
        # generation: mask_size one-bits, minimum of 5 (assumption:
        # reproduces microwatt's mask-generation process)
        comb += mask.eq(Cat(C(0x1f, 5), ((1 << r.mask_size) - 1)[5:16]))

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        comb += finalmask.eq(((1 << r.shift) - 1))

        # shift address bits 61--12 right by 0--47 bits and supply
        # the least significant 16 bits of the result, for indexing
        # the current level of the tree (assumption: equivalent of
        # microwatt's addrshifter process)
        comb += addrsh.eq(r.addr[12:62] >> r.shift)
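        # worked example (assumption, for illustration): a 9-bit
        # directory level has mask_size=9, giving mask=0x1ff (nine
        # index bits); a 2MB leaf page leaves r.shift=9 untranslated
        # bits, so finalmask=0x1ff keeps addr[12:21] when the TLB
        # entry is formed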

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait handles the non-present-PTE
                    # (DSI) case internally; an Else branch here
                    # would fire on every cycle spent waiting
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32] & finalmask[0:24]),
                    r.prtbl[36:56]
                ))
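        # the process table is indexed by PID: entries are 16 bytes,
        # hence the four zero bits at the bottom; the low 8 PID bits
        # index within a 4kB table, and finalmask splices the upper
        # PID bits into the PRTB base for larger tables (assumption,
        # following the microwatt address-generation scheme)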

        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh & mask),
                    r.pgbase[19:56]
                ))

        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56] & ~finalmask) |
                    (r.addr[12:56] & finalmask),
                ))
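        # the TLB entry takes the real-page number from the PDE,
        # with the untranslated low address bits spliced in under
        # finalmask so that leaf pages larger than 4kB translate
        # correctly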

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # minimal smoke-test stimulus (assumption: the previous body was
    # leftover register-file test code driving wp/rp ports that do
    # not exist on the MMU).  issue a single privileged load request,
    # then let the FSM run for a few cycles.
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.load.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    yield
    yield dut.l_in.valid.eq(0)
    for i in range(10):
        yield

def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()