Add memory loads and stores to simulator
[soc.git] / src / soc / scoreboard / addr_match.py
index 3f48008aa8d7a3a318b999f120336254f635da79..a47f635f4e9c56a7a13329810855576358110339 100644 (file)
@@ -47,7 +47,7 @@ class PartialAddrMatch(Elaboratable):
         self.bitwid = bitwid
         # inputs
         self.addrs_i = Array(Signal(bitwid, name="addr") for i in range(n_adr))
-        self.addr_we_i = Signal(n_adr, reset_less=True) # write-enable
+        #self.addr_we_i = Signal(n_adr, reset_less=True) # write-enable
         self.addr_en_i = Signal(n_adr, reset_less=True) # address latched in
         self.addr_rs_i = Signal(n_adr, reset_less=True) # address deactivated
 
@@ -67,7 +67,7 @@ class PartialAddrMatch(Elaboratable):
 
         # array of address-latches
         m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
-        self.addrs_r = addrs_r = Array(Signal(self.bitwid, reset_less=True,
+        self.adrs_r = adrs_r = Array(Signal(self.bitwid, reset_less=True,
                                               name="a_r") \
                                        for i in range(self.n_adr))
 
@@ -77,7 +77,7 @@ class PartialAddrMatch(Elaboratable):
 
         # copy in addresses (and "enable" signals)
         for i in range(self.n_adr):
-            latchregister(m, self.addrs_i[i], addrs_r[i], l.q[i])
+            latchregister(m, self.addrs_i[i], adrs_r[i], l.q[i])
 
         # is there a clash, yes/no
         matchgrp = []
@@ -85,8 +85,8 @@ class PartialAddrMatch(Elaboratable):
             match = []
             for j in range(self.n_adr):
                 match.append(self.is_match(i, j))
-            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match) & l.q)
-            matchgrp.append(self.addr_nomatch_a_o[i] == l.q)
+            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match))
+            matchgrp.append((self.addr_nomatch_a_o[i] & l.q) == l.q)
         comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)
 
         return m
@@ -94,11 +94,11 @@ class PartialAddrMatch(Elaboratable):
     def is_match(self, i, j):
         if i == j:
             return Const(0) # don't match against self!
-        return self.addrs_r[i] == self.addrs_r[j]
+        return self.adrs_r[i] == self.adrs_r[j]
 
     def __iter__(self):
         yield from self.addrs_i
-        yield self.addr_we_i
+        #yield self.addr_we_i
         yield self.addr_en_i
         yield from self.addr_nomatch_a_o
         yield self.addr_nomatch_o
@@ -123,7 +123,7 @@ class LenExpand(Elaboratable):
         self.bit_len = bit_len
         self.len_i = Signal(bit_len, reset_less=True)
         self.addr_i = Signal(bit_len, reset_less=True)
