7ae3fcc5778ebb96531c4939efa67549959c96c5
[soc.git] / src / soc / scoreboard / addr_split.py
# LDST Address Splitter: splits a misaligned load/store whose address range crosses a cache-line boundary into two line accesses.
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
10 from nmutil.latch import SRLatch, latchregister
11 from nmigen.back.pysim import Simulator, Delay
12 from nmigen.cli import verilog, rtlil
13
14 from soc.scoreboard.addr_match import LenExpand
15 #from nmutil.queue import Queue
16
17
class LDData(Record):
    """Load/store data record with two fields: a 1-bit ``err`` flag and
    ``dwidth`` bits of ``data``.

    :param dwidth: width of the ``data`` field
    :param name:   optional record name, passed straight to ``Record``
    """

    def __init__(self, dwidth, name=None):
        fields = (('err', 1), ('data', dwidth))
        super().__init__(fields, name=name)
21
22
class LDLatch(Elaboratable):
    """Hold one half of a split load/store: an LDData record behind a
    valid handshake.

    :param dwidth: width of the ``data`` field of ``ld_i`` / ``ld_o``
    :param awidth: width of ``addr_i``
    :param mlen:   width of ``mask_i``

    ``addr_i`` and ``mask_i`` are exposed as ports but are not read by
    ``elaborate``.
    """

    def __init__(self, dwidth, awidth, mlen):
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        # valid_i sets the SR latch; the output is valid only while both
        # the latch output and the incoming valid are high.
        # NOTE(review): latch.r is never driven here, so once set the latch
        # presumably stays set -- confirm this is intended.
        m.d.comb += [
            latch.s.eq(self.valid_i),
            self.valid_o.eq(latch.q & self.valid_i),
        ]

        # capture ld_i into ld_o under (latch.q & valid_o); assumed per
        # nmutil.latch.latchregister: pass-through while the condition
        # holds, otherwise replay the last captured value -- confirm.
        latchregister(m, self.ld_i, self.ld_o, latch.q & self.valid_o,
                      "ld_i_r")

        return m

    def __iter__(self):
        yield from (self.addr_i, self.mask_i,
                    self.ld_i.err, self.ld_i.data,
                    self.ld_o.err, self.ld_o.data,
                    self.valid_i, self.valid_o)

    def ports(self):
        return [port for port in self]
