Avoid timing violation on ECP5 PHY PAUSE signal
[gram.git] / gram / stream.py
1 # This file is Copyright (c) 2020 LambdaConcept <contact@lambdaconcept.com>
2
3 from nmigen import *
4 from nmigen.hdl.rec import *
5 from nmigen.lib import fifo
6
7
8 __all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer", "StrideConverter"]
9
10
11 def _make_fanout(layout):
12 r = []
13 for f in layout:
14 if isinstance(f[1], (int, tuple)):
15 r.append((f[0], f[1], DIR_FANOUT))
16 else:
17 r.append((f[0], _make_fanout(f[1])))
18 return r
19
20
class EndpointDescription:
    """Describe the payload carried by a stream endpoint.

    ``payload_layout`` is stored as given; ``get_full_layout`` wraps it with
    the ``valid``/``ready``/``first``/``last`` handshake fields.
    """

    def __init__(self, payload_layout):
        self.payload_layout = payload_layout

    def get_full_layout(self):
        """Build the complete endpoint layout (handshake fields plus payload).

        Raises:
            ValueError: if a payload field name appears twice, or collides
                with one of the reserved endpoint field names.
        """
        reserved = {"valid", "ready", "first", "last", "payload"}
        seen = set()
        for field in self.payload_layout:
            name = field[0]
            # The duplicate check runs before the reserved-name check.
            if name in seen:
                raise ValueError(
                    name + " already attributed in payload layout")
            if name in reserved:
                raise ValueError(name + " cannot be used in endpoint layout")
            seen.add(name)

        # Only "ready" flows against the stream direction.
        return [
            ("valid", 1, DIR_FANOUT),
            ("ready", 1, DIR_FANIN),
            ("first", 1, DIR_FANOUT),
            ("last", 1, DIR_FANOUT),
            ("payload", _make_fanout(self.payload_layout)),
        ]
44
45
class Endpoint(Record):
    """Stream endpoint record: handshake fields plus a ``payload`` sub-record.

    Accepts either an ``EndpointDescription`` or a bare payload layout (which
    is wrapped in a new ``EndpointDescription``).  Payload fields can also be
    read directly as attributes of the endpoint.
    """

    def __init__(self, layout_or_description, **kwargs):
        already_described = isinstance(
            layout_or_description, EndpointDescription)
        self.description = (
            layout_or_description
            if already_described
            else EndpointDescription(layout_or_description)
        )
        # NOTE(review): src_loc_at=1 presumably makes Record infer its name
        # from the caller of this constructor -- confirm against Record.
        super().__init__(
            self.description.get_full_layout(), src_loc_at=1, **kwargs)

    def __getattr__(self, name):
        # Names the base class rejects are looked up among the payload
        # fields.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return self.fields["payload"][name]
59
60
class _FIFOWrapper:
    """Stream (sink/source endpoint) plumbing shared by the FIFO classes.

    ``elaborate`` reads ``self.fifo``, which this class does not create; the
    ``SyncFIFO`` and ``AsyncFIFO`` subclasses in this file set it in their
    constructors.
    """

    def __init__(self, payload_layout):
        self.sink = Endpoint(payload_layout)
        self.source = Endpoint(payload_layout)

        # Everything stored per FIFO entry: the payload plus first/last.
        self.layout = Layout([
            ("payload", self.sink.description.payload_layout),
            ("first", 1, DIR_FANOUT),
            ("last", 1, DIR_FANOUT)
        ])

    def elaborate(self, platform):
        m = Module()

        m.submodules.fifo = core = self.fifo

        # Structured views of the raw FIFO data words.
        fifo_din = Record(self.layout)
        fifo_dout = Record(self.layout)
        m.d.comb += [
            core.w_data.eq(fifo_din),
            fifo_dout.eq(core.r_data),
        ]

        # Write side: sink endpoint -> FIFO.
        m.d.comb += [
            self.sink.ready.eq(core.w_rdy),
            core.w_en.eq(self.sink.valid),
            fifo_din.first.eq(self.sink.first),
            fifo_din.last.eq(self.sink.last),
            fifo_din.payload.eq(self.sink.payload),
        ]

        # Read side: FIFO -> source endpoint.
        m.d.comb += [
            self.source.valid.eq(core.r_rdy),
            self.source.first.eq(fifo_dout.first),
            self.source.last.eq(fifo_dout.last),
            self.source.payload.eq(fifo_dout.payload),
            core.r_en.eq(self.source.ready),
        ]

        return m
