addr_split.py: shift bytes not bits
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from soc.experiment.pimem import PortInterface
10
11 from nmigen import Elaboratable, Module, Signal, Record, Array, Const, Cat
12 from nmutil.latch import SRLatch, latchregister
13 from nmigen.back.pysim import Simulator, Delay
14 from nmigen.cli import verilog, rtlil
15
16 from soc.scoreboard.addr_match import LenExpand
17 #from nmutil.queue import Queue
18
19
class LDData(Record):
    """Data bundle: a 1-bit ``err`` flag followed by a ``dwidth``-bit ``data`` field."""

    def __init__(self, dwidth, name=None):
        layout = (("err", 1), ("data", dwidth))
        super().__init__(layout, name=name)
23
24
class LDLatch(Elaboratable):
    """Hold one cache-line result (``err`` + ``data``) behind an SR latch.

    ``valid_i`` sets the latch ``in_l``; ``valid_o`` is high while the latch
    is set *and* ``valid_i`` is still asserted.  ``ld_o`` is driven from
    ``ld_i`` by ``latchregister`` with enable ``in_l.q & valid_o``.
    ``addr_i`` and ``mask_i`` are exposed as ports but are not used by
    ``elaborate``.
    """

    def __init__(self, dwidth, awidth, mlen):
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        # NOTE(review): latch.r is never driven here, so this module never
        # clears the latch once valid_i has set it -- confirm intended.
        m.d.comb += [
            latch.s.eq(self.valid_i),
            self.valid_o.eq(latch.q & self.valid_i),
        ]
        # presumably ld_o follows ld_i while enabled and holds otherwise;
        # see nmutil.latch.latchregister to confirm
        latchregister(m, self.ld_i, self.ld_o, latch.q & self.valid_o,
                      "ld_i_r")
        return m

    def __iter__(self):
        yield from (self.addr_i, self.mask_i,
                    self.ld_i.err, self.ld_i.data,
                    self.ld_o.err, self.ld_o.data,
                    self.valid_i, self.valid_o)

    def ports(self):
        return [*self]
58
def byteExpand(signal):
    """Expand a per-byte mask into a per-bit mask (each bit repeated 8 times).

    Bit ``i`` of *signal* controls bits ``8*i .. 8*i+7`` of the result.

    :param signal: a non-negative Python ``int`` (expanded eagerly, e.g. for
                   simulation reference values; ``bool`` is accepted as an
                   ``int``), or an nmigen value (expanded structurally).
    :returns: an ``int`` for ``int`` input, otherwise a ``Cat`` that is
              ``8 * len(signal)`` bits wide.
    :raises ValueError: if *signal* is a negative ``int``, which has no
                        finite bit pattern to expand.
    """
    if isinstance(signal, int):
        if signal < 0:
            raise ValueError("byteExpand: negative mask %d" % signal)
        ret = 0
        shf = 0
        while signal:
            if signal & 1:
                ret |= 0xFF << shf
            signal >>= 1
            shf += 8
        return ret
    # index explicitly rather than iterating: nmigen values are addressed
    # bit-by-bit with signal[i]
    return Cat(*[signal[i] for i in range(len(signal)) for _ in range(8)])
