litex_sim: Rework Makefiles to put output files in gateware directory.
[litex.git] / litex / soc / interconnect / stream.py
1 # This file is Copyright (c) 2015 Sebastien Bourdeauducq <sb@m-labs.hk>
2 # This file is Copyright (c) 2015-2020 Florent Kermarrec <florent@enjoy-digital.fr>
3 # This file is Copyright (c) 2018 Tim 'mithro' Ansell <me@mith.ro>
4 # License: BSD
5
6 import math
7
8 from migen import *
9 from migen.util.misc import xdir
10 from migen.genlib import fifo
11 from migen.genlib.cdc import MultiReg, PulseSynchronizer
12
13 from litex.soc.interconnect.csr import *
14
15 # Endpoint -----------------------------------------------------------------------------------------
16
# Endpoint direction identifiers: DIR_SINK is 0, DIR_SOURCE is 1.
DIR_SINK, DIR_SOURCE = range(2)
18
19 def _make_m2s(layout):
20 r = []
21 for f in layout:
22 if isinstance(f[1], (int, tuple)):
23 r.append((f[0], f[1], DIR_M_TO_S))
24 else:
25 r.append((f[0], _make_m2s(f[1])))
26 return r
27
def set_reset_less(field):
    """Set `reset_less = True` on a Signal, or on every flat signal of a Record.

    Objects that are neither a Signal nor a Record are left untouched.
    """
    if isinstance(field, Signal):
        field.reset_less = True
        return
    if isinstance(field, Record):
        for sig, _ in field.iter_flat():
            sig.reset_less = True
34
class EndpointDescription:
    """Describe the payload and param fields carried by a stream Endpoint.

    Parameters
    ----------
    payload_layout : list
        Layout of the payload fields.
    param_layout : list, optional
        Layout of the param fields. When omitted, each instance gets its own
        empty list.
    """
    def __init__(self, payload_layout, param_layout=None):
        self.payload_layout = payload_layout
        # A `[]` default would be a single list object shared by every
        # description created without an explicit param layout.
        self.param_layout = [] if param_layout is None else param_layout

    def get_full_layout(self):
        """Return the full Record layout: valid/ready/first/last, payload and param.

        Raises
        ------
        ValueError
            If a field name appears twice across payload/param layouts, or
            if it collides with one of the reserved endpoint names.
        """
        reserved = {"valid", "ready", "payload", "param", "first", "last", "description"}
        attributed = set()
        for f in self.payload_layout + self.param_layout:
            if f[0] in attributed:
                raise ValueError(f[0] + " already attributed in payload or param layout")
            if f[0] in reserved:
                raise ValueError(f[0] + " cannot be used in endpoint layout")
            attributed.add(f[0])

        full_layout = [
            ("valid",   1, DIR_M_TO_S),
            ("ready",   1, DIR_S_TO_M),
            ("first",   1, DIR_M_TO_S),
            ("last",    1, DIR_M_TO_S),
            ("payload", _make_m2s(self.payload_layout)),
            ("param",   _make_m2s(self.param_layout))
        ]
        return full_layout
59
60
class Endpoint(Record):
    """Stream endpoint Record built from an EndpointDescription.

    Parameters
    ----------
    description_or_layout : EndpointDescription or list
        Either a ready-made description, or a payload layout that is wrapped
        in a new EndpointDescription.
    name : str, optional
        Forwarded to ``Record.__init__``.
    **kwargs
        Forwarded to ``Record.__init__``.
    """
    def __init__(self, description_or_layout, name=None, **kwargs):
        if isinstance(description_or_layout, EndpointDescription):
            self.description = description_or_layout
        else:
            self.description = EndpointDescription(description_or_layout)
        Record.__init__(self, self.description.get_full_layout(), name, **kwargs)
        # Only `valid` and `ready` keep their reset.
        set_reset_less(self.first)
        set_reset_less(self.last)
        set_reset_less(self.payload)
        set_reset_less(self.param)

    def __getattr__(self, name):
        """Resolve unknown attributes as payload fields first, then param fields."""
        # Only a failed payload lookup should fall through to `param`; a bare
        # `except:` would also hide unrelated errors and interrupts.
        try:
            return getattr(object.__getattribute__(self, "payload"), name)
        except AttributeError:
            return getattr(object.__getattribute__(self, "param"), name)
78
79 # Actor --------------------------------------------------------------------------------------------
80
81 def _rawbits_layout(l):
82 if isinstance(l, int):
83 return [("rawbits", l)]
84 else:
85 return l
86
def pack_layout(l, n):
    """Return a layout made of `n` fields "chunk0".."chunk<n-1>", each of layout `l`."""
    packed = []
    for index in range(n):
        packed.append(("chunk" + str(index), l))
    return packed
89
def get_endpoints(obj, filt=Endpoint):
    """Return a dict of the attributes of `obj` that are instances of `filt`.

    If `obj` provides its own callable ``get_endpoints``, that method is used
    instead and its result is returned unchanged.
    """
    if hasattr(obj, "get_endpoints") and callable(obj.get_endpoints):
        return obj.get_endpoints(filt)
    return {name: member for name, member in xdir(obj, True) if isinstance(member, filt)}
