add detection and disable of LoadStore Wishbone based on JTAG command
[soc.git] / src / soc / minerva / units / loadstore.py
index 3469851446d7b42ae7af74520bcf40559d3e734c..e9830c7a643e2364fbea48a6842bc1e1e28059c2 100644 (file)
@@ -2,150 +2,171 @@ from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
 from nmigen.utils import log2_int
 from nmigen.lib.fifo import SyncFIFO
 
-from ..cache import L1Cache
-from ..isa import Funct3
-from ..wishbone import wishbone_layout, WishboneArbiter, Cycle
+from soc.minerva.cache import L1Cache
+from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
+from soc.bus.wb_downconvert import WishboneDownConvert
 
+from copy import deepcopy
 
-__all__ = ["DataSelector", "LoadStoreUnitInterface", "BareLoadStoreUnit", "CachedLoadStoreUnit"]
-
-
-class DataSelector(Elaboratable):
-    def __init__(self):
-        self.x_offset = Signal(2)
-        self.x_funct3 = Signal(3)
-        self.x_store_operand = Signal(32)
-        self.w_offset = Signal(2)
-        self.w_funct3 = Signal(3)
-        self.w_load_data = Signal(32)
-
-        self.x_misaligned = Signal()
-        self.x_mask = Signal(4)
-        self.x_store_data = Signal(32)
-        self.w_load_result = Signal((32, True))
-
-    def elaborate(self, platform):
-        m = Module()
-
-        with m.Switch(self.x_funct3):
-            with m.Case(Funct3.H, Funct3.HU):
-                m.d.comb += self.x_misaligned.eq(self.x_offset[0])
-            with m.Case(Funct3.W):
-                m.d.comb += self.x_misaligned.eq(self.x_offset.bool())
-
-        with m.Switch(self.x_funct3):
-            with m.Case(Funct3.B, Funct3.BU):
-                m.d.comb += self.x_mask.eq(0b1 << self.x_offset)
-            with m.Case(Funct3.H, Funct3.HU):
-                m.d.comb += self.x_mask.eq(0b11 << self.x_offset)
-            with m.Case(Funct3.W):
-                m.d.comb += self.x_mask.eq(0b1111)
-
-        with m.Switch(self.x_funct3):
-            with m.Case(Funct3.B):
-                m.d.comb += self.x_store_data.eq(self.x_store_operand[:8] << self.x_offset*8)
-            with m.Case(Funct3.H):
-                m.d.comb += self.x_store_data.eq(self.x_store_operand[:16] << self.x_offset[1]*16)
-            with m.Case(Funct3.W):
-                m.d.comb += self.x_store_data.eq(self.x_store_operand)
-
-        w_byte = Signal((8, True))
-        w_half = Signal((16, True))
-
-        m.d.comb += [
-            w_byte.eq(self.w_load_data.word_select(self.w_offset, 8)),
-            w_half.eq(self.w_load_data.word_select(self.w_offset[1], 16))
-        ]
-
-        with m.Switch(self.w_funct3):
-            with m.Case(Funct3.B):
-                m.d.comb += self.w_load_result.eq(w_byte)
-            with m.Case(Funct3.BU):
-                m.d.comb += self.w_load_result.eq(Cat(w_byte, 0))
-            with m.Case(Funct3.H):
-                m.d.comb += self.w_load_result.eq(w_half)
-            with m.Case(Funct3.HU):
-                m.d.comb += self.w_load_result.eq(Cat(w_half, 0))
-            with m.Case(Funct3.W):
-                m.d.comb += self.w_load_result.eq(self.w_load_data)
-
-        return m
+__all__ = ["LoadStoreUnitInterface", "BareLoadStoreUnit",
+           "CachedLoadStoreUnit"]
 
 
 class LoadStoreUnitInterface:
-    def __init__(self):
-        self.dbus = Record(wishbone_layout)
-
-        self.x_addr = Signal(32)
-        self.x_mask = Signal(4)
-        self.x_load = Signal()
-        self.x_store = Signal()
-        self.x_store_data = Signal(32)
-        self.x_stall = Signal()
-        self.x_valid = Signal()
-        self.m_stall = Signal()
-        self.m_valid = Signal()
-
-        self.x_busy = Signal()
-        self.m_busy = Signal()
-        self.m_load_data = Signal(32)
-        self.m_load_error = Signal()
-        self.m_store_error = Signal()
-        self.m_badaddr = Signal(30)
+    def __init__(self, pspec):
+        self.pspec = pspec
+        self.pspecslave = pspec
+        if (hasattr(pspec, "dmem_test_depth") and
+                     isinstance(pspec.wb_data_wid, int) and
+                    pspec.wb_data_wid != pspec.reg_wid):
+            self.dbus = Record(make_wb_layout(pspec), name="int_dbus")
+            pspecslave = deepcopy(pspec)
+            pspecslave.reg_wid = pspec.wb_data_wid
+            mask_ratio = (pspec.reg_wid // pspec.wb_data_wid)
+            pspecslave.mask_wid = pspec.mask_wid // mask_ratio
+            self.pspecslave = pspecslave
+            self.slavebus = Record(make_wb_layout(pspecslave), name="dbus")
+            self.needs_cvt = True
+        else:
+            self.needs_cvt = False
+            self.dbus = self.slavebus = Record(make_wb_layout(pspec))
+
+        # detect whether the wishbone bus is enabled / disabled
+        if (hasattr(pspec, "wb_dcache_en") and
+                     isinstance(pspec.wb_dcache_en, Signal)):
+            self.jtag_en = pspec.wb_dcache_en
+        else:
+            self.jtag_en = Const(1, 1) # permanently on
+
+        print(self.dbus.sel.shape())
+        self.mask_wid = mask_wid = pspec.mask_wid
+        self.addr_wid = addr_wid = pspec.addr_wid
+        self.data_wid = data_wid = pspec.reg_wid
+        print("loadstoreunit addr mask data", addr_wid, mask_wid, data_wid)
+        self.adr_lsbs = log2_int(mask_wid)  # LSBs of addr covered by mask
+        badwid = addr_wid-self.adr_lsbs    # TODO: is this correct?
+
+        # INPUTS
+        self.x_addr_i = Signal(addr_wid)    # address used for loads/stores
+        self.x_mask_i = Signal(mask_wid)    # Mask of which bytes to write
+        self.x_ld_i = Signal()              # set to do a memory load
+        self.x_st_i = Signal()              # set to do a memory store
+        self.x_st_data_i = Signal(data_wid)  # The data to write when storing
+
+        self.x_stall_i = Signal()           # do nothing until low
+        self.x_valid_i = Signal()           # Whether x pipeline stage is
+        # currently enabled (I
+        # think?). Set to 1 for #now
+        self.m_stall_i = Signal()           # do nothing until low
+        self.m_valid_i = Signal()           # Whether m pipeline stage is
+        # currently enabled. Set
+        # to 1 for now
+
+        # OUTPUTS
+        self.x_busy_o = Signal()            # set when the memory is busy
+        self.m_busy_o = Signal()            # set when the memory is busy
+
+        self.m_ld_data_o = Signal(data_wid)  # Data returned from memory read
+        # Data validity is NOT indicated by m_valid_i or x_valid_i as
+        # those are inputs. I believe it is valid on the next cycle
+        # after raising m_load where busy is low
+
+        self.m_load_err_o = Signal()      # if there was an error when loading
+        self.m_store_err_o = Signal()     # if there was an error when storing
+        # The address of the load/store error
+        self.m_badaddr_o = Signal(badwid)
+
+    def __iter__(self):
+        yield self.x_addr_i
+        yield self.x_mask_i
+        yield self.x_ld_i
+        yield self.x_st_i
+        yield self.x_st_data_i
+
+        yield self.x_stall_i
+        yield self.x_valid_i
+        yield self.m_stall_i
+        yield self.m_valid_i
+        yield self.x_busy_o
+        yield self.m_busy_o
+        yield self.m_ld_data_o
+        yield self.m_load_err_o
+        yield self.m_store_err_o
+        yield self.m_badaddr_o
+        for sig in self.slavebus.fields.values():
+            yield sig
+
+    def ports(self):
+        return list(self)
 
 
 class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        with m.If(self.dbus.cyc):
-            with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid):
+        if self.needs_cvt:
+            self.cvt = WishboneDownConvert(self.dbus, self.slavebus)
+            m.submodules.cvt = self.cvt
+
+        with m.If(self.jtag_en): # for safety, JTAG can completely disable WB
+
+            with m.If(self.dbus.cyc):
+                with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid_i):
+                    m.d.sync += [
+                        self.dbus.cyc.eq(0),
+                        self.dbus.stb.eq(0),
+                        self.dbus.sel.eq(0),
+                        self.m_ld_data_o.eq(self.dbus.dat_r)
+                    ]
+            with m.Elif((self.x_ld_i | self.x_st_i) &
+                        self.x_valid_i & ~self.x_stall_i):
                 m.d.sync += [
-                    self.dbus.cyc.eq(0),
-                    self.dbus.stb.eq(0),
-                    self.m_load_data.eq(self.dbus.dat_r)
+                    self.dbus.cyc.eq(1),
+                    self.dbus.stb.eq(1),
+                    self.dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
+                    self.dbus.sel.eq(self.x_mask_i),
+                    self.dbus.we.eq(self.x_st_i),
+                    self.dbus.dat_w.eq(self.x_st_data_i)
+                ]
+            with m.Else():
+                m.d.sync += [
+                    self.dbus.adr.eq(0),
+                    self.dbus.sel.eq(0),
+                    self.dbus.we.eq(0),
+                    self.dbus.sel.eq(0),
+                    self.dbus.dat_w.eq(0),
                 ]
