From 839dedf481ac19ab57bd2ca816bc444772cf3ca3 Mon Sep 17 00:00:00 2001
From: Luke Kenneth Casson Leighton
Date: Sun, 16 Aug 2020 11:11:48 +0100
Subject: [PATCH] big reorg, shuffle code to functions, makes the FSM clearer

---
 src/soc/experiment/mmu.py | 340 ++++++++++++++++++++------------------
 1 file changed, 179 insertions(+), 161 deletions(-)

diff --git a/src/soc/experiment/mmu.py b/src/soc/experiment/mmu.py
index 28e2675d..46d9867c 100644
--- a/src/soc/experiment/mmu.py
+++ b/src/soc/experiment/mmu.py
@@ -6,6 +6,7 @@ based on Anton Blanchard microwatt mmu.vhdl
 from enum import Enum, unique
 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
 from nmigen.cli import main
+from nmigen.cli import rtlil
 
 from nmutil.iocontrol import RecordObject
 from nmutil.byterev import byte_reverse
@@ -160,6 +161,171 @@ class MMU(Elaboratable):
         self.d_in = DcacheToMmuType()
         self.i_out = MmuToIcacheType()
 
+    def radix_tree_idle(self, m, l_in, v):
+        comb = m.d.comb
+        pt_valid = Signal()
+        pgtbl = Signal(64)
+        with m.If(~l_in.addr[63]):
+            comb += pgtbl.eq(r.pgtbl0)
+            comb += pt_valid.eq(r.pt0_valid)
+        with m.Else():
+            comb += pgtbl.eq(r.pt3_valid)
+            comb += pt_valid.eq(r.pt3_valid)
+
+        # rts == radix tree size, number of address bits
+        # being translated
+        rts = Signal(6)
+        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
+
+        # mbits == number of address bits to index top
+        # level of tree
+        mbits = Signal(6)
+        comb += mbits.eq(pgtbl[0:5])
+
+        # set v.shift to rts so that we can use finalmask
+        # for the segment check
+        comb += v.shift.eq(rts)
+        comb += v.mask_size.eq(mbits[0:5])
+        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
+
+        with m.If(l_in.valid):
+            comb += v.addr.eq(l_in.addr)
+            comb += v.iside.eq(l_in.iside)
+            comb += v.store.eq(~(l_in.load | l_in.iside))
+
+            with m.If(l_in.tlbie):
+                # Invalidate all iTLB/dTLB entries for
+                # tlbie with RB[IS] != 0 or RB[AP] != 0,
+                # or for slbia
+                comb += v.inval_all.eq(l_in.slbia
+                                       | l_in.addr[11]
+                                       | l_in.addr[10]
+                                       | l_in.addr[7]
+                                       | l_in.addr[6]
+                                       | l_in.addr[5]
+                                       )
+                # The RIC field of the tlbie instruction
+                # comes across on the sprn bus as bits 2--3.
+                # RIC=2 flushes process table caches.
+                with m.If(l_in.sprn[3]):
+                    comb += v.pt0_valid.eq(0)
+                    comb += v.pt3_valid.eq(0)
+                comb += v.state.eq(State.DO_TLBIE)
+            with m.Else():
+                comb += v.valid.eq(1)
+                with m.If(~pt_valid):
+                    # need to fetch process table entry
+                    # set v.shift so we can use finalmask
+                    # for generating the process table
+                    # entry address
+                    comb += v.shift.eq(r.prtble[0:5])
+                    comb += v.state.eq(State.PROC_TBL_READ)
+
+                with m.If(~mbits):
+                    # Use RPDS = 0 to disable radix tree walks
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.invalid.eq(1)
+                with m.Else():
+                    comb += v.state.eq(State.SEGMENT_CHECK)
+
+        with m.If(l_in.mtspr):
+            # Move to PID needs to invalidate L1 TLBs
+            # and cached pgtbl0 value. Move to PRTBL
+            # does that plus invalidating the cached
+            # pgtbl3 value as well.
+            with m.If(~l_in.sprn[9]):
+                comb += v.pid.eq(l_in.rs[0:32])
+            with m.Else():
+                comb += v.prtbl.eq(l_in.rs)
+                comb += v.pt3_valid.eq(0)
+
+            comb += v.pt0_valid.eq(0)
+            comb += v.inval_all.eq(1)
+            comb += v.state.eq(State.DO_TLBIE)
+
+    def proc_tbl_wait(self, m, v, r, data):
+        comb = m.d.comb
+        with m.If(r.addr[63]):
+            comb += v.pgtbl3.eq(data)
+            comb += v.pt3_valid.eq(1)
+        with m.Else():
+            comb += v.pgtbl0.eq(data)
+            comb += v.pt0_valid.eq(1)
+        # rts == radix tree size, # address bits being translated
+        rts = Signal(6)
+        comb += rts.eq(Cat(data[5:8], data[61:63]))
+
+        # mbits == # address bits to index top level of tree
+        mbits = Signal(6)
+        comb += mbits.eq(data[0:5])
+        # set v.shift to rts so that we can use
+        # finalmask for the segment check
+        comb += v.shift.eq(rts)
+        comb += v.mask_size.eq(mbits[0:5])
+        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
+
+        with m.If(~mbits):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.invalid.eq(1)
+        comb += v.state.eq(State.SEGMENT_CHECK)
+
+    def radix_read_wait(self, m, v, r, d_in, data):
+        comb = m.d.comb
+        comb += v.pde.eq(data)
+        # test valid bit
+        with m.If(data[63]):
+            with m.If(data[62]):
+                # check permissions and RC bits
+                perm_ok = Signal()
+                comb += perm_ok.eq(0)
+                with m.If(r.priv | ~data[3]):
+                    with m.If(~r.iside):
+                        comb += perm_ok.eq((data[1] | data[2]) & (~r.store))
+                    with m.Else():
+                        # no IAMR, so no KUEP support
+                        # for now deny execute
+                        # permission if cache inhibited
+                        comb += perm_ok.eq(data[0] & ~data[5])
+
+                rc_ok = Signal()
+                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
+                with m.If(perm_ok & rc_ok):
+                    comb += v.state.eq(State.RADIX_LOAD_TLB)
+                with m.Else():
+                    comb += v.state.eq(State.RADIX_ERROR)
+                    comb += v.perm_err.eq(~perm_ok)
+                    # permission error takes precedence
+                    # over RC error
+                    comb += v.rc_error.eq(perm_ok)
+            with m.Else():
+                mbits = Signal(6)
+                comb += mbits.eq(data[0:5])
+                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.badtree.eq(1)
+                with m.Else():
+                    comb += v.shift.eq(v.shif - mbits)
+                    comb += v.mask_size.eq(mbits[0:5])
+                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
+                    comb += v.state.eq(State.RADIX_LOOKUP)
+
+    def segment_check(self, m, v, r, data):
+        comb = m.d.comb
+        mbits = Signal(6)
+        nonzero = Signal()
+        comb += mbits.eq(r.mask_size)
+        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
+        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
+        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.segerror.eq(1)
+        with m.Elif((mbits < 5) | (mbits > 16) |
+                    (mbits > (r.shift + (31-12)))):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.badtree.eq(1)
+        with m.Else():
+            comb += v.state.eq(State.RADIX_LOOKUP)
+
     def elaborate(self, platform):
         m = Module()
 
@@ -212,22 +378,16 @@ class MMU(Elaboratable):
         sync += r.eq(rin)
 
         v = RegStage()
-        dcrq = Signal()
+        dcreq = Signal()
         tlb_load = Signal()
         itlb_load = Signal()
         tlbie_req = Signal()
         prtbl_rd = Signal()
-        pt_valid = Signal()
         effpid = Signal(32)
         prtable_addr = Signal(64)
-        rts = Signal(6)
-        mbits = Signal(6)
         pgtable_addr = Signal(64)
         pte = Signal(64)
         tlb_data = Signal(64)
-        nonzero = Signal()
-        pgtbl = Signal(64)
-        rc_ok = Signal()
         addr = Signal(64)
 
         comb += v.eq(r)
@@ -252,80 +412,7 @@ class MMU(Elaboratable):
 
         with m.Switch(r.state):
             with m.Case(State.IDLE):
-                with m.If(~l_in.addr[63]):
-                    comb += pgtbl.eq(r.pgtbl0)
-                    comb += pt_valid.eq(r.pt0_valid)
-                with m.Else():
-                    comb += pgtbl.eq(r.pt3_valid)
-                    comb += pt_valid.eq(r.pt3_valid)
-
-                # rts == radix tree size, number of address bits
-                # being translated
-                comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
-
-                # mbits == number of address bits to index top
-                # level of tree
-                comb += mbits.eq(pgtbl[0:5])
-
-                # set v.shift to rts so that we can use finalmask
-                # for the segment check
-                comb += v.shift.eq(rts)
-                comb += v.mask_size.eq(mbits[0:5])
-                comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
-
-                with m.If(l_in.valid):
-                    comb += v.addr.eq(l_in.addr)
-                    comb += v.iside.eq(l_in.iside)
-                    comb += v.store.eq(~(l_in.load | l_in.iside))
-                    with m.If(l_in.tlbie):
-                        # Invalidate all iTLB/dTLB entries for
-                        # tlbie with RB[IS] != 0 or RB[AP] != 0,
-                        # or for slbia
-                        comb += v.inval_all.eq(l_in.slbia
-                                               | l_in.addr[11]
-                                               | l_in.addr[10]
-                                               | l_in.addr[7]
-                                               | l_in.addr[6]
-                                               | l_in.addr[5]
-                                               )
-                        # The RIC field of the tlbie instruction
-                        # comes across on the sprn bus as bits 2--3.
-                        # RIC=2 flushes process table caches.
-                        with m.If(l_in.sprn[3]):
-                            comb += v.pt0_valid.eq(0)
-                            comb += v.pt3_valid.eq(0)
-                        comb += v.state.eq(State.DO_TLBIE)
-                    with m.Else():
-                        comb += v.valid.eq(1)
-                        with m.If(~pt_valid):
-                            # need to fetch process table entry
-                            # set v.shift so we can use finalmask
-                            # for generating the process table
-                            # entry address
-                            comb += v.shift.eq(r.prtble[0:5])
-                            comb += v.state.eq(State.PROC_TBL_READ)
-
-                        with m.If(~mbits):
-                            # Use RPDS = 0 to disable radix tree walks
-                            comb += v.state.eq(State.RADIX_FINISH)
-                            comb += v.invalid.eq(1)
-                        with m.Else():
-                            comb += v.state.eq(State.SEGMENT_CHECK)
-
-                with m.If(l_in.mtspr):
-                    # Move to PID needs to invalidate L1 TLBs
-                    # and cached pgtbl0 value. Move to PRTBL
-                    # does that plus invalidating the cached
-                    # pgtbl3 value as well.
-                    with m.If(~l_in.sprn[9]):
-                        comb += v.pid.eq(l_in.rs[0:32])
-                    with m.Else():
-                        comb += v.prtbl.eq(l_in.rs)
-                        comb += v.pt3_valid.eq(0)
-
-                    comb += v.pt0_valid.eq(0)
-                    comb += v.inval_all.eq(1)
-                    comb += v.state.eq(State.DO_TLBIE)
+                self.radix_tree_idle(m, l_in, v)
 
             with m.Case(State.DO_TLBIE):
                 comb += dcreq.eq(1)
@@ -343,45 +430,14 @@ class MMU(Elaboratable):
 
             with m.Case(State.PROC_TBL_WAIT):
                 with m.If(d_in.done):
-                    with m.If(r.addr[63]):
-                        comb += v.pgtbl3.eq(data)
-                        comb += v.pt3_valid.eq(1)
-                    with m.Else():
-                        comb += v.pgtbl0.eq(data)
-                        comb += v.pt0_valid.eq(1)
-                    # rts == radix tree size, # address bits being translated
-                    comb += rts.eq(Cat(data[5:8], data[61:63]))
-
-                    # mbits == # address bits to index top level of tree
-                    comb += mbits.eq(data[0:5])
-                    # set v.shift to rts so that we can use
-                    # finalmask for the segment check
-                    comb += v.shift.eq(rts)
-                    comb += v.mask_size.eq(mbits[0:5])
-                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
-
-                    with m.If(~mbits):
-                        comb += v.state.eq(State.RADIX_FINISH)
-                        comb += v.invalid.eq(1)
-                    comb += v.state.eq(State.SEGMENT_CHECK)
+                    self.proc_tbl_wait(m, v, r, data)
 
                 with m.If(d_in.err):
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.badtree.eq(1)
 
             with m.Case(State.SEGMENT_CHECK):
-                comb += mbits.eq(r.mask_size)
-                comb += v.shift.eq(r.shift + (31 - 12) - mbits)
-                comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
-                with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
-                    comb += v.state.eq(State.RADIX_FINISH)
-                    comb += v.segerror.eq(1)
-                with m.Elif((mbits < 5) | (mbits > 16)
-                            | (mbits > (r.shift + (31-12)))):
-                    comb += v.state.eq(State.RADIX_FINISH)
-                    comb += v.badtree.eq(1)
-                with m.Else():
-                    comb += v.state.eq(State.RADIX_LOOKUP)
+                self.segment_check(m, v, r, data)
 
             with m.Case(State.RADIX_LOOKUP):
                 comb += dcreq.eq(1)
@@ -389,47 +445,11 @@ class MMU(Elaboratable):
 
             with m.Case(State.RADIX_READ_WAIT):
                 with m.If(d_in.done):
-                    comb += v.pde.eq(data)
-                    # test valid bit
-                    with m.If(data[63]):
-                        with m.If(data[62]):
-                            # check permissions and RC bits
-                            perm_ok = Signal()
-                            comb += perm_ok.eq(0)
-                            with m.If(r.priv | ~data[3]):
-                                with m.If(~r.iside):
-                                    comb += perm_ok.eq((data[1] | data[2]) &
-                                                       (~r.store))
-                                with m.Else():
-                                    # no IAMR, so no KUEP support
-                                    # for now deny execute
-                                    # permission if cache inhibited
-                                    comb += perm_ok.eq(data[0] & ~data[5])
-
-                            comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
-                            with m.If(perm_ok & rc_ok):
-                                comb += v.state.eq(State.RADIX_LOAD_TLB)
-                            with m.Else():
-                                comb += v.state.eq(State.RADIX_ERROR)
-                                comb += v.perm_err.eq(~perm_ok)
-                                # permission error takes precedence
-                                # over RC error
-                                comb += v.rc_error.eq(perm_ok)
-                        with m.Else():
-                            comb += mbits.eq(data[0:5])
-                            with m.If((mbits < 5) | (mbits > 16) |
-                                      (mbits > r.shift)):
-                                comb += v.state.eq(State.RADIX_FINISH)
-                                comb += v.badtree.eq(1)
-                            with m.Else():
-                                comb += v.shift.eq(v.shif - mbits)
-                                comb += v.mask_size.eq(mbits[0:5])
-                                comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
-                                comb += v.state.eq(State.RADIX_LOOKUP)
-                    with m.Else():
-                        # non-present PTE, generate a DSI
-                        comb += v.state.eq(State.RADIX_FINISH)
-                        comb += v.invalid.eq(1)
+                    self.radix_read_wait(m, v, r, d_in, data)
+                with m.Else():
+                    # non-present PTE, generate a DSI
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.invalid.eq(1)
 
                 with m.If(d_in.err):
                     comb += v.state.eq(State.RADIX_FINISH)
@@ -447,8 +467,8 @@ class MMU(Elaboratable):
 
             with m.Case(State.RADIX_FINISH):
                 comb += v.state.eq(State.IDLE)
 
-        with m.If((v.state == State.RADIX_FINISH)
-                  | ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
+        with m.If((v.state == State.RADIX_FINISH) |
+                  ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
             comb += v.err.eq(v.invalid | v.badtree | v.segerror
                              | v.perm_err | v.rc_error)
             comb += v.done.eq(~v.err)
@@ -502,7 +522,7 @@ class MMU(Elaboratable):
         comb += d_out.valid.eq(dcreq)
         comb += d_out.tlbie.eq(tlbie_req)
         comb += d_out.doall.eq(r.inval_all)
-        comb += d_out.tlbld.eeq(tlb_load)
+        comb += d_out.tlbld.eq(tlb_load)
         comb += d_out.addr.eq(addr)
         comb += d_out.pte.eq(tlb_data)
 
@@ -549,9 +569,7 @@ def mmu_sim():
 
 def test_mmu():
     dut = MMU()
-    rp = dut.read_port()
-    wp = dut.write_port()
-    vl = rtlil.convert(dut, ports=dut.ports())
+    vl = rtlil.convert(dut, ports=[])#dut.ports())
    with open("test_mmu.il", "w") as f:
         f.write(vl)
-- 
2.30.2
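
The structure this patch moves to can be seen in isolation: each m.Case() in MMU.elaborate() shrinks to a single call into a per-state helper method, so the m.Switch() reads as a short dispatch table. Below is a minimal standalone sketch of that same pattern on a hypothetical two-state design, following the same convention of an unconditional default assignment to the "next state" value that the Switch then overrides. The names ToyFSM, idle() and busy() are invented for illustration and are not part of mmu.py; the real helpers take additional arguments (l_in, r, v, d_in, data) carrying the loadstore request, the register stage and the dcache reply.

# illustrative sketch only, not part of the patch: the
# "one helper method per FSM state" structure on a toy design.
from enum import Enum, unique

from nmigen import Elaboratable, Module, Signal


@unique
class State(Enum):
    IDLE = 0
    BUSY = 1


class ToyFSM(Elaboratable):
    def __init__(self):
        self.go = Signal()
        self.done = Signal()

    def idle(self, m, v_state):
        # IDLE-state logic lives in its own method
        with m.If(self.go):
            m.d.comb += v_state.eq(State.BUSY)

    def busy(self, m, v_state):
        # BUSY-state logic lives in its own method
        m.d.comb += self.done.eq(1)
        m.d.comb += v_state.eq(State.IDLE)

    def elaborate(self, platform):
        m = Module()
        state = Signal(State, reset=State.IDLE)    # current state register
        v_state = Signal(State, reset=State.IDLE)  # combinatorial "next state"

        m.d.comb += v_state.eq(state)  # default: hold the current state
        m.d.sync += state.eq(v_state)  # register the result each clock

        # the Switch stays short: one helper call per state, which is what
        # the patch does with radix_tree_idle(), proc_tbl_wait(),
        # radix_read_wait() and segment_check()
        with m.Switch(state):
            with m.Case(State.IDLE):
                self.idle(m, v_state)
            with m.Case(State.BUSY):
                self.busy(m, v_state)
        return m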