98
def get_single_ep(obj, filt):
    """Return the ``(name, endpoint)`` pair of the only endpoint of `obj` matching `filt`.

    Raises
    ------
    ValueError
        If `obj` does not have exactly one matching endpoint (none or several).
    """
    eps = get_endpoints(obj, filt)
    if len(eps) != 1:
        # The check also fires for zero endpoints, so report the actual count.
        raise ValueError("Expected exactly one endpoint, found {}".format(len(eps)))
    return next(iter(eps.items()))
104
105
class BinaryActor(Module):
    """Base class for actors with one sink and one source.

    ``self.sink`` and ``self.source`` must already exist when ``__init__``
    runs, because they are passed to ``build_binary_control`` along with any
    extra arguments.
    """
    def __init__(self, *args, **kwargs):
        self.build_binary_control(self.sink, self.source, *args, **kwargs)

    def build_binary_control(self, sink, source):
        """Generate the control logic between `sink` and `source` (must be overridden)."""
        # The message now names the method that actually has to be overridden.
        raise NotImplementedError("Binary actor classes must overload build_binary_control")
112
113
class CombinatorialActor(BinaryActor):
    """Binary actor whose control signals pass straight through, with no register."""
    def build_binary_control(self, sink, source):
        # Forward path: valid/first/last travel from sink to source.
        self.comb += [
            getattr(source, name).eq(getattr(sink, name))
            for name in ("valid", "first", "last")
        ]
        # Backward path: ready travels from source to sink.
        self.comb += sink.ready.eq(source.ready)
122
123
class PipelinedActor(BinaryActor):
    """Binary actor whose valid/first/last controls go through `latency` register stages.

    ``self.sink`` and ``self.source`` must be set before this constructor is
    called, because ``BinaryActor.__init__`` reads them.

    Parameters
    ----------
    latency : int
        Number of register stages between the sink and source controls.
    """
    def __init__(self, latency):
        self.latency = latency
        self.pipe_ce = Signal()  # Enable shared by all pipeline registers.
        self.busy = Signal()     # OR of the registered valid bits.
        BinaryActor.__init__(self, latency)

    def build_binary_control(self, sink, source, latency):
        """Build the registered valid/first/last chains and the ready/busy logic."""
        busy = 0
        valid = sink.valid
        # Chain of `latency` valid registers; `busy` ORs every stage together.
        for i in range(latency):
            valid_n = Signal()
            self.sync += If(self.pipe_ce, valid_n.eq(valid))
            valid = valid_n
            busy = busy | valid

        self.comb += [
            # The pipeline advances when the source accepts data or the last stage is empty.
            self.pipe_ce.eq(source.ready | ~valid),
            sink.ready.eq(self.pipe_ce),
            source.valid.eq(valid),
            self.busy.eq(busy)
        ]
        # first/last are qualified by sink.valid before entering their own register chains.
        first = sink.valid & sink.first
        last = sink.valid & sink.last
        for i in range(latency):
            first_n = Signal(reset_less=True)
            last_n = Signal(reset_less=True)
            self.sync += \
                If(self.pipe_ce,
                    first_n.eq(first),
                    last_n.eq(last)
                )
            first = first_n
            last = last_n
        self.comb += [
            source.first.eq(first),
            source.last.eq(last)
        ]
162
163 # FIFO ---------------------------------------------------------------------------------------------
164
class _FIFOWrapper(Module):
    """Wrap a raw FIFO class behind stream sink/source Endpoints.

    Parameters
    ----------
    fifo_class : class
        FIFO implementation, instantiated as ``fifo_class(width, depth)``.
    layout : list or EndpointDescription
        Layout used for both the sink and the source endpoint.
    depth : int
        Depth passed to ``fifo_class``.
    """
    def __init__(self, fifo_class, layout, depth):
        self.sink = sink = Endpoint(layout)
        self.source = source = Endpoint(layout)

        # # #

        # Everything except valid/ready is stored in the FIFO word.
        description = sink.description
        fifo_layout = [
            ("payload", description.payload_layout),
            ("param", description.param_layout),
            ("first", 1),
            ("last", 1)
        ]

        self.submodules.fifo = fifo = fifo_class(layout_len(fifo_layout), depth)
        # Records used to pack/unpack the flat FIFO data word.
        fifo_in = Record(fifo_layout)
        fifo_out = Record(fifo_layout)
        self.comb += [
            fifo.din.eq(fifo_in.raw_bits()),
            fifo_out.raw_bits().eq(fifo.dout)
        ]

        self.comb += [
            # Write side: sink handshake drives writable/we.
            sink.ready.eq(fifo.writable),
            fifo.we.eq(sink.valid),
            fifo_in.first.eq(sink.first),
            fifo_in.last.eq(sink.last),
            fifo_in.payload.eq(sink.payload),
            fifo_in.param.eq(sink.param),

            # Read side: source handshake drives readable/re.
            source.valid.eq(fifo.readable),
            source.first.eq(fifo_out.first),
            source.last.eq(fifo_out.last),
            source.payload.eq(fifo_out.payload),
            source.param.eq(fifo_out.param),
            fifo.re.eq(source.ready)
        ]
