a1453403b09164701e912da23bb6b4d1d7a57813
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
10 from nmutil.latch import SRLatch, latchregister
11 from nmigen.back.pysim import Simulator, Delay
12 from nmigen.cli import verilog, rtlil
13
14 from soc.scoreboard.addr_match import LenExpand
15 #from nmutil.queue import Queue
16 from soc.experiment import l0_cache
17
18
class LDData(Record):
    """Record bundling a data word with a single-bit error flag.

    Fields:
        err:  1 bit
        data: ``dwidth`` bits
    """

    def __init__(self, dwidth, name=None):
        # field layout: error flag first, then the payload
        layout = (('err', 1), ('data', dwidth))
        Record.__init__(self, layout, name=name)
22
23
class LDLatch(Elaboratable):
    """Latches an incoming LDData record while a request is valid.

    Ports:
        addr_i, mask_i: address / mask inputs (declared and exported as
                        ports; not referenced inside elaborate()).
        valid_i:        request-valid input, also the set input of the
                        internal SR latch.
        ld_i:           incoming LDData record.
        ld_o:           latched LDData record.
        valid_o:        SR-latch output ANDed with valid_i.
    """

    def __init__(self, dwidth, awidth, mlen):
        # inputs
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.ld_i = LDData(dwidth, "ld_i")
        # outputs
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        in_l = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = in_l

        m.d.comb += [
            # the latch is set for as long as the request is valid
            in_l.s.eq(self.valid_i),
            # output-valid: latch output qualified by the request itself
            self.valid_o.eq(in_l.q & self.valid_i),
        ]

        # capture ld_i into ld_o under control of latch output & valid_o
        capture = in_l.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture, "ld_i_r")

        return m

    def __iter__(self):
        yield from (
            self.addr_i,
            self.mask_i,
            self.ld_i.err,
            self.ld_i.data,
            self.ld_o.err,
            self.ld_o.data,
            self.valid_i,
            self.valid_o,
        )

    def ports(self):
        """Return every port yielded by __iter__ as a list."""
        return [*self]
57
58
class LDSTSplitter(Elaboratable):
    """Splits one LD/ST request across two LDLatch units (ld1 and ld2).

    A byte mask is obtained from LenExpand(addr_i, len_i).  Its low
    ``1 << dlen`` bits select what ld1 handles, the remaining upper bits
    what ld2 handles.  ld1 is given ``addr_i[dlen:]`` and ld2 that value
    plus one.  A half whose mask is zero is never made valid.

    Arguments:
        dwidth: width of the data records
        awidth: width of addr_i
        dlen:   width of len_i; ``1 << dlen`` is the per-half mask width

    Ports:
        addr_i, len_i, valid_i: the incoming request
        is_ld_i / is_st_i:      select the load / store path
        ld_data_o:              recombined load result (data + err)
        st_data_i:              store data to be split
        valid_o:                request complete (one or both halves)
        exc:                    set when the "+1" address of the second
                                half carries out of its field
        sld_valid_i/o, sld_data_i: per-half load handshake and data
        sst_valid_i/o, sst_data_o: per-half store handshake and data
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth  # TODO: make this bytes not bits

        self.inport = l0_cache.PortInterface()
        print(self.inport)

        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        self.exc = Signal(reset_less=True)

        # TODO out port interface
        # one valid bit and one data record per half (index 0: ld1, 1: ld2)
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask. ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen])  # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])   # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        ld2_value = self.addr_i[dlen:] + 1
        comb += ld2.addr_i.eq(ld2_value)
        # exception if rolls: the sum is one bit wider than the address
        # field, so its top bit is the carry-out of the increment
        with m.If(ld2_value[self.awidth-dlen]):
            comb += self.exc.eq(1)

        # data needs recombining / splitting via shifting.
        # ashift1 is the offset within the line, ashift2 the remainder
        # up to (1<<dlen), truncated to dlen bits by the assignment.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1 << dlen)-ashift1)

        with m.If(self.is_ld_i):
            # set up connections to LD-split. note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # report per-half completion on the *store* valid outputs:
                # valid_o below is derived from sst_valid_o, which would
                # otherwise never be driven (this used to drive sld_valid_o)
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # split the store data: low part shifted up into the first
            # half, the remainder shifted down into the second half
            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # NOTE(review): only the load-side ports are listed here; the
        # store-side signals (is_st_i, st_data_i, sst_*) are not exported.
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        """Return the signals yielded by __iter__ as a list."""
        return list(self)
191
def sim(dut):
    """Simulate a single split load through *dut* and check the result.

    Two sync processes are run: ``send_ld`` issues the load request and
    waits for ``valid_o``; ``lds`` waits for the request and then feeds
    the two halves of the pre-shifted test data into the per-half load
    inputs.  A VCD trace is written to ``ldst_splitter.vcd``.

    The local ``dlen`` is presumably expected to match the ``dlen`` the
    dut was constructed with -- TODO confirm against callers.
    """

    sim = Simulator(dut)
    sim.add_clock(1e-6)
    data = 0b11010011
    dlen = 4  # 4 bits
    addr = 0b1100
    ld_len = 8
    ldm = ((1 << ld_len)-1)
    dlm = ((1 << dlen)-1)
    cline = 1 << dlen  # number of mask bits handled by each half
    data = data & ldm  # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> cline          # part of the mask spilling into half 2
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << cline)-1)  # part of the mask within half 1
    print("dmask", bin(dmask))

    def send_ld():
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        # poll once per clock until the splitter reports completion
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        # FIXME: wrong result here
        assert ld_data_o == data

    def lds():
        print("lds")
        # wait for the request to be raised by send_ld
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second half: same shift amount that produced dmask1 (was a
        # hard-coded 16, which only equals 1<<dlen for dlen == 4)
        data2 = (shfdata >> cline) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    sim.add_sync_process(lds)
    sim.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with sim.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        sim.run()
260
261
if __name__ == '__main__':
    # build a splitter: 32-bit data, 48-bit address, 4-bit length field
    dut = LDSTSplitter(32, 48, 4)

    # dump the design as RTLIL ...
    il_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(il_text)

    # ... then run the load simulation on the same instance
    sim(dut)