expand LenExpand (haha) to cover bytes, with an argument "cover"
[soc.git] / src / soc / scoreboard / addr_split.py
1 # LDST Address Splitter. For misaligned address crossing cache line boundary
2 """
3 Links:
4 * https://libre-riscv.org/3d_gpu/architecture/6600scoreboard/
5 * http://bugs.libre-riscv.org/show_bug.cgi?id=257
6 * http://bugs.libre-riscv.org/show_bug.cgi?id=216
7 """
8
9 from nmigen import Elaboratable, Module, Signal, Record, Array, Const
10 from nmutil.latch import SRLatch, latchregister
11 from nmigen.back.pysim import Simulator, Delay
12 from nmigen.cli import verilog, rtlil
13
14 from soc.scoreboard.addr_match import LenExpand
15 #from nmutil.queue import Queue
16
17
class LDData(Record):
    """Record pairing a 1-bit ``err`` flag with a ``dwidth``-bit ``data`` field.

    dwidth: width in bits of the ``data`` field.
    name:   optional name forwarded to the underlying Record.
    """

    def __init__(self, dwidth, name=None):
        layout = (('err', 1), ('data', dwidth))
        super().__init__(layout, name=name)
21
22
class LDLatch(Elaboratable):
    """Captures one incoming LDData record behind an SR latch.

    dwidth: bit-width of the data field of ld_i / ld_o
    awidth: bit-width of addr_i
    mlen:   bit-width of mask_i

    Note: addr_i and mask_i are declared as ports but are not read
    anywhere inside elaborate().
    """

    def __init__(self, dwidth, awidth, mlen):
        # request side
        self.addr_i = Signal(awidth, reset_less=True)
        self.mask_i = Signal(mlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        # data in / data out
        self.ld_i = LDData(dwidth, "ld_i")
        self.ld_o = LDData(dwidth, "ld_o")
        self.valid_o = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        latch = SRLatch(sync=False, name="in_l")
        m.submodules.in_l = latch

        # valid_i sets the latch; valid_o is only raised while the latch
        # output and valid_i are both high.
        m.d.comb += [
            latch.s.eq(self.valid_i),
            self.valid_o.eq(latch.q & self.valid_i),
        ]

        # hand ld_i -> ld_o over to nmutil's latchregister helper, enabled
        # by (latch output AND valid_o)
        capture_en = latch.q & self.valid_o
        latchregister(m, self.ld_i, self.ld_o, capture_en, "ld_i_r")

        return m

    def __iter__(self):
        yield self.addr_i
        yield self.mask_i
        # err then data, for the input record followed by the output record
        for rec in (self.ld_i, self.ld_o):
            yield rec.err
            yield rec.data
        yield self.valid_i
        yield self.valid_o

    def ports(self):
        return [*self]
56
57
class LDSTSplitter(Elaboratable):
    """LD/ST address splitter for accesses that may cross a line boundary.

    One request (addr_i, len_i) is expanded by LenExpand into a mask.  The
    low (1<<dlen) bits of that mask belong to the first access (ld1), the
    remaining bits to the second access (ld2) at the next line address.
    When the second mask is zero only the first access is waited for.

    Arguments:
    * dwidth: bit-width of the data fields (also used as the line width,
              see the TODO in __init__)
    * awidth: bit-width of addr_i
    * dlen:   bit-width of len_i; the low dlen bits of addr_i are the
              offset within a line

    Ports:
    * addr_i, len_i, valid_i : incoming request
    * is_ld_i / is_st_i      : select the load or the store path
    * ld_data_o              : recombined load result (err + data)
    * st_data_i              : store data to be split
    * sld_valid_i/o, sld_data_i : per-split (2x) load handshake and data
    * sst_valid_i/o, sst_data_o : per-split (2x) store handshake and data
    * valid_o                : request complete
    """

    def __init__(self, dwidth, awidth, dlen):
        self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
        #cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
        cline_wid = dwidth # TODO: make this bytes not bits
        self.addr_i = Signal(awidth, reset_less=True)
        self.len_i = Signal(dlen, reset_less=True)
        self.valid_i = Signal(reset_less=True)
        self.valid_o = Signal(reset_less=True)

        self.is_ld_i = Signal(reset_less=True)
        self.is_st_i = Signal(reset_less=True)

        self.ld_data_o = LDData(dwidth, "ld_data_o")
        self.st_data_i = LDData(dwidth, "st_data_i")

        # split-load side: one valid bit and one data record per access
        self.sld_valid_o = Signal(2, reset_less=True)
        self.sld_valid_i = Signal(2, reset_less=True)
        self.sld_data_i = Array((LDData(cline_wid, "ld_data_i1"),
                                 LDData(cline_wid, "ld_data_i2")))

        # split-store side: one valid bit and one data record per access
        self.sst_valid_o = Signal(2, reset_less=True)
        self.sst_valid_i = Signal(2, reset_less=True)
        self.sst_data_o = Array((LDData(cline_wid, "st_data_i1"),
                                 LDData(cline_wid, "st_data_i2")))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        dlen = self.dlen
        mlen = 1 << dlen
        mzero = Const(0, mlen)
        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
        m.submodules.lenexp = lenexp = LenExpand(self.dlen)

        # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
        comb += lenexp.addr_i.eq(self.addr_i)
        comb += lenexp.len_i.eq(self.len_i)
        mask1 = Signal(mlen, reset_less=True)
        mask2 = Signal(mlen, reset_less=True)
        comb += mask1.eq(lenexp.lexp_o[0:mlen]) # Lo bits of expanded len-mask
        comb += mask2.eq(lenexp.lexp_o[mlen:])  # Hi bits of expanded len-mask

        # set up new address records: addr1 is "as-is", addr2 is +1
        comb += ld1.addr_i.eq(self.addr_i[dlen:])
        comb += ld2.addr_i.eq(self.addr_i[dlen:] + 1) # TODO exception if rolls

        # data needs recombining / splitting via shifting.
        # ashift1 is the offset within the line (low dlen bits of the
        # address); ashift2 is (1<<dlen) minus that offset, truncated to
        # dlen bits by the width of the signal.
        ashift1 = Signal(self.dlen, reset_less=True)
        ashift2 = Signal(self.dlen, reset_less=True)
        comb += ashift1.eq(self.addr_i[:self.dlen])
        comb += ashift2.eq((1<<dlen)-ashift1)

        with m.If(self.is_ld_i):
            # set up connections to LD-split.  note: not active if mask is zero
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                ld_valid = Signal(name="ldvalid_i%d" % i, reset_less=True)
                comb += ld_valid.eq(self.valid_i & self.sld_valid_i[i])
                comb += ld.valid_i.eq(ld_valid & (mask != mzero))
                comb += ld.ld_i.eq(self.sld_data_i[i])
                comb += self.sld_valid_o[i].eq(ld.valid_o)

            # sort out valid: mask2 zero we ignore 2nd LD
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sld_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sld_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                comb += self.ld_data_o.err.eq(ld1.ld_o.err | ld2.ld_o.err)

                # note that data from LD1 will be in *cache-line* byte position
                # likewise from LD2 but we *know* it is at the start of the line
                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
                                               (ld2.ld_o.data << ashift2))

        with m.If(self.is_st_i):
            for i, (ld, mask) in enumerate(((ld1, mask1),
                                            (ld2, mask2))):
                valid = Signal(name="stvalid_i%d" % i, reset_less=True)
                comb += valid.eq(self.valid_i & self.sst_valid_i[i])
                comb += ld.valid_i.eq(valid & (mask != mzero))
                # drive the *store* handshake output: valid_o below is
                # derived from sst_valid_o, which would otherwise never be
                # driven (the load-side sld_valid_o was driven here instead)
                comb += self.sst_valid_o[i].eq(ld.valid_o)
                comb += self.sst_data_o[i].data.eq(ld.ld_o.data)

            # split the store data across the two latches: first part
            # shifted up by the line offset, second part shifted down by
            # the remainder, each ANDed with its own mask
            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)

            # sort out valid: mask2 zero we ignore 2nd ST
            with m.If(mask2 == mzero):
                comb += self.valid_o.eq(self.sst_valid_o[0])
            with m.Else():
                comb += self.valid_o.eq(self.sst_valid_o.all())

            # all bits valid (including when data error occurs!) decode ld1/ld2
            with m.If(self.valid_o):
                # errors cause error condition
                # NOTE(review): this drives the err field of st_data_i,
                # which is otherwise read as an input record - confirm
                # that this is intended.
                comb += self.st_data_i.err.eq(ld1.ld_o.err | ld2.ld_o.err)

        return m

    def __iter__(self):
        # NOTE(review): only the load-side signals are yielded; is_st_i,
        # st_data_i, sld_valid_o and the sst_* signals are not listed.
        yield self.addr_i
        yield self.len_i
        yield self.is_ld_i
        yield self.ld_data_o.err
        yield self.ld_data_o.data
        yield self.valid_i
        yield self.valid_o
        yield self.sld_valid_i
        for i in range(2):
            yield self.sld_data_i[i].err
            yield self.sld_data_i[i].data

    def ports(self):
        return list(self)