203
204
class SyncFIFO(_FIFOWrapper):
    """Synchronous stream FIFO with special cases for depth 0 and depth 1.

    Parameters
    ----------
    layout : list or EndpointDescription
        Layout of the sink and source endpoints.
    depth : int
        Depth of the FIFO; must be >= 0.
    buffered : bool
        For depth >= 2, selects ``fifo.SyncFIFOBuffered`` instead of ``fifo.SyncFIFO``.
    """
    def __init__(self, layout, depth, buffered=False):
        assert depth >= 0
        if depth >= 2:
            # Regular case: wrap a migen synchronous FIFO.
            _FIFOWrapper.__init__(self,
                fifo_class = fifo.SyncFIFOBuffered if buffered else fifo.SyncFIFO,
                layout     = layout,
                depth      = depth)
            self.depth = self.fifo.depth
            self.level = self.fifo.level
        elif depth == 1:
            # Single entry: a Buffer stands in for the FIFO.
            buf = Buffer(layout)
            self.submodules += buf
            self.sink = buf.sink
            self.source = buf.source
            self.depth = 1
            self.level = Signal()  # Not driven in this block.
        elif depth == 0:
            # No storage: sink is connected directly to source.
            self.sink = Endpoint(layout)
            self.source = Endpoint(layout)
            self.comb += self.sink.connect(self.source)
            self.depth = 0
            self.level = Signal()  # Not driven in this block.
228
229
class AsyncFIFO(_FIFOWrapper):
    """Asynchronous stream FIFO; `depth` must be at least 4.

    `buffered` selects ``fifo.AsyncFIFOBuffered`` instead of ``fifo.AsyncFIFO``.
    """
    def __init__(self, layout, depth, buffered=False):
        assert depth >= 4
        if buffered:
            fifo_class = fifo.AsyncFIFOBuffered
        else:
            fifo_class = fifo.AsyncFIFO
        _FIFOWrapper.__init__(self, fifo_class, layout, depth)
237
238 # Mux/Demux ----------------------------------------------------------------------------------------
239
class Multiplexer(Module):
    """Connect one of `n` sinks to a single source, chosen by `sel`.

    The sinks are exposed as attributes ``sink0`` .. ``sink<n-1>``.

    Parameters
    ----------
    layout : list or EndpointDescription
        Layout shared by every endpoint.
    n : int
        Number of sinks.
    """
    def __init__(self, layout, n):
        self.source = Endpoint(layout)
        sinks = []
        for i in range(n):
            sink = Endpoint(layout)
            setattr(self, "sink"+str(i), sink)
            sinks.append(sink)
        self.sel = Signal(max=n)

        # # #

        # One Case branch per sink; only the selected sink is connected to the source.
        cases = {}
        for i, sink in enumerate(sinks):
            cases[i] = sink.connect(self.source)
        self.comb += Case(self.sel, cases)
256
257
class Demultiplexer(Module):
    """Connect a single sink to one of `n` sources, chosen by `sel`.

    The sources are exposed as attributes ``source0`` .. ``source<n-1>``.

    Parameters
    ----------
    layout : list or EndpointDescription
        Layout shared by every endpoint.
    n : int
        Number of sources.
    """
    def __init__(self, layout, n):
        self.sink = Endpoint(layout)
        sources = []
        for i in range(n):
            source = Endpoint(layout)
            setattr(self, "source"+str(i), source)
            sources.append(source)
        self.sel = Signal(max=n)

        # # #

        # One Case branch per source; the sink is connected only to the selected source.
        cases = {}
        for i, source in enumerate(sources):
            cases[i] = self.sink.connect(source)
        self.comb += Case(self.sel, cases)
