1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12 from nmigen.utils import log2_int
13
14 from soc.experiment.mem_types import (LoadStore1ToMmuType,
15 MmuToLoadStore1Type,
16 MmuToDcacheType,
17 DcacheToMmuType,
18 MmuToIcacheType)
19
# Radix MMU
# Supports 4-level trees as in arch 3.0B, but not the
# two-step translation for guests under a hypervisor
# (i.e. there is no gRA -> hRA translation).

@unique
class State(Enum):
    IDLE = 0 # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9

class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State) # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()

class Mask(Elaboratable):
    """Generates a mask with the low `shift` bits set,
    e.g. shift=3 gives mask=0b111.
    """
    def __init__(self, sz):
        self.sz = sz
        self.shift = Signal(log2_int(sz, False))
        self.mask = Signal(sz)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb

        for i in range(self.sz):
            with m.If(self.shift > i):
                comb += self.mask[i].eq(1)

        return m


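def mask_example():
    """Stand-alone sketch (not used by the MMU) showing what Mask
    produces: for sz=8 and shift=3 this prints 0b111, i.e. the low
    `shift` bits set.  Assumes nmigen's pysim back-end; illustrative
    only.
    """
    from nmigen.back.pysim import Simulator, Settle
    dut = Mask(8)
    sim = Simulator(dut)

    def proc():
        yield dut.shift.eq(3)
        yield Settle()
        print(bin((yield dut.mask)))

    sim.add_process(proc)
    sim.run()

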
class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
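        # (per ISA 3.0B the RTS field encodes the address-space size
        # minus 31: e.g. rts = 21 selects a 52-bit address space)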

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)
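            # note: sprn bit 9 distinguishes the two SPRs decoded here:
            # PRTBL is SPR 720 (bit 9 set), PID is SPR 48 (bit 9 clear)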

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            # Use RPDS = 0 to disable radix tree walks
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)

        # test valid bit
        with m.If(data[63]):
            # test leaf bit
            with m.If(data[62]):
                # check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        # for a load/store, test the read and
                        # read/write permission bits
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support
                        # for now deny execute
                        # permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                # not a leaf: descend to the next tree level
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

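    # Radix PTE bit positions tested above (LSB0 numbering, following
    # microwatt mmu.vhdl): 63 = valid, 62 = leaf, 8 = R (reference),
    # 7 = C (change), 5 = cache-inhibit attribute, 3 = privileged,
    # 2 = read, 1 = read/write, 0 = execute
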
    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb
        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        # address bits 62 and 63 must match (quadrant 0 or quadrant 3)
        # and the bits above the translated range must be zero
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31 - 12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        # debug-reporting placeholders (Display statements disabled)
        with m.If(rin.valid):
            pass
            #sync += Display("MMU got tlb miss for %x", rin.addr)

        with m.If(l_out.done):
            pass
            #sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            #sync += Display("MMU completing op with err invalid=%d"
            #                " badtree=%d", l_out.invalid, l_out.badtree)

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            #sync += Display("radix lookup shift=%d msize=%d",
            #                rin.shift, rin.mask_size)

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            #sync += Display("send load addr=%x addrsh=%d mask=%x",
            #                d_out.addr, addrsh, mask)

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr
        # generation
        m.submodules.pte_mask = pte_mask = Mask(16-5)
        comb += pte_mask.shift.eq(r.mask_size - 5)
        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
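        # mask ends up with the low mask_size bits set (minimum 5):
        # e.g. mask_size=13 gives mask=0x1fff, so 13 index bits from
        # addrsh replace bits 3..15 of the page-table base when
        # forming pgtable_addr below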

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        m.submodules.tlb_mask = tlb_mask = Mask(44)
        comb += tlb_mask.shift.eq(r.shift)
        comb += finalmask.eq(tlb_mask.mask)

        # shift address bits 61..12 right by r.shift and supply the
        # least-significant 16 bits of the result (the index into the
        # current radix tree level), as in microwatt mmu.vhdl
        comb += addrsh.eq(r.addr[12:62] >> r.shift)

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait also covers the non-present
                    # PTE case (valid bit clear), which generates
                    # a DSI
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32]   &  finalmask[0:24]),
                    r.prtbl[36:56]
                ))

        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh         &  mask),
                    r.pgbase[19:56]
                ))

        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56]  & ~finalmask) |
                    (r.addr[12:56] &  finalmask),
                ))

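        # pte merges the RPN from the leaf PDE with low address bits
        # for pages larger than 4kB: e.g. a 64kB page leaves
        # r.shift = 4, so finalmask = 0xf and address bits 15..12
        # come from r.addr rather than from the PDE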
        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # Minimal smoke test of the SPR write/read path, driving only
    # fields that the FSM above actually reads.

    # mtspr with sprn bit 9 set writes PRTBL (bit 9 clear would be PID)
    yield dut.l_in.mtspr.eq(1)
    yield dut.l_in.sprn.eq(1 << 9)
    yield dut.l_in.rs.eq(0x12345678)
    yield
    yield dut.l_in.mtspr.eq(0)

    # the mtspr also triggers DO_TLBIE/TLB_WAIT to flush the TLBs:
    # wait for the dcache request and acknowledge it with d_in.done
    for i in range(16):
        if (yield dut.d_out.valid):
            yield dut.d_in.done.eq(1)
            yield
            yield dut.d_in.done.eq(0)
            break
        yield

    # read the SPR back: l_out.sprval is a comb mux on l_in.sprn
    sprval = yield dut.l_out.sprval
    print("prtbl", hex(sprval))
    assert sprval == 0x12345678


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[]) # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()