add detection and disable of Instruction Wishbone based on JTAG command
[soc.git] / src / soc / minerva / units / fetch.py
index b140aa20d300b4b837552a8e124928b2d8efd2f8..a1f14b3d5dfbeafcee830b86a0fc4192920b0b41 100644 (file)
-from nmigen import Elaboratable, Module, Signal, Record
+from nmigen import Elaboratable, Module, Signal, Record, Const, Mux
 from nmigen.utils import log2_int
 
-from ..cache import L1Cache
-from ..wishbone import wishbone_layout
+from soc.minerva.cache import L1Cache
+from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
 
 
-__all__ = ["PCSelector", "FetchUnitInterface", "BareFetchUnit", "CachedFetchUnit"]
-
-
-class PCSelector(Elaboratable):
-    def __init__(self):
-        self.f_pc = Signal(32)
-        self.d_pc = Signal(32)
-        self.d_branch_predict_taken = Signal()
-        self.d_branch_target = Signal(32)
-        self.d_valid = Signal()
-        self.x_pc = Signal(32)
-        self.x_fence_i = Signal()
-        self.x_valid = Signal()
-        self.m_branch_predict_taken = Signal()
-        self.m_branch_taken = Signal()
-        self.m_branch_target = Signal(32)
-        self.m_exception = Signal()
-        self.m_mret = Signal()
-        self.m_valid = Signal()
-        self.mtvec_r_base = Signal(30)
-        self.mepc_r_base = Signal(30)
-
-        self.a_pc = Signal(32)
-
-    def elaborate(self, platform):
-        m = Module()
-
-        with m.If(self.m_exception & self.m_valid):
-            m.d.comb += self.a_pc.eq(self.mtvec_r_base << 2)
-        with m.Elif(self.m_mret & self.m_valid):
-            m.d.comb += self.a_pc.eq(self.mepc_r_base << 2)
-        with m.Elif(self.m_branch_predict_taken & ~self.m_branch_taken & self.m_valid):
-            m.d.comb += self.a_pc.eq(self.x_pc)
-        with m.Elif(~self.m_branch_predict_taken & self.m_branch_taken & self.m_valid):
-            m.d.comb += self.a_pc.eq(self.m_branch_target),
-        with m.Elif(self.x_fence_i & self.x_valid):
-            m.d.comb += self.a_pc.eq(self.d_pc)
-        with m.Elif(self.d_branch_predict_taken & self.d_valid):
-            m.d.comb += self.a_pc.eq(self.d_branch_target),
-        with m.Else():
-            m.d.comb += self.a_pc.eq(self.f_pc + 4)
-
-        return m
+__all__ = ["FetchUnitInterface", "BareFetchUnit", "CachedFetchUnit"]
 
 
 class FetchUnitInterface:
-    def __init__(self):
-        self.ibus = Record(wishbone_layout)
-
-        self.a_pc = Signal(32)
-        self.a_stall = Signal()
-        self.a_valid = Signal()
-        self.f_stall = Signal()
-        self.f_valid = Signal()
-
-        self.a_busy = Signal()
-        self.f_busy = Signal()
-        self.f_instruction = Signal(32)
-        self.f_fetch_error = Signal()
-        self.f_badaddr = Signal(30)
+    def __init__(self, pspec):
+        self.pspec = pspec
+        self.addr_wid = pspec.addr_wid
+        if isinstance(pspec.imem_reg_wid, int):
+            self.data_wid = pspec.imem_reg_wid
+        else:
+            self.data_wid = pspec.reg_wid
+        self.adr_lsbs = log2_int(self.data_wid//8)
+        self.ibus = Record(make_wb_layout(pspec))
+        bad_wid = pspec.addr_wid - self.adr_lsbs # TODO: is this correct?
+
+        # inputs: address to fetch PC, and valid/stall signalling
+        self.a_pc_i = Signal(self.addr_wid)
+        self.a_stall_i = Signal()
+        self.a_valid_i = Signal()
+        self.f_stall_i = Signal()
+        self.f_valid_i = Signal()
+
+        # outputs: instruction (or error), and busy indicators
+        self.a_busy_o = Signal()
+        self.f_busy_o = Signal()
+        self.f_instr_o = Signal(self.data_wid)
+        self.f_fetch_err_o = Signal()
+        self.f_badaddr_o = Signal(bad_wid)
+
+        # detect whether the wishbone bus is enabled / disabled
+        if (hasattr(pspec, "wb_icache_en") and
+                     isinstance(pspec.wb_icache_en, Signal)):
+            self.jtag_en = pspec.wb_icache_en
+        else:
+            self.jtag_en = Const(1, 1) # permanently on
+
+
+    def __iter__(self):
+        yield self.a_pc_i
+        yield self.a_stall_i
+        yield self.a_valid_i
+        yield self.f_stall_i
+        yield self.f_valid_i
+        yield self.a_busy_o
+        yield self.f_busy_o
+        yield self.f_instr_o
+        yield self.f_fetch_err_o
+        yield self.f_badaddr_o
+        for sig in self.ibus.fields.values():
+            yield sig
+
+    def ports(self):
+        return list(self)
 
 
 class BareFetchUnit(FetchUnitInterface, Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        ibus_rdata = Signal.like(self.ibus.dat_r)
-        with m.If(self.ibus.cyc):
-            with m.If(self.ibus.ack | self.ibus.err | ~self.f_valid):
+        with m.If(self.jtag_en): # for safety, JTAG can completely disable WB
+
+            ibus_rdata = Signal.like(self.ibus.dat_r)
+            with m.If(self.ibus.cyc):
+                with m.If(self.ibus.ack | self.ibus.err | ~self.f_valid_i):
+                    m.d.sync += [
+                        self.ibus.cyc.eq(0),
+                        self.ibus.stb.eq(0),
+                        self.ibus.sel.eq(0),
+                        ibus_rdata.eq(self.ibus.dat_r)
+                    ]
+            with m.Elif(self.a_valid_i & ~self.a_stall_i):
                 m.d.sync += [
-                    self.ibus.cyc.eq(0),
-                    self.ibus.stb.eq(0),
-                    ibus_rdata.eq(self.ibus.dat_r)
+                    self.ibus.adr.eq(self.a_pc_i[self.adr_lsbs:]),
+                    self.ibus.cyc.eq(1),
+                    self.ibus.stb.eq(1),
+                    self.ibus.sel.eq((1<<(1<<self.adr_lsbs))-1),
                 ]
-        with m.Elif(self.a_valid & ~self.a_stall):
-            m.d.sync += [
-                self.ibus.adr.eq(self.a_pc[2:]),
-                self.ibus.cyc.eq(1),
-                self.ibus.stb.eq(1)
-            ]
 
-        with m.If(self.ibus.cyc & self.ibus.err):
-            m.d.sync += [
-                self.f_fetch_error.eq(1),
-                self.f_badaddr.eq(self.ibus.adr)
-            ]
-        with m.Elif(~self.f_stall):
-            m.d.sync += self.f_fetch_error.eq(0)
+            with m.If(self.ibus.cyc & self.ibus.err):
+                m.d.sync += [
+                    self.f_fetch_err_o.eq(1),
+                    self.f_badaddr_o.eq(self.ibus.adr)
+                ]
+            with m.Elif(~self.f_stall_i):
+                m.d.sync += self.f_fetch_err_o.eq(0)
 
-        m.d.comb += self.a_busy.eq(self.ibus.cyc)
+            m.d.comb += self.a_busy_o.eq(self.ibus.cyc)
 
-        with m.If(self.f_fetch_error):
-            m.d.comb += [
-                self.f_busy.eq(0),
-                self.f_instruction.eq(0x00000013) # nop (addi x0, x0, 0)
-            ]
-        with m.Else():
-            m.d.comb += [
-                self.f_busy.eq(self.ibus.cyc),
-                self.f_instruction.eq(ibus_rdata)
-            ]
+            with m.If(self.f_fetch_err_o):
+                m.d.comb += self.f_busy_o.eq(0)
+            with m.Else():
+                m.d.comb += [
+                    self.f_busy_o.eq(self.ibus.cyc),
+                    self.f_instr_o.eq(ibus_rdata)
+                ]
 
         return m
 
 
 class CachedFetchUnit(FetchUnitInterface, Elaboratable):
-    def __init__(self, *icache_args):
-        super().__init__()
+    def __init__(self, pspec):
+        super().__init__(pspec)
 
-        self.icache_args = icache_args
+        self.icache_args = pspec.icache_args
 
         self.a_flush = Signal()
-        self.f_pc = Signal(32)
+        self.f_pc = Signal(addr_wid)
 
     def elaborate(self, platform):
         m = Module()
@@ -125,81 +119,105 @@ class CachedFetchUnit(FetchUnitInterface, Elaboratable):
         icache = m.submodules.icache = L1Cache(*self.icache_args)
 
         a_icache_select = Signal()
+
+        # Test whether the target address is inside the L1 cache region.
+        # We use bit masks in order to avoid carry chains from arithmetic
+        # comparisons. This restricts the region boundaries to powers of 2.
+
+        # TODO: minerva defaults adr_lsbs to 2.  check this code
+        with m.Switch(self.a_pc_i[self.adr_lsbs:]):
+            def addr_below(limit):
+                assert limit in range(1, 2**30 + 1)
+                range_bits = log2_int(limit)
+                const_bits = 30 - range_bits
+                return "{}{}".format("0" * const_bits, "-" * range_bits)
+
+            if icache.base >= 4: # XX (1<<self.adr_lsbs?)
+                with m.Case(addr_below(icache.base >> self.adr_lsbs)):
+                    m.d.comb += a_icache_select.eq(0)
+            with m.Case(addr_below(icache.limit >> self.adr_lsbs)):
+                m.d.comb += a_icache_select.eq(1)
+            with m.Default():
+                m.d.comb += a_icache_select.eq(0)
+
         f_icache_select = Signal()
+        f_flush = Signal()
 
-        m.d.comb += a_icache_select.eq((self.a_pc >= icache.base) & (self.a_pc < icache.limit))
-        with m.If(~self.a_stall):
-            m.d.sync += f_icache_select.eq(a_icache_select)
+        with m.If(~self.a_stall_i):
+            m.d.sync += [
+                f_icache_select.eq(a_icache_select),
+                f_flush.eq(self.a_flush),
+            ]
 
         m.d.comb += [
-            icache.s1_addr.eq(self.a_pc[2:]),
+            icache.s1_addr.eq(self.a_pc_i[self.adr_lsbs:]),
             icache.s1_flush.eq(self.a_flush),
-            icache.s1_stall.eq(self.a_stall),
-            icache.s1_valid.eq(self.a_valid & a_icache_select),
-            icache.s2_addr.eq(self.f_pc[2:]),
+            icache.s1_stall.eq(self.a_stall_i),
+            icache.s1_valid.eq(self.a_valid_i & a_icache_select),
+            icache.s2_addr.eq(self.f_pc[self.adr_lsbs:]),
             icache.s2_re.eq(Const(1)),
             icache.s2_evict.eq(Const(0)),
-            icache.s2_valid.eq(self.f_valid & f_icache_select)
+            icache.s2_valid.eq(self.f_valid_i & f_icache_select)
         ]
 
-        ibus_arbiter = m.submodules.ibus_arbiter = WishboneArbiter()
-        m.d.comb += ibus_arbiter.bus.connect(self.ibus)
+        iba = WishboneArbiter(self.pspec)
+        m.submodules.ibus_arbiter = iba
+        m.d.comb += iba.bus.connect(self.ibus)
 
-        icache_port = ibus_arbiter.port(priority=0)
+        icache_port = iba.port(priority=0)
+        cti = Mux(icache.bus_last, Cycle.END, Cycle.INCREMENT)
         m.d.comb += [
             icache_port.cyc.eq(icache.bus_re),
             icache_port.stb.eq(icache.bus_re),
             icache_port.adr.eq(icache.bus_addr),
-            icache_port.cti.eq(Mux(icache.bus_last, Cycle.END, Cycle.INCREMENT)),
+            icache_port.cti.eq(cti),
             icache_port.bte.eq(Const(log2_int(icache.nwords) - 1)),
             icache.bus_valid.eq(icache_port.ack),
             icache.bus_error.eq(icache_port.err),
             icache.bus_rdata.eq(icache_port.dat_r)
         ]
 
-        bare_port = ibus_arbiter.port(priority=1)
+        bare_port = iba.port(priority=1)
         bare_rdata = Signal.like(bare_port.dat_r)
         with m.If(bare_port.cyc):
-            with m.If(bare_port.ack | bare_port.err | ~self.f_valid):
+            with m.If(bare_port.ack | bare_port.err | ~self.f_valid_i):
                 m.d.sync += [
                     bare_port.cyc.eq(0),
                     bare_port.stb.eq(0),
+                    bare_port.sel.eq(0),
                     bare_rdata.eq(bare_port.dat_r)
                 ]
-        with m.Elif(~a_icache_select & self.a_valid & ~self.a_stall):
+        with m.Elif(~a_icache_select & self.a_valid_i & ~self.a_stall_i):
             m.d.sync += [
                 bare_port.cyc.eq(1),
                 bare_port.stb.eq(1),
-                bare_port.adr.eq(self.a_pc[2:])
+                bare_port.sel.eq((1<<(1<<self.adr_lsbs))-1),
+                bare_port.adr.eq(self.a_pc_i[self.adr_lsbs:])
             ]
 
+        m.d.comb += self.a_busy_o.eq(bare_port.cyc)
+
         with m.If(self.ibus.cyc & self.ibus.err):
             m.d.sync += [
-                self.f_fetch_error.eq(1),
-                self.f_badaddr.eq(self.ibus.adr)
+                self.f_fetch_err_o.eq(1),
+                self.f_badaddr_o.eq(self.ibus.adr)
             ]
-        with m.Elif(~self.f_stall):
-            m.d.sync += self.f_fetch_error.eq(0)
+        with m.Elif(~self.f_stall_i):
+            m.d.sync += self.f_fetch_err_o.eq(0)
 
-        with m.If(a_icache_select):
-            m.d.comb += self.a_busy.eq(0)
-        with m.Else():
-            m.d.comb += self.a_busy.eq(bare_port.cyc)
-
-        with m.If(self.f_fetch_error):
-            m.d.comb += [
-                self.f_busy.eq(0),
-                self.f_instruction.eq(0x00000013) # nop (addi x0, x0, 0)
-            ]
+        with m.If(f_flush):
+            m.d.comb += self.f_busy_o.eq(~icache.s2_flush_ack)
+        with m.Elif(self.f_fetch_err_o):
+            m.d.comb += self.f_busy_o.eq(0)
         with m.Elif(f_icache_select):
             m.d.comb += [
-                self.f_busy.eq(icache.s2_re & icache.s2_miss),
-                self.f_instruction.eq(icache.s2_rdata)
+                self.f_busy_o.eq(icache.s2_miss),
+                self.f_instr_o.eq(icache.s2_rdata)
             ]
         with m.Else():
             m.d.comb += [
-                self.f_busy.eq(bare_port.cyc),
-                self.f_instruction.eq(bare_rdata)
+                self.f_busy_o.eq(bare_port.cyc),
+                self.f_instr_o.eq(bare_rdata)
             ]
 
         return m