3241e20e1c19aa2a457e572333120f103944f5f5
[soc.git] / src / soc / minerva / units / loadstore.py
1 from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
2 from nmigen.utils import log2_int
3 from nmigen.lib.fifo import SyncFIFO
4
5 from soc.minerva.cache import L1Cache
6 from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
7 from soc.bus.wb_downconvert import WishboneDownConvert
8
9 from copy import deepcopy
10
# Public classes exported by this module.
__all__ = [
    "LoadStoreUnitInterface",
    "BareLoadStoreUnit",
    "CachedLoadStoreUnit",
]
13
14
class LoadStoreUnitInterface:
    """Signals shared by every load/store unit implementation.

    Holds the execute-stage (``x_*``) request signals, the memory-stage
    (``m_*``) status/result signals and the Wishbone data bus record
    ``dbus`` built from ``make_wb_layout(pspec)``.

    When *pspec* has a ``dmem_test_depth`` attribute and an integer
    ``wb_data_wid`` that differs from ``reg_wid``, a second record
    ``slavebus`` is built from a deep copy of *pspec* whose ``reg_wid`` is
    ``wb_data_wid`` and whose ``mask_wid`` is divided by
    ``reg_wid // wb_data_wid``; ``needs_cvt`` is then True.  Otherwise
    ``slavebus`` is the very same record as ``dbus`` and ``needs_cvt`` is
    False.
    """

    def __init__(self, pspec):
        self.pspec = self.pspecslave = pspec
        self.dbus = self.slavebus = Record(make_wb_layout(pspec))
        print(self.dbus.sel.shape())
        self.needs_cvt = False
        if self._wants_downconvert(pspec):
            self._build_slave_bus(pspec)

        mask_wid = self.mask_wid = pspec.mask_wid
        addr_wid = self.addr_wid = pspec.addr_wid
        data_wid = self.data_wid = pspec.reg_wid
        print("loadstoreunit addr mask data", addr_wid, mask_wid, data_wid)
        # address LSBs selected by the byte mask rather than by dbus.adr
        self.adr_lsbs = log2_int(mask_wid)
        badwid = addr_wid - self.adr_lsbs  # TODO: is this correct?

        # ---- inputs: X (execute) stage request
        self.x_addr_i = Signal(addr_wid)      # address used for loads/stores
        self.x_mask_i = Signal(mask_wid)      # which bytes to write
        self.x_ld_i = Signal()                # set to do a memory load
        self.x_st_i = Signal()                # set to do a memory store
        self.x_st_data_i = Signal(data_wid)   # data to write when storing

        # ---- inputs: pipeline control
        self.x_stall_i = Signal()   # X stage: do nothing until low
        self.x_valid_i = Signal()   # X stage enabled (set to 1 for now)
        self.m_stall_i = Signal()   # M stage: do nothing until low
        self.m_valid_i = Signal()   # M stage enabled (set to 1 for now)

        # ---- outputs
        self.x_busy_o = Signal()    # memory busy, as seen by the X stage
        self.m_busy_o = Signal()    # memory busy, as seen by the M stage

        # Data returned by a memory read.  Its validity is NOT indicated by
        # m_valid_i / x_valid_i (those are inputs); the original author
        # believed it is valid on the cycle after m_load is raised while
        # busy is low -- unconfirmed.
        self.m_ld_data_o = Signal(data_wid)

        self.m_load_err_o = Signal()       # a load ended in an error
        self.m_store_err_o = Signal()      # a store ended in an error
        self.m_badaddr_o = Signal(badwid)  # address of the failing access

    @staticmethod
    def _wants_downconvert(pspec):
        """Return True when an integer wb_data_wid differs from reg_wid.

        Only consulted when *pspec* has a ``dmem_test_depth`` attribute;
        ``wb_data_wid`` is not read otherwise.
        """
        if not hasattr(pspec, "dmem_test_depth"):
            return False
        wb_wid = pspec.wb_data_wid
        return isinstance(wb_wid, int) and wb_wid != pspec.reg_wid

    def _build_slave_bus(self, pspec):
        """Derive the narrower slave-side pspec and its bus record."""
        slave = deepcopy(pspec)
        slave.reg_wid = pspec.wb_data_wid
        # byte-enable lanes shrink by the same factor as the data width
        ratio = pspec.reg_wid // pspec.wb_data_wid
        slave.mask_wid = pspec.mask_wid // ratio
        self.pspecslave = slave
        self.slavebus = Record(make_wb_layout(slave))
        self.needs_cvt = True

    def __iter__(self):
        """Yield every port signal, followed by the fields of ``dbus``."""
        yield from (
            self.x_addr_i, self.x_mask_i, self.x_ld_i, self.x_st_i,
            self.x_st_data_i,
            self.x_stall_i, self.x_valid_i, self.m_stall_i, self.m_valid_i,
            self.x_busy_o, self.m_busy_o,
            self.m_ld_data_o, self.m_load_err_o, self.m_store_err_o,
            self.m_badaddr_o,
        )
        yield from self.dbus.fields.values()

    def ports(self):
        """Return the signals produced by ``__iter__`` as a list."""
        return [*self]
