Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / minerva / units / loadstore.py
1 from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
2 from nmigen.utils import log2_int
3 from nmigen.lib.fifo import SyncFIFO
4
5 from soc.minerva.cache import L1Cache
6 from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
7 from soc.bus.wb_downconvert import WishboneDownConvert
8
9 from copy import deepcopy
10
# Names exported by ``from ... import *``: the shared signal interface and
# the two load/store unit implementations defined in this module.
__all__ = ["LoadStoreUnitInterface", "BareLoadStoreUnit",
           "CachedLoadStoreUnit"]
13
14
class LoadStoreUnitInterface:
    """Signal-level interface shared by the load/store unit variants.

    Construction builds the Wishbone data-bus record(s) plus the x-stage
    and m-stage input/output signals, all sized from ``pspec``.  Iterating
    an instance yields those signals followed by every field of
    ``slavebus``; ``ports()`` returns the same sequence as a list.
    """

    def __init__(self, pspec):
        """Create the bus records and pipeline signals.

        ``pspec`` must provide ``mask_wid``, ``addr_wid`` and ``reg_wid``.
        ``dmem_test_depth``, ``wb_data_wid`` and ``wb_dcache_en`` are
        consulted as described in the comments below.
        """
        self.pspec = pspec
        self.pspecslave = pspec

        # Decide whether the internal bus and the externally visible bus
        # have different data widths.
        # NOTE(review): the guard tests for "dmem_test_depth" but then reads
        # "wb_data_wid" -- confirm that this pairing is intended.
        needs_cvt = False
        if hasattr(pspec, "dmem_test_depth"):
            bus_wid = pspec.wb_data_wid
            needs_cvt = isinstance(bus_wid, int) and bus_wid != pspec.reg_wid

        if needs_cvt:
            # Internal bus keeps the pspec widths; the slave-side bus uses a
            # copied pspec whose data and mask widths are scaled to
            # wb_data_wid.
            self.dbus = Record(make_wb_layout(pspec), name="int_dbus")
            pspecslave = deepcopy(pspec)
            pspecslave.reg_wid = pspec.wb_data_wid
            mask_ratio = pspec.reg_wid // pspec.wb_data_wid
            pspecslave.mask_wid = pspec.mask_wid // mask_ratio
            self.pspecslave = pspecslave
            self.slavebus = Record(make_wb_layout(pspecslave), name="dbus")
        else:
            # One record serves as both buses.  The chained assignment is
            # deliberate: the record's name is taken from the first target.
            self.dbus = self.slavebus = Record(make_wb_layout(pspec))
        self.needs_cvt = needs_cvt

        # Bus enable: use the Signal supplied by pspec when there is one,
        # otherwise a constant 1 (permanently on).
        wb_en = getattr(pspec, "wb_dcache_en", None)
        self.jtag_en = wb_en if isinstance(wb_en, Signal) else Const(1, 1)

        print(self.dbus.sel.shape())
        self.mask_wid = mask_wid = pspec.mask_wid
        self.addr_wid = addr_wid = pspec.addr_wid
        self.data_wid = data_wid = pspec.reg_wid
        print("loadstoreunit addr mask data", addr_wid, mask_wid, data_wid)
        # Number of address LSBs that the byte mask covers.
        self.adr_lsbs = log2_int(mask_wid)
        # Width of the reported bad address.  TODO: is this correct?
        badwid = addr_wid - self.adr_lsbs

        # ---- inputs ----
        self.x_addr_i = Signal(addr_wid)     # address used for loads/stores
        self.x_mask_i = Signal(mask_wid)     # which bytes to write
        self.x_ld_i = Signal()               # set to do a memory load
        self.x_st_i = Signal()               # set to do a memory store
        self.x_st_data_i = Signal(data_wid)  # data to write when storing

        self.x_stall_i = Signal()            # do nothing until low
        # Presumably "x pipeline stage currently enabled" (the original
        # author was unsure); set to 1 for now.
        self.x_i_valid = Signal()
        self.m_stall_i = Signal()            # do nothing until low
        # Whether the m pipeline stage is currently enabled; set to 1 for
        # now.
        self.m_i_valid = Signal()

        # ---- outputs ----
        self.x_busy_o = Signal()             # set when the memory is busy
        self.m_busy_o = Signal()             # set when the memory is busy

        # Data returned from a memory read.  Validity is NOT signalled by
        # m_i_valid / x_i_valid (those are inputs); it is believed to be
        # valid on the cycle after raising m_load in which busy is low.
        self.m_ld_data_o = Signal(data_wid)

        self.m_load_err_o = Signal()         # error occurred while loading
        self.m_store_err_o = Signal()        # error occurred while storing
        self.m_badaddr_o = Signal(badwid)    # address of the failed access

    def __iter__(self):
        """Yield the interface signals, then every field of ``slavebus``."""
        yield from (
            self.x_addr_i,
            self.x_mask_i,
            self.x_ld_i,
            self.x_st_i,
            self.x_st_data_i,
            self.x_stall_i,
            self.x_i_valid,
            self.m_stall_i,
            self.m_i_valid,
            self.x_busy_o,
            self.m_busy_o,
            self.m_ld_data_o,
            self.m_load_err_o,
            self.m_store_err_o,
            self.m_badaddr_o,
        )
        yield from self.slavebus.fields.values()

    def ports(self):
        """Return everything ``__iter__`` yields, as a list."""
        return [*self]
101
102
class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit without a cache: each access is one data-bus cycle."""

    def elaborate(self, platform):
        """Build the bus-cycle state logic and the busy/error outputs.

        Returns the populated ``Module``.  ``platform`` is not used.
        """
        # NOTE(review): the block nesting below was reconstructed from a
        # listing whose indentation had been lost; in particular, confirm
        # that the error and busy logic belong inside the jtag_en guard.
        m = Module()
        dbus = self.dbus

        if self.needs_cvt:
            # Bridge the internal bus onto the differently sized slave bus.
            self.cvt = WishboneDownConvert(self.dbus, self.slavebus)
            m.submodules.cvt = self.cvt

        with m.If(self.jtag_en):  # for safety, JTAG can completely disable WB

            with m.If(dbus.cyc):
                # A cycle is in flight: end it on ack, on err, or when the
                # m stage is no longer valid, and capture the read data.
                with m.If(dbus.ack | dbus.err | ~self.m_i_valid):
                    m.d.sync += [
                        dbus.cyc.eq(0),
                        dbus.stb.eq(0),
                        dbus.sel.eq(0),
                        self.m_ld_data_o.eq(dbus.dat_r)
                    ]
            with m.Elif((self.x_ld_i | self.x_st_i) &
                        self.x_i_valid & ~self.x_stall_i):
                # Bus idle and a load or store is requested: start a cycle.
                # The address LSBs covered by the byte mask are dropped.
                m.d.sync += [
                    dbus.cyc.eq(1),
                    dbus.stb.eq(1),
                    dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                    dbus.sel.eq(self.x_mask_i),
                    dbus.we.eq(self.x_st_i),
                    dbus.dat_w.eq(self.x_st_data_i)
                ]
            with m.Else():
                # Bus idle and nothing requested: clear the request fields.
                m.d.sync += [
                    dbus.adr.eq(0),
                    dbus.sel.eq(0),
                    dbus.we.eq(0),
                    dbus.dat_w.eq(0),
                ]

            # Latch which kind of access failed, and its bus address; the
            # flags are cleared again once the m stage is not stalled.
            with m.If(dbus.cyc & dbus.err):
                m.d.sync += [
                    self.m_load_err_o.eq(~dbus.we),
                    self.m_store_err_o.eq(dbus.we),
                    self.m_badaddr_o.eq(dbus.adr)
                ]
            with m.Elif(~self.m_stall_i):
                m.d.sync += [
                    self.m_load_err_o.eq(0),
                    self.m_store_err_o.eq(0)
                ]

            m.d.comb += self.x_busy_o.eq(dbus.cyc)

            # A latched error releases the m stage even if cyc is still set.
            with m.If(self.m_load_err_o | self.m_store_err_o):
                m.d.comb += self.m_busy_o.eq(0)
            with m.Else():
                m.d.comb += self.m_busy_o.eq(dbus.cyc)

        return m
160
161
class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit with an L1 data cache, a store write buffer and an
    uncached ("bare") path, all sharing the data bus through an arbiter.
    """

    def __init__(self, pspec):
        """Create the base interface plus the cache-control signals.

        ``pspec.dcache_args`` is stored and later unpacked as the positional
        arguments of ``L1Cache``.
        """
        super().__init__(pspec)

        self.dcache_args = pspec.dcache_args

        self.x_fence_i = Signal()
        self.x_flush = Signal()
        # Read by elaborate() when computing m_busy_o.
        # NOTE(review): nothing in this class drives it, so it stays 0
        # unless an external user assigns it -- confirm the intended source.
        self.m_flush = Signal()
        self.m_load = Signal()
        self.m_store = Signal()

    def elaborate(self, platform):
        """Build the cache, write buffer, bus arbiter and busy/error logic.

        Returns the populated ``Module``.  ``platform`` is not used.
        """
        m = Module()

        dcache = m.submodules.dcache = L1Cache(*self.dcache_args)

        # Width of a word address, i.e. of x_addr_i[adr_lsbs:], which is
        # what the Switch below matches against.
        word_addr_bits = self.addr_wid - self.adr_lsbs

        x_dcache_select = Signal()
        # Test whether the target address is inside the L1 cache region.
        # We use bit masks in order to avoid carry chains from arithmetic
        # comparisons. This restricts the region boundaries to powers of 2.
        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
            def addr_below(limit):
                """Return the Case pattern matching word addresses < limit."""
                assert limit in range(1, 2**word_addr_bits + 1)
                range_bits = log2_int(limit)
                const_bits = word_addr_bits - range_bits
                return "{}{}".format("0" * const_bits, "-" * range_bits)

            # Cases are tried in order, so addresses below the cache base
            # are excluded before the "below limit" pattern can match.
            if dcache.base >= (1 << self.adr_lsbs):
                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
                    m.d.comb += x_dcache_select.eq(0)
            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
                m.d.comb += x_dcache_select.eq(1)
            with m.Default():
                m.d.comb += x_dcache_select.eq(0)

        # Carry the select decision and the address from x to m stage.
        m_dcache_select = Signal()
        m_addr = Signal.like(self.x_addr_i)

        with m.If(~self.x_stall_i):
            m.d.sync += [
                m_dcache_select.eq(x_dcache_select),
                m_addr.eq(self.x_addr_i),
            ]

        m.d.comb += [
            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
            dcache.s1_flush.eq(self.x_flush),
            dcache.s1_stall.eq(self.x_stall_i),
            dcache.s1_valid.eq(self.x_i_valid & x_dcache_select),
            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
            dcache.s2_re.eq(self.m_load),
            dcache.s2_evict.eq(self.m_store),
            dcache.s2_valid.eq(self.m_i_valid & m_dcache_select)
        ]

        # Write buffer: stores to the cached region are queued here as
        # (word address, byte mask, data) entries.
        wrbuf_w_data = Record([("addr", self.addr_wid-self.adr_lsbs),
                               ("mask", self.mask_wid),
                               ("data", self.data_wid)])
        wrbuf_r_data = Record.like(wrbuf_w_data)
        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
                                              depth=dcache.nwords)
        m.d.comb += [
            wrbuf.w_data.eq(wrbuf_w_data),
            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
            wrbuf_w_data.mask.eq(self.x_mask_i),
            wrbuf_w_data.data.eq(self.x_st_data_i),
            wrbuf.w_en.eq(self.x_st_i & self.x_i_valid &
                          x_dcache_select & ~self.x_stall_i),
            wrbuf_r_data.eq(wrbuf.r_data),
        ]

        # The arbiter's bus is connected to dbus; the write buffer, the
        # cache and the bare path each get their own arbiter port.
        dba = WishboneArbiter(self.pspec)
        m.submodules.dbus_arbiter = dba
        m.d.comb += dba.bus.connect(self.dbus)

        # Write-buffer port: always a write; the FIFO entry is popped when
        # the strobed transfer is acknowledged or errors.
        wrbuf_port = dba.port(priority=0)
        m.d.comb += [
            wrbuf_port.cyc.eq(wrbuf.r_rdy),
            wrbuf_port.we.eq(Const(1)),
        ]
        with m.If(wrbuf_port.stb):
            with m.If(wrbuf_port.ack | wrbuf_port.err):
                m.d.sync += wrbuf_port.stb.eq(0)
                m.d.comb += wrbuf.r_en.eq(1)
        with m.Elif(wrbuf.r_rdy):
            m.d.sync += [
                wrbuf_port.stb.eq(1),
                wrbuf_port.adr.eq(wrbuf_r_data.addr),
                wrbuf_port.sel.eq(wrbuf_r_data.mask),
                wrbuf_port.dat_w.eq(wrbuf_r_data.data)
            ]

        # Cache port: driven directly by the cache's bus_* signals, with
        # cti switching to Cycle.END when the cache flags the last beat.
        dcache_port = dba.port(priority=1)
        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
        m.d.comb += [
            dcache_port.cyc.eq(dcache.bus_re),
            dcache_port.stb.eq(dcache.bus_re),
            dcache_port.adr.eq(dcache.bus_addr),
            dcache_port.cti.eq(cti),
            dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
            dcache.bus_valid.eq(dcache_port.ack),
            dcache.bus_error.eq(dcache_port.err),
            dcache.bus_rdata.eq(dcache_port.dat_r)
        ]

        # Bare port: accesses outside the cached region, one bus cycle per
        # access; the read data is captured when the cycle ends.
        bare_port = dba.port(priority=2)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.m_i_valid):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    ~x_dcache_select & self.x_i_valid & ~self.x_stall_i):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                bare_port.sel.eq(self.x_mask_i),
                bare_port.we.eq(self.x_st_i),
                bare_port.dat_w.eq(self.x_st_data_i)
            ]

        # Latch which kind of access failed, and its bus address; the
        # flags are cleared again once the m stage is not stalled.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # x-stage busy: a fence waits while the write buffer has entries,
        # a cached store waits for buffer space, anything else waits for
        # the bare port.
        with m.If(self.x_fence_i):
            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
        with m.Elif(x_dcache_select):
            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
        with m.Else():
            m.d.comb += self.x_busy_o.eq(bare_port.cyc)

        # m-stage busy.  The second If chain is separate from the first, so
        # whichever of its branches is taken overrides the flush value.
        with m.If(self.m_flush):
            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Elif(m_dcache_select):
            m.d.comb += [
                self.m_busy_o.eq(dcache.s2_miss),
                self.m_ld_data_o.eq(dcache.s2_rdata)
            ]
        with m.Else():
            m.d.comb += [
                self.m_busy_o.eq(bare_port.cyc),
                self.m_ld_data_o.eq(bare_rdata)
            ]

        return m