c9808bf1a9ec38d5eab7a2b59d708edd22f4393b
[soc.git] / src / soc / minerva / units / loadstore.py
1 from nmigen import Elaboratable, Module, Signal, Record, Cat, Const, Mux
2 from nmigen.utils import log2_int
3 from nmigen.lib.fifo import SyncFIFO
4
5 from soc.minerva.cache import L1Cache
6 from soc.minerva.wishbone import make_wb_layout, WishboneArbiter, Cycle
7
8
# Public API of this module: the shared port interface and the two
# load/store unit implementations built on it.
__all__ = [
    "LoadStoreUnitInterface",
    "BareLoadStoreUnit",
    "CachedLoadStoreUnit",
]
11
12
class LoadStoreUnitInterface:
    """Ports shared by the load/store units in this module.

    Creates the Wishbone data-bus record ``dbus`` (layout from
    ``make_wb_layout(pspec)``) together with the ``x_*`` (x pipeline stage)
    and ``m_*`` (m pipeline stage) control and data signals.  Subclasses
    supply ``elaborate()``.

    :param pspec: parameter spec; ``mask_wid``, ``addr_wid`` and ``reg_wid``
                  are read from it, and it is passed to ``make_wb_layout``.
    """

    def __init__(self, pspec):
        self.pspec = pspec
        self.dbus = Record(make_wb_layout(pspec))
        self.mask_wid = mask_wid = pspec.mask_wid
        self.addr_wid = addr_wid = pspec.addr_wid
        self.data_wid = data_wid = pspec.reg_wid
        self.adr_lsbs = log2_int(mask_wid)  # LSBs of addr covered by mask
        # m_badaddr_o captures dbus.adr, which the subclasses drive from
        # x_addr_i[adr_lsbs:], so it is sized to that slice.
        # NOTE(review): assumes make_wb_layout gives dbus.adr the same width
        badwid = addr_wid - self.adr_lsbs

        # INPUTS
        self.x_addr_i = Signal(addr_wid)     # address used for loads/stores
        self.x_mask_i = Signal(mask_wid)     # mask of which bytes to write
        self.x_ld_i = Signal()               # set to do a memory load
        self.x_st_i = Signal()               # set to do a memory store
        self.x_st_data_i = Signal(data_wid)  # the data to write when storing

        self.x_stall_i = Signal()  # do nothing until low
        self.x_valid_i = Signal()  # whether the x pipeline stage is
                                   # currently enabled (I think?).
                                   # Set to 1 for now
        self.m_stall_i = Signal()  # do nothing until low
        self.m_valid_i = Signal()  # whether the m pipeline stage is
                                   # currently enabled. Set to 1 for now

        # OUTPUTS
        self.x_busy_o = Signal()  # set when the memory is busy
        self.m_busy_o = Signal()  # set when the memory is busy

        # Data returned from a memory read.  Its validity is NOT indicated
        # by m_valid_i or x_valid_i (those are inputs); presumably it is
        # valid on the cycle after a load is raised once busy is low
        # - TODO confirm.
        self.m_ld_data_o = Signal(data_wid)

        self.m_load_err_o = Signal()       # an error occurred when loading
        self.m_store_err_o = Signal()      # an error occurred when storing
        self.m_badaddr_o = Signal(badwid)  # address of the load/store error

    def __iter__(self):
        """Yield the unit's own ports, then every field of ``dbus``."""
        yield from (
            # inputs
            self.x_addr_i,
            self.x_mask_i,
            self.x_ld_i,
            self.x_st_i,
            self.x_st_data_i,
            self.x_stall_i,
            self.x_valid_i,
            self.m_stall_i,
            self.m_valid_i,
            # outputs
            self.x_busy_o,
            self.m_busy_o,
            self.m_ld_data_o,
            self.m_load_err_o,
            self.m_store_err_o,
            self.m_badaddr_o,
        )
        yield from self.dbus.fields.values()

    def ports(self):
        """Return all ports (as produced by ``__iter__``) as a list."""
        return list(self)
76
77
class BareLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit that issues every access directly on ``dbus``.

    An x-stage load/store request that is valid and not stalled starts a
    single Wishbone cycle.  The cycle ends on ``ack`` or ``err``, or is
    abandoned when ``m_valid_i`` drops.  The read data is latched into
    ``m_ld_data_o`` when the cycle ends.
    """

    def elaborate(self, platform):
        m = Module()

        with m.If(self.dbus.cyc):
            # A cycle is in flight: close it on ack/err, or abandon it when
            # the m stage is no longer valid.
            with m.If(self.dbus.ack | self.dbus.err | ~self.m_valid_i):
                m.d.sync += [
                    self.dbus.cyc.eq(0),
                    self.dbus.stb.eq(0),
                    self.m_ld_data_o.eq(self.dbus.dat_r),
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    self.x_valid_i & ~self.x_stall_i):
            # Start a new cycle for the x-stage request; the bus is
            # addressed with the address bits above adr_lsbs.
            m.d.sync += [
                self.dbus.cyc.eq(1),
                self.dbus.stb.eq(1),
                self.dbus.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                self.dbus.sel.eq(self.x_mask_i),
                self.dbus.we.eq(self.x_st_i),
                self.dbus.dat_w.eq(self.x_st_data_i),
            ]
        with m.Else():
            # Idle: park the address/select/write lines at zero.
            # (cyc/stb are already low on this path.)
            m.d.sync += [
                self.dbus.adr.eq(0),
                self.dbus.sel.eq(0),
                self.dbus.we.eq(0),
                self.dbus.dat_w.eq(0),
            ]

        # Record a bus error against the direction of the failing access
        # and capture its word address; clear the flags once m is unstalled.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr),
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0),
            ]

        m.d.comb += self.x_busy_o.eq(self.dbus.cyc)

        # A reported error releases the m stage; otherwise it is busy for
        # as long as a bus cycle is open.
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Else():
            m.d.comb += self.m_busy_o.eq(self.dbus.cyc)

        return m
128
129
class CachedLoadStoreUnit(LoadStoreUnitInterface, Elaboratable):
    """Load/store unit with an L1 data cache and a store write buffer.

    Addresses inside the L1Cache region (below ``dcache.limit`` and not
    below ``dcache.base``) are served by the cache.  Stores to that region
    are also queued in a SyncFIFO of depth ``dcache.nwords``.  All other
    accesses go out through a "bare" single-cycle port.

    Three ports share ``dbus`` through a WishboneArbiter:
      * write buffer drain   (``priority=0``)
      * cache refill bursts  (``priority=1``)
      * bare accesses        (``priority=2``)

    :param pspec: parameter spec; in addition to the fields read by
                  LoadStoreUnitInterface, ``dcache_args`` is forwarded
                  positionally to L1Cache.
    """

    def __init__(self, pspec):
        super().__init__(pspec)

        self.dcache_args = pspec.dcache_args

        self.x_fence_i = Signal()
        self.x_flush = Signal()
        self.m_flush = Signal()  # read by elaborate() to hold m busy
        self.m_load = Signal()
        self.m_store = Signal()

    def elaborate(self, platform):
        m = Module()

        dcache = m.submodules.dcache = L1Cache(*self.dcache_args)

        # Width of the word address x_addr_i[adr_lsbs:] that the region
        # decode Switch and the write buffer work on.
        word_wid = self.addr_wid - self.adr_lsbs

        x_dcache_select = Signal()
        # Test whether the target address is inside the L1 cache region.
        # We use bit masks in order to avoid carry chains from arithmetic
        # comparisons. This restricts the region boundaries to powers of 2.
        with m.Switch(self.x_addr_i[self.adr_lsbs:]):
            def addr_below(limit):
                # Case pattern matching word addresses in [0, limit): the
                # top bits must be 0, the low log2(limit) bits don't care.
                # Its length must equal the Switch subject width.
                assert limit in range(1, 2**word_wid + 1)
                range_bits = log2_int(limit)
                const_bits = word_wid - range_bits
                return "{}{}".format("0" * const_bits, "-" * range_bits)

            if dcache.base >= (1 << self.adr_lsbs):
                with m.Case(addr_below(dcache.base >> self.adr_lsbs)):
                    m.d.comb += x_dcache_select.eq(0)
            with m.Case(addr_below(dcache.limit >> self.adr_lsbs)):
                m.d.comb += x_dcache_select.eq(1)
            with m.Default():
                m.d.comb += x_dcache_select.eq(0)

        # Carry the region decision and address into the m stage.
        m_dcache_select = Signal()
        m_addr = Signal.like(self.x_addr_i)

        with m.If(~self.x_stall_i):
            m.d.sync += [
                m_dcache_select.eq(x_dcache_select),
                m_addr.eq(self.x_addr_i),
            ]

        m.d.comb += [
            dcache.s1_addr.eq(self.x_addr_i[self.adr_lsbs:]),
            dcache.s1_flush.eq(self.x_flush),
            dcache.s1_stall.eq(self.x_stall_i),
            dcache.s1_valid.eq(self.x_valid_i & x_dcache_select),
            dcache.s2_addr.eq(m_addr[self.adr_lsbs:]),
            dcache.s2_re.eq(self.m_load),
            dcache.s2_evict.eq(self.m_store),
            dcache.s2_valid.eq(self.m_valid_i & m_dcache_select),
        ]

        # Write buffer: stores to the cached region are queued here.
        wrbuf_w_data = Record([("addr", word_wid),
                               ("mask", self.mask_wid),
                               ("data", self.data_wid)])
        wrbuf_r_data = Record.like(wrbuf_w_data)
        wrbuf = m.submodules.wrbuf = SyncFIFO(width=len(wrbuf_w_data),
                                              depth=dcache.nwords)
        m.d.comb += [
            wrbuf.w_data.eq(wrbuf_w_data),
            wrbuf_w_data.addr.eq(self.x_addr_i[self.adr_lsbs:]),
            wrbuf_w_data.mask.eq(self.x_mask_i),
            wrbuf_w_data.data.eq(self.x_st_data_i),
            wrbuf.w_en.eq(self.x_st_i & self.x_valid_i &
                          x_dcache_select & ~self.x_stall_i),
            wrbuf_r_data.eq(wrbuf.r_data),
        ]

        dba = WishboneArbiter(self.pspec)
        m.submodules.dbus_arbiter = dba
        m.d.comb += dba.bus.connect(self.dbus)

        # Port 0: drain the write buffer one entry per bus cycle.
        wrbuf_port = dba.port(priority=0)
        m.d.comb += [
            wrbuf_port.cyc.eq(wrbuf.r_rdy),
            wrbuf_port.we.eq(Const(1)),
        ]
        with m.If(wrbuf_port.stb):
            # pop the entry once the bus has acknowledged (or errored) it
            with m.If(wrbuf_port.ack | wrbuf_port.err):
                m.d.sync += wrbuf_port.stb.eq(0)
                m.d.comb += wrbuf.r_en.eq(1)
        with m.Elif(wrbuf.r_rdy):
            m.d.sync += [
                wrbuf_port.stb.eq(1),
                wrbuf_port.adr.eq(wrbuf_r_data.addr),
                wrbuf_port.sel.eq(wrbuf_r_data.mask),
                wrbuf_port.dat_w.eq(wrbuf_r_data.data),
            ]

        # Port 1: cache refill, as an incrementing burst ended on bus_last.
        dcache_port = dba.port(priority=1)
        cti = Mux(dcache.bus_last, Cycle.END, Cycle.INCREMENT)
        m.d.comb += [
            dcache_port.cyc.eq(dcache.bus_re),
            dcache_port.stb.eq(dcache.bus_re),
            dcache_port.adr.eq(dcache.bus_addr),
            dcache_port.cti.eq(cti),
            dcache_port.bte.eq(Const(log2_int(dcache.nwords) - 1)),
            dcache.bus_valid.eq(dcache_port.ack),
            dcache.bus_error.eq(dcache_port.err),
            dcache.bus_rdata.eq(dcache_port.dat_r),
        ]

        # Port 2: uncached ("bare") accesses, one bus cycle per request.
        bare_port = dba.port(priority=2)
        bare_rdata = Signal.like(bare_port.dat_r)
        with m.If(bare_port.cyc):
            with m.If(bare_port.ack | bare_port.err | ~self.m_valid_i):
                m.d.sync += [
                    bare_port.cyc.eq(0),
                    bare_port.stb.eq(0),
                    bare_rdata.eq(bare_port.dat_r),
                ]
        with m.Elif((self.x_ld_i | self.x_st_i) &
                    ~x_dcache_select & self.x_valid_i & ~self.x_stall_i):
            m.d.sync += [
                bare_port.cyc.eq(1),
                bare_port.stb.eq(1),
                bare_port.adr.eq(self.x_addr_i[self.adr_lsbs:]),
                bare_port.sel.eq(self.x_mask_i),
                bare_port.we.eq(self.x_st_i),
                bare_port.dat_w.eq(self.x_st_data_i),
            ]

        # Record any error seen on the shared bus; clear once m unstalls.
        with m.If(self.dbus.cyc & self.dbus.err):
            m.d.sync += [
                self.m_load_err_o.eq(~self.dbus.we),
                self.m_store_err_o.eq(self.dbus.we),
                self.m_badaddr_o.eq(self.dbus.adr),
            ]
        with m.Elif(~self.m_stall_i):
            m.d.sync += [
                self.m_load_err_o.eq(0),
                self.m_store_err_o.eq(0),
            ]

        with m.If(self.x_fence_i):
            m.d.comb += self.x_busy_o.eq(wrbuf.r_rdy)
        with m.Elif(x_dcache_select):
            m.d.comb += self.x_busy_o.eq(self.x_st_i & ~wrbuf.w_rdy)
        with m.Else():
            m.d.comb += self.x_busy_o.eq(bare_port.cyc)

        # One priority chain for m_busy_o.  The flush check used to be a
        # separate If that this chain always overrode, so it had no effect.
        with m.If(self.m_load_err_o | self.m_store_err_o):
            m.d.comb += self.m_busy_o.eq(0)
        with m.Elif(self.m_flush):
            m.d.comb += self.m_busy_o.eq(~dcache.s2_flush_ack)
        with m.Elif(m_dcache_select):
            m.d.comb += [
                self.m_busy_o.eq(dcache.s2_miss),
                self.m_ld_data_o.eq(dcache.s2_rdata),
            ]
        with m.Else():
            m.d.comb += [
                self.m_busy_o.eq(bare_port.cyc),
                self.m_ld_data_o.eq(bare_rdata),
            ]

        return m
288
289 return m