91
92
class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit that drives the Wishbone data bus directly.

    There is no cache and no write buffer.  A request is launched from the
    X stage.  The bus cycle is held until the slave answers (``ack`` or
    ``err``) or until ``m_valid_i`` drops.  ``dat_r`` is latched into
    ``m_ld_data_o`` when the cycle ends.
    """

    def elaborate(self, platform):
        m = Module()

        if self.needs_cvt:
            # dbus is reg_wid wide; the converter bridges it to the
            # narrower slavebus built by the interface constructor.
            self.cvt = WishboneDownConvert(self.dbus, self.slavebus)
            m.submodules.cvt = self.cvt

        with m.If(self.dbus.cyc):
            # Finish the cycle on a slave response, or abandon it if the
            # M stage is no longer valid; capture whatever dat_r holds.
            with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid_i):
                m.d.sync += [
                    self.dbus.cyc.eq(0),
                    self.dbus.stb.eq(0),
                    self.dbus.sel.eq(0),
                    self.m_ld_data_o.eq(self.dbus.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    self.x_valid_i & ~self.x_stall_i):
            # Start a new bus cycle for the X-stage load/store request.
            # dbus.adr is a word address: the low adr_lsbs bits of the
            # byte address are covered by the byte mask instead.
            m.d.sync += [
                self.dbus.cyc.eq(1),
                self.dbus.stb.eq(1),
                self.dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                self.dbus.sel.eq(self.x_mask_i),
                self.dbus.we.eq(self.x_st_i),
                self.dbus.dat_w.eq(self.x_st_data_i)
            ]
        with m.Else():
            # Idle: park the address/control/data lines at zero.
            # (A duplicated `sel.eq(0)` that used to be here was removed.)
            m.d.sync += [
                self.dbus.adr.eq(0),
                self.dbus.sel.eq(0),
                self.dbus.we.eq(0),
                self.dbus.dat_w.eq(0),
            ]

        # Latch errors: dbus.we tells whether the failing access was a
        # store or a load.  The flags are cleared once M is not stalled.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # X stage is busy for as long as a bus cycle is outstanding.
        m.d.comb += self.x_busy_o.eq(self.dbus.cyc)

        # M stage is released as soon as an error has been latched;
        # otherwise it is busy while the cycle is outstanding.
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Else():
            m.d.comb += self.m_busy_o.eq(self.dbus.cyc)

        return m
148
149
class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit with an L1 data cache, a write buffer and a bare port.

    Accesses whose address lies in ``[dcache.base, dcache.limit)`` are
    served by the L1 cache.  Stores to that region are also queued in a
    write-buffer FIFO that is drained onto the bus.  Every other access
    goes through a bare Wishbone port.  The write buffer, the cache refill
    port and the bare port share ``dbus`` through a ``WishboneArbiter``.

    Inputs on top of ``LoadStoreUnitInterface``:

    * ``x_fence_i`` - while set, ``x_busy_o`` follows ``wrbuf.r_rdy``.
    * ``x_flush``   - forwarded to ``dcache.s1_flush``.
    * ``m_load``    - forwarded to ``dcache.s2_re``.
    * ``m_store``   - forwarded to ``dcache.s2_evict``.
    * ``m_flush``   - while set, and unless a bus error has been latched,
      ``m_busy_o`` is ``~dcache.s2_flush_ack``.

    NOTE(review): unlike BareLoadStoreUnit, this unit does not use
    ``needs_cvt`` / ``slavebus``; confirm whether a down-converter is needed.
    """

    def __init__(self, pspec):
        super().__init__(pspec)

        # positional arguments for the L1Cache constructor
        # (was `psiec.dcache_args`, a NameError)
        self.dcache_args = pspec.dcache_args

        self.x_fence_i = Signal()
        self.x_flush = Signal()
        self.m_load = Signal()
        self.m_store = Signal()
        # read by elaborate(); previously never declared (AttributeError)
        self.m_flush = Signal()

    def elaborate(self, platform):
        m = Module()

        dcache = m.submodules.dcache = L1Cache(*self.dcache_args)

        # ---- X stage: does the access target the cached region?
        x_dcache_select = Signal()
        # Number of word-address bits being switched on.  This was
        # hard-coded as 30 (32-bit byte addresses, 4-byte words), which
        # gives wrongly sized Case patterns for any other geometry.
        word_bits = self.addr_wid - self.adr_lsbs
        # Test whether the target address is inside the L1 cache region.
        # We use bit masks in order to avoid carry chains from arithmetic
        # comparisons. This restricts the region boundaries to powers of 2.
        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
            def addr_below(limit):
                # Case pattern matching every word address < limit,
                # where limit must be a power of two.
                assert limit in range(1, 2**word_bits + 1)
                range_bits = log2_int(limit)
                const_bits = word_bits - range_bits
                return "{}{}".format("0" * const_bits, "-" * range_bits)

            if dcache.base >= (1 << self.adr_lsbs):
                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
                    m.d.comb += x_dcache_select.eq(0)
            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
                m.d.comb += x_dcache_select.eq(1)
            with m.Default():
                m.d.comb += x_dcache_select.eq(0)

        # ---- carry the selection and address into the M stage
        m_dcache_select = Signal()
        m_addr = Signal.like(self.x_addr_i)

        with m.If(~self.x_stall_i):
            m.d.sync += [
                m_dcache_select.eq(x_dcache_select),
                m_addr.eq(self.x_addr_i),
            ]

        m.d.comb += [
            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
            dcache.s1_flush.eq(self.x_flush),
            dcache.s1_stall.eq(self.x_stall_i),
            dcache.s1_valid.eq(self.x_valid_i & x_dcache_select),
            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
            dcache.s2_re.eq(self.m_load),
            dcache.s2_evict.eq(self.m_store),
            dcache.s2_valid.eq(self.m_valid_i & m_dcache_select)
        ]

        # ---- write buffer: stores to the cached region are queued here
        wrbuf_w_data = Record([("addr", self.addr_wid-self.adr_lsbs),
                               ("mask", self.mask_wid),
                               ("data", self.data_wid)])
        wrbuf_r_data = Record.like(wrbuf_w_data)
        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
                                              depth=dcache.nwords)
        m.d.comb += [
            wrbuf.w_data.eq(wrbuf_w_data),
            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
            wrbuf_w_data.mask.eq(self.x_mask_i),
            wrbuf_w_data.data.eq(self.x_st_data_i),
            wrbuf.w_en.eq(self.x_st_i & self.x_valid_i &
                          x_dcache_select & ~self.x_stall_i),
            wrbuf_r_data.eq(wrbuf.r_data),
        ]

        # ---- arbiter in front of dbus
        dba = WishboneArbiter(self.pspec)
        m.submodules.dbus_arbiter = dba
        m.d.comb += dba.bus.connect(self.dbus)

        # write-buffer drain port (was `dbus_arbiter.port`, a NameError)
        wrbuf_port = dba.port(priority=0)
        m.d.comb += [
            wrbuf_port.cyc.eq(wrbuf.r_rdy),
            wrbuf_port.we.eq(Const(1)),
        ]
        with m.If(wrbuf_port.stb):
            with m.If(wrbuf_port.ack | wrbuf_port.err):
                # entry written (or failed): pop it from the FIFO
                m.d.sync += wrbuf_port.stb.eq(0)
                m.d.comb += wrbuf.r_en.eq(1)
        with m.Elif(wrbuf.r_rdy):
            m.d.sync += [
                wrbuf_port.stb.eq(1),
                wrbuf_port.adr.eq(wrbuf_r_data.addr),
                wrbuf_port.sel.eq(wrbuf_r_data.mask),
                wrbuf_port.dat_w.eq(wrbuf_r_data.data)
            ]

        # cache refill port: incrementing burst that ends on bus_last
        dcache_port = dba.port(priority=1)
        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
        m.d.comb += [
            dcache_port.cyc.eq(dcache.bus_re),
            dcache_port.stb.eq(dcache.bus_re),
            dcache_port.adr.eq(dcache.bus_addr),
            dcache_port.cti.eq(cti),
            dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
            dcache.bus_valid.eq(dcache_port.ack),
            dcache.bus_error.eq(dcache_port.err),
            dcache.bus_rdata.eq(dcache_port.dat_r)
        ]

        # bare port for accesses outside the cached region
        bare_port = dba.port(priority=2)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.m_valid_i):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    ~x_dcache_select & self.x_valid_i & ~self.x_stall_i):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                bare_port.sel.eq(self.x_mask_i),
                bare_port.we.eq(self.x_st_i),
                bare_port.dat_w.eq(self.x_st_data_i)
            ]

        # ---- error latching: dbus.we distinguishes store from load
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # ---- X-stage busy
        with m.If(self.x_fence_i):
            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
        with m.Elif(x_dcache_select):
            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
        with m.Else():
            m.d.comb += self.x_busy_o.eq(bare_port.cyc)

        # ---- M-stage busy and load data
        m_err = self.m_load_err_o | self.m_store_err_o
        with m.If(m_err):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Elif(m_dcache_select):
            m.d.comb += [
                self.m_busy_o.eq(dcache.s2_miss),
                self.m_ld_data_o.eq(dcache.s2_rdata)
            ]
        with m.Else():
            m.d.comb += [
                self.m_busy_o.eq(bare_port.cyc),
                self.m_ld_data_o.eq(bare_rdata)
            ]
        # A flush holds M until the cache acknowledges it.  This sits after
        # the chain above so that it takes effect: previously it came first,
        # and every branch of the chain overwrote m_busy_o.  Errors still
        # take precedence, as they did before.
        with m.If(self.m_flush & ~m_err):
            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)

        return m