First test of down-converted load/store from 64-bit to 32-bit.
[soc.git] / src / soc / minerva / units / loadstore.py
1 from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
2 from nmigen.utils import log2_int
3 from nmigen.lib.fifo import SyncFIFO
4
5 from soc.minerva.cache import L1Cache
6 from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
7 from soc.bus.wb_downconvert import WishboneDownConvert
8
9 from copy import deepcopy
10
11 __all__ = ["LoadStoreUnitInterface", "BareLoadStoreUnit",
12 "CachedLoadStoreUnit"]
13
14
class LoadStoreUnitInterface:
    """Signal interface shared by the load/store unit implementations.

    Builds the Wishbone data bus (optionally placed behind a width
    down-converter) and declares the x/m pipeline-stage handshake
    signals used by the core.
    """

    def __init__(self, pspec):
        self.pspec = pspec
        self.dbus = self.slavebus = Record(make_wb_layout(pspec))
        print(self.dbus.sel.shape())
        # A distinct wishbone data width requests a down-converter
        # between the master-side dbus and the (narrower) slave bus.
        if isinstance(pspec.wb_data_wid, int):
            slave_spec = deepcopy(pspec)
            slave_spec.reg_wid = pspec.wb_data_wid
            self.slavebus = Record(make_wb_layout(slave_spec))
            self.cvt = WishboneDownConvert(self.dbus, self.slavebus)
        self.mask_wid = mask_wid = pspec.mask_wid
        self.addr_wid = addr_wid = pspec.addr_wid
        self.data_wid = data_wid = pspec.reg_wid
        print("loadstoreunit addr mask data", addr_wid, mask_wid, data_wid)
        self.adr_lsbs = log2_int(mask_wid)   # LSBs of addr covered by mask
        badwid = addr_wid - self.adr_lsbs    # TODO: is this correct?

        # INPUTS
        self.x_addr_i = Signal(addr_wid)     # address used for loads/stores
        self.x_mask_i = Signal(mask_wid)     # mask of which bytes to write
        self.x_ld_i = Signal()               # set to do a memory load
        self.x_st_i = Signal()               # set to do a memory store
        self.x_st_data_i = Signal(data_wid)  # the data to write when storing

        self.x_stall_i = Signal()            # do nothing until low
        self.x_valid_i = Signal()            # whether x pipeline stage is
                                             # currently enabled (I think?)
                                             # -- set to 1 for now
        self.m_stall_i = Signal()            # do nothing until low
        self.m_valid_i = Signal()            # whether m pipeline stage is
                                             # currently enabled -- set to
                                             # 1 for now

        # OUTPUTS
        self.x_busy_o = Signal()             # set when the memory is busy
        self.m_busy_o = Signal()             # set when the memory is busy

        self.m_ld_data_o = Signal(data_wid)  # data returned from memory read
        # Data validity is NOT indicated by m_valid_i or x_valid_i as
        # those are inputs.  It appears to be valid on the cycle after
        # raising m_load where busy is low -- TODO confirm.

        self.m_load_err_o = Signal()         # error occurred during a load
        self.m_store_err_o = Signal()        # error occurred during a store
        self.m_badaddr_o = Signal(badwid)    # address of the load/store error

    def __iter__(self):
        # Inputs first, then outputs, then every field of the data bus
        # (same order the original interface exposed them in).
        yield from (self.x_addr_i, self.x_mask_i,
                    self.x_ld_i, self.x_st_i, self.x_st_data_i,
                    self.x_stall_i, self.x_valid_i,
                    self.m_stall_i, self.m_valid_i,
                    self.x_busy_o, self.m_busy_o, self.m_ld_data_o,
                    self.m_load_err_o, self.m_store_err_o,
                    self.m_badaddr_o)
        yield from self.dbus.fields.values()

    def ports(self):
        return [*self]