-        with m.Elif((self.x_load | self.x_store) & self.x_valid & ~self.x_stall):
-            m.d.sync += [
-                self.dbus.cyc.eq(1),
-                self.dbus.stb.eq(1),
-                self.dbus.adr.eq(self.x_addr[2:]),
-                self.dbus.sel.eq(self.x_mask),
-                self.dbus.we.eq(self.x_store),
-                self.dbus.dat_w.eq(self.x_store_data)
-            ]
 
-        with m.If(self.dbus.cyc & self.dbus.err):
-            m.d.sync += [
-                self.m_load_error.eq(~self.dbus.we),
-                self.m_store_error.eq(self.dbus.we),
-                self.m_badaddr.eq(self.dbus.adr)
-            ]
-        with m.Elif(~self.m_stall):
-            m.d.sync += [
-                self.m_load_error.eq(0),
-                self.m_store_error.eq(0)
-            ]
+            with m.If(self.dbus.cyc & self.dbus.err):
+                m.d.sync += [
+                    self.m_load_err_o.eq(~self.dbus.we),
+                    self.m_store_err_o.eq(self.dbus.we),
+                    self.m_badaddr_o.eq(self.dbus.adr)
+                ]
+            with m.Elif(~self.m_stall_i):
+                m.d.sync += [
+                    self.m_load_err_o.eq(0),
+                    self.m_store_err_o.eq(0)
+                ]
 
-        m.d.comb += self.x_busy.eq(self.dbus.cyc)
+            m.d.comb += self.x_busy_o.eq(self.dbus.cyc)
 
-        with m.If(self.m_load_error | self.m_store_error):
-            m.d.comb += self.m_busy.eq(0)
-        with m.Else():
-            m.d.comb += self.m_busy.eq(self.dbus.cyc)
+            with m.If(self.m_load_err_o | self.m_store_err_o):
+                m.d.comb += self.m_busy_o.eq(0)
+            with m.Else():
+                m.d.comb += self.m_busy_o.eq(self.dbus.cyc)
 
         return m
 
 
 class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
-    def __init__(self, *dcache_args):
-        super().__init__()
+    def __init__(self, pspec):
+        super().__init__(pspec)
 
-        self.dcache_args = dcache_args
+        self.dcache_args = psiec.dcache_args
 
         self.x_fence_i = Signal()
         self.x_flush = Signal()
-        self.m_addr = Signal(32)
         self.m_load = Signal()
         self.m_store = Signal()
 
@@ -155,120 +176,146 @@ class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
         dcache = m.submodules.dcache = L1Cache(*self.dcache_args)
 
         x_dcache_select = Signal()
+        # Test whether the target address is inside the L1 cache region.
+        # We use bit masks in order to avoid carry chains from arithmetic
+        # comparisons. This restricts the region boundaries to powers of 2.
+        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
+            def addr_below(limit):
+                assert limit in range(1, 2**30 + 1)
+                range_bits = log2_int(limit)
+                const_bits = 30 - range_bits
+                return "{}{}".format("0" * const_bits, "-" * range_bits)
+
+            if dcache.base >= (1 << self.adr_lsbs):
+                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
+                    m.d.comb += x_dcache_select.eq(0)
+            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
+                m.d.comb += x_dcache_select.eq(1)
+            with m.Default():
+                m.d.comb += x_dcache_select.eq(0)
+
         m_dcache_select = Signal()