56
57
class LDSTSplitter(Elaboratable):
    """LD/ST address splitter for accesses that cross a cache-line boundary.

    The address is split into a line index, ``addr_i[dlen:]``, and an
    offset within the line, ``addr_i[:dlen]``.

    * The first access (ld1) targets the line as-is.
    * The second access (ld2) targets the following line (index + 1).
      It is used only when the expanded length mask spills past the end
      of the first line, i.e. when the upper half of ``LenExpand.lexp_o``
      is non-zero.

    :param dwidth: width of the load/store data records
    :param awidth: width of ``addr_i``
    :param dlen:   width of ``len_i`` and of the in-line offset; each line
                   is covered by ``1 << dlen`` mask bits

    Shift amounts are taken directly from the address offset and applied
    to the ``data`` fields, i.e. they are currently in bit units (see the
    TODO on ``cline_wid``).
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        #cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth # TODO: make this bytes not bits

        # request: address, length and valid handshake
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        # operation select
        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        # recombined load result / store data to be split
        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # set when the next-line address (line index + 1) carries out of
        # the address width while the second access is needed
        self.exc = Signal(reset_less=True)

        # load side, per half (index 0: ld1, index 1: ld2)
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # store side, per half (index 0: ld1, index 1: ld2)
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask. ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen]) # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])  # Hi bits of expanded len-mask

        # the second access is needed only when the mask spills into the
        # next line; otherwise ld2 is never made valid and its latched
        # output still holds whatever an earlier access left there.
        use_ld2 = Signal(reset_less=True)
        comb += use_ld2.eq(mask2 != mzero)

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        ld2_value = self.addr_i[dlen:] + 1
        comb += ld2.addr_i.eq(ld2_value)
        # exception if the +1 rolls out of the address width (carry bit),
        # but only when the second line is actually accessed
        comb += self.exc.eq(ld2_value[self.awidth-dlen] & use_ld2)

        # data needs recombining / splitting via shifting.
        # ashift2 = (1<<dlen) - ashift1 lies in 1..(1<<dlen), so it needs
        # dlen+1 bits: at dlen bits an aligned address (ashift1 == 0)
        # would wrap the shift to 0.
        ashift1 = Signal(dlen, reset_less=True)
        ashift2 = Signal(dlen+1, reset_less=True)
        comb += ashift1.eq(self.addr_i[:dlen])
        comb += ashift2.eq((1 << dlen) - ashift1)

        # ld2's contribution to a recombined load: forced to zero when the
        # second access is unused, so stale latched data/err cannot leak in
        ld2_data = Signal(self.dwidth, reset_less=True)
        ld2_err = Signal(reset_less=True)
        with m.If(use_ld2):
            comb += ld2_data.eq(ld2.ld_o.data << ashift2)
            comb += ld2_err.eq(ld2.ld_o.err)

        with m.If(self.is_ld_i):
            # set up connections to LD-split. note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2_err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               ld2_data)

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # store handshake goes on the *store* valid outputs, which
                # is what valid_o is computed from below
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)
                comb += self.sst_data_o[i].err.eq(ld.ld_o.err)
                # carry the store record's err flag along with its data
                comb += ld.ld_i.err.eq(self.st_data_i.err)

            # shift only the data field: shifting the whole record would
            # also move the err bit (bit 0 of the record) into the data
            comb += ld1.ld_i.data.eq((self.st_data_i.data << ashift1) & mask1)
            comb += ld2.ld_i.data.eq((self.st_data_i.data >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # NOTE: st_data_i is an input; it is deliberately not driven
            # here. Assigning ld1/ld2 err back into it formed a
            # combinational loop through the (transparent) latches.

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data
        # store-side and status ports: these were not listed before, so the
        # converted design did not expose them
        yield self.sld_valid_o
        yield self.is_st_i
        yield self.st_data_i.err
        yield self.st_data_i.data
        yield self.sst_valid_i
        yield self.sst_valid_o
        for i in range(2):
            yield self.sst_data_o[i].err
            yield self.sst_data_o[i].data
        yield self.exc

    def ports(self):
        """Return all top-level ports, in ``__iter__`` order."""
        return list(self)
185
def sim(dut, data=0b11010011, addr=0b1100, ld_len=8, prefix="ldst_splitter",
        timeout=1000):
    """Simulate one split load through *dut* and check the recombined data.

    A stand-in "cache" process (``lds``) presents the two line halves of
    ``data`` placed at ``addr``. ``send_ld`` then issues the load and
    asserts that ``ld_data_o.data`` equals ``data`` truncated to
    ``ld_len`` bits.

    The line geometry is taken from ``dut.dlen``: the line is
    ``1 << dut.dlen`` bits and the in-line offset is ``addr`` masked to
    ``dut.dlen`` bits.

    :param dut:     an LDSTSplitter instance
    :param data:    value to load (masked to ``ld_len`` bits)
    :param addr:    load address
    :param ld_len:  load length, driven onto ``len_i``
    :param prefix:  basename of the VCD trace file written
    :param timeout: cycles to wait for ``valid_o`` before failing
    :raises AssertionError: on data mismatch or if ``valid_o`` never
                            asserts within ``timeout`` cycles
    """
    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    dlen = dut.dlen            # keep the harness geometry in step with dut
    cline_bits = 1 << dlen     # line width, in the same units as the shifts
    ldm = ((1<<ld_len)-1)
    dlm = ((1<<dlen)-1)
    data = data & ldm # truncate data to be tested, mask to within ld len
    print ("ldm", ldm, bin(data&ldm))
    print ("dlm", dlm, bin(addr&dlm))
    dmask = ldm << (addr & dlm)
    print ("dmask", bin(dmask))
    dmask1 = dmask >> cline_bits             # part spilling into line 2
    print ("dmask1", bin(dmask1))
    dmask = dmask & ((1<<cline_bits)-1)      # part within line 1
    print ("dmask", bin(dmask))

    def send_ld():
        print ("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print ("waiting")
        # bounded wait: a broken DUT must fail the run, not hang it
        for _ in range(timeout):
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        else:
            raise AssertionError("valid_o never asserted within %d cycles"
                                 % timeout)
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print (bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        print ("lds")
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print ("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second half: what spilled past the end of the first line,
        # re-based to the start of the next line
        data2 = (shfdata >> cline_bits) & dmask1
        print ("ld data2", cline_bits, bin(shfdata >> cline_bits), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
253
254
if __name__ == '__main__':
    # build a 32-bit-data, 48-bit-address splitter (dlen=4), emit RTLIL,
    # then run the load simulation against the same instance
    splitter = LDSTSplitter(32, 48, 4)
    il_text = rtlil.convert(splitter, ports=splitter.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)
    sim(splitter)