274
275 # Converter ----------------------------------------------------------------------------------------
276
class _UpConverter(Module):
    """Assemble `ratio` narrow sink words into one wide source word.

    Parameters
    ----------
    nbits_from : int
        Width of ``sink.data``.
    nbits_to : int
        Width of ``source.data``.
    ratio : int
        Number of sink words per source word.
    reverse : bool
        When True, the first sink word goes to the highest slot of the source
        word instead of the lowest.
    """
    def __init__(self, nbits_from, nbits_to, ratio, reverse):
        self.sink = sink = Endpoint([("data", nbits_from)])
        self.source = source = Endpoint([("data", nbits_to), ("valid_token_count", bits_for(ratio))])
        self.latency = 1

        # # #

        # Control path
        demux = Signal(max=ratio)   # Index of the slot the next sink word is written to.
        load_part = Signal()        # A sink word is accepted this cycle.
        strobe_all = Signal()       # A complete source word is available.
        self.comb += [
            sink.ready.eq(~strobe_all | source.ready),
            source.valid.eq(strobe_all),
            load_part.eq(sink.valid & sink.ready)
        ]

        # A source word is completed by its last slot or by sink.last.
        demux_last = ((demux == (ratio - 1)) | sink.last)

        self.sync += [
            If(source.ready, strobe_all.eq(0)),
            # Listed after the clear above, so setting strobe_all wins in the same cycle.
            If(load_part,
                If(demux_last,
                    demux.eq(0),
                    strobe_all.eq(1)
                ).Else(
                    demux.eq(demux + 1)
                )
            ),
            # first/last: restart from the sink values when the source word is consumed,
            # otherwise accumulate (OR) over the sink words of the current source word.
            If(source.valid & source.ready,
                If(sink.valid & sink.ready,
                    source.first.eq(sink.first),
                    source.last.eq(sink.last)
                ).Else(
                    source.first.eq(0),
                    source.last.eq(0)
                )
            ).Elif(sink.valid & sink.ready,
                source.first.eq(sink.first | source.first),
                source.last.eq(sink.last | source.last)
            )
        ]

        # Data path
        cases = {}
        for i in range(ratio):
            n = ratio-i-1 if reverse else i
            cases[i] = source.data[n*nbits_from:(n+1)*nbits_from].eq(sink.data)
        self.sync += If(load_part, Case(demux, cases))

        # Valid token count
        # Registers the number of sink words loaded so far into the current source word.
        self.sync += If(load_part, source.valid_token_count.eq(demux + 1))
330
331
class _DownConverter(Module):
    """Split one wide sink word into `ratio` narrow source words.

    Parameters
    ----------
    nbits_from : int
        Width of ``sink.data``.
    nbits_to : int
        Width of ``source.data``.
    ratio : int
        Number of source words per sink word.
    reverse : bool
        When True, the highest slice of the sink word is emitted first.
    """
    def __init__(self, nbits_from, nbits_to, ratio, reverse):
        self.sink = sink = Endpoint([("data", nbits_from)])
        self.source = source = Endpoint([("data", nbits_to), ("valid_token_count", 1)])
        self.latency = 0

        # # #

        # Control path
        mux = Signal(max=ratio)  # Index of the slice currently presented on the source.
        first = Signal()
        last = Signal()
        self.comb += [
            first.eq(mux == 0),
            last.eq(mux == (ratio-1)),
            source.valid.eq(sink.valid),
            source.first.eq(sink.first & first),
            source.last.eq(sink.last & last),
            # The sink word is consumed only once its last slice is accepted.
            sink.ready.eq(last & source.ready)
        ]
        self.sync += \
            If(source.valid & source.ready,
                If(last,
                    mux.eq(0)
                ).Else(
                    mux.eq(mux + 1)
                )
            )

        # Data path
        cases = {}
        for i in range(ratio):
            n = ratio-i-1 if reverse else i
            cases[i] = source.data.eq(sink.data[n*nbits_to:(n+1)*nbits_to])
        self.comb += Case(mux, cases).makedefault()

        # Valid token count
        self.comb += source.valid_token_count.eq(last)
370
371
class _IdentityConverter(Module):
    """Pass-through converter with the same constructor signature as the up/down converters.

    `ratio` and `reverse` are accepted but not used.
    """
    def __init__(self, nbits_from, nbits_to, ratio, reverse):
        self.sink = sink = Endpoint([("data", nbits_from)])
        self.source = source = Endpoint([("data", nbits_to), ("valid_token_count", 1)])
        self.latency = 0

        # # #

        self.comb += [
            sink.connect(source),
            source.valid_token_count.eq(1)
        ]
384
385
def _get_converter_ratio(nbits_from, nbits_to):
    """Pick the converter class and integer ratio for a width conversion.

    Returns ``(converter_cls, ratio)``. Raises ValueError when the wider
    width is not a multiple of the narrower one.
    """
    if nbits_from == nbits_to:
        return _IdentityConverter, 1
    if nbits_from > nbits_to:
        converter_cls, wide, narrow = _DownConverter, nbits_from, nbits_to
    else:
        converter_cls, wide, narrow = _UpConverter, nbits_to, nbits_from
    if wide % narrow:
        raise ValueError("Ratio must be an int")
    return converter_cls, wide // narrow
401
402
class Converter(Module):
    """Width converter on raw data, selecting up/down/identity conversion from the widths.

    Parameters
    ----------
    nbits_from : int
        Width of ``sink.data``.
    nbits_to : int
        Width of ``source.data``.
    reverse : bool
        Forwarded to the selected converter class.
    report_valid_token_count : bool
        When True, the source is the inner converter's source and includes
        ``valid_token_count``; otherwise a data-only source is exposed.
    """
    def __init__(self, nbits_from, nbits_to,
        reverse                  = False,
        report_valid_token_count = False):
        self.cls, self.ratio = _get_converter_ratio(nbits_from, nbits_to)

        # # #

        converter = self.cls(nbits_from, nbits_to, self.ratio, reverse)
        self.submodules += converter
        self.latency = converter.latency

        self.sink = converter.sink
        if report_valid_token_count:
            self.source = converter.source
        else:
            # Hide valid_token_count behind a data-only endpoint.
            self.source = Endpoint([("data", nbits_to)])
            self.comb += converter.source.connect(self.source, omit=set(["valid_token_count"]))
