1 """MMU
2
3 based on Anton Blanchard microwatt mmu.vhdl
4
5 """
6 from enum import Enum, unique
7 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
8 from nmigen.cli import main
9 from nmigen.cli import rtlil
10 from nmutil.iocontrol import RecordObject
11 from nmutil.byterev import byte_reverse
12
13 from soc.experiment.mem_types import (LoadStore1ToMmuType,
14 MmuToLoadStore1Type,
15 MmuToDcacheType,
16 DcacheToMmuType,
17 MmuToIcacheType)

# Radix MMU
# Supports 4-level trees as in arch 3.0B, but not the two-step translation
# for guests under a hypervisor (i.e. there is no gRA -> hRA translation).

@unique
class State(Enum):
    IDLE = 0  # zero is default on reset for r.state
    DO_TLBIE = 1
    TLB_WAIT = 2
    PROC_TBL_READ = 3
    PROC_TBL_WAIT = 4
    SEGMENT_CHECK = 5
    RADIX_LOOKUP = 6
    RADIX_READ_WAIT = 7
    RADIX_LOAD_TLB = 8
    RADIX_FINISH = 9


class RegStage(RecordObject):
    def __init__(self, name=None):
        super().__init__(name=name)
        # latched request from loadstore1
        self.valid = Signal()
        self.iside = Signal()
        self.store = Signal()
        self.priv = Signal()
        self.addr = Signal(64)
        self.inval_all = Signal()
        # config SPRs
        self.prtbl = Signal(64)
        self.pid = Signal(32)
        # internal state
        self.state = Signal(State)  # resets to IDLE
        self.done = Signal()
        self.err = Signal()
        self.pgtbl0 = Signal(64)
        self.pt0_valid = Signal()
        self.pgtbl3 = Signal(64)
        self.pt3_valid = Signal()
        self.shift = Signal(6)
        self.mask_size = Signal(5)
        self.pgbase = Signal(56)
        self.pde = Signal(64)
        self.invalid = Signal()
        self.badtree = Signal()
        self.segerror = Signal()
        self.perm_err = Signal()
        self.rc_error = Signal()


class MMU(Elaboratable):
    """Radix MMU

    Supports 4-level trees as in arch 3.0B, but not the
    two-step translation for guests under a hypervisor
    (i.e. there is no gRA -> hRA translation).
    """
    def __init__(self):
        self.l_in = LoadStore1ToMmuType()
        self.l_out = MmuToLoadStore1Type()
        self.d_out = MmuToDcacheType()
        self.d_in = DcacheToMmuType()
        self.i_out = MmuToIcacheType()

    def radix_tree_idle(self, m, l_in, r, v):
        comb = m.d.comb
        pt_valid = Signal()
        pgtbl = Signal(64)
        with m.If(~l_in.addr[63]):
            comb += pgtbl.eq(r.pgtbl0)
            comb += pt_valid.eq(r.pt0_valid)
        with m.Else():
            comb += pgtbl.eq(r.pgtbl3)
            comb += pt_valid.eq(r.pt3_valid)

        # rts == radix tree size, number of address bits
        # being translated
        rts = Signal(6)
        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))

        # mbits == number of address bits to index top
        # level of tree
        mbits = Signal(6)
        comb += mbits.eq(pgtbl[0:5])

        # set v.shift to rts so that we can use finalmask
        # for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
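
        # Worked example of the field extraction above (values are
        # illustrative, not read from real hardware): a Linux-style
        # 52-bit radix configuration has RTS = 52 - 31 = 21 and
        # RPDS = 13 (13 + 9 + 9 + 9 index bits plus a 12-bit page
        # offset = 52 bits), so
        #   rts   = 0b10101  (pgtbl[61:63] = 0b10, pgtbl[5:8] = 0b101)
        #   mbits = 0b01101  (pgtbl[0:5], the RPDS field)
        # note that Cat() packs LSB-first, so pgtbl[5:8] forms the low
        # three bits of rts.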

        with m.If(l_in.valid):
            comb += v.addr.eq(l_in.addr)
            comb += v.iside.eq(l_in.iside)
            comb += v.store.eq(~(l_in.load | l_in.iside))

            with m.If(l_in.tlbie):
                # Invalidate all iTLB/dTLB entries for
                # tlbie with RB[IS] != 0 or RB[AP] != 0,
                # or for slbia
                comb += v.inval_all.eq(l_in.slbia
                                       | l_in.addr[11]
                                       | l_in.addr[10]
                                       | l_in.addr[7]
                                       | l_in.addr[6]
                                       | l_in.addr[5]
                                       )
                # The RIC field of the tlbie instruction
                # comes across on the sprn bus as bits 2--3.
                # RIC=2 flushes process table caches.
                with m.If(l_in.sprn[3]):
                    comb += v.pt0_valid.eq(0)
                    comb += v.pt3_valid.eq(0)
                comb += v.state.eq(State.DO_TLBIE)
            with m.Else():
                comb += v.valid.eq(1)
                with m.If(~pt_valid):
                    # need to fetch process table entry
                    # set v.shift so we can use finalmask
                    # for generating the process table
                    # entry address
                    comb += v.shift.eq(r.prtbl[0:5])
                    comb += v.state.eq(State.PROC_TBL_READ)
                with m.Elif(mbits == 0):
                    # Use RPDS = 0 to disable radix tree walks
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.invalid.eq(1)
                with m.Else():
                    comb += v.state.eq(State.SEGMENT_CHECK)

        with m.If(l_in.mtspr):
            # Move to PID needs to invalidate L1 TLBs
            # and cached pgtbl0 value.  Move to PRTBL
            # does that plus invalidating the cached
            # pgtbl3 value as well.
            with m.If(~l_in.sprn[9]):
                comb += v.pid.eq(l_in.rs[0:32])
            with m.Else():
                comb += v.prtbl.eq(l_in.rs)
                comb += v.pt3_valid.eq(0)

            comb += v.pt0_valid.eq(0)
            comb += v.inval_all.eq(1)
            comb += v.state.eq(State.DO_TLBIE)

    def proc_tbl_wait(self, m, v, r, data):
        comb = m.d.comb
        with m.If(r.addr[63]):
            comb += v.pgtbl3.eq(data)
            comb += v.pt3_valid.eq(1)
        with m.Else():
            comb += v.pgtbl0.eq(data)
            comb += v.pt0_valid.eq(1)

        # rts == radix tree size, # address bits being translated
        rts = Signal(6)
        comb += rts.eq(Cat(data[5:8], data[61:63]))

        # mbits == # address bits to index top level of tree
        mbits = Signal(6)
        comb += mbits.eq(data[0:5])

        # set v.shift to rts so that we can use
        # finalmask for the segment check
        comb += v.shift.eq(rts)
        comb += v.mask_size.eq(mbits[0:5])
        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))

        with m.If(mbits == 0):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)
        with m.Else():
            comb += v.state.eq(State.SEGMENT_CHECK)

    def radix_read_wait(self, m, v, r, d_in, data):
        comb = m.d.comb
        comb += v.pde.eq(data)

        # test valid bit
        with m.If(data[63]):
            with m.If(data[62]):
                # this is a leaf: check permissions and RC bits
                perm_ok = Signal()
                comb += perm_ok.eq(0)
                with m.If(r.priv | ~data[3]):
                    with m.If(~r.iside):
                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                    with m.Else():
                        # no IAMR, so no KUEP support
                        # for now deny execute
                        # permission if cache inhibited
                        comb += perm_ok.eq(data[0] & ~data[5])

                rc_ok = Signal()
                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                with m.If(perm_ok & rc_ok):
                    comb += v.state.eq(State.RADIX_LOAD_TLB)
                with m.Else():
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.perm_err.eq(~perm_ok)
                    # permission error takes precedence
                    # over RC error
                    comb += v.rc_error.eq(perm_ok)
            with m.Else():
                # this is a pointer to the next level of the tree
                mbits = Signal(6)
                comb += mbits.eq(data[0:5])
                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)
                with m.Else():
                    comb += v.shift.eq(r.shift - mbits)
                    comb += v.mask_size.eq(mbits[0:5])
                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                    comb += v.state.eq(State.RADIX_LOOKUP)
        with m.Else():
            # non-present PTE, generate a DSI
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.invalid.eq(1)

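    # For reference, the radix PTE bits tested in radix_read_wait above
    # (LSB-first numbering, i.e. after the byte-swap done in elaborate;
    # summarised from the arch 3.0B radix PTE layout):
    #   data[63] V (valid)         data[62] L (leaf)
    #   data[8]  R (reference)     data[7]  C (change)
    #   data[5]  attribute bit, treated here as cache-inhibited
    #   data[3]  privileged-only   data[2]  read
    #   data[1]  read/write        data[0]  execute
    # for a non-leaf entry, data[0:5] is NLS (next-level size) and
    # data[8:56] is NLB (next-level base).
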
    def segment_check(self, m, v, r, data, finalmask):
        comb = m.d.comb
        mbits = Signal(6)
        nonzero = Signal()
        comb += mbits.eq(r.mask_size)
        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.segerror.eq(1)
        with m.Elif((mbits < 5) | (mbits > 16) |
                    (mbits > (r.shift + (31 - 12)))):
            comb += v.state.eq(State.RADIX_FINISH)
            comb += v.badtree.eq(1)
        with m.Else():
            comb += v.state.eq(State.RADIX_LOOKUP)
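
    # quick sanity check of the shift arithmetic above (illustrative
    # numbers): with the 52-bit example configuration, radix_tree_idle
    # set v.shift = rts = 21, so a root level with mbits = 13 gives
    #   v.shift = 21 + (31 - 12) - 13 = 27
    # i.e. exactly the three remaining 9-bit levels still to be
    # consumed as the walk descends the tree.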

    def elaborate(self, platform):
        m = Module()

        comb = m.d.comb
        sync = m.d.sync

        addrsh = Signal(16)
        mask = Signal(16)
        finalmask = Signal(44)

        r = RegStage("r")
        rin = RegStage("r_in")

        l_in = self.l_in
        l_out = self.l_out
        d_out = self.d_out
        d_in = self.d_in
        i_out = self.i_out

        # Multiplex internal SPR values back to loadstore1,
        # selected by l_in.sprn.
        with m.If(l_in.sprn[9]):
            comb += l_out.sprval.eq(r.prtbl)
        with m.Else():
            comb += l_out.sprval.eq(r.pid)

        with m.If(rin.valid):
            pass
            # sync += Display(f"MMU got tlb miss for {rin.addr}")

        with m.If(l_out.done):
            pass
            # sync += Display("MMU completing op without error")

        with m.If(l_out.err):
            pass
            # sync += Display(f"MMU completing op with err "
            #                 f"invalid={l_out.invalid} "
            #                 f"badtree={l_out.badtree}")

        with m.If(rin.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"radix lookup shift={rin.shift} "
            #                 f"msize={rin.mask_size}")

        with m.If(r.state == State.RADIX_LOOKUP):
            pass
            # sync += Display(f"send load addr={d_out.addr} "
            #                 f"addrsh={addrsh} mask={mask}")

        sync += r.eq(rin)

        v = RegStage()
        dcreq = Signal()
        tlb_load = Signal()
        itlb_load = Signal()
        tlbie_req = Signal()
        prtbl_rd = Signal()
        effpid = Signal(32)
        prtable_addr = Signal(64)
        pgtable_addr = Signal(64)
        pte = Signal(64)
        tlb_data = Signal(64)
        addr = Signal(64)

        comb += v.eq(r)
        comb += v.valid.eq(0)
        comb += dcreq.eq(0)
        comb += v.done.eq(0)
        comb += v.err.eq(0)
        comb += v.invalid.eq(0)
        comb += v.badtree.eq(0)
        comb += v.segerror.eq(0)
        comb += v.perm_err.eq(0)
        comb += v.rc_error.eq(0)
        comb += tlb_load.eq(0)
        comb += itlb_load.eq(0)
        comb += tlbie_req.eq(0)
        comb += v.inval_all.eq(0)
        comb += prtbl_rd.eq(0)

        # Radix tree data structures in memory are
        # big-endian, so we need to byte-swap them
        data = byte_reverse(m, "data", d_in.data, 8)

        # generate mask for extracting address fields for PTE addr
        # generation: mask_size ones, with the low five bits always set
        # (mask_size >= 5 in any valid tree)
        comb += mask.eq(C(0x1f, 5) | ((1 << r.mask_size) - 1))

        # generate mask for extracting address bits to go in
        # TLB entry in order to support pages > 4kB
        comb += finalmask.eq((1 << r.shift) - 1)

        # shift address bits 61--12 right by 0--47 bits to supply the
        # least significant 16 bits of the result, the index field
        # used by pgtable_addr below
        comb += addrsh.eq(r.addr[12:62] >> r.shift)
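
        # worked example of the two masks (illustrative values only):
        # mask_size = 9 (a 9-bit index at this level) gives
        #   mask      = 0x01ff  (9 ones)
        # and shift = 9 (e.g. a 2MB large page) gives
        #   finalmask = (1 << 9) - 1 = 0x1ff
        # so address bits 12--20 are taken from the effective address
        # when the PTE is loaded into the TLB.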

        with m.Switch(r.state):
            with m.Case(State.IDLE):
                self.radix_tree_idle(m, l_in, r, v)

            with m.Case(State.DO_TLBIE):
                comb += dcreq.eq(1)
                comb += tlbie_req.eq(1)
                comb += v.state.eq(State.TLB_WAIT)

            with m.Case(State.TLB_WAIT):
                with m.If(d_in.done):
                    comb += v.state.eq(State.RADIX_FINISH)

            with m.Case(State.PROC_TBL_READ):
                comb += dcreq.eq(1)
                comb += prtbl_rd.eq(1)
                comb += v.state.eq(State.PROC_TBL_WAIT)

            with m.Case(State.PROC_TBL_WAIT):
                with m.If(d_in.done):
                    self.proc_tbl_wait(m, v, r, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.SEGMENT_CHECK):
                self.segment_check(m, v, r, data, finalmask)

            with m.Case(State.RADIX_LOOKUP):
                comb += dcreq.eq(1)
                comb += v.state.eq(State.RADIX_READ_WAIT)

            with m.Case(State.RADIX_READ_WAIT):
                with m.If(d_in.done):
                    # radix_read_wait covers the leaf, next-level and
                    # non-present PTE (DSI) outcomes
                    self.radix_read_wait(m, v, r, d_in, data)

                with m.If(d_in.err):
                    comb += v.state.eq(State.RADIX_FINISH)
                    comb += v.badtree.eq(1)

            with m.Case(State.RADIX_LOAD_TLB):
                comb += tlb_load.eq(1)
                with m.If(~r.iside):
                    comb += dcreq.eq(1)
                    comb += v.state.eq(State.TLB_WAIT)
                with m.Else():
                    comb += itlb_load.eq(1)
                    comb += v.state.eq(State.IDLE)

            with m.Case(State.RADIX_FINISH):
                comb += v.state.eq(State.IDLE)

        with m.If((v.state == State.RADIX_FINISH) |
                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
            comb += v.err.eq(v.invalid | v.badtree | v.segerror
                             | v.perm_err | v.rc_error)
            comb += v.done.eq(~v.err)

        # effpid is the process id for user addresses (addr[63] == 0)
        # and 0 for kernel addresses
        with m.If(~r.addr[63]):
            comb += effpid.eq(r.pid)

        comb += prtable_addr.eq(Cat(
                    C(0b0000, 4),
                    effpid[0:8],
                    (r.prtbl[12:36] & ~finalmask[0:24]) |
                    (effpid[8:32] & finalmask[0:24]),
                    r.prtbl[36:56]
                ))
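        # Cat() above packs LSB-first, so prtable_addr is the process
        # table base from r.prtbl with its low bits replaced by
        # effpid * 16: each process-table entry is two doublewords
        # (16 bytes), hence the four zero bits at the bottom.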

        comb += pgtable_addr.eq(Cat(
                    C(0b000, 3),
                    (r.pgbase[3:19] & ~mask) |
                    (addrsh & mask),
                    r.pgbase[19:56]
                ))
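        # i.e. the current level's base (r.pgbase) with the index field
        # (the shifted address bits selected by mask) merged in; PDEs
        # are one doubleword each, hence the three zero bits at the
        # bottom.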

        comb += pte.eq(Cat(
                    r.pde[0:12],
                    (r.pde[12:56] & ~finalmask) |
                    (r.addr[12:56] & finalmask),
                ))
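        # the TLB entry keeps the leaf PDE's attribute bits [0:12] and
        # its real page number, but for pages larger than 4kB the bits
        # below the page size (selected by finalmask) come from the
        # effective address, e.g. bits 12--20 for a 2MB page.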

        # update registers
        comb += rin.eq(v)

        # drive outputs
        with m.If(tlbie_req):
            comb += addr.eq(r.addr)
        with m.Elif(tlb_load):
            comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
            comb += tlb_data.eq(pte)
        with m.Elif(prtbl_rd):
            comb += addr.eq(prtable_addr)
        with m.Else():
            comb += addr.eq(pgtable_addr)

        comb += l_out.done.eq(r.done)
        comb += l_out.err.eq(r.err)
        comb += l_out.invalid.eq(r.invalid)
        comb += l_out.badtree.eq(r.badtree)
        comb += l_out.segerr.eq(r.segerror)
        comb += l_out.perm_error.eq(r.perm_err)
        comb += l_out.rc_error.eq(r.rc_error)

        comb += d_out.valid.eq(dcreq)
        comb += d_out.tlbie.eq(tlbie_req)
        comb += d_out.doall.eq(r.inval_all)
        comb += d_out.tlbld.eq(tlb_load)
        comb += d_out.addr.eq(addr)
        comb += d_out.pte.eq(tlb_data)

        comb += i_out.tlbld.eq(itlb_load)
        comb += i_out.tlbie.eq(tlbie_req)
        comb += i_out.doall.eq(r.inval_all)
        comb += i_out.addr.eq(addr)
        comb += i_out.pte.eq(tlb_data)

        return m


def mmu_sim(dut):
    # minimal placeholder stimulus (a sketch, not a real walk test):
    # hold the request inputs inactive for a few cycles and check that
    # the FSM stays quiescent.  a proper test needs a simulated dcache
    # to serve the page-table reads.
    yield dut.l_in.valid.eq(0)
    yield dut.l_in.tlbie.eq(0)
    yield dut.l_in.mtspr.eq(0)
    for _ in range(10):
        yield
    done = yield dut.l_out.done
    err = yield dut.l_out.err
    print("done", done, "err", err)
    assert done == 0
    assert err == 0


def test_mmu():
    dut = MMU()
    vl = rtlil.convert(dut, ports=[])  # TODO: use dut.ports()
    with open("test_mmu.il", "w") as f:
        f.write(vl)

    run_simulation(dut, mmu_sim(dut), vcd_name='test_mmu.vcd')


if __name__ == '__main__':
    test_mmu()