96
97
class SyncFIFO(Elaboratable, _FIFOWrapper):
    """Stream FIFO built on ``fifo.SyncFIFO`` or ``fifo.SyncFIFOBuffered``.

    Args:
        layout: payload layout (or description) used for sink and source.
        depth: requested FIFO depth.
        fwft: first-word-fall-through flag, forwarded to ``fifo.SyncFIFO``.
        buffered: build a ``fifo.SyncFIFOBuffered`` instead of a
            ``fifo.SyncFIFO``.

    Raises:
        ValueError: if ``buffered`` is combined with ``fwft=False``.
    """

    def __init__(self, layout, depth, fwft=True, buffered=False):
        super().__init__(layout)
        width = len(Record(self.layout))
        if buffered:
            # SyncFIFOBuffered takes only width/depth and is always
            # first-word-fall-through; passing fwft= raised a TypeError.
            if not fwft:
                raise ValueError(
                    "buffered SyncFIFO only supports fwft=True")
            self.fifo = fifo.SyncFIFOBuffered(width=width, depth=depth)
        else:
            self.fifo = fifo.SyncFIFO(width=width, depth=depth, fwft=fwft)
        self.depth = self.fifo.depth
        self.level = self.fifo.level
109
110
class AsyncFIFO(Elaboratable, _FIFOWrapper):
    """Stream FIFO built on ``fifo.AsyncFIFO``.

    ``r_domain`` and ``w_domain`` are forwarded unchanged to the underlying
    FIFO, whose width is the bit length of the wrapper's entry layout.
    """

    def __init__(self, layout, depth, r_domain="read", w_domain="write"):
        super().__init__(layout)
        width = len(Record(self.layout))
        self.fifo = fifo.AsyncFIFO(
            width=width,
            depth=depth,
            r_domain=r_domain,
            w_domain=w_domain,
        )
        self.depth = self.fifo.depth
117
118
class PipeValid(Elaboratable):
    """Pipe valid/payload to cut timing path

    ``valid``, ``first``, ``last`` and ``payload`` are registered in the
    ``sync`` domain; ``ready`` is still driven combinationally from the
    source side.
    """

    def __init__(self, layout):
        self.sink = Endpoint(layout)
        self.source = Endpoint(layout)

    def elaborate(self, platform):
        m = Module()

        # The output register can take a new beat when it holds nothing
        # valid or when the current beat is being accepted downstream.
        can_load = ~self.source.valid | self.source.ready
        m.d.comb += self.sink.ready.eq(can_load)

        with m.If(can_load):
            m.d.sync += self.source.valid.eq(self.sink.valid)
            m.d.sync += self.source.first.eq(self.sink.first)
            m.d.sync += self.source.last.eq(self.sink.last)
            m.d.sync += self.source.payload.eq(self.sink.payload)
            # TODO ensure the "param" copy can stay disabled:
            # self.source.param.eq(self.sink.param)

        return m
141
142
class Buffer(PipeValid):
    """``PipeValid`` under the name ``Buffer``; adds no behaviour."""
    # FIXME: Replace Buffer with PipeValid in codebase?
145
146
class _UpConverter(Elaboratable):
    """Gather ``ratio`` narrow sink words into one wide source word.

    Args:
        nbits_from: width of the sink ``data`` field.
        nbits_to: width of the source ``data`` field.
        ratio: number of sink words per source word.
        reverse: when true, the first sink word goes to the highest slot of
            the source word instead of the lowest.
        report_valid_token_count: add a ``valid_token_count`` source field
            holding how many sink words were loaded into the source word.
    """

    def __init__(self, nbits_from, nbits_to, ratio, reverse,
                 report_valid_token_count):
        self.sink = Endpoint([("data", nbits_from)])
        source_layout = [("data", nbits_to)]
        if report_valid_token_count:
            # bits_for is not exported by the star imports at the top of
            # this file, so it is imported here from nmigen.utils.
            from nmigen.utils import bits_for
            source_layout.append(("valid_token_count", bits_for(ratio)))
        self.source = Endpoint(source_layout)
        self.ratio = ratio
        self._nbits_from = nbits_from
        self._reverse = reverse
        self._report_valid_token_count = report_valid_token_count

    def elaborate(self, platform):
        m = Module()

        # control path
        demux = Signal(range(self.ratio))   # slot the next sink word fills
        load_part = Signal()                # a sink word is accepted
        strobe_all = Signal()               # the source word is complete
        m.d.comb += [
            self.sink.ready.eq(~strobe_all | self.source.ready),
            self.source.valid.eq(strobe_all),
            load_part.eq(self.sink.valid & self.sink.ready)
        ]

        # A source word is finished by its last slot or by an early "last".
        demux_last = ((demux == (self.ratio - 1)) | self.sink.last)

        # The clear is written before the set below, so the set takes
        # priority when both conditions hold in the same cycle.
        with m.If(self.source.ready):
            m.d.sync += strobe_all.eq(0)

        with m.If(load_part):
            with m.If(demux_last):
                m.d.sync += [
                    demux.eq(0),
                    strobe_all.eq(1),
                ]
            with m.Else():
                m.d.sync += demux.eq(demux+1)

        # "last" is reloaded when a source word is consumed, and otherwise
        # accumulated over the sink words accepted so far.
        with m.If(self.source.valid & self.source.ready):
            m.d.sync += self.source.last.eq(self.sink.last)
        with m.Elif(self.sink.valid & self.sink.ready):
            m.d.sync += self.source.last.eq(self.sink.last | self.source.last)

        # data path
        with m.If(load_part):
            with m.Switch(demux):
                for i in range(self.ratio):
                    with m.Case(i):
                        n = self.ratio-i-1 if self._reverse else i
                        m.d.sync += self.source.payload.lower()[n*self._nbits_from:(
                            n+1)*self._nbits_from].eq(self.sink.payload)

        if self._report_valid_token_count:
            with m.If(load_part):
                m.d.sync += self.source.valid_token_count.eq(demux + 1)

        return m
206
207
class _DownConverter(Elaboratable):
    """Split one wide sink word into ``ratio`` narrow source words.

    Args:
        nbits_from: width of the sink ``data`` field.
        nbits_to: width of the source ``data`` field.
        ratio: number of source words per sink word.
        reverse: when true, the highest slice of the sink word is emitted
            first instead of the lowest.
        report_valid_token_count: add a 1-bit ``valid_token_count`` source
            field, driven high on the last slice of each sink word.
    """

    def __init__(self, nbits_from, nbits_to, ratio, reverse,
                 report_valid_token_count):
        self.sink = Endpoint([("data", nbits_from)])
        source_layout = [("data", nbits_to)]
        if report_valid_token_count:
            source_layout.append(("valid_token_count", 1))
        self.source = Endpoint(source_layout)
        self.ratio = ratio
        self._reverse = reverse
        self._nbits_to = nbits_to
        self._report_valid_token_count = report_valid_token_count

    def elaborate(self, platform):
        m = Module()

        # control path
        mux = Signal(range(self.ratio))     # index of the slice being sent
        last = Signal()                     # sending the final slice
        m.d.comb += [
            last.eq(mux == (self.ratio-1)),
            self.source.valid.eq(self.sink.valid),
            self.source.last.eq(self.sink.last & last),
            # The sink word is only released once its final slice is taken.
            self.sink.ready.eq(last & self.source.ready)
        ]
        with m.If(self.source.valid & self.source.ready):
            with m.If(last):
                m.d.sync += mux.eq(0)
            with m.Else():
                m.d.sync += mux.eq(mux+1)

        # data path
        with m.Switch(mux):
            for i in range(self.ratio):
                with m.Case(i):
                    n = self.ratio-i-1 if self._reverse else i
                    m.d.comb += self.source.data.eq(
                        self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to])
            with m.Case():
                # Default case: select the same slice as the final regular
                # case (i == ratio-1), i.e. slice 0 when reversed.  The old
                # expression "ratio-ratio-1-1" evaluated to -2.
                n = 0 if self._reverse else self.ratio-1
                m.d.comb += self.source.data.eq(
                    self.sink.data[n*self._nbits_to:(n+1)*self._nbits_to])

        if self._report_valid_token_count:
            m.d.comb += self.source.valid_token_count.eq(last)

        return m
261
262
class _IdentityConverter(Elaboratable):
    """Pass-through converter used when sink and source widths are equal.

    Takes the same constructor arguments as the up/down converters;
    ``reverse`` is accepted but unused, and ``ratio`` must be 1.
    """

    def __init__(self, nbits_from, nbits_to, ratio, reverse,
                 report_valid_token_count):
        self.sink = Endpoint([("data", nbits_from)])
        token_field = (
            [("valid_token_count", 1)] if report_valid_token_count else []
        )
        self.source = Endpoint([("data", nbits_to)] + token_field)
        assert ratio == 1
        self.ratio = ratio
        self._report_valid_token_count = report_valid_token_count

    def elaborate(self, platform):
        m = Module()

        # Sink is wired straight to source.
        m.d.comb += self.sink.connect(self.source)
        if self._report_valid_token_count:
            # The optional token count is tied to a constant 1.
            m.d.comb += self.source.valid_token_count.eq(1)

        return m
283
284
def _get_converter_ratio(nbits_from, nbits_to):
    """Pick the converter class and width ratio for a width conversion.

    Returns:
        A ``(converter_class, ratio)`` tuple: ``_DownConverter`` when
        narrowing, ``_UpConverter`` when widening, and ``_IdentityConverter``
        with ratio 1 when the widths are equal.

    Raises:
        ValueError: if the wider width is not a multiple of the narrower one.
    """
    if nbits_from == nbits_to:
        return _IdentityConverter, 1

    if nbits_from > nbits_to:
        specialized_cls = _DownConverter
        wide, narrow = nbits_from, nbits_to
    else:
        specialized_cls = _UpConverter
        wide, narrow = nbits_to, nbits_from

    ratio, remainder = divmod(wide, narrow)
    if remainder:
        raise ValueError("Ratio must be an int")

    return specialized_cls, ratio
301
302
class Converter(Elaboratable):
    """Width converter that delegates to the up/down/identity converter.

    The specialised converter is chosen by ``_get_converter_ratio``; its
    ``sink`` and ``source`` endpoints are re-exported on this object.
    """

    def __init__(self, nbits_from, nbits_to, reverse=False,
                 report_valid_token_count=False):
        specialized_cls, ratio = _get_converter_ratio(nbits_from, nbits_to)
        inner = specialized_cls(
            nbits_from, nbits_to, ratio, reverse, report_valid_token_count)
        self.specialized = inner
        self.sink = inner.sink
        self.source = inner.source

    def elaborate(self, platform):
        m = Module()

        # All logic lives in the specialised converter.
        m.submodules += self.specialized

        return m
318
319
class StrideConverter(Elaboratable):
    """Convert between two endpoint layouts of different total width.

    The sink payload is flattened to raw bits, passed through a
    ``Converter``, and the result is mapped back onto the source payload
    fields.  ``layout_from`` and ``layout_to`` may each be a plain payload
    layout or an ``EndpointDescription``; extra positional and keyword
    arguments are forwarded to ``Converter``.
    """

    def __init__(self, layout_from, layout_to, *args, **kwargs):
        self.sink = Endpoint(layout_from)
        self.source = Endpoint(layout_to)

        self._layout_to = layout_to
        self._layout_from = layout_from

        nbits_from = len(self.sink.payload.lower())
        nbits_to = len(self.source.payload.lower())
        self.converter = Converter(nbits_from, nbits_to, *args, **kwargs)

    def elaborate(self, platform):
        m = Module()

        nbits_from = len(self.sink.payload.lower())
        nbits_to = len(self.source.payload.lower())

        # Read the field lists from the endpoints' descriptions rather than
        # from the raw constructor arguments: a plain layout list has no
        # ``payload_layout`` attribute, an Endpoint description always does.
        fields_from = self.sink.description.payload_layout
        fields_to = self.source.description.payload_layout

        m.submodules += self.converter

        # cast sink to converter.sink (user fields --> raw bits)
        m.d.comb += [
            self.converter.sink.valid.eq(self.sink.valid),
            self.converter.sink.last.eq(self.sink.last),
            self.sink.ready.eq(self.converter.sink.ready)
        ]
        if isinstance(self.converter.specialized, _DownConverter):
            ratio = self.converter.specialized.ratio
            # Regroup the wide word so that chunk i holds slice i of every
            # field, packed in source-layout order.
            for i in range(ratio):
                j = 0
                for name, width in fields_to:
                    src = getattr(self.sink, name)[i*width:(i+1)*width]
                    dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
                    m.d.comb += dst.eq(src)
                    j += width
        else:
            m.d.comb += self.converter.sink.payload.eq(
                self.sink.payload.lower())

        # cast converter.source to source (raw bits --> user fields)
        m.d.comb += [
            self.source.valid.eq(self.converter.source.valid),
            self.source.last.eq(self.converter.source.last),
            self.converter.source.ready.eq(self.source.ready)
        ]
        if isinstance(self.converter.specialized, _UpConverter):
            ratio = self.converter.specialized.ratio
            # Inverse regrouping: chunk i of the wide word supplies slice i
            # of every source field.
            for i in range(ratio):
                j = 0
                for name, width in fields_from:
                    src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
                    dst = getattr(self.source, name)[i*width:(i+1)*width]
                    m.d.comb += dst.eq(src)
                    j += width
        else:
            m.d.comb += self.source.payload.lower().eq(self.converter.source.payload)

        return m
378
379
class Pipeline(Elaboratable):
    """Chain stream modules (or bare endpoints) source-to-sink.

    Each element is either an ``Endpoint`` or an object exposing ``sink``
    and ``source`` endpoints.  ``elaborate`` only adds the connections
    between consecutive elements; it does not register the elements as
    submodules.
    """

    def __init__(self, *modules):
        self._modules = modules
        head, tail = modules[0], modules[-1]

        # Re-export the outer endpoints of the chain when they exist.
        if hasattr(head, "sink"):
            self.sink = head.sink
        if hasattr(tail, "source"):
            self.source = tail.source

    def elaborate(self, platform):
        m = Module()

        # Walk consecutive (upstream, downstream) pairs.
        for upstream, downstream in zip(self._modules, self._modules[1:]):
            source = (
                upstream if isinstance(upstream, Endpoint)
                else upstream.source
            )
            sink = (
                downstream if isinstance(downstream, Endpoint)
                else downstream.sink
            )

            # An element listed twice in a row is not connected to itself.
            if upstream is not downstream:
                m.d.comb += source.connect(sink)

        return m