421
422
class StrideConverter(Module):
    """Width converter between two endpoint descriptions with named payload fields.

    The payloads are cast to raw bits, passed through a ``Converter``, and
    cast back to the fields of the destination description.

    Parameters
    ----------
    description_from : list or EndpointDescription
        Description of the sink.
    description_to : list or EndpointDescription
        Description of the source.
    reverse : bool
        Forwarded to ``Converter``.
    """
    def __init__(self, description_from, description_to, reverse=False):
        self.sink = sink = Endpoint(description_from)
        self.source = source = Endpoint(description_to)

        # # #

        nbits_from = len(sink.payload.raw_bits())
        nbits_to = len(source.payload.raw_bits())

        converter = Converter(nbits_from, nbits_to, reverse)
        self.submodules += converter

        # Cast sink to converter.sink (user fields --> raw bits)
        self.comb += [
            converter.sink.valid.eq(sink.valid),
            converter.sink.first.eq(sink.first),
            converter.sink.last.eq(sink.last),
            sink.ready.eq(converter.sink.ready)
        ]
        if converter.cls == _DownConverter:
            # Regroup bits so each nbits_to chunk holds slice i of every field,
            # laid out in the order of the source payload layout.
            ratio = converter.ratio
            for i in range(ratio):
                j = 0
                for name, width in source.description.payload_layout:
                    src = getattr(sink, name)[i*width:(i+1)*width]
                    dst = converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
                    self.comb += dst.eq(src)
                    j += width
        else:
            self.comb += converter.sink.data.eq(sink.payload.raw_bits())


        # Cast converter.source to source (raw bits --> user fields)
        self.comb += [
            source.valid.eq(converter.source.valid),
            source.first.eq(converter.source.first),
            source.last.eq(converter.source.last),
            converter.source.ready.eq(source.ready)
        ]
        if converter.cls == _UpConverter:
            # Inverse regrouping: slice i of each source field is taken from
            # the i-th nbits_from chunk of the converted word.
            ratio = converter.ratio
            for i in range(ratio):
                j = 0
                for name, width in sink.description.payload_layout:
                    src = converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
                    dst = getattr(source, name)[i*width:(i+1)*width]
                    self.comb += dst.eq(src)
                    j += width
        else:
            self.comb += source.payload.raw_bits().eq(converter.source.data)

        # Connect params
        # Params bypass the converter: combinatorial for latency 0, one register for latency 1.
        if converter.latency == 0:
            self.comb += source.param.eq(sink.param)
        elif converter.latency == 1:
            self.sync += source.param.eq(sink.param)
        else:
            raise ValueError
482
483 # Gearbox ------------------------------------------------------------------------------------------
484
def lcm(a, b):
    """Return the least common multiple of `a` and `b`."""
    product = a * b
    divisor = math.gcd(a, b)
    return product // divisor
487
488
def inc_mod(s, m):
    """Return statements that increment `s`, wrapping to 0 when `s` equals `m - 1`.

    The wrap statement is listed after the increment so that it takes
    precedence when both apply.
    """
    increment = s.eq(s + 1)
    wrap = If(s == (m -1), s.eq(0))
    return [increment, wrap]
