from nmigen.lib import fifo
-__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer"]
+__all__ = ["Endpoint", "SyncFIFO", "AsyncFIFO", "Buffer", "StrideConverter"]
def _make_fanout(layout):
class Buffer(PipeValid):
pass # FIXME: Replace Buffer with PipeValid in codebase?
+
+
+class _UpConverter(Elaboratable):
+ def __init__(self, nbits_from, nbits_to, ratio, reverse,
+ report_valid_token_count):
+ self.sink = sink = Endpoint([("data", nbits_from)])
+ source_layout = [("data", nbits_to)]
+ if report_valid_token_count:
+ source_layout.append(("valid_token_count", bits_for(ratio)))
+ self.source = source = Endpoint(source_layout)
+ self.ratio = ratio
+ self._nbits_from = nbits_from
+
+ def elaborate(self, platform):
+ m = Module()
+
+ # control path
+ demux = Signal(max=ratio)
+ load_part = Signal()
+ strobe_all = Signal()
+ m.d.comb += [
+ self.sink.ack.eq(~strobe_all | self.source.ack),
+ self.source.stb.eq(strobe_all),
+ load_part.eq(self.sink.stb & self.sink.ack)
+ ]
+
+ demux_last = ((demux == (ratio - 1)) | self.sink.eop)
+
+ with m.If(source.ack):
+ m.d.sync += strobe_all.eq(0)
+
+ with m.If(load_part):
+ with m.If(demux_last):
+ m.d.sync += [
+ demux.eq(0),
+ strobe_all.eq(1),
+ ]
+ with m.Else():
+ m.d.sync += demux.eq(demux+1)
+
+ with m.If(self.source.stb & self.source.ack):
+ m.d.sync += self.source.eop.eq(self.sink.eop)
+ with m.Elif(self.sink.stb & self.sink.ack):
+ m.d.sync += self.source.eop.eq(self.sink.eop | self.source.eop)
+
+ # data path
+ cases = {}
+ for i in range(ratio):
+ n = ratio-i-1 if reverse else i
+ cases[i] = source.data[n*self._nbits_from:(n+1)*self._nbits_from].eq(sink.data)
+ m.d.sync += If(load_part, Case(demux, cases))
+
+ if report_valid_token_count:
+ with m.If(load_part):
+ m.d.sync += self.source.valid_token_count.eq(demux + 1)
+
+ return m
+
+
+class _DownConverter(Elaboratable):
+ def __init__(self, nbits_from, nbits_to, ratio, reverse,
+ report_valid_token_count):
+ self.sink = sink = Endpoint([("data", nbits_from)])
+ source_layout = [("data", nbits_to)]
+ if report_valid_token_count:
+ source_layout.append(("valid_token_count", 1))
+ self.source = source = Endpoint(source_layout)
+ self.ratio = ratio
+
+ def elaborate(self, platform):
+ m = Module()
+
+ # control path
+ mux = Signal(range(self.ratio))
+ last = Signal()
+ m.d.comb += [
+ last.eq(mux == (self.ratio-1)),
+ self.source.stb.eq(self.sink.stb),
+ self.source.eop.eq(sink.eop & last),
+ self.sink.ack.eq(last & self.source.ack)
+ ]
+ with m.If(self.source.stb & self.source.ack):
+ with m.If(last):
+ m.d.sync += mux.eq(0)
+ with m.Else():
+ m.d.sync += mux.eq(mux+1)
+
+ # data path
+ cases = {}
+ for i in range(ratio):
+ n = ratio-i-1 if reverse else i
+ cases[i] = self.source.data.eq(self.sink.data[n*nbits_to:(n+1)*nbits_to])
+ m.d.comb += Case(mux, cases).makedefault()
+
+ if report_valid_token_count:
+ m.d.comb += self.source.valid_token_count.eq(last)
+
+ return m
+
+
+class _IdentityConverter(Elaboratable):
+ def __init__(self, nbits_from, nbits_to, ratio, reverse,
+ report_valid_token_count):
+ self.sink = Endpoint([("data", nbits_from)])
+ source_layout = [("data", nbits_to)]
+ if report_valid_token_count:
+ source_layout.append(("valid_token_count", 1))
+ self.source = Endpoint(source_layout)
+ assert ratio == 1
+ self.ratio = ratio
+ self._report_valid_token_count = report_valid_token_count
+
+ def elaborate(self, platform):
+ m = Module()
+
+ m.d.comb += self.sink.connect(self.source)
+ if self._report_valid_token_count:
+ m.d.comb += self.source.valid_token_count.eq(1)
+
+ return m
+
+
+def _get_converter_ratio(nbits_from, nbits_to):
+ if nbits_from > nbits_to:
+ specialized_cls = _DownConverter
+ if nbits_from % nbits_to:
+ raise ValueError("Ratio must be an int")
+ ratio = nbits_from//nbits_to
+ elif nbits_from < nbits_to:
+ specialized_cls = _UpConverter
+ if nbits_to % nbits_from:
+ raise ValueError("Ratio must be an int")
+ ratio = nbits_to//nbits_from
+ else:
+ specialized_cls = _IdentityConverter
+ ratio = 1
+
+ return specialized_cls, ratio
+
+
+class Converter(Elaboratable):
+ def __init__(self, nbits_from, nbits_to, reverse=False,
+ report_valid_token_count=False):
+ cls, ratio = _get_converter_ratio(nbits_from, nbits_to)
+ self.specialized = cls(nbits_from, nbits_to, ratio,
+ reverse, report_valid_token_count)
+ self.sink = self.specialized.sink
+ self.source = self.specialized.source
+
+ def elaborate(self, platform):
+ m = Module()
+
+ m.submodules += self.specialized
+
+ return m
+
+
+class StrideConverter(Elaboratable):
+ def __init__(self, layout_from, layout_to, *args, **kwargs):
+ self.sink = sink = Endpoint(layout_from)
+ self.source = source = Endpoint(layout_to)
+
+ nbits_from = len(sink.payload.raw_bits())
+ nbits_to = len(source.payload.raw_bits())
+ self.converter = Converter(nbits_from, nbits_to, *args, **kwargs)
+
+ def elaborate(self, platform):
+ m = Module()
+
+ nbits_from = len(sink.payload.raw_bits())
+ nbits_to = len(source.payload.raw_bits())
+
+ m.d.submodules += self.converter
+
+ # cast sink to converter.sink (user fields --> raw bits)
+ m.d.comb += [
+ self.converter.sink.stb.eq(sink.stb),
+ self.converter.sink.eop.eq(sink.eop),
+ sink.ack.eq(self.converter.sink.ack)
+ ]
+ if isinstance(self.converter.specialized, _DownConverter):
+ ratio = self.converter.specialized.ratio
+ for i in range(ratio):
+ j = 0
+ for name, width in layout_to:
+ src = getattr(sink, name)[i*width:(i+1)*width]
+ dst = self.converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
+ self.comb += dst.eq(src)
+ j += width
+ else:
+ m.d.comb += self.converter.sink.data.eq(sink.payload.raw_bits())
+
+
+ # cast converter.source to source (raw bits --> user fields)
+ m.D.comb += [
+ source.stb.eq(self.converter.source.stb),
+ source.eop.eq(self.converter.source.eop),
+ self.converter.source.ack.eq(source.ack)
+ ]
+ if isinstance(self.converter.specialized, _UpConverter):
+ ratio = self.converter.specialized.ratio
+ for i in range(ratio):
+ j = 0
+ for name, width in layout_from:
+ src = self.converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
+ dst = getattr(source, name)[i*width:(i+1)*width]
+ m.d.comb += dst.eq(src)
+ j += width
+ else:
+ m.d.comb += source.payload.raw_bits().eq(self.converter.source.data)
+
+ return m