add description of how PartialAddrBitmap works
[soc.git] / src / soc / scoreboard / addr_match.py
1 """ Load / Store partial address matcher
2
3 Loads and Stores do not need a full match (CAM), they need "good enough"
4 avoidance. Around 11 bits on a 64-bit address is "good enough".
5
6 The simplest way to use this module is to ignore not only the top bits,
7 but also the bottom bits as well: in this case (this RV64 processor),
8 enough to cover a DWORD (64-bit). That means ignoring the bottom 4 bits,
9 due to the possibility of 64-bit LD/ST being misaligned.
10
11 To reiterate: the use of this module is an *optimisation*. All it has
12 to do is cover the cases that are *definitely* matches (by checking 11
13 bits or so), and if a few opportunities for parallel LD/STs are missed
14 because the top (or bottom) bits weren't checked, so what: all that
15 happens is: the mis-matched addresses are LD/STd on single-cycles. Big Deal.
16
17 However, if we wanted to enhance this algorithm (without using a CAM and
18 without using expensive comparators) probably the best way to do so would
19 be to turn the last 16 bits into a byte-level bitmap. LD/ST on a byte
20 would have 1 of the 16 bits set. LD/ST on a DWORD would have 8 of the 16
21 bits set (offset if the LD/ST was misaligned). TODO.
22
23 Notes:
24
25 > I have used bits <11:6> as they are not translated (4KB pages)
26 > and larger than a cache line (64 bytes).
27 > I have used bits <11:4> when the L1 cache was QuadW sized and
28 > the L2 cache was Line sized.
29 """
30
31 from nmigen.compat.sim import run_simulation
32 from nmigen.cli import verilog, rtlil
33 from nmigen import Module, Signal, Const, Array, Cat, Elaboratable
34 from nmigen.lib.coding import Decoder
35
36 from nmutil.latch import latchregister, SRLatch
37
38
class PartialAddrMatch(Elaboratable):
    """A partial address matcher

    Holds n_adr (partial) addresses, each bitwid bits wide, and reports
    for each one whether it clashes with any of the others.

    Inputs:

    * addrs_i   - array of n_adr addresses to be compared
    * addr_we_i - write-enable, one bit per address. declared and exported
                  as a port, but not used by the logic in this class.
    * addr_en_i - one bit per address: drives the "set" of the SR latch
    * addr_rs_i - one bit per address: drives the "reset" of the SR latch

    Outputs:

    * addr_nomatch_a_o - array of n_adr bit-vectors. bit j of entry i is
                         set when latch bit j (l.q[j]) is set and
                         is_match(i, j) is false.
    * addr_nomatch_o   - bit i is set when latch bit i is set and entry i
                         of addr_nomatch_a_o equals l.q, i.e. address i
                         matched none of the other latched addresses.

    Subclasses may override is_match() to change the comparison.
    """
    def __init__(self, n_adr, bitwid):
        self.n_adr = n_adr
        self.bitwid = bitwid
        # inputs
        self.addrs_i = Array(Signal(bitwid, name="addr") for i in range(n_adr))
        self.addr_we_i = Signal(n_adr, reset_less=True) # write-enable
        self.addr_en_i = Signal(n_adr, reset_less=True) # address latched in
        self.addr_rs_i = Signal(n_adr, reset_less=True) # address deactivated

        # output: a nomatch for each address plus individual nomatch signals
        self.addr_nomatch_o = Signal(n_adr, name="nomatch_o", reset_less=True)
        self.addr_nomatch_a_o = Array(Signal(n_adr, reset_less=True,
                                             name="nomatch_array_o") \
                                  for i in range(n_adr))

    def elaborate(self, platform):
        m = Module()
        return self._elaborate(m, platform)

    def _elaborate(self, m, platform):
        """Adds the latch, address registers and match logic to module m.

        Stores the latch as self.l and the registered addresses as
        self.addrs_r so that is_match() (and subclasses) can use them.
        """
        comb = m.d.comb

        # array of address-latches
        m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
        self.addrs_r = addrs_r = Array(Signal(self.bitwid, reset_less=True,
                                              name="a_r") \
                                       for i in range(self.n_adr))

        # latch set/reset
        comb += l.s.eq(self.addr_en_i)
        comb += l.r.eq(self.addr_rs_i)

        # copy in addresses (and "enable" signals)
        for i in range(self.n_adr):
            latchregister(m, self.addrs_i[i], addrs_r[i], l.q[i])

        # is there a clash, yes/no
        matchgrp = []
        for i in range(self.n_adr):
            # one match bit per "other" address (self-compare is Const(0))
            match = [self.is_match(i, j) for j in range(self.n_adr)]
            # only latched (l.q) entries count as candidates
            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match) & l.q)
            # all latched entries are "nomatch" => no clash for address i
            matchgrp.append(self.addr_nomatch_a_o[i] == l.q)
        comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)

        return m

    def is_match(self, i, j):
        """Returns an expression that is true when address i equals address j
        """
        if i == j:
            return Const(0) # don't match against self!
        return self.addrs_r[i] == self.addrs_r[j]

    def __iter__(self):
        yield from self.addrs_i
        yield self.addr_we_i
        yield self.addr_en_i
        # addr_rs_i drives the latch reset: it has to be a port as well,
        # otherwise the latch cannot be cleared from outside the module
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
105
106
class LenExpand(Elaboratable):
    """LenExpand: expands binary length (and LSBs of an address) into unary

    this basically produces a bitmap of which *bytes* are to be read (written)
    in memory.  examples:

    (bit_len=4) len=4, addr=0b0011 => 0b1111 << addr
                                   => 0b1111000
    (bit_len=4) len=8, addr=0b0101 => 0b11111111 << addr
                                   => 0b1111111100000

    Inputs:

    * len_i  - length, in binary (bit_len bits wide)
    * addr_i - shift amount: the LSBs of the address (bit_len bits wide)

    Outputs:

    * explen_o - len_i consecutive 1s, shifted left by addr_i.
                 1<<(bit_len+1) bits wide.
    """

    def __init__(self, bit_len):
        self.bit_len = bit_len
        self.len_i = Signal(bit_len, reset_less=True)
        self.addr_i = Signal(bit_len, reset_less=True)
        self.explen_o = Signal(1<<(bit_len+1), reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # temp: len_i in unary.  (1<<len)-1 gives exactly "len" 1s.
        # (shifting by 1+len would give one bit too many: len=4 => 0b11111)
        binlen = Signal((1<<self.bit_len)+1, reset_less=True)
        comb += binlen.eq((Const(1, self.bit_len+1) << self.len_i) - 1)
        # move the unary length up to the byte position given by the address
        comb += self.explen_o.eq(binlen << self.addr_i)

        return m

    def ports(self):
        return [self.len_i, self.addr_i, self.explen_o,]
138
139
class PartialAddrBitmap(PartialAddrMatch):
    """PartialAddrBitMap

    makes two comparisons for each address, with each (addr,len)
    being extended to a unary byte-map.

    two comparisons are needed because when an address is misaligned,
    the byte-map is split into two halves.  example:

    address = 0b1011011, len=8 => 0b101 and shift of 11 (0b1011)
                                  len in unary is 0b0000 0000 1111 1111
                                  when shifted becomes TWO addresses:

    * 0b101   and a byte-map of 0b1111 1000 0000 0000 (len-mask shifted by 11)
    * 0b101+1 and a byte-map of 0b0000 0000 0000 0111 (overlaps onto next 16)

    therefore, because this now covers two addresses, we need *two*
    comparisons per address *not* one.

    Parameters:

    * n_adr  - number of addresses
    * bitwid - number of address LSBs to turn into unary (kept as
               self.lsbwid)
    * bitlen - width of the full address.  the remaining bitlen-bitwid
               upper bits (self.midlen) are compared by PartialAddrMatch.

    NOTE: the byte-map comparison itself is not implemented yet: is_match()
    currently only compares the upper address bits, exactly as in
    PartialAddrMatch.
    """
    def __init__(self, n_adr, bitwid, bitlen):
        self.midlen = bitlen-bitwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)
        # the parent constructor stores *its* compare width (midlen) in
        # self.bitwid, so the number of LSBs to turn into unary has to be
        # kept under a different name or it gets silently overwritten.
        self.lsbwid = bitwid

        # input: length of the LOAD/STORE
        self.len_i = Array(Signal(bitwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # intermediary: address + 1 (same width as the compared address)
        self.addr1s = Array(Signal(self.midlen, reset_less=True,
                                   name="adr1") \
                            for i in range(n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        l = self.l
        expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
        explen_i = Array(Signal(expwid, reset_less=True,
                                name="a_l") \
                         for i in range(self.n_adr))
        lenexp_r = Array(Signal(expwid, reset_less=True,
                                name="a_l") \
                         for i in range(self.n_adr))

        # copy the upper bits (lsbwid and above) of the full addresses
        # into the addresses that the parent class compares
        for i in range(self.n_adr):
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

        # copy in lengths and latch them
        for i in range(self.n_adr):
            latchregister(m, explen_i[i], lenexp_r[i], l.q[i])

        # add one to intermediate addresses
        for i in range(self.n_adr):
            comb += self.addr1s[i].eq(self.addrs_r[i]+1)

        # put the bottom bits into the LenExpanders.  One is for
        # non-aligned stores.
        # TODO: not yet implemented (explen_i is not driven yet)

        return m

    def is_match(self, i, j):
        """Returns an expression that is true when address i equals address j

        currently identical to PartialAddrMatch.is_match: the byte-map
        is not yet taken into account.
        """
        if i == j:
            return Const(0) # don't match against self!
        return self.addrs_r[i] == self.addrs_r[j]

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.len_i
        yield self.addr_we_i
        yield self.addr_en_i
        # addr_rs_i drives the latch reset: it has to be a port as well
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
222
def part_addr_sim(dut):
    """Simulation process for a PartialAddrMatch-style dut.

    dut must provide addrs_i (at least two entries), addr_en_i and
    addr_rs_i.  The sequence is:

    * latch address 5 into entry 0
    * latch the same address into entry 1 (a deliberate clash)
    * reset entry 0, then reset entry 1

    A bare "yield" advances the simulation by one clock.
    """
    # latch an address into entry 0
    yield dut.addrs_i[0].eq(5)
    yield dut.addr_en_i.eq(0b001)
    yield
    yield dut.addr_en_i.eq(0)
    yield
    # latch the same address into entry 1
    yield dut.addrs_i[1].eq(5)
    yield dut.addr_en_i.eq(0b010)
    yield
    yield dut.addr_en_i.eq(0)
    yield
    # deactivate entry 0
    yield dut.addr_rs_i.eq(0b001)
    yield
    yield dut.addr_rs_i.eq(0)
    yield
    # deactivate entry 1
    yield dut.addr_rs_i.eq(0b010)
    yield
    yield dut.addr_rs_i.eq(0)
    yield
242
def test_part_addr():
    """Converts each design in this file to RTLIL, then simulates the matcher.

    Writes test_len_expand.il, test_part_bit.il and test_part_addr.il to
    the current directory, and runs part_addr_sim against the
    PartialAddrMatch instance with a VCD trace in test_part_addr.vcd.
    """
    def write_il(design, path):
        # convert to RTLIL using the design's own port list, save to file
        il_text = rtlil.convert(design, ports=design.ports())
        with open(path, "w") as out:
            out.write(il_text)
        return design

    write_il(LenExpand(4), "test_len_expand.il")
    write_il(PartialAddrBitmap(3, 4, 10), "test_part_bit.il")
    matcher = write_il(PartialAddrMatch(3, 10), "test_part_addr.il")

    run_simulation(matcher, part_addr_sim(matcher),
                   vcd_name='test_part_addr.vcd')
260
# run the conversion / simulation test when executed as a script
if __name__ == "__main__":
    test_part_addr()