179
def sim(dut):
    """Run a single split-load simulation against *dut* and write a VCD.

    Two sync processes are used: send_ld() issues the load request and
    checks the recombined result, lds() plays the memory side and
    supplies the two halves of the data.  The trace is written to
    "ldst_splitter.vcd".

    dut: the splitter instance to simulate; it must provide the signals
         accessed below and a ports() method.
    """

    simulator = Simulator(dut)
    simulator.add_clock(1e-6)
    data = 0b11010011
    dlen = 4 # 4 bits
    addr = 0b1100
    ld_len = 8
    line_bits = 1 << dlen # width of one line's worth of mask/data bits
    ldm = ((1 << ld_len) - 1)
    dlm = ((1 << dlen) - 1)
    data = data & ldm # truncate data to be tested, mask to within ld len
    print("ldm", ldm, bin(data & ldm))
    print("dlm", dlm, bin(addr & dlm))
    # full mask shifted to the offset within the line, then split into the
    # part beyond the first line (dmask1) and the part within it (dmask)
    dmask = ldm << (addr & dlm)
    print("dmask", bin(dmask))
    dmask1 = dmask >> line_bits
    print("dmask1", bin(dmask1))
    dmask = dmask & ((1 << line_bits) - 1)
    print("dmask", bin(dmask))

    def send_ld():
        # issue the load request and hold it until the DUT reports valid
        print("send_ld")
        yield dut.is_ld_i.eq(1)
        yield dut.len_i.eq(ld_len)
        yield dut.addr_i.eq(addr)
        yield dut.valid_i.eq(1)
        print("waiting")
        while True:
            valid_o = yield dut.valid_o
            if valid_o:
                break
            yield
        ld_data_o = yield dut.ld_data_o.data
        yield dut.is_ld_i.eq(0)
        yield

        print(bin(ld_data_o), bin(data))
        assert ld_data_o == data

    def lds():
        # wait for the request, then supply each half on successive cycles
        print("lds")
        while True:
            valid_i = yield dut.valid_i
            if valid_i:
                break
            yield

        shf = addr & dlm
        shfdata = (data << shf)
        data1 = shfdata & dmask
        print("ld data1", bin(data), bin(data1), shf, bin(dmask))

        # second half: same shift as used for dmask1 above (previously a
        # hard-coded 16, which only matched while dlen == 4)
        data2 = (shfdata >> line_bits) & dmask1
        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
        yield dut.sld_data_i[0].data.eq(data1)
        yield dut.sld_valid_i[0].eq(1)
        yield
        yield dut.sld_data_i[1].data.eq(data2)
        yield dut.sld_valid_i[1].eq(1)
        yield

    simulator.add_sync_process(lds)
    simulator.add_sync_process(send_ld)

    prefix = "ldst_splitter"
    with simulator.write_vcd("%s.vcd" % prefix, traces=dut.ports()):
        simulator.run()
247
248
if __name__ == '__main__':
    # Script entry point: build a splitter (dwidth=32, awidth=48, dlen=4),
    # dump its RTLIL to disk, then run the simulation on the same instance.
    dut = LDSTSplitter(32, 48, 4)
    rtlil_text = rtlil.convert(dut, ports=dut.ports())
    with open("ldst_splitter.il", "w") as il_file:
        il_file.write(rtlil_text)

    sim(dut)