add TODOs
[litex.git] / litex / soc / interconnect / stream_packet.py
1 from litex.gen import *
2 from litex.gen.genlib.roundrobin import *
3 from litex.gen.genlib.record import *
4 from litex.gen.genlib.fsm import FSM, NextState
5
6 from litex.soc.interconnect.stream import *
7
8 # TODO: clean up code below
9 # XXX
10
def reverse_bytes(signal):
    """Return ``signal`` with its byte order reversed, as a ``Cat``.

    The signal is cut into 8-bit chunks starting at bit 0 (the topmost chunk
    is narrower when the width is not a multiple of 8) and the chunks are
    concatenated highest-first, so the top chunk lands in the low bits.
    """
    width = len(signal)
    chunk_count = (width + 7)//8
    chunks = [signal[8*k:min(8*k + 8, width)]
              for k in range(chunk_count - 1, -1, -1)]
    return Cat(*chunks)
17
18
@ResetInserter()
@CEInserter()
class Counter(Module):
    """Accumulator with inserted ``reset`` and ``ce`` controls.

    Extra positional/keyword arguments are forwarded to the ``Signal`` that
    holds the count (for instance ``max=``). On every clock cycle where the
    inserted ``ce`` is asserted, ``increment`` is added to ``value``.
    """
    def __init__(self, *args, increment=1, **kwargs):
        self.value = Signal(*args, **kwargs)
        value = self.value
        self.width = len(value)
        self.sync += [value.eq(value + increment)]
26
class Status(Module):
    """Track packet boundaries on a stream endpoint.

    ``sop`` is high on a valid first beat (``stb & sop``).
    ``eop`` is high when a last beat is accepted (``stb & eop & ack``).
    ``ongoing`` is high while a packet is in progress: from its first valid
    beat up to, but excluding, the cycle its last beat is accepted.
    """
    def __init__(self, endpoint):
        self.sop = sop = Signal()
        self.eop = eop = Signal()
        self.ongoing = Signal()

        ongoing = Signal()
        # Shared "inside a packet" expression: used combinationally for the
        # output and registered so the state persists across beats.
        in_packet = (sop | ongoing) & ~eop
        self.comb += [
            sop.eq(endpoint.stb & endpoint.sop),
            eop.eq(endpoint.stb & endpoint.eop & endpoint.ack),
        ]
        self.sync += ongoing.eq(in_packet)
        self.comb += self.ongoing.eq(in_packet)
41
42
class Arbiter(Module):
    """Merge several packet streams (``masters``) into ``slave``.

    With a single master it is connected straight through. With several,
    each master's ``Status.ongoing`` is fed as its request to a
    ``RoundRobin`` arbiter, and the granted master is connected to ``slave``.
    ``self.grant`` exposes the arbiter's choice. With fewer than two masters
    it is an undriven dummy ``Signal``, defined even with zero masters so
    callers can always read it.

    ``masters`` is not modified.
    """
    def __init__(self, masters, slave):
        if len(masters) == 0:
            # Nothing to arbitrate; still expose ``grant`` (as Dispatcher
            # always exposes ``sel``).
            self.grant = Signal()
        elif len(masters) == 1:
            self.grant = Signal()
            # Read the only master without popping it, so the caller's
            # list is left untouched.
            self.comb += Record.connect(next(iter(masters)), slave)
        else:
            self.submodules.rr = RoundRobin(len(masters))
            self.grant = self.rr.grant
            cases = {}
            for i, master in enumerate(masters):
                status = Status(master)
                self.submodules += status
                # A master requests while it has a packet in progress, so the
                # grant can be held for a whole packet (exact switching is up
                # to RoundRobin's policy).
                self.comb += self.rr.request[i].eq(status.ongoing)
                cases[i] = [Record.connect(master, slave)]
            self.comb += Case(self.grant, cases)
60
61
class Dispatcher(Module):
    """Route each packet from ``master`` to one of ``slaves``.

    ``self.sel`` picks the destination. It is used directly on a packet's
    first beat (``sop``), latched there, and the latched value is used for
    the remaining beats, so changing ``sel`` mid-packet has no effect.
    With ``one_hot=True``, ``sel`` has one bit per slave (bit ``i`` selects
    ``slaves[i]``); otherwise it is a binary index. If ``sel`` matches no
    slave, the master's beats are acknowledged and discarded.

    With no slaves, nothing is connected. ``slaves`` is not modified.
    """
    def __init__(self, master, slaves, one_hot=False):
        if len(slaves) == 0:
            self.sel = Signal()
        elif len(slaves) == 1:
            # Read the only slave without popping it, so the caller's list
            # is left untouched.
            self.comb += Record.connect(master, next(iter(slaves)))
            self.sel = Signal()
        else:
            if one_hot:
                self.sel = Signal(len(slaves))
            else:
                self.sel = Signal(max=len(slaves))

            # # #

            status = Status(master)
            self.submodules += status

            sel = Signal.like(self.sel)
            sel_ongoing = Signal.like(self.sel)
            # Remember the selection made on the first beat of the packet.
            self.sync += \
                If(status.sop,
                    sel_ongoing.eq(self.sel)
                )
            # First beat uses the live selection, later beats the latched one.
            self.comb += \
                If(status.sop,
                    sel.eq(self.sel)
                ).Else(
                    sel.eq(sel_ongoing)
                )
            cases = {}
            for i, slave in enumerate(slaves):
                if one_hot:
                    idx = 2**i
                else:
                    idx = i
                cases[idx] = [Record.connect(master, slave)]
            # Unknown selection: drain the master so it cannot stall.
            cases["default"] = [master.ack.eq(1)]
            self.comb += Case(sel, cases)
101
102
class HeaderField:
    """Position and size of one field inside a packet header.

    ``Header`` places the field's first bit at ``byte*8 + offset``;
    ``width`` is the field size in bits.
    """
    def __init__(self, byte, offset, width):
        self.byte, self.offset, self.width = byte, offset, width
108
109
class Header:
    """Describe how named fields are packed into a fixed-length header.

    ``fields`` maps a field name to a ``HeaderField``-like object exposing
    ``byte``, ``offset`` and ``width``. ``length`` is the header size in
    bytes. When ``swap_field_bytes`` is true, each field's byte order is
    reversed (via ``reverse_bytes``) while encoding and decoding.
    """
    def __init__(self, fields, length, swap_field_bytes=True):
        self.fields = fields
        self.length = length
        self.swap_field_bytes = swap_field_bytes

    def get_layout(self):
        """Return ``[(name, width), ...]`` for all fields, sorted by name."""
        return [(name, field.width)
                for name, field in sorted(self.fields.items())]

    def get_field(self, obj, name, width):
        """Return the part of ``obj`` that backs header field ``name``.

        A name ending in ``_lsb`` / ``_msb`` refers to the low / next
        ``width`` bits of the attribute named without that suffix (one wide
        attribute split across two header fields). Any other name is the
        attribute itself. Only a true suffix is treated specially, so names
        that merely contain ``_lsb``/``_msb`` are looked up verbatim.
        """
        if name.endswith("_lsb"):
            return getattr(obj, name[:-len("_lsb")])[:width]
        if name.endswith("_msb"):
            return getattr(obj, name[:-len("_msb")])[width:2*width]
        return getattr(obj, name)

    @staticmethod
    def _bit_range(field):
        """Return ``(start, end)`` bit bounds of ``field`` in the header."""
        start = field.byte*8 + field.offset
        return start, start + field.width

    def encode(self, obj, signal):
        """Return statements driving ``signal``'s field bits from ``obj``."""
        r = []
        for name, field in sorted(self.fields.items()):
            start, end = self._bit_range(field)
            value = self.get_field(obj, name, field.width)
            if self.swap_field_bytes:
                value = reverse_bytes(value)
            r.append(signal[start:end].eq(value))
        return r

    def decode(self, signal, obj):
        """Return statements driving ``obj``'s fields from ``signal``'s bits."""
        r = []
        for name, field in sorted(self.fields.items()):
            start, end = self._bit_range(field)
            value = signal[start:end]
            if self.swap_field_bytes:
                value = reverse_bytes(value)
            r.append(self.get_field(obj, name, field.width).eq(value))
        return r
153
154
class Packetizer(Module):
    """Prepend a fixed-length header to each packet of a stream.

    ``self.header`` is driven combinationally by ``header.encode`` from the
    fields of ``sink``. When a packet starts (``sink.stb & sink.sop``), the
    header is emitted on ``source`` as ``header.length*8 // dw`` words
    (``dw`` = sink data width). The first word carries ``sop``, and the sink's
    first beat is held back until then. After the header, the sink's beats
    (data, ``eop``, ``error``) are forwarded unchanged.
    """
    def __init__(self, sink_description, source_description, header):
        self.sink = sink = Sink(sink_description)
        self.source = source = Source(source_description)
        self.header = Signal(header.length*8)

        # # #

        dw = len(self.sink.data)

        header_reg = Signal(header.length*8)
        # NOTE(review): assumes header.length*8 is a non-zero multiple of dw;
        # otherwise trailing header bits are dropped (or, with 0 words, the
        # FSM never leaves SEND_HEADER) — confirm with callers.
        header_words = (header.length*8)//dw
        load = Signal()
        shift = Signal()
        # Counts header words sent from SEND_HEADER (reset while in IDLE).
        counter = Counter(max=max(header_words, 2))
        self.submodules += counter

        self.comb += header.encode(sink, self.header)
        if header_words == 1:
            self.sync += [
                If(load,
                    header_reg.eq(self.header)
                )
            ]
        else:
            # Snapshot the header when the first word is accepted, then shift
            # it down one word per accepted header word. The top is filled
            # with a fresh undriven Signal (presumably reads as its reset
            # value, 0).
            self.sync += [
                If(load,
                    header_reg.eq(self.header)
                ).Elif(shift,
                    header_reg.eq(Cat(header_reg[dw:], Signal(dw)))
                )
            ]

        fsm = FSM(reset_state="IDLE")
        self.submodules += fsm

        if header_words == 1:
            idle_next_state = "COPY"
        else:
            idle_next_state = "SEND_HEADER"

        fsm.act("IDLE",
            # Outside a packet start, sink beats are acknowledged and dropped.
            sink.ack.eq(1),
            counter.reset.eq(1),
            If(sink.stb & sink.sop,
                # Hold the first payload beat; emit header word 0 instead.
                sink.ack.eq(0),
                source.stb.eq(1),
                source.sop.eq(1),
                source.eop.eq(0),
                source.data.eq(self.header[:dw]),
                If(source.stb & source.ack,
                    load.eq(1),
                    NextState(idle_next_state)
                )
            )
        )
        if header_words != 1:
            fsm.act("SEND_HEADER",
                source.stb.eq(1),
                source.sop.eq(0),
                source.eop.eq(0),
                # Word 0 was sent from self.header in IDLE; the next word is
                # always at [dw:2*dw] because header_reg shifts each beat.
                source.data.eq(header_reg[dw:2*dw]),
                If(source.stb & source.ack,
                    shift.eq(1),
                    counter.ce.eq(1),
                    If(counter.value == header_words-2,
                        NextState("COPY")
                    )
                )
            )
        fsm.act("COPY",
            source.stb.eq(sink.stb),
            source.sop.eq(0),
            source.eop.eq(sink.eop),
            source.data.eq(sink.data),
            # NOTE(review): unlike Depacketizer, assumes sink has an `error`
            # field.
            source.error.eq(sink.error),
            If(source.stb & source.ack,
                sink.ack.eq(1),
                If(source.eop,
                    NextState("IDLE")
                )
            )
        )
238
239
class Depacketizer(Module):
    """Strip a fixed-length header from the front of each packet.

    The first ``header.length*8 // dw`` beats (``dw`` = sink data width) are
    shifted into ``self.header``. The remaining beats are forwarded on
    ``source``, whose header fields are driven from ``self.header`` by
    ``header.decode``. A packet made only of a header still produces a single
    ``source`` beat with ``sop`` and ``eop`` set. That beat does not consume
    anything from ``sink``.
    """
    def __init__(self, sink_description, source_description, header):
        self.sink = sink = Sink(sink_description)
        self.source = source = Source(source_description)
        self.header = Signal(header.length*8)

        # # #

        dw = len(sink.data)

        # NOTE(review): assumes header.length*8 is a non-zero multiple of dw;
        # confirm with callers.
        header_words = (header.length*8)//dw

        shift = Signal()
        # Counts header words received in RECEIVE_HEADER (reset in IDLE).
        counter = Counter(max=max(header_words, 2))
        self.submodules += counter

        # Words enter at the top and shift down, so the first header word
        # received ends up in the low bits of self.header.
        if header_words == 1:
            self.sync += \
                If(shift,
                    self.header.eq(sink.data)
                )
        else:
            self.sync += \
                If(shift,
                    self.header.eq(Cat(self.header[dw:], sink.data))
                )

        fsm = FSM(reset_state="IDLE")
        self.submodules += fsm

        if header_words == 1:
            idle_next_state = "COPY"
        else:
            idle_next_state = "RECEIVE_HEADER"

        fsm.act("IDLE",
            sink.ack.eq(1),
            counter.reset.eq(1),
            If(sink.stb,
                shift.eq(1),
                NextState(idle_next_state)
            )
        )
        if header_words != 1:
            fsm.act("RECEIVE_HEADER",
                sink.ack.eq(1),
                If(sink.stb,
                    counter.ce.eq(1),
                    shift.eq(1),
                    If(counter.value == header_words-2,
                        NextState("COPY")
                    )
                )
            )
        # On the transition into COPY (i.e. while the last header word is
        # being accepted), mark the next source beat as sop and remember
        # whether that header word already ended the packet.
        no_payload = Signal()
        self.sync += \
            If(fsm.before_entering("COPY"),
                source.sop.eq(1),
                no_payload.eq(sink.eop)
            ).Elif(source.stb & source.ack,
                source.sop.eq(0)
            )

        if hasattr(sink, "error"):
            self.comb += source.error.eq(sink.error)
        self.comb += [
            source.eop.eq(sink.eop | no_payload),
            source.data.eq(sink.data),
            header.decode(self.header, source)
        ]
        fsm.act("COPY",
            # With no payload the packet was fully consumed with its header,
            # so any beat now on sink belongs to the NEXT packet: emit the
            # synthetic end-of-packet beat without acknowledging sink.
            sink.ack.eq(source.ack & ~no_payload),
            source.stb.eq(sink.stb | no_payload),
            If(source.stb & source.ack & source.eop,
                NextState("IDLE")
            )
        )
317
318
class Buffer(Module):
    """Store-and-forward packet buffer.

    Incoming beats go into ``data_fifo``. Each accepted end-of-packet pushes
    one entry (carrying the sink's ``error`` flag, if present) into
    ``cmd_fifo``. Output starts only once a command is available, i.e. only
    after a whole packet has been stored. ``data_depth`` and ``cmd_depth`` size
    the data and command FIFOs. If ``almost_full`` is given,
    ``self.almost_full`` is asserted while the data FIFO level exceeds it.
    """
    def __init__(self, description, data_depth, cmd_depth=4, almost_full=None):
        self.sink = sink = Sink(description)
        self.source = source = Source(description)

        # # #

        sink_status = Status(self.sink)
        source_status = Status(self.source)
        self.submodules += sink_status, source_status

        # store incoming packets
        # cmds: one entry per accepted end-of-packet
        def cmd_description():
            layout = [("error", 1)]
            return EndpointDescription(layout)
        cmd_fifo = SyncFIFO(cmd_description(), cmd_depth)
        self.submodules += cmd_fifo
        self.comb += cmd_fifo.sink.stb.eq(sink_status.eop)
        if hasattr(sink, "error"):
            self.comb += cmd_fifo.sink.error.eq(sink.error)

        # data
        data_fifo = SyncFIFO(description, data_depth, buffered=True)
        self.submodules += data_fifo
        # The two statements after Record.connect override its stb/ack
        # assignments (later comb assignments win): a beat is accepted only
        # when BOTH FIFOs can take it, so an eop is never stored without its
        # command entry.
        self.comb += [
            Record.connect(self.sink, data_fifo.sink),
            data_fifo.sink.stb.eq(self.sink.stb & cmd_fifo.sink.ack),
            self.sink.ack.eq(data_fifo.sink.ack & cmd_fifo.sink.ack),
        ]

        # output packets
        self.fsm = fsm = FSM(reset_state="IDLE")
        self.submodules += fsm
        # Wait until at least one complete packet is stored.
        fsm.act("IDLE",
            If(cmd_fifo.source.stb,
                NextState("SEEK_SOP")
            )
        )
        # Discard data words until one flagged sop is at the FIFO head.
        # NOTE(review): data_fifo.source.stb is not checked here — confirm
        # this is safe when the data FIFO is empty.
        fsm.act("SEEK_SOP",
            If(~data_fifo.source.sop,
                data_fifo.source.ack.eq(1)
            ).Else(
                NextState("OUTPUT")
            )
        )
        # Without an `error` field on source, drive a throwaway signal instead.
        if hasattr(source, "error"):
            source_error = self.source.error
        else:
            source_error = Signal()

        # Forward the packet; the error assignment comes after
        # Record.connect, so the command's error flag overrides the per-beat one.
        # The command entry is popped when the packet's last beat is accepted.
        fsm.act("OUTPUT",
            Record.connect(data_fifo.source, self.source),
            source_error.eq(cmd_fifo.source.error),
            If(source_status.eop,
                cmd_fifo.source.ack.eq(1),
                NextState("IDLE")
            )
        )

        # compute almost full (reads the buffered SyncFIFO's inner `.fifo`)
        if almost_full is not None:
            self.almost_full = Signal()
            self.comb += self.almost_full.eq(data_fifo.fifo.level > almost_full)
381 self.almost_full = Signal()
382 self.comb += self.almost_full.eq(data_fifo.fifo.level > almost_full)
383
384 # XXX