Merge remote-tracking branch 'origin/master'
[soc.git] / src / soc / minerva / units / loadstore.py
1 from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
2 from nmigen.utils import log2_int
3 from nmigen.lib.fifo import SyncFIFO
4
5 from soc.minerva.cache import L1Cache
6 from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
7
8
# Public names exported by a star-import of this module.
__all__ = [
    "LoadStoreUnitInterface",
    "BareLoadStoreUnit",
    "CachedLoadStoreUnit",
]
11
12
class LoadStoreUnitInterface:
    """Common signal bundle for the load/store unit implementations.

    Creates the Wishbone data-bus record (``dbus``) together with the
    x-stage and m-stage request, handshake and status signals.  Widths are
    read from ``pspec``: ``mask_wid`` (byte mask), ``addr_wid`` (address)
    and ``reg_wid`` (data).
    """

    def __init__(self, pspec):
        self.pspec = pspec
        self.dbus = Record(make_wb_layout(pspec))
        print(self.dbus.sel.shape())

        self.mask_wid = pspec.mask_wid
        self.addr_wid = pspec.addr_wid
        self.data_wid = pspec.reg_wid
        print("loadstoreunit addr mask data",
              self.addr_wid, self.mask_wid, self.data_wid)

        # Number of low address bits already covered by the byte mask.
        self.adr_lsbs = log2_int(self.mask_wid)
        # Width of the reported faulting address: the address without those
        # low bits.  TODO: the original author was unsure this is correct.
        badaddr_wid = self.addr_wid - self.adr_lsbs

        # ---- inputs -------------------------------------------------------
        # NOTE: each signal is assigned directly to its attribute so that
        # nmigen can infer the signal name from the assignment target.
        self.x_addr_i = Signal(self.addr_wid)     # load/store address
        self.x_mask_i = Signal(self.mask_wid)     # which bytes to write
        self.x_ld_i = Signal()                    # request a memory load
        self.x_st_i = Signal()                    # request a memory store
        self.x_st_data_i = Signal(self.data_wid)  # data written by a store

        self.x_stall_i = Signal()  # x stage: do nothing until low
        self.x_valid_i = Signal()  # x stage enabled (presumably; the
                                   # original notes say "set to 1 for now")
        self.m_stall_i = Signal()  # m stage: do nothing until low
        self.m_valid_i = Signal()  # m stage enabled ("set to 1 for now")

        # ---- outputs ------------------------------------------------------
        self.x_busy_o = Signal()   # memory is busy (x stage)
        self.m_busy_o = Signal()   # memory is busy (m stage)

        # Data returned by a load.  Its validity is NOT signalled through
        # m_valid_i / x_valid_i (those are inputs); the original author
        # believed it is valid on the cycle after raising m_load in which
        # busy is low -- unconfirmed.
        self.m_ld_data_o = Signal(self.data_wid)

        self.m_load_err_o = Signal()   # an error occurred while loading
        self.m_store_err_o = Signal()  # an error occurred while storing
        self.m_badaddr_o = Signal(badaddr_wid)  # address of that error

    def __iter__(self):
        """Yield every interface signal, then each field of ``dbus``."""
        yield from (
            self.x_addr_i,
            self.x_mask_i,
            self.x_ld_i,
            self.x_st_i,
            self.x_st_data_i,
            self.x_stall_i,
            self.x_valid_i,
            self.m_stall_i,
            self.m_valid_i,
            self.x_busy_o,
            self.m_busy_o,
            self.m_ld_data_o,
            self.m_load_err_o,
            self.m_store_err_o,
            self.m_badaddr_o,
        )
        yield from self.dbus.fields.values()

    def ports(self):
        """Return everything produced by iterating this object, as a list."""
        return [*self]
77
78
class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit that drives the Wishbone data bus directly.

    Requests from the x stage are registered onto ``dbus``; the read data
    and any bus error are registered into the m-stage outputs.
    """

    def elaborate(self, platform):
        """Build and return the nmigen ``Module`` for this unit.

        ``platform`` is accepted for the ``Elaboratable`` interface and is
        not used here.
        """
        m = Module()

        with m.If(self.dbus.cyc):
            # A transaction is in flight: end it on ack, on error, or when
            # the m stage is not valid, and latch the bus read data.
            with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid_i):
                m.d.sync += [
                    self.dbus.cyc.eq(0),
                    self.dbus.stb.eq(0),
                    self.m_ld_data_o.eq(self.dbus.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    self.x_valid_i & ~self.x_stall_i):
            # Bus idle and a load/store is requested by a valid, unstalled
            # x stage: start a new transaction.
            m.d.sync += [
                self.dbus.cyc.eq(1),
                self.dbus.stb.eq(1),
                # The bus address drops the LSBs covered by the byte mask.
                self.dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                self.dbus.sel.eq(self.x_mask_i),
                self.dbus.we.eq(self.x_st_i),
                self.dbus.dat_w.eq(self.x_st_data_i)
            ]
        with m.Else():
            # Bus idle and nothing requested: clear the request fields.
            # (The original cleared ``sel`` twice; the duplicate is gone.)
            m.d.sync += [
                self.dbus.adr.eq(0),
                self.dbus.sel.eq(0),
                self.dbus.we.eq(0),
                self.dbus.dat_w.eq(0),
            ]

        with m.If(self.dbus.cyc & self.dbus.err):
            # Bus error: flag it as a load or store error depending on the
            # write-enable, and record the bus address.
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            # No error this cycle and the m stage is not stalled: clear flags.
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # The x stage is busy for as long as a bus cycle is active.
        m.d.comb += self.x_busy_o.eq(self.dbus.cyc)

        # The m stage is busy while a bus cycle is active, unless an error
        # has been flagged, in which case busy is forced low.
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Else():
            m.d.comb += self.m_busy_o.eq(self.dbus.cyc)

        return m
129
130
class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit with an L1 data cache and a store write buffer.

    Addresses that fall inside the cache region (``dcache.base`` up to
    ``dcache.limit``) go through the ``L1Cache`` and a ``SyncFIFO`` write
    buffer; all other accesses use a "bare" bus port.  The three bus
    masters share ``dbus`` through a ``WishboneArbiter``.

    Besides the widths used by ``LoadStoreUnitInterface``, ``pspec`` must
    provide ``dcache_args``, the positional arguments passed to ``L1Cache``.
    """

    def __init__(self, pspec):
        super().__init__(pspec)

        # Fixed: this previously read the undefined name ``psiec``.
        self.dcache_args = pspec.dcache_args

        self.x_fence_i = Signal()
        self.x_flush = Signal()
        self.m_load = Signal()
        self.m_store = Signal()
        # Fixed: elaborate() reads ``m_flush`` but it was never created,
        # which raised AttributeError.
        # NOTE(review): added as an externally driven input; confirm this
        # is the intended source of the m-stage flush indication.
        self.m_flush = Signal()

    def elaborate(self, platform):
        """Build and return the nmigen ``Module`` for this unit.

        ``platform`` is accepted for the ``Elaboratable`` interface and is
        not used here.
        """
        m = Module()

        dcache = m.submodules.dcache = L1Cache(*self.dcache_args)

        # Width of the word address matched by the region test below.  It
        # was hard-coded to 30, which only fits a 32-bit address with a
        # 4-bit mask; Case patterns must be as wide as the switched value.
        adr_bits = self.addr_wid - self.adr_lsbs

        def addr_below(limit):
            """Return a Case pattern matching word addresses below ``limit``.

            ``limit`` must be a power of two (``log2_int`` rejects other
            values) no larger than ``2**adr_bits``.
            """
            assert limit in range(1, 2**adr_bits + 1)
            range_bits = log2_int(limit)
            const_bits = adr_bits - range_bits
            return "{}{}".format("0" * const_bits, "-" * range_bits)

        x_dcache_select = Signal()
        # Test whether the target address is inside the L1 cache region.
        # We use bit masks in order to avoid carry chains from arithmetic
        # comparisons. This restricts the region boundaries to powers of 2.
        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
            # Addresses below the cache base (if the base is at least one
            # bus word) are not cached; the first matching Case wins.
            if dcache.base >= (1 << self.adr_lsbs):
                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
                    m.d.comb += x_dcache_select.eq(0)
            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
                m.d.comb += x_dcache_select.eq(1)
            with m.Default():
                m.d.comb += x_dcache_select.eq(0)

        # Carry the select decision and the address from the x stage into
        # the m stage whenever the x stage is not stalled.
        m_dcache_select = Signal()
        m_addr = Signal.like(self.x_addr_i)

        with m.If(~self.x_stall_i):
            m.d.sync += [
                m_dcache_select.eq(x_dcache_select),
                m_addr.eq(self.x_addr_i),
            ]

        # Cache stage 1 is fed from the x stage, stage 2 from the m stage.
        m.d.comb += [
            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
            dcache.s1_flush.eq(self.x_flush),
            dcache.s1_stall.eq(self.x_stall_i),
            dcache.s1_valid.eq(self.x_valid_i & x_dcache_select),
            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
            dcache.s2_re.eq(self.m_load),
            dcache.s2_evict.eq(self.m_store),
            dcache.s2_valid.eq(self.m_valid_i & m_dcache_select)
        ]

        # Write buffer: stores to the cached region are queued as
        # (addr, mask, data) records in a FIFO of depth dcache.nwords.
        wrbuf_w_data = Record([("addr", self.addr_wid-self.adr_lsbs),
                               ("mask", self.mask_wid),
                               ("data", self.data_wid)])
        wrbuf_r_data = Record.like(wrbuf_w_data)
        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
                                              depth=dcache.nwords)
        m.d.comb += [
            wrbuf.w_data.eq(wrbuf_w_data),
            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
            wrbuf_w_data.mask.eq(self.x_mask_i),
            wrbuf_w_data.data.eq(self.x_st_data_i),
            wrbuf.w_en.eq(self.x_st_i & self.x_valid_i &
                          x_dcache_select & ~self.x_stall_i),
            wrbuf_r_data.eq(wrbuf.r_data),
        ]

        # One arbiter shares the external data bus between the three ports.
        dba = WishboneArbiter(self.pspec)
        m.submodules.dbus_arbiter = dba
        m.d.comb += dba.bus.connect(self.dbus)

        # Write-buffer port.  Fixed: this previously called ``.port`` on
        # the undefined name ``dbus_arbiter``; the arbiter is ``dba``.
        wrbuf_port = dba.port(priority=0)
        m.d.comb += [
            wrbuf_port.cyc.eq(wrbuf.r_rdy),
            wrbuf_port.we.eq(Const(1)),
        ]
        with m.If(wrbuf_port.stb):
            # A buffered store is on the bus: on ack/err drop the strobe
            # and pop the entry from the FIFO.
            with m.If(wrbuf_port.ack | wrbuf_port.err):
                m.d.sync += wrbuf_port.stb.eq(0)
                m.d.comb += wrbuf.r_en.eq(1)
        with m.Elif(wrbuf.r_rdy):
            # Nothing on the bus and the FIFO has an entry: issue it.
            m.d.sync += [
                wrbuf_port.stb.eq(1),
                wrbuf_port.adr.eq(wrbuf_r_data.addr),
                wrbuf_port.sel.eq(wrbuf_r_data.mask),
                wrbuf_port.dat_w.eq(wrbuf_r_data.data)
            ]

        # Cache refill port: driven combinatorially from the cache's bus
        # request; cti marks the last beat of the burst.
        dcache_port = dba.port(priority=1)
        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
        m.d.comb += [
            dcache_port.cyc.eq(dcache.bus_re),
            dcache_port.stb.eq(dcache.bus_re),
            dcache_port.adr.eq(dcache.bus_addr),
            dcache_port.cti.eq(cti),
            dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
            dcache.bus_valid.eq(dcache_port.ack),
            dcache.bus_error.eq(dcache_port.err),
            dcache.bus_rdata.eq(dcache_port.dat_r)
        ]

        # Bare port: accesses outside the cached region, handled with the
        # same start/finish sequence as BareLoadStoreUnit.
        bare_port = dba.port(priority=2)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.m_valid_i):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r)
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    ~x_dcache_select & self.x_valid_i & ~self.x_stall_i):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                bare_port.sel.eq(self.x_mask_i),
                bare_port.we.eq(self.x_st_i),
                bare_port.dat_w.eq(self.x_st_data_i)
            ]

        # Error reporting is taken from the shared (arbitrated) bus.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr)
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0)
            ]

        # x-stage busy: a fence waits for the write buffer to drain, a
        # cached store waits for write-buffer space, and anything else
        # waits for the bare port.
        with m.If(self.x_fence_i):
            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
        with m.Elif(x_dcache_select):
            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
        with m.Else():
            m.d.comb += self.x_busy_o.eq(bare_port.cyc)

        # m-stage busy.
        # NOTE(review): every branch of the If/Elif/Else chain below also
        # assigns m_busy_o; if later assignments take precedence (as they
        # normally do in nmigen) this flush assignment is always
        # overridden -- confirm the intended priority.
        with m.If(self.m_flush):
            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Elif(m_dcache_select):
            m.d.comb += [
                self.m_busy_o.eq(dcache.s2_miss),
                self.m_ld_data_o.eq(dcache.s2_rdata)
            ]
        with m.Else():
            m.d.comb += [
                self.m_busy_o.eq(bare_port.cyc),
                self.m_ld_data_o.eq(bare_rdata)
            ]

        return m