clarify is_match manually
[soc.git] / src / soc / scoreboard / addr_match.py
1 """ Load / Store partial address matcher
2
3 Related bugreports:
4 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
5
6 Loads and Stores do not need a full match (CAM), they need "good enough"
7 avoidance. Around 11 bits on a 64-bit address is "good enough".
8
The simplest way to use this module is to ignore not only the top bits
but also the bottom bits: in this case (this RV64 processor), enough
of them to cover a DWORD (64-bit). That means ignoring the bottom 4 bits,
due to the possibility of 64-bit LD/ST being misaligned.
13
14 To reiterate: the use of this module is an *optimisation*. All it has
15 to do is cover the cases that are *definitely* matches (by checking 11
16 bits or so), and if a few opportunities for parallel LD/STs are missed
17 because the top (or bottom) bits weren't checked, so what: all that
18 happens is: the mis-matched addresses are LD/STd on single-cycles. Big Deal.
19
20 However, if we wanted to enhance this algorithm (without using a CAM and
21 without using expensive comparators) probably the best way to do so would
22 be to turn the last 16 bits into a byte-level bitmap. LD/ST on a byte
23 would have 1 of the 16 bits set. LD/ST on a DWORD would have 8 of the 16
24 bits set (offset if the LD/ST was misaligned). TODO.
25
26 Notes:
27
28 > I have used bits <11:6> as they are not translated (4KB pages)
29 > and larger than a cache line (64 bytes).
30 > I have used bits <11:4> when the L1 cache was QuadW sized and
31 > the L2 cache was Line sized.
32 """
33
34 from nmigen.compat.sim import run_simulation
35 from nmigen.cli import verilog, rtlil
36 from nmigen import Module, Signal, Const, Array, Cat, Elaboratable
37 from nmigen.lib.coding import Decoder
38
39 from nmutil.latch import latchregister, SRLatch
40
41
class PartialAddrMatch(Elaboratable):
    """A partial address matcher

    Holds n_adr addresses of bitwid bits each and compares every held
    address against every other one (see is_match).

    Arguments:

    * n_adr  - number of addresses
    * bitwid - width in bits of each address

    Inputs:

    * addrs_i   - Array of n_adr address inputs
    * addr_en_i - one bit per address, wired to the "s" input of the
                  SRLatch (address latched in)
    * addr_rs_i - one bit per address, wired to the "r" input of the
                  SRLatch (address deactivated)

    Outputs:

    * addr_nomatch_a_o - Array of n_adr vectors: bit j of entry i is the
                         inverse of is_match(i, j)
    * addr_nomatch_o   - one bit per address: set when that address is
                         active (latch q set) and every other active
                         address has its nomatch bit set
    """
    def __init__(self, n_adr, bitwid):
        self.n_adr = n_adr
        self.bitwid = bitwid
        # inputs
        self.addrs_i = Array(Signal(bitwid, name="addr") for i in range(n_adr))
        #self.addr_we_i = Signal(n_adr, reset_less=True) # write-enable
        self.addr_en_i = Signal(n_adr, reset_less=True) # address latched in
        self.addr_rs_i = Signal(n_adr, reset_less=True) # address deactivated

        # output: a nomatch for each address plus individual nomatch signals
        self.addr_nomatch_o = Signal(n_adr, name="nomatch_o", reset_less=True)
        self.addr_nomatch_a_o = Array(Signal(n_adr, reset_less=True,
                                             name="nomatch_array_o") \
                                      for i in range(n_adr))

    def elaborate(self, platform):
        m = Module()
        return self._elaborate(m, platform)

    def _elaborate(self, m, platform):
        """Adds the latches and the all-against-all comparison to m.

        Kept separate from elaborate() so that the Module is created in
        one place and populated in another.  Sets self.l and self.adrs_r
        as a side-effect: is_match() relies on self.adrs_r.
        """
        comb = m.d.comb

        # array of address-latches
        m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
        self.adrs_r = adrs_r = Array(Signal(self.bitwid, reset_less=True,
                                            name="a_r") \
                                     for i in range(self.n_adr))

        # latch set/reset
        comb += l.s.eq(self.addr_en_i)
        comb += l.r.eq(self.addr_rs_i)

        # copy in addresses (and "enable" signals)
        for i in range(self.n_adr):
            latchregister(m, self.addrs_i[i], adrs_r[i], l.q[i])

        # is there a clash, yes/no
        matchgrp = []
        for i in range(self.n_adr):
            # one match bit for each of the other addresses (self is 0)
            match = [self.is_match(i, j) for j in range(self.n_adr)]
            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match))
            # only the active addresses (l.q) count: address i is clear
            # when the nomatch bit of every active address is set
            matchgrp.append((self.addr_nomatch_a_o[i] & l.q) == l.q)
        # an inactive address never reports "nomatch"
        comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)

        return m

    def is_match(self, i, j):
        """Returns an expression that is true when address i equals address j
        """
        if i == j:
            return Const(0) # don't match against self!
        return self.adrs_r[i] == self.adrs_r[j]

    def __iter__(self):
        yield from self.addrs_i
        #yield self.addr_we_i
        yield self.addr_en_i
        # the reset is an input as well: list it so that ports() covers
        # every input of the module
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
108
109
class LenExpand(Elaboratable):
    """LenExpand: turns a binary length (plus address LSBs) into a unary mask

    the result is a bitmap of which *bytes* are to be read (written) in
    memory: len_i ones, shifted up by addr_i.  examples:

    (bit_len=4) len=4, addr=0b0011 => 0b1111 << addr
                                   => 0b1111000
    (bit_len=4) len=8, addr=0b0101 => 0b11111111 << addr
                                   => 0b1111111100000
    """

    def __init__(self, bit_len):
        self.bit_len = bit_len
        # inputs: the length and the low address bits, bit_len wide each
        self.len_i = Signal(bit_len, reset_less=True)
        self.addr_i = Signal(bit_len, reset_less=True)
        # output: the shifted unary mask
        self.lexp_o = Signal(1 << (bit_len + 1), reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # temp: the length in unary, (1<<len_i)-1, not yet shifted
        binlen = Signal((1 << self.bit_len) + 1, reset_less=True)
        one = Const(1, self.bit_len + 1)
        m.d.comb += [
            binlen.eq((one << self.len_i) - 1),
            self.lexp_o.eq(binlen << self.addr_i),
        ]

        return m

    def ports(self):
        return [self.len_i, self.addr_i, self.lexp_o]
141
142
class PartialAddrBitmap(PartialAddrMatch):
    """PartialAddrBitMap

    makes two comparisons for each address, with each (addr,len)
    being extended to an unary byte-map.

    two comparisons are needed because when an address is misaligned,
    the byte-map is split into two halves.  example:

    address = 0b1011011, len=8 => 0b101 and shift of 11 (0b1011)
                                  len in unary is 0b0000 0000 1111 1111
                                  when shifted becomes TWO addresses:

    * 0b101   and a byte-map of 0b1111 1000 0000 0000 (len-mask shifted by 11)
    * 0b101+1 and a byte-map of 0b0000 0000 0000 0111 (overlaps onto next 16)

    therefore, because this now covers two addresses, we need *two*
    comparisons per address *not* one.

    Arguments:

    * n_adr  - number of addresses
    * lsbwid - number of bottom address bits to turn into unary
    * bitlen - width of the full addresses (faddrs_i).  the top
               bitlen-lsbwid bits are the ones compared in binary
    """
    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid # number of bits to turn into unary
        self.midlen = bitlen-lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        self.len_i = Array(Signal(lsbwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = Array(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # intermediary: address + 1
        self.addr1s = Array(Signal(self.midlen, reset_less=True,
                                   name="adr1") \
                            for i in range(n_adr))

        # expanded lengths, needed in match
        expwid = 1+self.lsbwid # XXX assume LD/ST no greater than 8
        self.lexp = Array(Signal(1<<expwid, reset_less=True,
                                 name="a_l") \
                          for i in range(self.n_adr))

    def elaborate(self, platform):
        # the base class builds the latches and calls (our) is_match
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        l = self.l
        len_r = Array(Signal(self.lsbwid, reset_less=True,
                             name="l_r") \
                      for i in range(self.n_adr))

        for i in range(self.n_adr):
            # create a bit-expander for each address
            be = LenExpand(self.lsbwid)
            setattr(m.submodules, "le%d" % i, be)
            # the bits above the bottom lsbwid ones are the address that
            # the base class latches and compares
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in lengths and latch them
            latchregister(m, self.len_i[i], len_r[i], l.q[i])

            # add one to intermediate addresses
            comb += self.addr1s[i].eq(self.adrs_r[i]+1)

            # put the bottom bits of each address into each LenExpander.
            comb += be.len_i.eq(len_r[i])
            comb += be.addr_i.eq(self.faddrs_i[i][:self.lsbwid])
            # connect expander output
            comb += self.lexp[i].eq(be.lexp_o)

        return m

    # TODO make this a module.  too much.
    def is_match(self, i, j):
        """Returns an expression that is true when (addr,len) i overlaps j

        three cases are ORed together: same address with overlapping
        byte-maps, and the two cases where one access spills over into
        the address that the other one sits at.
        """
        if i == j:
            return Const(0) # don't match against self!
        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1<<self.lsbwid
        hexp = expwid >> 1
        expwid2 = expwid + hexp
        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.lexp[i][:expwid] & self.lexp[j][:expwid]).bool()
        # compare i (addr+1) to j (addr), but top unary against bottom unary
        i1_eq_j = (self.addr1s[i] == self.adrs_r[j]) & \
                  (self.lexp[i][expwid:expwid2] & self.lexp[j][:hexp]).bool()
        # compare i (addr) to j (addr+1), but bottom unary against top unary
        i_eq_j1 = (self.adrs_r[i] == self.addr1s[j]) & \
                  (self.lexp[i][:hexp] & self.lexp[j][expwid:expwid2]).bool()
        return straight_eq | i1_eq_j | i_eq_j1

    def __iter__(self):
        # addrs_i is driven from faddrs_i in elaborate, so it is the full
        # addresses that are the inputs here
        yield from self.faddrs_i
        yield from self.len_i
        #yield self.addr_we_i
        yield self.addr_en_i
        # the reset is an input as well: list it so that ports() covers
        # every input of the module
        yield self.addr_rs_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
248
def part_addr_sim(dut):
    """Simulation stimulus for a PartialAddrMatch with (at least) 3 addresses

    only drives signals that PartialAddrMatch itself declares: addrs_i,
    addr_en_i and addr_rs_i.  each address is set up, then its enable
    bit is raised for one clock and dropped again.
    """
    # slot 0: first address
    yield dut.addrs_i[0].eq(0b1011011)
    yield dut.addr_en_i[0].eq(1)
    yield
    yield dut.addr_en_i[0].eq(0)
    yield
    # slot 1: identical to slot 0
    yield dut.addrs_i[1].eq(0b1011011)
    yield dut.addr_en_i[1].eq(1)
    yield
    yield dut.addr_en_i[1].eq(0)
    yield
    # slot 2: a different address
    yield dut.addrs_i[2].eq(0b1100010)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    # pulse the reset of slot 1
    yield dut.addr_rs_i[1].eq(1)
    yield
    yield dut.addr_rs_i[1].eq(0)
    yield
268
def part_addr_bit(dut):
    """Simulation stimulus for a PartialAddrBitmap with 3 addresses

    for each table entry the length and the full address of one slot are
    set up, then that slot's enable bit is pulsed.  the last step pulses
    the reset bit of slot 1.
    """
    def pulse(bit):
        # raise for one clock, drop for one clock
        yield bit.eq(1)
        yield
        yield bit.eq(0)
        yield

    # (slot, len, full address).  the byte-maps that come out are:
    #
    #                            0b110                 |   0b101             |
    # 0b101 1011 / 8 ==> 0b0000 0000 0000 0111 | 1111 1000 0000 0000 |
    # 0b110 0010 / 2 ==> 0b0000 0000 0000 1100 | 0000 0000 0000 0000 |
    # 0b101 1010 / 2 ==> 0b0000 0000 0000 0000 | 0000 1100 0000 0000 |
    # 0b101 1001 / 2 ==> 0b0000 0000 0000 0000 | 0000 0110 0000 0000 |
    stimulus = (
        (0, 8, 0b1011011),
        (1, 2, 0b1100010),
        (2, 2, 0b1011010),
        (2, 2, 0b1011001),
    )
    for slot, length, address in stimulus:
        yield dut.len_i[slot].eq(length)
        yield dut.faddrs_i[slot].eq(address)
        yield from pulse(dut.addr_en_i[slot])

    yield from pulse(dut.addr_rs_i[1])
306
def test_part_addr():
    """Converts each module to RTLIL and runs the two simulations

    writes test_len_expand.il, test_part_bit.il and test_part_addr.il,
    and has run_simulation produce test_part_bit.vcd and
    test_part_addr.vcd.
    """
    def write_rtlil(module, fname):
        # convert with the ports that the module itself lists
        text = rtlil.convert(module, ports=module.ports())
        with open(fname, "w") as f:
            f.write(text)

    write_rtlil(LenExpand(4), "test_len_expand.il")

    bitmap = PartialAddrBitmap(3, 4, 10)
    write_rtlil(bitmap, "test_part_bit.il")
    run_simulation(bitmap, part_addr_bit(bitmap),
                   vcd_name='test_part_bit.vcd')

    matcher = PartialAddrMatch(3, 10)
    write_rtlil(matcher, "test_part_addr.il")
    run_simulation(matcher, part_addr_sim(matcher),
                   vcd_name='test_part_addr.vcd')
326
# when run as a script: do the RTLIL conversions and the simulations
if __name__ == '__main__':
    test_part_addr()