+        m_addr = Signal.like(self.x_addr_i)
 
-        m.d.comb += x_dcache_select.eq((self.x_addr >= dcache.base) & (self.x_addr < dcache.limit))
-        with m.If(~self.x_stall):
-            m.d.sync += m_dcache_select.eq(x_dcache_select)
+        with m.If(~self.x_stall_i):
+            m.d.sync += [
+                m_dcache_select.eq(x_dcache_select),
+                m_addr.eq(self.x_addr_i),
+            ]
 
         m.d.comb += [
-            dcache.s1_addr.eq(self.x_addr[2:]),
+            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
             dcache.s1_flush.eq(self.x_flush),
-            dcache.s1_stall.eq(self.x_stall),
-            dcache.s1_valid.eq(self.x_valid & x_dcache_select),
-            dcache.s2_addr.eq(self.m_addr[2:]),
+            dcache.s1_stall.eq(self.x_stall_i),
+            dcache.s1_valid.eq(self.x_valid_i & x_dcache_select),
+            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
             dcache.s2_re.eq(self.m_load),
             dcache.s2_evict.eq(self.m_store),
-            dcache.s2_valid.eq(self.m_valid & m_dcache_select)
+            dcache.s2_valid.eq(self.m_valid_i & m_dcache_select)
         ]
 
-        wrbuf_w_data = Record([("addr", 30), ("mask", 4), ("data", 32)])
+        wrbuf_w_data = Record([("addr", self.addr_wid-self.adr_lsbs),
+                               ("mask", self.mask_wid),
+                               ("data", self.data_wid)])
         wrbuf_r_data = Record.like(wrbuf_w_data)
-        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data), depth=dcache.nwords)
+        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
+                                              depth=dcache.nwords)
         m.d.comb += [
             wrbuf.w_data.eq(wrbuf_w_data),
-            wrbuf_w_data.addr.eq(self.x_addr[2:]),
-            wrbuf_w_data.mask.eq(self.x_mask),
-            wrbuf_w_data.data.eq(self.x_store_data),
-            wrbuf.w_en.eq(self.x_store & self.x_valid & x_dcache_select & ~self.x_stall),
+            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
+            wrbuf_w_data.mask.eq(self.x_mask_i),
+            wrbuf_w_data.data.eq(self.x_st_data_i),
+            wrbuf.w_en.eq(self.x_st_i & self.x_valid_i &
+                          x_dcache_select & ~self.x_stall_i),
             wrbuf_r_data.eq(wrbuf.r_data),
         ]
 
-        dbus_arbiter = m.submodules.dbus_arbiter = WishboneArbiter()
-        m.d.comb += dbus_arbiter.bus.connect(self.dbus)
+        dba = WishboneArbiter(self.pspec)
+        m.submodules.dbus_arbiter = dba
+        m.d.comb += dba.bus.connect(self.dbus)
 
         wrbuf_port = dbus_arbiter.port(priority=0)
-        with m.If(wrbuf_port.cyc):
+        m.d.comb += [
+            wrbuf_port.cyc.eq(wrbuf.r_rdy),
+            wrbuf_port.we.eq(Const(1)),
+        ]
+        with m.If(wrbuf_port.stb):
             with m.If(wrbuf_port.ack | wrbuf_port.err):
-                m.d.sync += [
-                    wrbuf_port.cyc.eq(0),
-                    wrbuf_port.stb.eq(0)
-                ]
+                m.d.sync += wrbuf_port.stb.eq(0)
                 m.d.comb += wrbuf.r_en.eq(1)
         with m.Elif(wrbuf.r_rdy):
             m.d.sync += [
-                wrbuf_port.cyc.eq(1),
                 wrbuf_port.stb.eq(1),
                 wrbuf_port.adr.eq(wrbuf_r_data.addr),
                 wrbuf_port.sel.eq(wrbuf_r_data.mask),
                 wrbuf_port.dat_w.eq(wrbuf_r_data.data)
             ]