-        self.explen_o = Signal(1<<(bit_len+1), reset_less=True)
+        self.lexp_o = Signal(1<<(bit_len+1), reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
@@ -131,13 +131,92 @@ class LenExpand(Elaboratable):
 
         # temp
         binlen = Signal((1<<self.bit_len)+1, reset_less=True)
-        comb += binlen.eq((Const(1, self.bit_len+1) << (1+self.len_i)) - 1)
-        comb += self.explen_o.eq(binlen << self.addr_i)
+        comb += binlen.eq((Const(1, self.bit_len+1) << (self.len_i)) - 1)
+        comb += self.lexp_o.eq(binlen << self.addr_i)
 
         return m
 
     def ports(self):
-        return [self.len_i, self.addr_i, self.explen_o,]
+        return [self.len_i, self.addr_i, self.lexp_o,]
+
+
class TwinPartialAddrBitmap(PartialAddrMatch):
    """TwinPartialAddrBitmap

    designed to be connected to via LDSTSplitter, which generates
    *pairs* of addresses and covers the misalignment across cache
    line boundaries *in the splitter*.  Also LDSTSplitter takes
    care of expanding the LSBs of each address into a bitmap, itself.

    the key difference between this and PartialAddrMatch is that the
    knowledge (fact) that pairs of addresses from the same LDSTSplitter
    are 1 apart is *guaranteed* to be a miss for those two addresses.
    therefore is_match specially takes that into account.
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        """
        * n_adr  -- number of address-match entries
        * lsbwid -- number of address LSBs covered by the (externally
                    pre-expanded) unary byte-enable bitmap
        * bitlen -- width of each full input address
        """
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen-lsbwid
        # binary comparison (in the base class) is done only on the
        # top (bitlen-lsbwid) address bits
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: expanded (unary) byte-enable bitmap of the LOAD/STORE.
        # the bitmap covers two cache lines, hence 1<<expwid bits.
        expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp_i = Array(Signal(1<<expwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                      name="fadr") for i in range(n_adr))

        # registers latching a full copy of lexp_i.  these must be the
        # same width as lexp_i (1<<expwid): a narrower register would
        # truncate the unary bitmap, and is_match (which reads up to
        # bit 1<<lsbwid) would compare garbage.
        self.len_r = Array(Signal(1<<expwid, reset_less=True, name="l_r") \
                                       for i in range(self.n_adr))

    def elaborate(self, platform):
        """wires the full addresses and expanded lengths into the
        base-class comparator: top address bits go to the binary
        comparators, expanded lengths are latched alongside.
        """
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        adrs_r, l = self.adrs_r, self.l
        expwid = 1+self.lsbwid

        for i in range(self.n_adr):
            # pass the top (bitlen-lsbwid) bits of each address to the
            # base class for the binary compare
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in expanded-lengths and latch them (same latch-enable
            # as the address itself)
            latchregister(m, self.lexp_i[i], self.len_r[i], l.q[i])

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        """returns the match condition between entries i and j:
        binary equality on the top address bits ANDed with a non-zero
        overlap of the two unary byte-enable bitmaps.
        """
        if i == j:
            return Const(0) # don't match against self!
        # we know that pairs have addr and addr+1 therefore it is
        # guaranteed that they will not match.
        if (i // 2) == (j // 2):
            return Const(0) # don't match against twin, either.

        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1<<self.lsbwid
        #if i % 2 == 1 or j % 2 == 1: # XXX hmmm...
        #   expwid >>= 1

        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.len_r[i][:expwid] & self.len_r[j][:expwid]).bool()
        return straight_eq

    def __iter__(self):
        # note: lexp_i replaces the base class's addrs_i in the port list
        yield from self.faddrs_i
        yield from self.lexp_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
 
 
 class PartialAddrBitmap(PartialAddrMatch):
@@ -178,7 +257,7 @@ class PartialAddrBitmap(PartialAddrMatch):
 
         # expanded lengths, needed in match
         expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
-        self.explen = Array(Signal(1<<expwid, reset_less=True,
+        self.lexp = Array(Signal(1<<expwid, reset_less=True,
                                 name="a_l") \
                                        for i in range(self.n_adr))
 
@@ -187,7 +266,7 @@ class PartialAddrBitmap(PartialAddrMatch):
         comb = m.d.comb
 
         # intermediaries
-        addrs_r, l = self.addrs_r, self.l
+        adrs_r, l = self.adrs_r, self.l
         len_r = Array(Signal(self.lsbwid, reset_less=True,
                                 name="l_r") \
                                        for i in range(self.n_adr))
@@ -203,32 +282,42 @@ class PartialAddrBitmap(PartialAddrMatch):
             latchregister(m, self.len_i[i], len_r[i], l.q[i])
 
             # add one to intermediate addresses
-            comb += self.addr1s[i].eq(self.addrs_r[i]+1)
+            comb += self.addr1s[i].eq(self.adrs_r[i]+1)
 
             # put the bottom bits of each address into each LenExpander.
             comb += be.len_i.eq(len_r[i])
             comb += be.addr_i.eq(self.faddrs_i[i][:self.lsbwid])
             # connect expander output
-            comb += self.explen[i].eq(be.explen_o)
+            comb += self.lexp[i].eq(be.lexp_o)
 
         return m
 
+    # TODO make this a module.  too much.
     def is_match(self, i, j):
         if i == j:
             return Const(0) # don't match against self!
+        # the bitmask contains data for *two* cache lines (16 bytes).
+        # however len==8 only covers *half* a cache line so we only
+        # need to compare half the bits
         expwid = 1<<self.lsbwid
         hexp = expwid >> 1
         expwid2 = expwid + hexp
         print (self.lsbwid, expwid)
-        return ((self.addrs_r[i] == self.addrs_r[j]) & \
-                (self.explen[i][:expwid] & self.explen[j][:expwid]).bool() |
-               (self.addr1s[i] == self.addrs_r[j]) & \
-                (self.explen[i][expwid:expwid2] & self.explen[j][:hexp]).bool())
+        # straight compare: binary top bits of addr, *unary* compare on bottom
+        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
+                      (self.lexp[i][:expwid] & self.lexp[j][:expwid]).bool()
+        # compare i (addr+1) to j (addr), but top unary against bottom unary
+        i1_eq_j = (self.addr1s[i] == self.adrs_r[j]) & \
+                  (self.lexp[i][expwid:expwid2] & self.lexp[j][:hexp]).bool()
+        # compare i (addr) to j (addr+1), but bottom unary against top unary
+        i_eq_j1 = (self.adrs_r[i] == self.addr1s[j]) & \
+                  (self.lexp[i][:hexp] & self.lexp[j][expwid:expwid2]).bool()
+        return straight_eq | i1_eq_j | i_eq_j1
 
     def __iter__(self):
         yield from self.faddrs_i
         yield from self.len_i
-        yield self.addr_we_i
+        #yield self.addr_we_i
         yield self.addr_en_i
         yield from self.addr_nomatch_a_o
         yield self.addr_nomatch_o
@@ -236,6 +325,7 @@ class PartialAddrBitmap(PartialAddrMatch):
     def ports(self):
         return list(self)
 
+
 def part_addr_sim(dut):
     yield dut.dest_i.eq(1)
     yield dut.issue_i.eq(1)
@@ -256,17 +346,62 @@ def part_addr_sim(dut):
     yield dut.go_wr_i.eq(0)
     yield
 
def part_addr_bit(dut):
    """simulator process for PartialAddrBitmap: latches a sequence of
    (entry, len, full-address) LOAD/STOREs, then releases entry 1.

    expected expanded bitmaps per entry, shown as
    "addr-tag 0b110 half | addr-tag 0b101 half":
    """
    sequence = [
        #                                            0b110 |               0b101
        # 0b101 1011 / 8 ==>     0b0000 0000 0000 0111 | 1111 1000 0000 0000
        (0, 8, 0b1011011),
        # 0b110 0010 / 2 ==>     0b0000 0000 0000 1100 | 0000 0000 0000 0000
        (1, 2, 0b1100010),
        # 0b101 1010 / 2 ==>     0b0000 0000 0000 0000 | 0000 1100 0000 0000
        (2, 2, 0b1011010),
        # 0b101 1001 / 2 ==>     0b0000 0000 0000 0000 | 0000 0110 0000 0000
        (2, 2, 0b1011001),
    ]
    for entry, oplen, addr in sequence:
        # present the operation, pulse the latch-enable for one cycle
        yield dut.len_i[entry].eq(oplen)
        yield dut.faddrs_i[entry].eq(addr)
        yield dut.addr_en_i[entry].eq(1)
        yield
        yield dut.addr_en_i[entry].eq(0)
        yield
    # deactivate (release) entry 1
    yield dut.addr_rs_i[1].eq(1)
    yield
    yield dut.addr_rs_i[1].eq(0)
    yield
+
 def test_part_addr():
     dut = LenExpand(4)
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("test_len_expand.il", "w") as f:
         f.write(vl)
 
+    dut = TwinPartialAddrBitmap(3, 4, 10)
+    vl = rtlil.convert(dut, ports=dut.ports())
+    with open("test_twin_part_bit.il", "w") as f:
+        f.write(vl)
+
     dut = PartialAddrBitmap(3, 4, 10)
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("test_part_bit.il", "w") as f:
         f.write(vl)
 
+    run_simulation(dut, part_addr_bit(dut), vcd_name='test_part_bit.vcd')
+
     dut = PartialAddrMatch(3, 10)
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("test_part_addr.il", "w") as f: