# src/soc/scoreboard/addr_match.py
1 """ Load / Store partial address matcher
2
3 Related bugreports:
4 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
5
6 Loads and Stores do not need a full match (CAM), they need "good enough"
7 avoidance. Around 11 bits on a 64-bit address is "good enough".
8
9 The simplest way to use this module is to ignore not only the top bits,
10 but also the bottom bits as well: in this case (this RV64 processor),
11 enough to cover a DWORD (64-bit). that means ignore the bottom 4 bits,
12 due to the possibility of 64-bit LD/ST being misaligned.
13
14 To reiterate: the use of this module is an *optimisation*. All it has
15 to do is cover the cases that are *definitely* matches (by checking 11
16 bits or so), and if a few opportunities for parallel LD/STs are missed
17 because the top (or bottom) bits weren't checked, so what: all that
18 happens is: the mis-matched addresses are LD/STd on single-cycles. Big Deal.
19
20 However, if we wanted to enhance this algorithm (without using a CAM and
21 without using expensive comparators) probably the best way to do so would
22 be to turn the last 16 bits into a byte-level bitmap. LD/ST on a byte
23 would have 1 of the 16 bits set. LD/ST on a DWORD would have 8 of the 16
24 bits set (offset if the LD/ST was misaligned). TODO.
25
26 Notes:
27
28 > I have used bits <11:6> as they are not translated (4KB pages)
29 > and larger than a cache line (64 bytes).
30 > I have used bits <11:4> when the L1 cache was QuadW sized and
31 > the L2 cache was Line sized.
32 """

from nmigen.compat.sim import run_simulation, Settle
from nmigen.cli import verilog, rtlil
from nmigen import Module, Signal, Const, Cat, Elaboratable, Repl
from nmigen.lib.coding import Decoder
from nmigen.utils import log2_int

from nmutil.latch import latchregister, SRLatch
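

# A minimal pure-python sketch of the byte-level bitmap idea from the
# module docstring above.  Illustrative only: the function name and the
# window parameter are assumptions, not part of the hardware below.
def byte_bitmap_clash(addr_a, len_a, addr_b, len_b, lsbs=4):
    """Model two LD/STs as byte-masks over a (1 << lsbs)-byte window.

    Each access sets "len" bits in the bitmap, offset by the bottom lsbs
    bits of its address; a clash is a non-zero AND of the two bitmaps.
    Note: a misaligned mask can spill past the window into the next one;
    PartialAddrBitmap (below) covers that case with *two* comparisons.
    """
    lomask = (1 << lsbs) - 1
    bits_a = ((1 << len_a) - 1) << (addr_a & lomask)
    bits_b = ((1 << len_b) - 1) << (addr_b & lomask)
    same_window = (addr_a >> lsbs) == (addr_b >> lsbs)
    return same_window and (bits_a & bits_b) != 0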


class PartialAddrMatch(Elaboratable):
    """A partial address matcher
    """

    def __init__(self, n_adr, bitwid):
        self.n_adr = n_adr
        self.bitwid = bitwid
        # inputs
        self.addrs_i = tuple(Signal(bitwid, name="addr") for i in range(n_adr))
        # self.addr_we_i = Signal(n_adr, reset_less=True)  # write-enable
        self.addr_en_i = Signal(n_adr, reset_less=True)  # address latched in
        self.addr_rs_i = Signal(n_adr, reset_less=True)  # address deactivated

        # output: a nomatch for each address plus individual nomatch signals
        self.addr_nomatch_o = Signal(n_adr, name="nomatch_o", reset_less=True)
        self.addr_nomatch_a_o = tuple(Signal(n_adr, reset_less=True,
                                             name="nomatch_array_o")
                                      for i in range(n_adr))

    def elaborate(self, platform):
        m = Module()
        return self._elaborate(m, platform)

    def _elaborate(self, m, platform):
        comb = m.d.comb
        sync = m.d.sync

        # array of address-latches
        m.submodules.l = self.l = l = SRLatch(llen=self.n_adr, sync=False)
        self.adrs_r = adrs_r = tuple(Signal(self.bitwid, reset_less=True,
                                            name="a_r")
                                     for i in range(self.n_adr))

        # latch set/reset
        comb += l.s.eq(self.addr_en_i)
        comb += l.r.eq(self.addr_rs_i)

        # copy in addresses (and "enable" signals)
        for i in range(self.n_adr):
            latchregister(m, self.addrs_i[i], adrs_r[i], l.q[i])

        # is there a clash, yes/no
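        # (an N x N matrix of comparisons is built here: for each address
        # i, is_match(i, j) is evaluated against every other address j,
        # and the inverted concatenation gives a per-address "no-match"
        # vector.  each matchgrp entry then checks that vector against
        # the latch outputs l.q, so only *active* (latched) entries count.)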
        matchgrp = []
        for i in range(self.n_adr):
            match = []
            for j in range(self.n_adr):
                match.append(self.is_match(i, j))
            comb += self.addr_nomatch_a_o[i].eq(~Cat(*match))
            matchgrp.append((self.addr_nomatch_a_o[i] & l.q) == l.q)
        comb += self.addr_nomatch_o.eq(Cat(*matchgrp) & l.q)

        return m

    def is_match(self, i, j):
        if i == j:
            return Const(0)  # don't match against self!
        return self.adrs_r[i] == self.adrs_r[j]

    def __iter__(self):
        yield from self.addrs_i
        # yield self.addr_we_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
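

# A minimal usage sketch (assumptions: 3 LD/ST units and a 10-bit
# partial address, mirroring test_part_addr below; not a definitive
# recipe):
#
#   pam = PartialAddrMatch(3, 10)
#   # drive pam.addrs_i[i] and pulse pam.addr_en_i[i] to latch address i;
#   # bit i of pam.addr_nomatch_o is then high when address i is active
#   # and clashes with no other active address.  pulse pam.addr_rs_i[i]
#   # to deactivate (free) the entry.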


class LenExpand(Elaboratable):
    """LenExpand: expands binary length (and LSBs of an address) into unary

    this basically produces a bitmap of which *bytes* are to be read
    (written) in memory.  examples:

    (bit_len=4) len=4, addr=0b0011 => 0b1111 << addr
                                   => 0b1111000
    (bit_len=4) len=8, addr=0b0101 => 0b11111111 << addr
                                   => 0b1111111100000

    note: by setting cover=8 this can also be used as a shift-mask.  the
    bit-mask is replicated (expanded out), each bit expanded to "cover" bits.
    """

    def __init__(self, bit_len, cover=1):
        self.bit_len = bit_len
        self.cover = cover
        self.len_i = Signal(bit_len, reset_less=True)
        self.addr_i = Signal(bit_len, reset_less=True)
        self.lexp_o = Signal(self.llen(1), reset_less=True)
        if cover > 1:
            self.rexp_o = Signal(self.llen(cover), reset_less=True)
        print("LenExpand", bit_len, cover, self.lexp_o.shape())

    def llen(self, cover):
        # output width: "cover" bits per unary bitmap entry, plus
        # headroom to absorb the shift-by-addr (see elaborate, below)
        cl = log2_int(self.cover)
        return (cover << (self.bit_len)) + (cl << self.bit_len)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # unary (cover=1) output width
        llen = self.llen(1)
        # temp
        binlen = Signal((1 << self.bit_len)+1, reset_less=True)
        comb += binlen.eq((Const(1, self.bit_len+1) << (self.len_i)) - 1)
        comb += self.lexp_o.eq(binlen << self.addr_i)
        if self.cover == 1:
            return m
        l = []
        print("llen", llen)
        for i in range(llen):
            l.append(Repl(self.lexp_o[i], self.cover))
        comb += self.rexp_o.eq(Cat(*l))
        return m

    def ports(self):
        return [self.len_i, self.addr_i, self.lexp_o, ]
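

# A pure-python reference model of LenExpand, for cross-checking (a
# sketch only: lexp_model is an illustrative name, not repo API).
def lexp_model(length, addr, cover=1):
    """Return the expected lexp_o (cover=1) or rexp_o (cover>1) value."""
    unary = ((1 << length) - 1) << addr  # unary length-mask, shifted
    if cover == 1:
        return unary
    out = 0  # replicate each set bit "cover" times, as rexp_o does
    for i in range(unary.bit_length()):
        if (unary >> i) & 1:
            out |= ((1 << cover) - 1) << (i * cover)
    return out
# examples from the docstring above: lexp_model(4, 0b0011) == 0b1111000
# and lexp_model(8, 0b0101) == 0b1111111100000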


class TwinPartialAddrBitmap(PartialAddrMatch):
    """TwinPartialAddrBitMap

    designed to be connected to via LDSTSplitter, which generates
    *pairs* of addresses and covers the misalignment across cache
    line boundaries *in the splitter*.  LDSTSplitter also takes care
    of expanding the LSBs of each address into a bitmap, itself.

    the key difference between this and PartialAddrMatch is the
    knowledge (fact) that pairs of addresses from the same LDSTSplitter
    are exactly 1 apart, and are therefore *guaranteed* never to match
    each other.  is_match specially takes that into account.
    """

    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid  # number of bits to turn into unary
        self.midlen = bitlen - lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        expwid = 1 + self.lsbwid  # XXX assume LD/ST no greater than 8
        self.lexp_i = tuple(Signal(1 << expwid, reset_less=True,
                                   name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = tuple(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # registers for the expanded len (wide enough to hold the bitmap)
        self.len_r = tuple(Signal(1 << expwid, reset_less=True, name="l_r")
                           for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        adrs_r, l = self.adrs_r, self.l
        expwid = 1 + self.lsbwid

        for i in range(self.n_adr):
            # copy the top (bitlen - lsbwid) bits of each address to compare
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in expanded-lengths and latch them
            latchregister(m, self.lexp_i[i], self.len_r[i], l.q[i])

        return m

    # TODO make this a module. too much.
    def is_match(self, i, j):
        if i == j:
            return Const(0)  # don't match against self!
        # we know that pairs have addr and addr+1, therefore it is
        # guaranteed that they will not match.
        if (i // 2) == (j // 2):
            return Const(0)  # don't match against twin, either.

        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1 << self.lsbwid
        # if i % 2 == 1 or j % 2 == 1:  # XXX hmmm...
        #    expwid >>= 1

        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.len_r[i][:expwid] & self.len_r[j][:expwid]).bool()
        return straight_eq

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.lexp_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
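

# A minimal wiring sketch (assumptions: n_adr=4 giving two splitter
# pairs, lsbwid=4, bitlen=10; not a definitive recipe):
#
#   tpam = TwinPartialAddrBitmap(4, 4, 10)
#   # entries (0, 1) come from one LDSTSplitter: addr and addr+1, with
#   # pre-expanded byte-maps driven onto tpam.lexp_i[0] / tpam.lexp_i[1].
#   # the (i // 2) == (j // 2) test in is_match exempts such twins,
#   # since addr and addr+1 are guaranteed never to clash.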


class PartialAddrBitmap(PartialAddrMatch):
    """PartialAddrBitMap

    makes two comparisons for each address, with each (addr,len)
    being extended to a unary byte-map.

    two comparisons are needed because when an address is misaligned,
    the byte-map is split into two halves.  example:

    address = 0b1011011, len=8 => top bits 0b101, shift of 11 (0b1011)
    len in unary is 0b0000 0000 1111 1111
    when shifted this becomes TWO addresses:

    * 0b101 and a byte-map of 0b1111 1000 0000 0000 (len-mask shifted by 11)
    * 0b101+1 and a byte-map of 0b0000 0000 0000 0111 (overlaps onto next 16)

    therefore, because this now covers two addresses, we need *two*
    comparisons per address, *not* one.
    """

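    # worked check of the docstring example in pure python (illustrative):
    #   bytemap = ((1 << 8) - 1) << 0b1011   # == 0b111_1111_1000_0000_0000
    #   lo = bytemap & 0xffff                # 0b1111_1000_0000_0000 (addr)
    #   hi = bytemap >> 16                   # 0b0000_0000_0000_0111 (addr+1)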

    def __init__(self, n_adr, lsbwid, bitlen):
        self.lsbwid = lsbwid  # number of bits to turn into unary
        self.midlen = bitlen - lsbwid
        PartialAddrMatch.__init__(self, n_adr, self.midlen)

        # input: length of the LOAD/STORE
        self.len_i = tuple(Signal(lsbwid, reset_less=True,
                                  name="len") for i in range(n_adr))
        # input: full address
        self.faddrs_i = tuple(Signal(bitlen, reset_less=True,
                                     name="fadr") for i in range(n_adr))

        # intermediary: address + 1
        self.addr1s = tuple(Signal(self.midlen, reset_less=True,
                                   name="adr1")
                            for i in range(n_adr))

        # expanded lengths, needed in match
        expwid = 1 + self.lsbwid  # XXX assume LD/ST no greater than 8
        self.lexp = tuple(Signal(1 << expwid, reset_less=True,
                                 name="a_l")
                          for i in range(self.n_adr))

    def elaborate(self, platform):
        m = PartialAddrMatch.elaborate(self, platform)
        comb = m.d.comb

        # intermediaries
        adrs_r, l = self.adrs_r, self.l
        len_r = tuple(Signal(self.lsbwid, reset_less=True,
                             name="l_r")
                      for i in range(self.n_adr))

        for i in range(self.n_adr):
            # create a bit-expander for each address
            be = LenExpand(self.lsbwid)
            setattr(m.submodules, "le%d" % i, be)
            # copy the top (bitlen - lsbwid) bits of each address to compare
            comb += self.addrs_i[i].eq(self.faddrs_i[i][self.lsbwid:])

            # copy in lengths and latch them
            latchregister(m, self.len_i[i], len_r[i], l.q[i])

            # add one to intermediate addresses
            comb += self.addr1s[i].eq(self.adrs_r[i] + 1)

            # put the bottom bits of each address into each LenExpander.
            comb += be.len_i.eq(len_r[i])
            comb += be.addr_i.eq(self.faddrs_i[i][:self.lsbwid])
            # connect expander output
            comb += self.lexp[i].eq(be.lexp_o)

        return m

    # TODO make this a module. too much.
    def is_match(self, i, j):
        if i == j:
            return Const(0)  # don't match against self!
        # the bitmask contains data for *two* cache lines (16 bytes).
        # however len==8 only covers *half* a cache line so we only
        # need to compare half the bits
        expwid = 1 << self.lsbwid
        hexp = expwid >> 1
        expwid2 = expwid + hexp
        print(self.lsbwid, expwid)
        # straight compare: binary top bits of addr, *unary* compare on bottom
        straight_eq = (self.adrs_r[i] == self.adrs_r[j]) & \
                      (self.lexp[i][:expwid] & self.lexp[j][:expwid]).bool()
        # compare i (addr+1) to j (addr), but top unary against bottom unary
        i1_eq_j = (self.addr1s[i] == self.adrs_r[j]) & \
                  (self.lexp[i][expwid:expwid2] & self.lexp[j][:hexp]).bool()
        # compare i (addr) to j (addr+1), but bottom unary against top unary
        i_eq_j1 = (self.adrs_r[i] == self.addr1s[j]) & \
                  (self.lexp[i][:hexp] & self.lexp[j][expwid:expwid2]).bool()
        return straight_eq | i1_eq_j | i_eq_j1

    def __iter__(self):
        yield from self.faddrs_i
        yield from self.len_i
        # yield self.addr_we_i
        yield self.addr_en_i
        yield from self.addr_nomatch_a_o
        yield self.addr_nomatch_o

    def ports(self):
        return list(self)
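

# A pure-python sketch of the three-way overlap test performed by
# PartialAddrBitmap.is_match above (illustrative; overlap_model and its
# arguments are assumptions, not repo API).  each entry holds the top
# address bits plus a shifted byte-map whose upper bits spill into addr+1:
def overlap_model(adr_i, lexp_i, adr_j, lexp_j, lsbwid=4):
    expwid = 1 << lsbwid  # 16: one window of the byte-map
    hexp = expwid >> 1    # 8: max spill (LD/ST no greater than 8 bytes)
    lomask = (1 << expwid) - 1
    himask = (1 << hexp) - 1
    # same window: compare the two bottom byte-maps directly
    straight = adr_i == adr_j and (lexp_i & lexp_j & lomask) != 0
    # i's spill (addr+1 portion) against j's bottom window
    i1_j = (adr_i + 1 == adr_j) and ((lexp_i >> expwid) & lexp_j & himask) != 0
    # j's spill against i's bottom window
    i_j1 = (adr_i == adr_j + 1) and (lexp_i & (lexp_j >> expwid) & himask) != 0
    return straight or i1_j or i_j1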


def part_addr_sim(dut):
    # XXX disabled (note the immediate return): the body below appears
    # to be a stale copy from a different scoreboard test; dest_i,
    # issue_i, go_rd_i and go_wr_i do not exist on PartialAddrMatch.
    return
    yield dut.dest_i.eq(1)
    yield dut.issue_i.eq(1)
    yield
    yield dut.issue_i.eq(0)
    yield
    yield dut.src1_i.eq(1)
    yield dut.issue_i.eq(1)
    yield
    yield dut.issue_i.eq(0)
    yield
    yield dut.go_rd_i.eq(1)
    yield
    yield dut.go_rd_i.eq(0)
    yield
    yield dut.go_wr_i.eq(1)
    yield
    yield dut.go_wr_i.eq(0)
    yield


def part_addr_bit(dut):
    #                     | 0b110 |          0b101           |
    # 0b101 1011 / 8 ==>  0b0000 0000 0000 0111 | 1111 1000 0000 0000 |
    yield dut.len_i[0].eq(8)
    yield dut.faddrs_i[0].eq(0b1011011)
    yield dut.addr_en_i[0].eq(1)
    yield
    yield dut.addr_en_i[0].eq(0)
    yield
    #                     | 0b110 |          0b101           |
    # 0b110 0010 / 2 ==>  0b0000 0000 0000 1100 | 0000 0000 0000 0000 |
    yield dut.len_i[1].eq(2)
    yield dut.faddrs_i[1].eq(0b1100010)
    yield dut.addr_en_i[1].eq(1)
    yield
    yield dut.addr_en_i[1].eq(0)
    yield
    #                     | 0b110 |          0b101           |
    # 0b101 1010 / 2 ==>  0b0000 0000 0000 0000 | 0000 1100 0000 0000 |
    yield dut.len_i[2].eq(2)
    yield dut.faddrs_i[2].eq(0b1011010)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    #                     | 0b110 |          0b101           |
    # 0b101 1001 / 2 ==>  0b0000 0000 0000 0000 | 0000 0110 0000 0000 |
    yield dut.len_i[2].eq(2)
    yield dut.faddrs_i[2].eq(0b1011001)
    yield dut.addr_en_i[2].eq(1)
    yield
    yield dut.addr_en_i[2].eq(0)
    yield
    yield dut.addr_rs_i[1].eq(1)
    yield
    yield dut.addr_rs_i[1].eq(0)
    yield


def part_addr_byte(dut):
    for l in range(8):
        for a in range(1 << dut.bit_len):
            maskbit = (1 << l) - 1
            mask = (1 << (l*8)) - 1
            yield dut.len_i.eq(l)
            yield dut.addr_i.eq(a)
            yield Settle()
            lexp = yield dut.lexp_o
            exp = yield dut.rexp_o
            print("pa", l, a, bin(lexp), hex(exp))
            assert exp == (mask << (a*8))
            assert lexp == (maskbit << a)


def test_lenexpand_byte():
    dut = LenExpand(4, 8)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_len_expand_byte.il", "w") as f:
        f.write(vl)
    run_simulation(dut, part_addr_byte(dut), vcd_name='test_part_byte.vcd')


def test_part_addr():
    dut = LenExpand(4)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_len_expand.il", "w") as f:
        f.write(vl)

    dut = TwinPartialAddrBitmap(3, 4, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_twin_part_bit.il", "w") as f:
        f.write(vl)

    dut = PartialAddrBitmap(3, 4, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_part_bit.il", "w") as f:
        f.write(vl)

    run_simulation(dut, part_addr_bit(dut), vcd_name='test_part_bit.vcd')

    dut = PartialAddrMatch(3, 10)
    vl = rtlil.convert(dut, ports=dut.ports())
    with open("test_part_addr.il", "w") as f:
        f.write(vl)

    run_simulation(dut, part_addr_sim(dut), vcd_name='test_part_addr.vcd')


if __name__ == '__main__':
    test_part_addr()
    test_lenexpand_byte()