-        m.d.comb += wrbuf_port.we.eq(Const(1))
 
-        dcache_port = dbus_arbiter.port(priority=1)
+        dcache_port = dba.port(priority=1)
+        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
         m.d.comb += [
             dcache_port.cyc.eq(dcache.bus_re),
             dcache_port.stb.eq(dcache.bus_re),
             dcache_port.adr.eq(dcache.bus_addr),
-            dcache_port.cti.eq(Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)),
+            dcache_port.cti.eq(cti),
             dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
             dcache.bus_valid.eq(dcache_port.ack),
             dcache.bus_error.eq(dcache_port.err),
             dcache.bus_rdata.eq(dcache_port.dat_r)
         ]
 
-        bare_port = dbus_arbiter.port(priority=2)
+        bare_port = dba.port(priority=2)
         bare_rdata = Signal.like(bare_port.dat_r)
         with m.If(bare_port.cyc):
-            with m.If(bare_port.ack | bare_port.err | ~self.m_valid):
+            with m.If(bare_port.ack | bare_port.err | ~self.m_valid_i):
                 m.d.sync += [
                     bare_port.cyc.eq(0),
                     bare_port.stb.eq(0),
                     bare_rdata.eq(bare_port.dat_r)
                 ]
-        with m.Elif((self.x_load | self.x_store) & ~x_dcache_select & self.x_valid & ~self.x_stall):
+        with m.Elif((self.x_ld_i | self.x_st_i) &
+                    ~x_dcache_select & self.x_valid_i & ~self.x_stall_i):
             m.d.sync += [
                 bare_port.cyc.eq(1),
                 bare_port.stb.eq(1),
-                bare_port.adr.eq(self.x_addr[2:]),
-                bare_port.sel.eq(self.x_mask),
-                bare_port.we.eq(self.x_store),
-                bare_port.dat_w.eq(self.x_store_data)
+                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
+                bare_port.sel.eq(self.x_mask_i),
+                bare_port.we.eq(self.x_st_i),
+                bare_port.dat_w.eq(self.x_st_data_i)
             ]
 
         with m.If(self.dbus.cyc & self.dbus.err):
             m.d.sync += [
-                self.m_load_error.eq(~self.dbus.we),
-                self.m_store_error.eq(self.dbus.we),
-                self.m_badaddr.eq(self.dbus.adr)
+                self.m_load_err_o.eq(~self.dbus.we),
+                self.m_store_err_o.eq(self.dbus.we),
+                self.m_badaddr_o.eq(self.dbus.adr)
             ]
-        with m.Elif(~self.m_stall):
+        with m.Elif(~self.m_stall_i):
             m.d.sync += [
-                self.m_load_error.eq(0),
-                self.m_store_error.eq(0)
+                self.m_load_err_o.eq(0),
+                self.m_store_err_o.eq(0)
             ]
 
         with m.If(self.x_fence_i):
-            m.d.comb += self.x_busy.eq(wrbuf.r_rdy)
+            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
         with m.Elif(x_dcache_select):
-            m.d.comb += self.x_busy.eq(self.x_store & ~wrbuf.w_rdy)
+            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
         with m.Else():
-            m.d.comb += self.x_busy.eq(bare_port.cyc)
+            m.d.comb += self.x_busy_o.eq(bare_port.cyc)
 
-        with m.If(self.m_load_error | self.m_store_error):
-            m.d.comb += [
-                self.m_busy.eq(0),
-                self.m_load_data.eq(0)
-            ]
+        with m.If(self.m_flush):
+            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)
+        with m.If(self.m_load_err_o | self.m_store_err_o):
+            m.d.comb += self.m_busy_o.eq(0)
         with m.Elif(m_dcache_select):
             m.d.comb += [
-                self.m_busy.eq(dcache.s2_re & dcache.s2_miss),
-                self.m_load_data.eq(dcache.s2_rdata)
+                self.m_busy_o.eq(dcache.s2_miss),
+                self.m_ld_data_o.eq(dcache.s2_rdata)
             ]
         with m.Else():
             m.d.comb += [
-                self.m_busy.eq(bare_port.cyc),
-                self.m_load_data.eq(bare_rdata)
+                self.m_busy_o.eq(bare_port.cyc),
+                self.m_ld_data_o.eq(bare_rdata)
             ]
 
         return m