format code
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
10 from nmutil.latch import SRLatch, latchregister
11 from nmigen.back.pysim import Simulator, Delay
12 from nmigen.cli import verilog, rtlil
13
14 from soc.scoreboard.addr_match import LenExpand
15 #from nmutil.queue import Queue
16
17
class LDData(Record):
    """Record bundling a data word with a single-bit error flag.

    Fields, in layout order: ``err`` (1 bit), ``data`` (``dwidth`` bits).
    """

    def __init__(self, dwidth, name=None):
        # field order matters: it fixes the bit layout of the Record
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
21
22
class LDLatch(Elaboratable):
    """Captures one incoming LDData record behind an SR latch.

    ``valid_i`` sets the latch; ``valid_o`` is the latch output gated by
    ``valid_i``.  ``ld_i`` is passed to ``ld_o`` through ``latchregister``.
    ``addr_i`` and ``mask_i`` are declared and exposed as ports but are not
    used inside ``elaborate``.
    """

    def __init__(self, dwidth, awidth, mlen):
        # NOTE: signal names are inferred from the attribute names, and
        # the creation order is kept as-is.
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.submodules.in_l = in_l = SRLatch(sync=False, name="in_l")

        # latch is set by an incoming valid; output valid needs both the
        # latch output and the (still asserted) incoming valid
        m.d.comb += [
            in_l.s.eq(self.valid_i),
            self.valid_o.eq(in_l.q & self.valid_i),
        ]

        # NOTE(review): in_l.r is never driven here -- confirm that the
        # latch is not expected to be reset by this module.
        capture = in_l.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (
            self.addr_i,
            self.mask_i,
            self.ld_i.err,
            self.ld_i.data,
            self.ld_o.err,
            self.ld_o.data,
            self.valid_i,
            self.valid_o,
        )

    def ports(self):
        """Return the signals yielded by ``__iter__`` as a list."""
        return list(self)
56
57
class LDSTSplitter(Elaboratable):
    """Splits one LD/ST request into (up to) two line-aligned requests.

    The low ``dlen`` bits of ``addr_i`` are treated as the offset within
    a line, the remaining upper bits as the line address.  The first
    sub-request uses the line address as-is, the second uses line
    address + 1.  ``LenExpand`` turns (offset, len) into a mask whose low
    ``1 << dlen`` bits select the first sub-request and whose remaining
    bits select the second; a sub-request with an all-zero mask is not
    activated.

    Arguments:
    * dwidth - width of the data field of every LDData record
    * awidth - width of ``addr_i``
    * dlen   - width of ``len_i``, and number of offset bits in ``addr_i``

    Request side:
    * addr_i, len_i, valid_i - the request
    * is_ld_i / is_st_i      - selects load or store handling
    * ld_data_o              - recombined load result (data + err)
    * st_data_i              - store data to be split
    * valid_o                - all required sub-requests have completed
    * exc                    - set when line address + 1 carries out

    Split side (index 0 = first line, index 1 = second line):
    * sld_valid_i / sld_data_i / sld_valid_o - load halves
    * sst_valid_i / sst_data_o / sst_valid_o - store halves
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth  # TODO: make this bytes not bits
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        self.exc = Signal(reset_less=True)

        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        ld2_value = self.addr_i[dlen:] + 1
        comb += ld2.addr_i.eq(ld2_value)
        # exception if rolls: the extra (carry) bit of the +1 is set
        with m.If(ld2_value[self.awidth-dlen]):
            comb += self.exc.eq(1)

        # data needs recombining / splitting via shifting.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        # truncated to dlen bits, so an offset of zero gives a shift of zero
        comb += ashift2.eq((1 << dlen)-ashift1)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # per-half completion goes to the *store* valid outputs:
                # valid_o (below) is derived from sst_valid_o, so driving
                # sld_valid_o here would leave store completion stuck at 0
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # NOTE(review): the shift is applied to the whole st_data_i
            # record (err and data together), not just .data -- confirm.
            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                # NOTE(review): this drives the err field of the st_data_i
                # record, which is otherwise used as an input -- confirm.
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # NOTE(review): only request/load-side signals are yielded; the
        # store-side signals (is_st_i, st_data_i, sst_*) and exc are not.
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return the signals yielded by ``__iter__`` as a list."""
        return list(self)
185
186
def sim(dut):
    """Simulate a single split load on *dut* and check the result.

    A load of length ``ld_len`` at address ``addr`` is requested.  A second
    process supplies the two half-loads (the shifted test data masked into
    a low part and a high part) on ``sld_data_i`` / ``sld_valid_i``.  Once
    ``valid_o`` is seen, ``ld_data_o.data`` is asserted equal to the
    original test data.  A VCD trace is written to ``ldst_splitter.vcd``.
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    # number of positions covered by the first half; everything at or
    # above this belongs to the second half
    line_bits = 1 << dlen
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bits
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bits)-1)
    print("dmask", bin(dmask))

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be raised before supplying any data
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # the high part must be shifted down by the same amount that was
        # used to derive dmask1 (previously a hard-coded 16)
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
254
255
if __name__ == '__main__':
    # 32-bit data, 48-bit address, 4 offset/length bits
    dut = LDSTSplitter(32, 48, 4)

    # convert first, then write: nothing is created on disk if the
    # conversion fails
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    # run the load-split simulation on the same instance
    sim(dut)