491
492
class Gearbox(Module):
    """Convert a stream of `i_dw`-bit words into a stream of `o_dw`-bit words.

    Data goes through a shift register of ``lcm(i_dw, o_dw)`` bits.

    Parameters
    ----------
    i_dw : int
        Width of ``sink.data``.
    o_dw : int
        Width of ``source.data``.
    msb_first : bool
        When False, the bit order of the data is reversed on both the input
        and the output side.
    """
    def __init__(self, i_dw, o_dw, msb_first=True):
        self.sink = sink = Endpoint([("data", i_dw)])
        self.source = source = Endpoint([("data", o_dw)])

        # # #

        io_lcm = lcm(i_dw, o_dw)

        # Control path

        level = Signal(max=io_lcm)           # Number of bits currently stored.
        i_inc = Signal()                     # Input word accepted this cycle.
        i_count = Signal(max=io_lcm//i_dw)   # Write slot index.
        o_inc = Signal()                     # Output word accepted this cycle.
        o_count = Signal(max=io_lcm//o_dw)   # Read slot index.

        self.comb += [
            sink.ready.eq(level < (io_lcm - i_dw)),
            source.valid.eq(level >= o_dw),
        ]
        self.comb += [
            i_inc.eq(sink.valid & sink.ready),
            o_inc.eq(source.valid & source.ready)
        ]
        self.sync += [
            If(i_inc, *inc_mod(i_count, io_lcm//i_dw)),
            If(o_inc, *inc_mod(o_count, io_lcm//o_dw)),
            # Level bookkeeping for the three possible handshake combinations.
            If(i_inc & ~o_inc, level.eq(level + i_dw)),
            If(~i_inc & o_inc, level.eq(level - o_dw)),
            If(i_inc & o_inc, level.eq(level + i_dw - o_dw)),
        ]

        # Data path

        shift_register = Signal(io_lcm, reset_less=True)

        i_cases = {}
        i_data = Signal(i_dw)
        if msb_first:
            self.comb += i_data.eq(sink.data)
        else:
            self.comb += i_data.eq(sink.data[::-1])
        # Slot 0 is the topmost i_dw bits of the shift register.
        for i in range(io_lcm//i_dw):
            i_cases[i] = shift_register[io_lcm - i_dw*(i+1):io_lcm - i_dw*i].eq(i_data)
        self.sync += If(sink.valid & sink.ready, Case(i_count, i_cases))

        o_cases = {}
        o_data = Signal(o_dw)
        # Slot 0 is the topmost o_dw bits of the shift register.
        for i in range(io_lcm//o_dw):
            o_cases[i] = o_data.eq(shift_register[io_lcm - o_dw*(i+1):io_lcm - o_dw*i])
        self.comb += Case(o_count, o_cases)
        if msb_first:
            self.comb += source.data.eq(o_data)
        else:
            self.comb += source.data.eq(o_data[::-1])
549
550 # Monitor ------------------------------------------------------------------------------------------
551
class Monitor(Module, AutoCSR):
    """Count stream events on an endpoint and expose the counts through CSRs.

    Parameters
    ----------
    endpoint : Endpoint
        Endpoint whose valid/ready signals are observed.
    count_width : int
        Width of each counter and of its CSRStatus.
    clock_domain : str
        Clock domain the counters run in. When it is not "sys", the reset and
        latch CSR pulses are transferred with PulseSynchronizers.
    with_tokens : bool
        Add a ``tokens`` counter (cycles with valid & ready).
    with_overflows : bool
        Add an ``overflows`` counter (cycles with valid & ~ready).
    with_underflows : bool
        Add an ``underflows`` counter (cycles with ~valid & ready).
    """
    def __init__(self, endpoint, count_width=32, clock_domain="sys",
        with_tokens     = False,
        with_overflows  = False,
        with_underflows = False):

        self.reset = CSR()
        self.latch = CSR()
        if with_tokens:
            self.tokens = CSRStatus(count_width)
        if with_overflows:
            self.overflows = CSRStatus(count_width)
        if with_underflows:
            self.underflows = CSRStatus(count_width)

        # # #

        reset = Signal()
        latch = Signal()
        if clock_domain == "sys":
            self.comb += reset.eq(self.reset.re)
            self.comb += latch.eq(self.latch.re)
        else:
            reset_ps = PulseSynchronizer("sys", clock_domain)
            latch_ps = PulseSynchronizer("sys", clock_domain)
            self.submodules += reset_ps, latch_ps
            self.comb += reset_ps.i.eq(self.reset.re)
            self.comb += reset.eq(reset_ps.o)
            self.comb += latch_ps.i.eq(self.latch.re)
            self.comb += latch.eq(latch_ps.o)

        # Generic Monitor Counter ------------------------------------------------------------------
        class MonitorCounter(Module):
            """Saturating event counter with a latched copy resynchronized into `count`."""
            def __init__(self, reset, latch, enable, count):
                _count = Signal.like(count)
                _count_latched = Signal.like(count)
                _sync = getattr(self.sync, clock_domain)
                _sync += [
                    If(reset,
                        _count.eq(0),
                        _count_latched.eq(0),
                    ).Elif(enable,
                        # Saturate at the maximum value instead of wrapping.
                        If(_count != (2**len(count)-1),
                            _count.eq(_count + 1)
                        )
                    ),
                    If(latch,
                        _count_latched.eq(_count)
                    )
                ]
                self.specials += MultiReg(_count_latched, count)

        # Tokens Count -----------------------------------------------------------------------------
        if with_tokens:
            # The same name must be used for assignment and registration; the
            # original mismatch raised NameError whenever with_tokens was True.
            token_counter = MonitorCounter(reset, latch, endpoint.valid & endpoint.ready, self.tokens.status)
            self.submodules += token_counter

        # Overflows Count (only useful when endpoint is expected to always be ready) ---------------
        if with_overflows:
            overflow_counter = MonitorCounter(reset, latch, endpoint.valid & ~endpoint.ready, self.overflows.status)
            self.submodules += overflow_counter

        # Underflows Count (only useful when endpoint is expected to always be valid) --------------
        if with_underflows:
            underflow_counter = MonitorCounter(reset, latch, ~endpoint.valid & endpoint.ready, self.underflows.status)
            self.submodules += underflow_counter
618
619 # Pipe ---------------------------------------------------------------------------------------------
620
class PipeValid(Module):
    """Pipe valid/payload to cut timing path

    All sink-to-source signals (valid, first, last, payload, param) are
    registered; ``sink.ready`` stays combinatorial.
    """
    def __init__(self, layout):
        self.sink = sink = Endpoint(layout)
        self.source = source = Endpoint(layout)

        # # #

        # Pipe when source is not valid or is ready.
        self.sync += [
            If(~source.valid | source.ready,
                source.valid.eq(sink.valid),
                source.first.eq(sink.first),
                source.last.eq(sink.last),
                source.payload.eq(sink.payload),
                source.param.eq(sink.param),
            )
        ]
        # The sink is accepted under the same condition that loads the register.
        self.comb += sink.ready.eq(~source.valid | source.ready)
640
641
class PipeReady(Module):
    """Pipe ready to cut timing path

    ``sink.ready`` depends only on the registered ``valid`` flag. ``sink_d``
    holds a copy of the sink, and the source is driven from that copy while
    ``valid`` is set.
    """
    def __init__(self, layout):
        self.sink = sink = Endpoint(layout)
        self.source = source = Endpoint(layout)

        # # #

        valid = Signal()            # Set while sink_d, not sink, drives the source.
        sink_d = Endpoint(layout)   # Registered copy of the sink.

        self.sync += [
            If(sink.valid & ~source.ready,
                valid.eq(1)
            ).Elif(source.ready,
                valid.eq(0)
            ),
            # Capture the sink only while the source is not ready and nothing is held yet.
            If(~source.ready & ~valid,
                sink_d.eq(sink)
            )
        ]
        self.comb += [
            sink.ready.eq(~valid),
            # Source takes the held copy when valid, otherwise the live sink;
            # ready is omitted because it is driven above.
            If(valid,
                sink_d.connect(source, omit={"ready"})
            ).Else(
                sink.connect(source, omit={"ready"})
            )
        ]
671
672 # Buffer -------------------------------------------------------------------------------------------
673
class Buffer(PipeValid):
    """Alias of PipeValid; adds nothing of its own."""
    # FIXME: Replace Buffer with PipeValid in codebase?
675
676 # Cast ---------------------------------------------------------------------------------------------
677
class Cast(CombinatorialActor):
    """Combinatorial actor that reinterprets the payload bits of one layout as another.

    An integer layout is turned into a single "rawbits" field of that width.
    `reverse_from` / `reverse_to` reverse the order in which the flattened
    payload signals are concatenated on each side. Raises TypeError when the
    total payload widths differ.
    """
    def __init__(self, layout_from, layout_to, reverse_from=False, reverse_to=False):
        self.sink = Endpoint(_rawbits_layout(layout_from))
        self.source = Endpoint(_rawbits_layout(layout_to))
        CombinatorialActor.__init__(self)

        # # #

        def ordered_signals(endpoint, reverse):
            # Flattened payload signals, optionally in reversed order.
            signals = endpoint.payload.flatten()
            if reverse:
                signals = list(reversed(signals))
            return signals

        def total_width(signals):
            return sum(len(signal) for signal in signals)

        inputs = ordered_signals(self.sink, reverse_from)
        outputs = ordered_signals(self.source, reverse_to)
        if total_width(inputs) != total_width(outputs):
            raise TypeError
        self.comb += Cat(*outputs).eq(Cat(*inputs))
695
696 # Unpack/Pack --------------------------------------------------------------------------------------
697
class Unpack(Module):
    """Split a sink payload made of `n` chunks into `n` successive source words.

    Parameters
    ----------
    n : int
        Number of chunks in the sink payload.
    layout_to : list or EndpointDescription
        Layout of the source; the sink payload is ``pack_layout`` of it.
    reverse : bool
        When True, the chunks are emitted from the highest index to the lowest.
    """
    def __init__(self, n, layout_to, reverse=False):
        self.source = source = Endpoint(layout_to)
        # Build a new description for the sink. Modifying the description in
        # place would also change source.description, and the caller's object
        # when layout_to is an EndpointDescription.
        description_to = source.description
        description_from = EndpointDescription(
            pack_layout(description_to.payload_layout, n),
            description_to.param_layout)
        self.sink = sink = Endpoint(description_from)

        # # #

        mux = Signal(max=n)  # Index of the chunk currently presented on the source.
        first = Signal()
        last = Signal()
        self.comb += [
            first.eq(mux == 0),
            last.eq(mux == (n-1)),
            source.valid.eq(sink.valid),
            # The sink word is consumed only once its last chunk is accepted.
            sink.ready.eq(last & source.ready)
        ]
        self.sync += [
            If(source.valid & source.ready,
                If(last,
                    mux.eq(0)
                ).Else(
                    mux.eq(mux + 1)
                )
            )
        ]
        cases = {}
        for i in range(n):
            chunk = n-i-1 if reverse else i
            cases[i] = [source.payload.raw_bits().eq(getattr(sink.payload, "chunk"+str(chunk)).raw_bits())]
        self.comb += Case(mux, cases).makedefault()

        # Params are forwarded combinatorially.
        for f in description_from.param_layout:
            src = getattr(self.sink, f[0])
            dst = getattr(self.source, f[0])
            self.comb += dst.eq(src)

        self.comb += [
            source.first.eq(sink.first & first),
            source.last.eq(sink.last & last)
        ]
740
741
class Pack(Module):
    """Gather up to `n` successive sink words into one source payload made of `n` chunks.

    Parameters
    ----------
    layout_from : list or EndpointDescription
        Layout of the sink; the source payload is ``pack_layout`` of it.
    n : int
        Number of chunks in the source payload.
    reverse : bool
        When True, the first sink word goes to the highest chunk index.
    """
    def __init__(self, layout_from, n, reverse=False):
        self.sink = sink = Endpoint(layout_from)
        # Build a new description for the source. Modifying the description in
        # place would also change sink.description, and the caller's object
        # when layout_from is an EndpointDescription.
        description_from = sink.description
        description_to = EndpointDescription(
            pack_layout(description_from.payload_layout, n),
            description_from.param_layout)
        self.source = source = Endpoint(description_to)

        # # #

        demux = Signal(max=n)  # Index of the chunk the next sink word is written to.

        load_part = Signal()   # A sink word is accepted this cycle.
        strobe_all = Signal()  # A complete source word is available.
        cases = {}
        for i in range(n):
            chunk = n-i-1 if reverse else i
            cases[i] = [getattr(source.payload, "chunk"+str(chunk)).raw_bits().eq(sink.payload.raw_bits())]
        self.comb += [
            sink.ready.eq(~strobe_all | source.ready),
            source.valid.eq(strobe_all),
            load_part.eq(sink.valid & sink.ready)
        ]

        # Params are registered each time a sink word is loaded.
        for f in description_to.param_layout:
            src = getattr(self.sink, f[0])
            dst = getattr(self.source, f[0])
            self.sync += If(load_part, dst.eq(src))

        # A source word is completed by its last chunk or by sink.last.
        demux_last = ((demux == (n - 1)) | sink.last)

        self.sync += [
            If(source.ready, strobe_all.eq(0)),
            # Listed after the clear above, so setting strobe_all wins in the same cycle.
            If(load_part,
                Case(demux, cases),
                If(demux_last,
                    demux.eq(0),
                    strobe_all.eq(1)
                ).Else(
                    demux.eq(demux + 1)
                )
            ),
            If(source.valid & source.ready,
                source.first.eq(sink.first),
                source.last.eq(sink.last),
            ).Elif(sink.valid & sink.ready,
                source.first.eq(sink.first | source.first),
                source.last.eq(sink.last | source.last)
            )
        ]
791
792 # Pipeline -----------------------------------------------------------------------------------------
793
class Pipeline(Module):
    """Chain modules and/or endpoints by connecting each source to the next sink.

    Each positional argument is either an Endpoint or an object exposing
    ``sink`` / ``source``. The first element's sink and the last element's
    source are re-exposed on the pipeline when they exist.
    """
    def __init__(self, *modules):
        # Expose the sink of the first module, if available.
        head = modules[0]
        if hasattr(head, "sink"):
            self.sink = head.sink

        # Connect every neighbouring pair.
        for upstream, downstream in zip(modules, modules[1:]):
            source = upstream if isinstance(upstream, Endpoint) else upstream.source
            sink = downstream if isinstance(downstream, Endpoint) else downstream.sink
            # The same object listed twice in a row is not connected to itself.
            if upstream is not downstream:
                self.comb += source.connect(sink)

        # Expose the source of the last module, if available.
        tail = modules[-1]
        if hasattr(tail, "source"):
            self.source = tail.source
819
820 # BufferizeEndpoints -------------------------------------------------------------------------------
821
822 # Add buffers on Endpoints (can be used to improve timings)
class BufferizeEndpoints(ModuleTransformer):
    """Insert a Buffer on selected endpoints of a module.

    Parameters
    ----------
    endpoint_dict : dict
        Maps an endpoint attribute name to DIR_SINK or DIR_SOURCE.
    """
    def __init__(self, endpoint_dict):
        self.endpoint_dict = endpoint_dict

    def transform_instance(self, submodule):
        """Replace each listed endpoint of `submodule` with the outer side of a new Buffer.

        Raises ValueError when a direction is neither DIR_SINK nor DIR_SOURCE.
        """
        for name, direction in self.endpoint_dict.items():
            endpoint = getattr(submodule, name)
            # add buffer on sinks
            if direction == DIR_SINK:
                buf = Buffer(endpoint.description)
                submodule.submodules += buf
                # The attribute now points to the buffer's sink; the buffer feeds the original endpoint.
                setattr(submodule, name, buf.sink)
                submodule.comb += buf.source.connect(endpoint)
            # add buffer on sources
            elif direction == DIR_SOURCE:
                buf = Buffer(endpoint.description)
                submodule.submodules += buf
                # The original endpoint feeds the buffer; the attribute now points to the buffer's source.
                submodule.comb += endpoint.connect(buf.sink)
                setattr(submodule, name, buf.source)
            else:
                raise ValueError