get rid of rst
[soc.git] / src / soc / experiment / dcache.py
index 0bb648dc9bdfffa6adce193051936d4e90848da3..b23c46e3aad328b00e3a0868f22dfd7c64220549 100644 (file)
@@ -7,17 +7,12 @@ based on Anton Blanchard microwatt dcache.vhdl
 from enum import Enum, unique
 
 from nmigen import Module, Signal, Elaboratable, Cat, Repl, Array, Const
-try:
-    from nmigen.hdl.ast import Display
-except ImportError:
-    def Display(*args):
-        return []
+from nmutil.util import Display
 
 from random import randint
 
 from nmigen.cli import main
 from nmutil.iocontrol import RecordObject
-from nmutil.util import wrap
 from nmigen.utils import log2_int
 from soc.experiment.mem_types import (LoadStore1ToDCacheType,
                                      DCacheToLoadStore1Type,
@@ -41,6 +36,7 @@ if True:
     from nmigen.back.pysim import Simulator, Delay, Settle
 else:
     from nmigen.sim.cxxsim import Simulator, Delay, Settle
+from nmutil.util import wrap
 
 
 # TODO: make these parameters of DCache at some point
@@ -72,6 +68,7 @@ BRAM_ROWS = NUM_LINES * ROW_PER_LINE
 print ("ROW_SIZE", ROW_SIZE)
 print ("ROW_PER_LINE", ROW_PER_LINE)
 print ("BRAM_ROWS", BRAM_ROWS)
+print ("NUM_WAYS", NUM_WAYS)
 
 # Bit fields counts in the address
 
@@ -376,6 +373,7 @@ class RegStage1(RecordObject):
         self.write_bram       = Signal()
         self.write_tag        = Signal()
         self.slow_valid       = Signal()
+        self.real_adr         = Signal(REAL_ADDR_BITS)
         self.wb               = WBMasterOut("wb")
         self.reload_tag       = Signal(TAG_BITS)
         self.store_way        = Signal(WAY_BITS)
@@ -798,7 +796,7 @@ class DCache(Elaboratable):
         sync += cache_tag_set.eq(cache_tags[index])
 
     def dcache_request(self, m, r0, ra, req_index, req_row, req_tag,
-                       r0_valid, r1, cache_valid_bits, replace_way,
+                       r0_valid, r1, cache_valids, replace_way,
                        use_forward1_next, use_forward2_next,
                        req_hit_way, plru_victim, rc_ok, perm_attr,
                        valid_ra, perm_ok, access_ok, req_op, req_go,
@@ -820,18 +818,19 @@ class DCache(Elaboratable):
         nc          = Signal()
         hit_set     = Array(Signal(name="hit_set_%d" % i) \
                                   for i in range(TLB_NUM_WAYS))
-        cache_valid_idx = Signal(INDEX_BITS)
+        cache_valid_idx = Signal(NUM_WAYS)
 
         # Extract line, row and tag from request
         comb += req_index.eq(get_index(r0.req.addr))
         comb += req_row.eq(get_row(r0.req.addr))
         comb += req_tag.eq(get_tag(ra))
 
-        comb += Display("dcache_req addr:%x ra: %x idx: %x tag: %x row: %x",
-                r0.req.addr, ra, req_index, req_tag, req_row)
+        if False: # display on comb is a bit... busy.
+            comb += Display("dcache_req addr:%x ra: %x idx: %x tag: %x row: %x",
+                    r0.req.addr, ra, req_index, req_tag, req_row)
 
         comb += go.eq(r0_valid & ~(r0.tlbie | r0.tlbld) & ~r1.ls_error)
-        comb += cache_valid_idx.eq(cache_valid_bits[req_index])
+        comb += cache_valid_idx.eq(cache_valids[req_index])
 
         m.submodules.dcache_pend = dc = DCachePendingHit(tlb_pte_way,
                                 tlb_valid_way, tlb_hit_way,
@@ -854,7 +853,9 @@ class DCache(Elaboratable):
             # For a store, consider this a hit even if the row isn't
             # valid since it will be by the time we perform the store.
             # For a load, check the appropriate row valid bit.
-            valid = r1.rows_valid[req_row[:ROW_LINE_BITS]]
+            rrow = Signal(ROW_LINE_BITS)
+            comb += rrow.eq(req_row)
+            valid = r1.rows_valid[rrow]
             comb += is_hit.eq(~r0.req.load | valid)
             comb += hit_way.eq(replace_way)
 
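Note on the rrow change above (and the matching srow further down): the r1.rows_valid Array is now indexed through an explicitly ROW_LINE_BITS-wide Signal rather than an inline slice of req_row / r1.store_row. A minimal standalone sketch of that idiom follows; ROW_LINE_BITS and the rows_valid size are assumed values for illustration only, not taken from the file:

    from nmigen import Module, Signal, Array

    ROW_LINE_BITS = 3            # assumed for the sketch (rows-per-line bits)

    m = Module()
    rows_valid = Array(Signal(name="rv%d" % i)      # one valid bit per row of a line
                       for i in range(1 << ROW_LINE_BITS))
    req_row = Signal(8)                             # wider than the Array needs
    rrow = Signal(ROW_LINE_BITS)
    m.d.comb += rrow.eq(req_row)                    # assignment truncates to ROW_LINE_BITS
    valid = rows_valid[rrow]                        # Array indexed by a bounded Signal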
@@ -1179,7 +1180,7 @@ class DCache(Elaboratable):
     # All wishbone requests generation is done here.
     # This machine operates at stage 1.
     def dcache_slow(self, m, r1, use_forward1_next, use_forward2_next,
-                    cache_valid_bits, r0, replace_way,
+                    cache_valids, r0, replace_way,
                     req_hit_way, req_same_tag,
                     r0_valid, req_op, cache_tags, req_go, ra):
 
@@ -1290,7 +1291,7 @@ class DCache(Elaboratable):
         with m.Switch(r1.state):
 
             with m.Case(State.IDLE):
-                sync += r1.wb.adr.eq(req.real_addr)
+                sync += r1.real_adr.eq(req.real_addr)
                 sync += r1.wb.sel.eq(req.byte_sel)
                 sync += r1.wb.dat.eq(req.data)
                 sync += r1.dcbz.eq(req.dcbz)
@@ -1386,19 +1387,21 @@ class DCache(Elaboratable):
                     # Clear stb and set ld_stbs_done
                     # so we can handle an eventual
                     # last ack on the same cycle.
-                    with m.If(is_last_row_addr(r1.wb.adr, r1.end_row_ix)):
+                    with m.If(is_last_row_addr(r1.real_adr, r1.end_row_ix)):
                         sync += r1.wb.stb.eq(0)
                         comb += ld_stbs_done.eq(1)
 
                     # Calculate the next row address in the current cache line
-                    rarange = Signal(LINE_OFF_BITS-ROW_OFF_BITS)
-                    comb += rarange.eq(r1.wb.adr[ROW_OFF_BITS:LINE_OFF_BITS]+1)
-                    sync += r1.wb.adr[ROW_OFF_BITS:LINE_OFF_BITS].eq(rarange)
+                    row = Signal(LINE_OFF_BITS-ROW_OFF_BITS)
+                    comb += row.eq(r1.real_adr[ROW_OFF_BITS:])
+                    sync += r1.real_adr[ROW_OFF_BITS:LINE_OFF_BITS].eq(row+1)
 
                 # Incoming acks processing
                 sync += r1.forward_valid1.eq(wb_in.ack)
                 with m.If(wb_in.ack):
-                    sync += r1.rows_valid[r1.store_row[:ROW_LINE_BITS]].eq(1)
+                    srow = Signal(ROW_LINE_BITS)
+                    comb += srow.eq(r1.store_row)
+                    sync += r1.rows_valid[srow].eq(1)
 
                     # If this is the data we were looking for,
                     # we can complete the request next cycle.
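Note on the row-advance rewrite above: only the row-within-line field of r1.real_adr is incremented, and the slice assignment truncates row+1, so the address wraps inside the cache line. A minimal pure-Python sketch of the same arithmetic; ROW_OFF_BITS and LINE_OFF_BITS here are assumed values (8-byte rows, 64-byte lines) for illustration only:

    ROW_OFF_BITS  = 3                          # assumed: log2(8-byte row)
    LINE_OFF_BITS = 6                          # assumed: log2(64-byte line)
    ROW_FIELD = LINE_OFF_BITS - ROW_OFF_BITS   # width of the row-within-line field

    def next_row_addr(real_adr):
        """Advance the row-within-line field, wrapping inside the line."""
        row  = (real_adr >> ROW_OFF_BITS) & ((1 << ROW_FIELD) - 1)
        row  = (row + 1) & ((1 << ROW_FIELD) - 1)        # wraps, like the slice .eq()
        keep = real_adr & ~(((1 << ROW_FIELD) - 1) << ROW_OFF_BITS)
        return keep | (row << ROW_OFF_BITS)

    assert next_row_addr(0x1038) == 0x1000   # last row of the line wraps to row 0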
@@ -1426,9 +1429,9 @@ class DCache(Elaboratable):
 
                         # Cache line is now valid
                         cv = Signal(INDEX_BITS)
-                        comb += cv.eq(cache_valid_bits[r1.store_index])
+                        comb += cv.eq(cache_valids[r1.store_index])
                         comb += cv.bit_select(r1.store_way, 1).eq(1)
-                        sync += cache_valid_bits[r1.store_index].eq(cv)
+                        sync += cache_valids[r1.store_index].eq(cv)
                         sync += r1.state.eq(State.IDLE)
 
                     # Increment store row counter
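Note on the cache_valids read-modify-write above: each index of the (renamed) cache_valids array holds a NUM_WAYS-wide bitvector of per-way valid bits; the reload-completion path reads the whole vector, sets the stored way's bit with bit_select(), and writes it back on the next clock edge. A minimal standalone sketch of that pattern; NUM_LINES and NUM_WAYS are assumed values for illustration only:

    from nmigen import Module, Signal, Array

    NUM_LINES = 16                       # assumed for the sketch
    NUM_WAYS  = 4                        # assumed for the sketch

    m = Module()
    cache_valids = Array(Signal(NUM_WAYS, name="valids%d" % i)
                         for i in range(NUM_LINES))
    store_index = Signal(range(NUM_LINES))
    store_way   = Signal(range(NUM_WAYS))

    cv = Signal(NUM_WAYS)
    m.d.comb += cv.eq(cache_valids[store_index])     # current per-way valid bits
    m.d.comb += cv.bit_select(store_way, 1).eq(1)    # mark the reloaded way valid
    m.d.sync += cache_valids[store_index].eq(cv)     # write back next cycle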
@@ -1455,7 +1458,7 @@ class DCache(Elaboratable):
                     # to be done which is in the same real page.
                     with m.If(req.valid):
                         ra = req.real_addr[0:SET_SIZE_BITS]
-                        sync += r1.wb.adr[0:SET_SIZE_BITS].eq(ra)
+                        sync += r1.real_adr[0:SET_SIZE_BITS].eq(ra)
                         sync += r1.wb.dat.eq(req.data)
                         sync += r1.wb.sel.eq(req.byte_sel)
 
@@ -1515,7 +1518,7 @@ class DCache(Elaboratable):
         sync += log_out.eq(Cat(r1.state[:3], valid_ra, tlb_hit_way[:3],
                                stall_out, req_op[:3], d_out.valid, d_out.error,
                                r1.wb.cyc, r1.wb.stb, wb_in.ack, wb_in.stall,
-                               r1.wb.adr[3:6]))
+                               r1.real_adr[3:6]))
 
     def elaborate(self, platform):
 
@@ -1525,7 +1528,7 @@ class DCache(Elaboratable):
         # Storage. Hopefully "cache_rows" is a BRAM, the rest is LUTs
         cache_tags       = CacheTagArray()
         cache_tag_set    = Signal(TAG_RAM_WIDTH)
-        cache_valid_bits = CacheValidBitsArray()
+        cache_valids = CacheValidBitsArray()
 
         # TODO attribute ram_style : string;
         # TODO attribute ram_style of cache_tags : signal is "distributed";
@@ -1605,6 +1608,7 @@ class DCache(Elaboratable):
         comb += self.stall_out.eq(r0_stall)
 
         # Wire up wishbone request latch out of stage 1
+        comb += r1.wb.adr.eq(r1.real_adr[ROW_OFF_BITS:]) # truncate LSBs
         comb += self.wb_out.eq(r1.wb)
 
         # call sub-functions putting everything together, using shared
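Note on the wiring added above: r1 now carries the full byte address in real_adr, and the wishbone address is derived in one place by dropping the ROW_OFF_BITS least-significant bits, so wb.adr is row-addressed rather than byte-addressed. A minimal pure-Python sketch of the relationship; ROW_SIZE = 8 bytes (64-bit wishbone data bus) is an assumed value for illustration only:

    ROW_SIZE     = 8                          # assumed: bytes per row / WB beat
    ROW_OFF_BITS = ROW_SIZE.bit_length() - 1  # log2(8) = 3

    def wb_adr_from_real(real_adr):
        """wb.adr is the row address: the byte offset within a row is dropped."""
        return real_adr >> ROW_OFF_BITS

    assert wb_adr_from_real(0x1008) == 0x201  # byte 0x1008 lies in row 0x201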
@@ -1623,7 +1627,7 @@ class DCache(Elaboratable):
         self.maybe_tlb_plrus(m, r1, tlb_plru_victim)
         self.cache_tag_read(m, r0_stall, req_index, cache_tag_set, cache_tags)
         self.dcache_request(m, r0, ra, req_index, req_row, req_tag,
-                           r0_valid, r1, cache_valid_bits, replace_way,
+                           r0_valid, r1, cache_valids, replace_way,
                            use_forward1_next, use_forward2_next,
                            req_hit_way, plru_victim, rc_ok, perm_attr,
                            valid_ra, perm_ok, access_ok, req_op, req_go,
@@ -1640,7 +1644,7 @@ class DCache(Elaboratable):
                         req_hit_way, req_index, req_tag, access_ok,
                         tlb_hit, tlb_hit_way, tlb_req_index)
         self.dcache_slow(m, r1, use_forward1_next, use_forward2_next,
-                    cache_valid_bits, r0, replace_way,
+                    cache_valids, r0, replace_way,
                     req_hit_way, req_same_tag,
                          r0_valid, req_op, cache_tags, req_go, ra)
         #self.dcache_log(m, r1, valid_ra, tlb_hit_way, stall_out)
@@ -1720,7 +1724,7 @@ def dcache_random_sim(dut):
         assert data == sim_data, \
             "check %x data %x != %x" % (addr, data, sim_data)
 
-    for addr in range(8):
+    for addr in range(256):
         data = yield from dcache_load(dut, addr*8)
         assert data == sim_mem[addr], \
             "final check %x data %x != %x" % (addr*8, data, sim_mem[addr])
@@ -1810,7 +1814,7 @@ def test_dcache(mem, test_fn, test_name):
     m.d.comb += sram.bus.stb.eq(dut.wb_out.stb)
     m.d.comb += sram.bus.we.eq(dut.wb_out.we)
     m.d.comb += sram.bus.sel.eq(dut.wb_out.sel)
-    m.d.comb += sram.bus.adr.eq(dut.wb_out.adr[3:])
+    m.d.comb += sram.bus.adr.eq(dut.wb_out.adr)
     m.d.comb += sram.bus.dat_w.eq(dut.wb_out.dat)
 
     m.d.comb += dut.wb_in.ack.eq(sram.bus.ack)
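Note: because wb_out.adr is now already row-addressed (see the truncation added in elaborate() above), the test harness connects it to the 64-bit SRAM directly; keeping the old [3:] slice here would strip the row-offset bits a second time.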
@@ -1835,5 +1839,5 @@ if __name__ == '__main__':
         mem.append((i*2)| ((i*2+1)<<32))
 
     test_dcache(mem, dcache_sim, "")
-    #test_dcache(None, dcache_random_sim, "random")
+    test_dcache(None, dcache_random_sim, "random")