+# MMU
+#
+# License for original copyright mmu.vhdl by microwatt authors: CC4
+# License for copyrighted modifications made in mmu.py: LGPLv3+
+#
+# Although this derivative work includes CC4-licensed material, it is
+# covered by the LGPLv3+
+
"""MMU
based on Anton Blanchard microwatt mmu.vhdl
from nmigen.cli import rtlil
from nmutil.iocontrol import RecordObject
from nmutil.byterev import byte_reverse
+from nmutil.mask import Mask, masked
+from nmutil.util import Display
+
+if True:
+ from nmigen.back.pysim import Simulator, Delay, Settle
+else:
+ from nmigen.sim.cxxsim import Simulator, Delay, Settle
+from nmutil.util import wrap
-from soc.experiment.mem_types import (LoadStore1ToMmuType,
- MmuToLoadStore1Type,
- MmuToDcacheType,
- DcacheToMmuType,
- MmuToIcacheType)
+from soc.experiment.mem_types import (LoadStore1ToMMUType,
+ MMUToLoadStore1Type,
+ MMUToDCacheType,
+ DCacheToMMUType,
+ MMUToICacheType)
-# -- Radix MMU
-# -- Supports 4-level trees as in arch 3.0B, but not the
-# -- two-step translation
-# -- for guests under a hypervisor (i.e. there is no gRA -> hRA translation).
@unique
class State(Enum):
(i.e. there is no gRA -> hRA translation).
"""
def __init__(self):
- self.l_in = LoadStore1ToMmuType()
- self.l_out = MmuToLoadStore1Type()
- self.d_out = MmuToDcacheType()
- self.d_in = DcacheToMmuType()
- self.i_out = MmuToIcacheType()
+ self.l_in = LoadStore1ToMMUType()
+ self.l_out = MMUToLoadStore1Type()
+ self.d_out = MMUToDCacheType()
+ self.d_in = DCacheToMMUType()
+ self.i_out = MMUToICacheType()
def radix_tree_idle(self, m, l_in, r, v):
comb = m.d.comb
+ sync = m.d.sync
+
pt_valid = Signal()
pgtbl = Signal(64)
+ rts = Signal(6)
+ mbits = Signal(6)
+
with m.If(~l_in.addr[63]):
comb += pgtbl.eq(r.pgtbl0)
comb += pt_valid.eq(r.pt0_valid)
with m.Else():
- comb += pgtbl.eq(r.pt3_valid)
+ comb += pgtbl.eq(r.pgtbl3)
comb += pt_valid.eq(r.pt3_valid)
# rts == radix tree size, number of address bits
# being translated
- rts = Signal(6)
comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
# mbits == number of address bits to index top
# level of tree
- mbits = Signal(6)
comb += mbits.eq(pgtbl[0:5])
# set v.shift to rts so that we can use finalmask
comb += v.addr.eq(l_in.addr)
comb += v.iside.eq(l_in.iside)
comb += v.store.eq(~(l_in.load | l_in.iside))
+ comb += v.priv.eq(l_in.priv)
+
+ comb += Display("state %d l_in.valid addr %x iside %d store %d "
+ "rts %x mbits %x pt_valid %d",
+ v.state, v.addr, v.iside, v.store,
+ rts, mbits, pt_valid)
with m.If(l_in.tlbie):
# Invalidate all iTLB/dTLB entries for
comb += v.shift.eq(r.prtbl[0:5])
comb += v.state.eq(State.PROC_TBL_READ)
- with m.If(~mbits):
+ with m.Elif(mbits == 0):
# Use RPDS = 0 to disable radix tree walks
comb += v.state.eq(State.RADIX_FINISH)
comb += v.invalid.eq(1)
with m.Else():
comb += v.pgtbl0.eq(data)
comb += v.pt0_valid.eq(1)
- # rts == radix tree size, # address bits being translated
+
rts = Signal(6)
+ mbits = Signal(6)
+
+ # rts == radix tree size, # address bits being translated
comb += rts.eq(Cat(data[5:8], data[61:63]))
# mbits == # address bits to index top level of tree
- mbits = Signal(6)
comb += mbits.eq(data[0:5])
- # set v.shift to rts so that we can use
- # finalmask for the segment check
+
+ # set v.shift to rts so that we can use finalmask for the segment check
comb += v.shift.eq(rts)
comb += v.mask_size.eq(mbits[0:5])
comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
- with m.If(~mbits):
+ with m.If(mbits):
+ comb += v.state.eq(State.SEGMENT_CHECK)
+ with m.Else():
comb += v.state.eq(State.RADIX_FINISH)
comb += v.invalid.eq(1)
- comb += v.state.eq(State.SEGMENT_CHECK)
def radix_read_wait(self, m, v, r, d_in, data):
comb = m.d.comb
+ sync = m.d.sync
+
+ perm_ok = Signal()
+ rc_ok = Signal()
+ mbits = Signal(6)
+ valid = Signal()
+ leaf = Signal()
+ badtree = Signal()
+
+ comb += Display("RDW %016x done %d "
+ "perm %d rc %d mbits %d shf %d "
+ "valid %d leaf %d bad %d",
+ data, d_in.done, perm_ok, rc_ok,
+ mbits, r.shift, valid, leaf, badtree)
+
+ # set pde
comb += v.pde.eq(data)
+
# test valid bit
- with m.If(data[63]):
- with m.If(data[62]):
+ comb += valid.eq(data[63]) # valid=data[63]
+ comb += leaf.eq(data[62]) # leaf=data[62]
+
+ comb += v.pde.eq(data)
+ # valid & leaf
+ with m.If(valid):
+ with m.If(leaf):
# check permissions and RC bits
- perm_ok = Signal()
- comb += perm_ok.eq(0)
with m.If(r.priv | ~data[3]):
with m.If(~r.iside):
- comb += perm_ok.eq((data[1] | data[2]) & (~r.store))
+ comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
with m.Else():
- # no IAMR, so no KUEP support
- # for now deny execute
- # permission if cache inhibited
+ # no IAMR, so no KUEP support for now
+ # deny execute permission if cache inhibited
comb += perm_ok.eq(data[0] & ~data[5])
- rc_ok = Signal()
- comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
+ comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
with m.If(perm_ok & rc_ok):
comb += v.state.eq(State.RADIX_LOAD_TLB)
with m.Else():
comb += v.state.eq(State.RADIX_FINISH)
comb += v.perm_err.eq(~perm_ok)
- # permission error takes precedence
- # over RC error
+ # permission error takes precedence over RC error
comb += v.rc_error.eq(perm_ok)
+
+ # valid & !leaf
with m.Else():
- mbits = Signal(6)
comb += mbits.eq(data[0:5])
- with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
+ comb += badtree.eq((mbits < 5) |
+ (mbits > 16) |
+ (mbits > r.shift))
+ with m.If(badtree):
comb += v.state.eq(State.RADIX_FINISH)
comb += v.badtree.eq(1)
with m.Else():
- comb += v.shift.eq(v.shift - mbits)
+ comb += v.shift.eq(r.shift - mbits)
comb += v.mask_size.eq(mbits[0:5])
comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
comb += v.state.eq(State.RADIX_LOOKUP)
+ with m.Else():
+ # non-present PTE, generate a DSI
+ comb += v.state.eq(State.RADIX_FINISH)
+ comb += v.invalid.eq(1)
+
def segment_check(self, m, v, r, data, finalmask):
comb = m.d.comb
+
mbits = Signal(6)
nonzero = Signal()
comb += mbits.eq(r.mask_size)
with m.Else():
comb += v.state.eq(State.RADIX_LOOKUP)
- def elaborate(self, platform):
- m = Module()
-
+ def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
comb = m.d.comb
sync = m.d.sync
- addrsh = Signal(16)
- mask = Signal(16)
- finalmask = Signal(44)
-
- r = RegStage("r")
- rin = RegStage("r_in")
-
- l_in = self.l_in
- l_out = self.l_out
- d_out = self.d_out
- d_in = self.d_in
- i_out = self.i_out
-
# Multiplex internal SPR values back to loadstore1,
# selected by l_in.sprn.
with m.If(l_in.sprn[9]):
comb += l_out.sprval.eq(r.pid)
with m.If(rin.valid):
- pass
- #sync += Display(f"MMU got tlb miss for {rin.addr}")
+ sync += Display("MMU got tlb miss for %x", rin.addr)
with m.If(l_out.done):
- pass
- # sync += Display("MMU completing op without error")
+ sync += Display("MMU completing op without error")
with m.If(l_out.err):
- pass
- # sync += Display(f"MMU completing op with err invalid"
- # "{l_out.invalid} badtree={l_out.badtree}")
+ sync += Display("MMU completing op with err invalid"
+ "%d badtree=%d", l_out.invalid, l_out.badtree)
with m.If(rin.state == State.RADIX_LOOKUP):
- pass
- # sync += Display (f"radix lookup shift={rin.shift}"
- # "msize={rin.mask_size}")
+ sync += Display ("radix lookup shift=%d msize=%d",
+ rin.shift, rin.mask_size)
with m.If(r.state == State.RADIX_LOOKUP):
- pass
- # sync += Display(f"send load addr={d_out.addr}"
- # "addrsh={addrsh} mask={mask}")
-
+ sync += Display(f"send load addr=%x addrsh=%d mask=%x",
+ d_out.addr, addrsh, mask)
sync += r.eq(rin)
+ def elaborate(self, platform):
+ m = Module()
+
+ comb = m.d.comb
+ sync = m.d.sync
+
+ addrsh = Signal(16)
+ mask = Signal(16)
+ finalmask = Signal(44)
+
+ self.rin = rin = RegStage("r_in")
+ r = RegStage("r")
+
+ l_in = self.l_in
+ l_out = self.l_out
+ d_out = self.d_out
+ d_in = self.d_in
+ i_out = self.i_out
+
+ self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)
+
v = RegStage()
dcreq = Signal()
tlb_load = Signal()
tlbie_req = Signal()
prtbl_rd = Signal()
effpid = Signal(32)
- prtable_addr = Signal(64)
- pgtable_addr = Signal(64)
+ prtb_adr = Signal(64)
+ pgtb_adr = Signal(64)
pte = Signal(64)
tlb_data = Signal(64)
addr = Signal(64)
data = byte_reverse(m, "data", d_in.data, 8)
# generate mask for extracting address fields for PTE addr generation
- comb += mask.eq(Cat(C(0x1f,5), ((1<<r.mask_size)-1)))
+ m.submodules.pte_mask = pte_mask = Mask(16-5)
+ comb += pte_mask.shift.eq(r.mask_size - 5)
+ comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
# generate mask for extracting address bits to go in
# TLB entry in order to support pages > 4kB
- comb += finalmask.eq(((1<<r.shift)-1))
+ m.submodules.tlb_mask = tlb_mask = Mask(44)
+ comb += tlb_mask.shift.eq(r.shift)
+ comb += finalmask.eq(tlb_mask.mask)
+
+ with m.If(r.state != State.IDLE):
+ sync += Display("MMU state %d %016x", r.state, data)
with m.Switch(r.state):
with m.Case(State.IDLE):
comb += v.state.eq(State.RADIX_FINISH)
with m.Case(State.PROC_TBL_READ):
+ sync += Display(" TBL_READ %016x", prtb_adr)
comb += dcreq.eq(1)
comb += prtbl_rd.eq(1)
comb += v.state.eq(State.PROC_TBL_WAIT)
self.segment_check(m, v, r, data, finalmask)
with m.Case(State.RADIX_LOOKUP):
+ sync += Display(" RADIX_LOOKUP")
comb += dcreq.eq(1)
comb += v.state.eq(State.RADIX_READ_WAIT)
with m.Case(State.RADIX_READ_WAIT):
+ sync += Display(" READ_WAIT")
with m.If(d_in.done):
self.radix_read_wait(m, v, r, d_in, data)
- with m.Else():
- # non-present PTE, generate a DSI
- comb += v.state.eq(State.RADIX_FINISH)
- comb += v.invalid.eq(1)
-
with m.If(d_in.err):
comb += v.state.eq(State.RADIX_FINISH)
comb += v.badtree.eq(1)
comb += v.state.eq(State.IDLE)
with m.Case(State.RADIX_FINISH):
+ sync += Display(" RADIX_FINISH")
comb += v.state.eq(State.IDLE)
with m.If((v.state == State.RADIX_FINISH) |
with m.If(~r.addr[63]):
comb += effpid.eq(r.pid)
- comb += prtable_addr.eq(Cat(
- C(0b0000, 4),
- effpid[0:8],
- (r.prtbl[12:36] & ~finalmask[0:24]) |
- (effpid[8:32] & finalmask[0:24]),
- r.prtbl[36:56]
- ))
-
- comb += pgtable_addr.eq(Cat(
- C(0b000, 3),
- (r.pgbase[3:19] & ~mask) |
- (addrsh & mask),
- r.pgbase[19:56]
- ))
-
- comb += pte.eq(Cat(
- r.pde[0:12],
- (r.pde[12:56] & ~finalmask) |
- (r.addr[12:56] & finalmask),
- ))
+ pr24 = Signal(24, reset_less=True)
+ comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
+ comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))
+
+ pg16 = Signal(16, reset_less=True)
+ comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
+ comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))
+
+ pd44 = Signal(44, reset_less=True)
+ comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
+ comb += pte.eq(Cat(r.pde[0:12], pd44))
# update registers
- rin.eq(v)
+ comb += rin.eq(v)
# drive outputs
with m.If(tlbie_req):
comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
comb += tlb_data.eq(pte)
with m.Elif(prtbl_rd):
- comb += addr.eq(prtable_addr)
+ comb += addr.eq(prtb_adr)
with m.Else():
- comb += addr.eq(pgtable_addr)
+ comb += addr.eq(pgtb_adr)
comb += l_out.done.eq(r.done)
comb += l_out.err.eq(r.err)
return m
+stop = False
+
+def dcache_get(dut):
+ """simulator process for getting memory load requests
+ """
-def mmu_sim():
- yield wp.waddr.eq(1)
- yield wp.data_i.eq(2)
- yield wp.wen.eq(1)
+ global stop
+
+ def b(x):
+ return int.from_bytes(x.to_bytes(8, byteorder='little'),
+ byteorder='big', signed=False)
+
+ mem = {0x0: 0x000000, # to get mtspr prtbl working
+
+ 0x10000: # PARTITION_TABLE_2
+ # PATB_GR=1 PRTB=0x1000 PRTS=0xb
+ b(0x800000000100000b),
+
+ 0x30000: # RADIX_ROOT_PTE
+ # V = 1 L = 0 NLB = 0x400 NLS = 9
+ b(0x8000000000040009),
+
+ 0x40000: # RADIX_SECOND_LEVEL
+ # V = 1 L = 1 SW = 0 RPN = 0
+ # R = 1 C = 1 ATT = 0 EAA 0x7
+ b(0xc000000000000187),
+
+ 0x1000000: # PROCESS_TABLE_3
+ # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
+ b(0x40000000000300ad),
+ }
+
+ while not stop:
+ while True: # wait for dc_valid
+ if stop:
+ return
+ dc_valid = yield (dut.d_out.valid)
+ if dc_valid:
+ break
+ yield
+ addr = yield dut.d_out.addr
+ if addr not in mem:
+ print (" DCACHE LOOKUP FAIL %x" % (addr))
+ stop = True
+ return
+
+ yield
+ data = mem[addr]
+ yield dut.d_in.data.eq(data)
+ print (" DCACHE GET %x data %x" % (addr, data))
+ yield dut.d_in.done.eq(1)
+ yield
+ yield dut.d_in.done.eq(0)
+
+def mmu_wait(dut):
+ global stop
+ while not stop: # wait for dc_valid / err
+ l_done = yield (dut.l_out.done)
+ l_err = yield (dut.l_out.err)
+ l_badtree = yield (dut.l_out.badtree)
+ l_permerr = yield (dut.l_out.perm_error)
+ l_rc_err = yield (dut.l_out.rc_error)
+ l_segerr = yield (dut.l_out.segerr)
+ l_invalid = yield (dut.l_out.invalid)
+ if (l_done or l_err or l_badtree or
+ l_permerr or l_rc_err or l_segerr or l_invalid):
+ break
+ yield
+ yield dut.l_in.valid.eq(0) # data already in MMU by now
+ yield dut.l_in.mtspr.eq(0) # captured by RegStage(s)
+ yield dut.l_in.load.eq(0) # can reset everything safely
+
+def mmu_sim(dut):
+ global stop
+
+ # MMU MTSPR set prtbl
+ yield dut.l_in.mtspr.eq(1)
+ yield dut.l_in.sprn[9].eq(1) # totally fake way to set SPR=prtbl
+ yield dut.l_in.rs.eq(0x1000000) # set process table
+ yield dut.l_in.valid.eq(1)
+ yield from mmu_wait(dut)
yield
- yield wp.wen.eq(0)
- yield rp.ren.eq(1)
- yield rp.raddr.eq(1)
- yield Settle()
- data = yield rp.data_o
- print(data)
- assert data == 2
+ yield dut.l_in.sprn.eq(0)
+ yield dut.l_in.rs.eq(0)
yield
- yield wp.waddr.eq(5)
- yield rp.raddr.eq(5)
- yield rp.ren.eq(1)
- yield wp.wen.eq(1)
- yield wp.data_i.eq(6)
- yield Settle()
- data = yield rp.data_o
- print(data)
- assert data == 6
- yield
- yield wp.wen.eq(0)
- yield rp.ren.eq(0)
- yield Settle()
- data = yield rp.data_o
- print(data)
- assert data == 0
+ prtbl = yield (dut.rin.prtbl)
+ print ("prtbl after MTSPR %x" % prtbl)
+ assert prtbl == 0x1000000
+
+ #yield dut.rin.prtbl.eq(0x1000000) # manually set process table
+ #yield
+
+
+ # MMU PTE request
+ yield dut.l_in.load.eq(1)
+ yield dut.l_in.priv.eq(1)
+ yield dut.l_in.addr.eq(0x10000)
+ yield dut.l_in.valid.eq(1)
+ yield from mmu_wait(dut)
+
+ addr = yield dut.d_out.addr
+ pte = yield dut.d_out.pte
+ l_done = yield (dut.l_out.done)
+ l_err = yield (dut.l_out.err)
+ l_badtree = yield (dut.l_out.badtree)
+ print ("translated done %d err %d badtree %d addr %x pte %x" % \
+ (l_done, l_err, l_badtree, addr, pte))
yield
- data = yield rp.data_o
- print(data)
+ yield dut.l_in.priv.eq(0)
+ yield dut.l_in.addr.eq(0)
+
+
+ stop = True
+
def test_mmu():
dut = MMU()
with open("test_mmu.il", "w") as f:
f.write(vl)
- run_simulation(dut, mmu_sim(), vcd_name='test_mmu.vcd')
+ m = Module()
+ m.submodules.mmu = dut
+
+ # nmigen Simulation
+ sim = Simulator(m)
+ sim.add_clock(1e-6)
+
+ sim.add_sync_process(wrap(mmu_sim(dut)))
+ sim.add_sync_process(wrap(dcache_get(dut)))
+ with sim.write_vcd('test_mmu.vcd'):
+ sim.run()
if __name__ == '__main__':
test_mmu()