75
class LDSTSplitter(Elaboratable):
    """LD/ST splitter for accesses that may cross a cache-line boundary.

    The byte mask of the access (from ``LenExpand``, fed with the low
    ``dlen`` address bits and ``len_i``) is split into ``mask1`` (bytes in
    the line containing ``addr_i``) and ``mask2`` (bytes spilling into the
    following line).  Each half drives one ``LDLatch``; the second latch
    only takes part when ``mask2`` is non-zero.  ``exc`` is raised when the
    following line's address overflows.

    :param dwidth: data port width in bytes (ports are ``dwidth*8`` bits).
    :param awidth: address width in bits; the latches hold ``awidth-dlen``
                   line-address bits.
    :param dlen:   log2 of the cache-line length in bytes (each half of the
                   byte mask is ``1 << dlen`` bits).
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        cline_wid = dwidth*8  # per-line data port width: bytes -> bits

        self.pi = PortInterface()

        self.addr_i = self.pi.addr.data  # Signal(awidth, reset_less=True)
        # no match in PortInterface
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = self.pi.is_ld_i  # Signal(reset_less=True)
        self.is_st_i = self.pi.is_st_i  # Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth*8, "ld_data_o")  # port.ld
        self.st_data_i = LDData(dwidth*8, "st_data_i")  # port.st

        # set when the second line's address (line address + 1) carries out
        self.exc = Signal(reset_less=True)  # pi.exc TODO
        # TODO : create/connect two outgoing port interfaces

        # load side: one valid/data pair per cache line
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # store side: one valid/data pair per cache line
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen  # bytes per line == bits in each byte-mask half
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # len-expander produces a *byte* mask covering two lines:
        # mask1 = first line (low half), mask2 = following line (high half)
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        line_addr = self.addr_i[dlen:]
        ld2_value = line_addr + 1
        comb += ld1.addr_i.eq(line_addr)
        comb += ld2.addr_i.eq(ld2_value)
        # exception if rolls: the sum is one bit wider than line_addr, and
        # that top bit is the carry out
        with m.If(ld2_value[len(line_addr)]):
            comb += self.exc.eq(1)

        # data needs recombining / splitting via shifting (in bytes).
        # ashift1: byte offset of the access within the first line.
        # ashift2: bytes from that offset to the end of the first line.
        # ashift2 needs dlen+1 bits: for ashift1 == 0 it equals 1 << dlen,
        # which a dlen-bit signal would truncate to 0.
        ashift1 = Signal(dlen, reset_less=True)
        ashift2 = Signal(dlen+1, reset_less=True)
        comb += ashift1.eq(self.addr_i[:dlen])
        comb += ashift2.eq((1 << dlen)-ashift1)

        # byte masks -> bit masks (8 bits per byte)
        bmask1 = byteExpand(mask1)
        bmask2 = byteExpand(mask2)
        # a line (and its latch) takes part only if its byte mask is non-zero
        use1 = mask1.any()
        use2 = mask2.any()

        with m.If(self.is_ld_i):
            # set up connections to LD-split. note: not active if mask is zero
            for i, (ld, use) in enumerate(((ld1, use1), (ld2, use2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & use)
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(use2):
                comb += self.valid_o.eq(self.sld_valid_o.all())
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o[0])

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # LD2 is ignored when unused: its latch output may still hold
                # an earlier access's err/data (LDLatch registers ld_o while
                # not enabled)
                ld2_part = Signal(len(self.ld_data_o.data), reset_less=True)
                ld2_err = Signal(reset_less=True)
                with m.If(use2):
                    # LD2's data starts at byte 0 of its line: move it up past
                    # the ashift2 bytes that came from LD1
                    comb += ld2_part.eq(ld2.ld_o.data << (ashift2*8))
                    comb += ld2_err.eq(ld2.ld_o.err)

                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2_err)

                # LD1's data is at its *cache-line* byte position: move it
                # down to byte 0
                comb += self.ld_data_o.data.eq(
                    (ld1.ld_o.data >> (ashift1*8)) | ld2_part)

        with m.If(self.is_st_i):
            for i, (ld, use) in enumerate(((ld1, use1), (ld2, use2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & use)
                # store handshakes are reported on the *store* valid outputs,
                # which the valid_o logic below reads
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # shift only the data field: shifting the whole record would mix
            # its err bit into the data (and feed err back combinatorially)
            comb += ld1.ld_i.data.eq(
                (self.st_data_i.data << (ashift1*8)) & bmask1)
            comb += ld2.ld_i.data.eq(
                (self.st_data_i.data >> (ashift2*8)) & bmask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(use2):
                comb += self.valid_o.eq(self.sst_valid_o.all())
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o[0])

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # NOTE(review): st_data_i is an input port; driving its err
                # field from here is questionable -- a dedicated store error
                # output is probably wanted.  Kept for compatibility.
                comb += self.st_data_i.err.eq(ld1.ld_o.err |
                                              (ld2.ld_o.err & use2))

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data
        # load-side per-line valid outputs, address-overflow flag and the
        # store-side interface
        yield self.sld_valid_o
        yield self.exc
        yield self.is_st_i
        yield self.st_data_i.data
        yield self.sst_valid_i
        yield self.sst_valid_o
        for i in range(2):
            yield self.sst_data_o[i].data

    def ports(self):
        return list(self)
216
217
def sim(dut):
    """Simulate one misaligned load through *dut* and check the result.

    A load of ``ld_len`` bytes is issued at ``addr``, which is chosen so the
    access straddles a cache-line boundary.  A second process plays the
    memory side: it returns the two line fragments on ``sld_data_i[0]`` and
    ``sld_data_i[1]``, each at its in-line byte position.  The recombined
    ``ld_data_o.data`` must equal the original data, and ``exc`` must be 0.
    Writes ``ldst_splitter.vcd``.
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0x0102030405060708A1A2A3A4A5A6A7A8
    # cache line length in *bytes* (dut.dlen is its log2)
    line_bytes = 1 << dut.dlen
    addr = 0b1110  # byte offset 14: an 8-byte load crosses into line 2
    ld_len = 8     # load length in bytes
    ldm = ((1 << ld_len)-1)  # load byte mask, one bit per byte
    ldme = byteExpand(ldm)   # same mask, one bit per data bit
    dlm = line_bytes - 1     # selects the byte offset within a line
    data = data & ldme  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, hex(data & ldme))
    print("dlm", dlm, bin(addr & dlm))

    # byte mask of the access relative to the start of the first line
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bytes  # bytes that spill into the second line
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bytes)-1)  # bytes within the first line
    print("dmask", bin(dmask))
    dmask1 = byteExpand(dmask1)
    dmask = byteExpand(dmask)

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        exc = yield dut.exc
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(exc)
        assert exc == 0
        print(hex(ld_data_o), hex(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = (addr & dlm)*8  # shift bytes not bits
        print("shf", shf/8.0)
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", hex(data), hex(data1), shf, shf/8.0, hex(dmask))

        # the bytes that crossed the boundary start at byte 0 of line 2
        line_bits = line_bytes*8
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", line_bits, hex(shfdata >> line_bits), hex(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
293
294
if __name__ == '__main__':
    # build a 32-byte-data, 48-bit-address splitter with 16-byte lines,
    # dump its RTLIL, then run the load simulation against it
    splitter = LDSTSplitter(32, 48, 4)
    il_text = rtlil.convert(splitter, ports=splitter.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    sim(splitter)