1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12 from nmutil.mask import Mask
13
14
15 from soc.experiment.mem_types import (LoadStore1ToMmuType,
16 MmuToLoadStore1Type,
17 MmuToDcacheType,
18 DcacheToMmuType,
19 MmuToIcacheType)

# Radix MMU
# Supports 4-level trees as in arch 3.0B, but not the two-step
# translation for guests under a hypervisor (i.e. there is no
# gRA -> hRA translation).

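# Table-walk state machine.  A typical dside TLB miss flows:
#   IDLE -> PROC_TBL_READ -> PROC_TBL_WAIT -> SEGMENT_CHECK
#        -> RADIX_LOOKUP -> RADIX_READ_WAIT (once per tree level)
#        -> RADIX_LOAD_TLB -> TLB_WAIT -> RADIX_FINISH -> IDLE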
@unique
class State(Enum):
    IDLE = 0            # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State)   # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
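        # (the RTS field is stored split in the table entry: the
        # 3 LSBs at bits 5:8 and the 2 MSBs at bits 61:63; the tree
        # translates RTS + 31 address bits in total)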

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5])
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)

            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask for
                    # generating the process table entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(~mbits):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs and the
            # cached pgtbl0 value.  Move to PRTBL does that plus
            # invalidating the cached pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
        with m.If(~mbits):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)
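        # radix PTE bit usage below (LSB0 numbering, matching the
        # Linux radix layout): 63 = valid, 62 = leaf, 8 = reference
        # (R), 7 = change (C), 5:4 = cache attributes, 3 = privileged,
        # 2 = read, 1 = write, 0 = execute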
        # test valid bit
        with m.If(data[63]):
            # valid leaf PTE: check permissions and RC bits
            with m.If(data[62]):
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support for now;
                        # deny execute permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence over RC error
                    comb += v.rc_error.eq(perm_ok)

            # valid directory entry: descend to the next tree level
            with m.Else():
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(v.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)

        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb
        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
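        # segment check: the address must be canonical, i.e. bits 63
        # and 62 must match (quadrant 0 or 3), and no address bits may
        # be set between the range covered by the tree and bit 62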
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31-12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            pass
            # sync += Display(f"MMU got tlb miss for {rin.addr}")

        with m.If(l_out.done):
            pass
            # sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            # sync += Display(f"MMU completing op with err "
            #                 f"invalid={l_out.invalid} "
            #                 f"badtree={l_out.badtree}")

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"radix lookup shift={rin.shift} "
            #                 f"msize={rin.mask_size}")

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"send load addr={d_out.addr} "
            #                 f"addrsh={addrsh} mask={mask}")

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE
        # address generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
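        # mask ends up with r.mask_size low bits set: the bottom five
        # are always set and the rest come from the shifter (assuming
        # nmutil Mask sets the low `shift` bits of its output)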

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)
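        # finalmask has r.shift low bits set: those address bits sit
        # below the part of the address translated so far and pass
        # through to the TLB entry, supporting large pages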

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

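        # form the process-table entry address: entries are 16 bytes,
        # hence the 4 zero LSBs; the low 8 PID bits index the table,
        # and the upper PID bits merge into the table base according
        # to the table size (finalmask)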
        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32] & finalmask[0:24]),
                    r.prtbl[36:56]))

        # addrsh = 16 LSBs of (addr >> (12 + shift)): the index bits
        # for the current level of the tree (assumed here, following
        # microwatt mmu.vhdl)
        comb += addrsh.eq(r.addr[12:62] >> r.shift)

        # index the current page-table level: addrsh merges into
        # pgbase under mask; entries are 8 bytes, hence 3 zero LSBs
        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh & mask),
                    r.pgbase[19:56]))

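        # form the PTE for the TLB: real-page-number bits come from
        # the PDE, while address bits below the page size (finalmask)
        # pass straight through from the effective address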
        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56] & ~finalmask) |
                    (r.addr[12:56] & finalmask)))

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # a minimal stimulus sketch, not a full test: issue a single
    # privileged load request, then wait (bounded) for a response.
    # with no dcache model answering d_in, the table walk cannot
    # complete, so this mainly exercises the request hand-off and
    # produces a VCD trace for inspection.  signal names follow the
    # mem_types record layouts.
    yield dut.l_in.priv.eq(1)
    yield dut.l_in.load.eq(1)
    yield dut.l_in.addr.eq(0x10000)
    yield dut.l_in.valid.eq(1)
    yield
    yield dut.l_in.valid.eq(0)

    for _ in range(64):
        done = yield dut.l_out.done
        err = yield dut.l_out.err
        if done or err:
            break
        yield

def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[])  # dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()