85
class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit that drives the Wishbone data bus directly.

    One outstanding access at a time: a load or store request in the
    x stage starts a bus cycle, which terminates on ack/err (or when
    the m stage is no longer valid), at which point the read data and
    any error condition are latched.
    """

    def elaborate(self, platform):
        m = Module()

        # Wire in the down-converter if the interface created one.
        if hasattr(self, "cvt"):
            m.submodules.cvt = self.cvt

        with m.If(self.dbus.cyc):
            # End the cycle on ack or err, or when the m stage was
            # cancelled; capture the read data as we do so.
            with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid_i):
                m.d.sync += [
                    self.dbus.cyc.eq(0),
                    self.dbus.stb.eq(0),
                    self.dbus.sel.eq(0),
                    self.m_ld_data_o.eq(self.dbus.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    self.x_valid_i & ~self.x_stall_i):
            # Start a new bus cycle from the x-stage request.
            m.d.sync += [
                self.dbus.cyc.eq(1),
                self.dbus.stb.eq(1),
                self.dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                self.dbus.sel.eq(self.x_mask_i),
                self.dbus.we.eq(self.x_st_i),
                self.dbus.dat_w.eq(self.x_st_data_i)
            ]
        with m.Else():
            # Idle: park the bus request signals at zero.  (Fixed a
            # duplicated reset of `sel` in this branch.)
            m.d.sync += [
                self.dbus.adr.eq(0),
                self.dbus.sel.eq(0),
                self.dbus.we.eq(0),
                self.dbus.dat_w.eq(0),
            ]

        # Latch the error type and faulting address on a bus error;
        # clear the error flags once the m stage is not stalled.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        m.d.comb += self.x_busy_o.eq(self.dbus.cyc)

        # m stage stays busy while the bus cycle is in flight, unless
        # an error has already been flagged.
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Else():
            m.d.comb += self.m_busy_o.eq(self.dbus.cyc)

        return m
140
141
class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit with an L1 data cache in front of the bus.

    Accesses whose address falls inside the cache region go through
    the L1 cache, with stores staged in a write-buffer FIFO; accesses
    outside the region use a "bare" Wishbone port.  Three bus masters
    (write buffer, cache refill, bare port) are multiplexed onto dbus
    by a WishboneArbiter.
    """

    def __init__(self, pspec):
        super().__init__(pspec)

        # Fixed: original read `psiec.dcache_args` (NameError).
        self.dcache_args = pspec.dcache_args

        self.x_fence_i = Signal()  # fence request: drain write buffer
        self.x_flush = Signal()    # flush request, fed to cache s1 stage
        self.m_flush = Signal()    # m-stage flush; read by elaborate()
                                   # but previously never declared -- added
        self.m_load = Signal()     # m-stage load in progress
        self.m_store = Signal()    # m-stage store in progress

    def elaborate(self, platform):
        m = Module()

        dcache = m.submodules.dcache = L1Cache(*self.dcache_args)

        x_dcache_select = Signal()
        # Test whether the target address is inside the L1 cache region.
        # We use bit masks in order to avoid carry chains from arithmetic
        # comparisons.  This restricts the region boundaries to powers of 2.
        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
            def addr_below(limit):
                assert limit in range(1, 2**30 + 1)
                range_bits = log2_int(limit)
                const_bits = 30 - range_bits
                return "{}{}".format("0" * const_bits, "-" * range_bits)

            if dcache.base >= (1 << self.adr_lsbs):
                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
                    m.d.comb += x_dcache_select.eq(0)
            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
                m.d.comb += x_dcache_select.eq(1)
            with m.Default():
                m.d.comb += x_dcache_select.eq(0)

        # Pipeline the cache-select decision and address into the m stage.
        m_dcache_select = Signal()
        m_addr = Signal.like(self.x_addr_i)

        with m.If(~self.x_stall_i):
            m.d.sync += [
                m_dcache_select.eq(x_dcache_select),
                m_addr.eq(self.x_addr_i),
            ]

        m.d.comb += [
            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
            dcache.s1_flush.eq(self.x_flush),
            dcache.s1_stall.eq(self.x_stall_i),
            dcache.s1_valid.eq(self.x_valid_i & x_dcache_select),
            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
            dcache.s2_re.eq(self.m_load),
            dcache.s2_evict.eq(self.m_store),
            dcache.s2_valid.eq(self.m_valid_i & m_dcache_select)
        ]

        # Write buffer between the x stage and the bus, so cached stores
        # do not stall the pipeline while the bus is busy.
        wrbuf_w_data = Record([("addr", self.addr_wid - self.adr_lsbs),
                               ("mask", self.mask_wid),
                               ("data", self.data_wid)])
        wrbuf_r_data = Record.like(wrbuf_w_data)
        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
                                              depth=dcache.nwords)
        m.d.comb += [
            wrbuf.w_data.eq(wrbuf_w_data),
            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
            wrbuf_w_data.mask.eq(self.x_mask_i),
            wrbuf_w_data.data.eq(self.x_st_data_i),
            wrbuf.w_en.eq(self.x_st_i & self.x_valid_i &
                          x_dcache_select & ~self.x_stall_i),
            wrbuf_r_data.eq(wrbuf.r_data),
        ]

        dba = WishboneArbiter(self.pspec)
        m.submodules.dbus_arbiter = dba
        m.d.comb += dba.bus.connect(self.dbus)

        # Highest priority: drain the write buffer.
        # Fixed: original called `dbus_arbiter.port(...)` (NameError);
        # the arbiter is bound to `dba` above.
        wrbuf_port = dba.port(priority=0)
        m.d.comb += [
            wrbuf_port.cyc.eq(wrbuf.r_rdy),
            wrbuf_port.we.eq(Const(1)),
        ]
        with m.If(wrbuf_port.stb):
            with m.If(wrbuf_port.ack | wrbuf_port.err):
                m.d.sync += wrbuf_port.stb.eq(0)
                m.d.comb += wrbuf.r_en.eq(1)
        with m.Elif(wrbuf.r_rdy):
            m.d.sync += [
                wrbuf_port.stb.eq(1),
                wrbuf_port.adr.eq(wrbuf_r_data.addr),
                wrbuf_port.sel.eq(wrbuf_r_data.mask),
                wrbuf_port.dat_w.eq(wrbuf_r_data.data)
            ]

        # Next priority: cache refills, as incrementing burst cycles.
        dcache_port = dba.port(priority=1)
        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
        m.d.comb += [
            dcache_port.cyc.eq(dcache.bus_re),
            dcache_port.stb.eq(dcache.bus_re),
            dcache_port.adr.eq(dcache.bus_addr),
            dcache_port.cti.eq(cti),
            dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
            dcache.bus_valid.eq(dcache_port.ack),
            dcache.bus_error.eq(dcache_port.err),
            dcache.bus_rdata.eq(dcache_port.dat_r)
        ]

        # Lowest priority: uncached ("bare") accesses, same protocol
        # as BareLoadStoreUnit but gated on ~x_dcache_select.
        bare_port = dba.port(priority=2)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.m_valid_i):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    ~x_dcache_select & self.x_valid_i & ~self.x_stall_i):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                bare_port.sel.eq(self.x_mask_i),
                bare_port.we.eq(self.x_st_i),
                bare_port.dat_w.eq(self.x_st_data_i)
            ]

        # Error capture: latch type and faulting address on a bus error;
        # clear once the m stage is not stalled.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # x-stage busy: a fence waits for the write buffer to drain; a
        # cached store waits for write-buffer space; uncached accesses
        # wait for the bare bus cycle to finish.
        with m.If(self.x_fence_i):
            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
        with m.Elif(x_dcache_select):
            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
        with m.Else():
            m.d.comb += self.x_busy_o.eq(bare_port.cyc)

        # m-stage busy and load-data selection.
        # NOTE(review): upstream minerva chains the error case as an
        # Elif of the flush case; kept as two separate Ifs here to match
        # this file (the later comb assignment wins on overlap) -- confirm.
        with m.If(self.m_flush):
            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Elif(m_dcache_select):
            m.d.comb += [
                self.m_busy_o.eq(dcache.s2_miss),
                self.m_ld_data_o.eq(dcache.s2_rdata)
            ]
        with m.Else():
            m.d.comb += [
                self.m_busy_o.eq(bare_port.cyc),
                self.m_ld_data_o.eq(bare_rdata)
            ]

        return m
300
301 return m